File size: 2,494 Bytes
88435ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Literal, cast, get_args

from anthropic import Anthropic

from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.claude.abstract_claude import AbstractClaude
from neollm.types import APIPricing, ClientSettings
from neollm.utils.utils import cprint

# price: https://www.anthropic.com/api
# models: https://docs.anthropic.com/claude/docs/models-overview

# Dated Claude model identifiers accepted by get_anthoropic_llm. Undated
# aliases such as "claude-3-opus" are mapped onto these inside that function.
SUPPORTED_MODELS = Literal[
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
]


def get_anthoropic_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Create the Anthropic Claude LLM wrapper for ``model_name``.

    Args:
        model_name: A dated model identifier listed in ``SUPPORTED_MODELS``.
            The undated aliases ``"claude-3-opus"``, ``"claude-3-sonnet"`` and
            ``"claude-3-haiku"`` are also accepted: a warning is printed and the
            alias is replaced by its dated identifier.
        client_settings: Passed unchanged to the constructor of the selected
            LLM class.

    Returns:
        A new instance of the LLM class registered for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is neither a supported identifier nor a
            known undated alias.
    """
    # Undated alias -> dated model identifier.
    replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
    }
    if model_name in replace_map_for_nodate:
        # Runtime message (Japanese): "please include the date in model_name".
        cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
        print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
        model_name = replace_map_for_nodate[model_name]

    # Map identifiers to classes (not instances) so that only the requested
    # model is constructed, instead of instantiating every model per call.
    # Note: the Haiku class name carries 20240229 but its `model` is 20240307.
    supported_model_map: dict[SUPPORTED_MODELS, type[AnthoropicLLM]] = {
        "claude-3-opus-20240229": AnthropicClaude3Opus20240229,
        "claude-3-sonnet-20240229": AnthropicClaude3Sonnet20240229,
        "claude-3-haiku-20240307": AnthropicClaude3Haiku20240229,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")


class AnthoropicLLM(AbstractClaude):
    """Base class for Claude models reached through the ``anthropic`` SDK client.

    Concrete subclasses in this module set ``model``, ``context_window`` and
    ``dollar_per_ktoken``.
    """

    @property
    def client(self) -> Anthropic:
        """Return an ``Anthropic`` client built from ``self.client_settings``.

        A fresh client is constructed on every access; the settings are
        unpacked as keyword arguments to the ``Anthropic`` constructor.
        """
        return Anthropic(**self.client_settings)


class AnthropicClaude3Opus20240229(AnthoropicLLM):
    """Claude 3 Opus, model id ``claude-3-opus-20240229``."""

    # $15 input / $75 output per million tokens, i.e. per-1K-token rates below.
    dollar_per_ktoken = APIPricing(input=15e-3, output=75e-3)
    model: str = "claude-3-opus-20240229"
    context_window: int = 200_000


class AnthropicClaude3Sonnet20240229(AnthoropicLLM):
    """Claude 3 Sonnet, model id ``claude-3-sonnet-20240229``."""

    # $3 input / $15 output per million tokens, i.e. per-1K-token rates below.
    dollar_per_ktoken = APIPricing(input=3e-3, output=15e-3)
    model: str = "claude-3-sonnet-20240229"
    context_window: int = 200_000


class AnthropicClaude3Haiku20240229(AnthoropicLLM):
    """Claude 3 Haiku, model id ``claude-3-haiku-20240307``.

    The date in the class name (20240229) does not match the model id
    (20240307); the name is kept as-is because callers reference it.
    """

    # $0.25 input / $1.25 output per million tokens, i.e. per-1K-token rates below.
    dollar_per_ktoken = APIPricing(input=0.25e-3, output=1.25e-3)
    model: str = "claude-3-haiku-20240307"
    context_window: int = 200_000