Kpenciler's picture
Upload 53 files
88435ed verified
raw
history blame
2.49 kB
from typing import Literal, cast, get_args
from anthropic import Anthropic
from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.claude.abstract_claude import AbstractClaude
from neollm.types import APIPricing, ClientSettings
from neollm.utils.utils import cprint
# Pricing reference: https://www.anthropic.com/api
# Model catalogue:   https://docs.anthropic.com/claude/docs/models-overview
#
# Dated Anthropic model identifiers that get_anthoropic_llm can resolve.
# Keep this list in sync with the model classes defined below.
SUPPORTED_MODELS = Literal[
    "claude-3-opus-20240229",
    "claude-3-sonnet-20240229",
    "claude-3-haiku-20240307",
]
def get_anthoropic_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Return the Anthropic Claude wrapper matching ``model_name``.

    Args:
        model_name: A dated model identifier (one of ``SUPPORTED_MODELS``) or
            an undated alias such as ``"claude-3-opus"``. An undated alias is
            resolved to a fixed dated snapshot, and a warning is printed.
        client_settings: Passed as the single positional argument to the
            selected model class's constructor.

    Returns:
        A newly constructed instance of the model class for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is neither a supported dated identifier
            nor a known undated alias.
    """
    # Undated aliases -> the dated snapshot each one is pinned to.
    replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
    }
    if model_name in replace_map_for_nodate:
        # The warning text asks the caller to specify a dated model_name.
        cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
        print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
        model_name = replace_map_for_nodate[model_name]

    # Map each model to its class rather than to a pre-built instance, so only
    # the requested wrapper is constructed. The other classes are never
    # instantiated, and their constructors' side effects never run.
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "claude-3-opus-20240229": AnthropicClaude3Opus20240229,
        "claude-3-sonnet-20240229": AnthropicClaude3Sonnet20240229,
        "claude-3-haiku-20240307": AnthropicClaude3Haiku20240229,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
class AnthoropicLLM(AbstractClaude):
    """Base class for Claude models accessed through ``anthropic.Anthropic``."""

    @property
    def client(self) -> Anthropic:
        """Build an ``Anthropic`` client from ``self.client_settings``.

        ``client_settings`` is unpacked as keyword arguments to the client.
        A new client object is created on every access; nothing is cached.
        """
        return Anthropic(**self.client_settings)
class AnthropicClaude3Opus20240229(AnthoropicLLM):
    """Claude 3 Opus, model identifier ``claude-3-opus-20240229``."""

    # USD per 1K tokens: $15 input / $75 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=0.015, output=0.075)
    model: str = "claude-3-opus-20240229"
    context_window: int = 200_000
class AnthropicClaude3Sonnet20240229(AnthoropicLLM):
    """Claude 3 Sonnet, model identifier ``claude-3-sonnet-20240229``."""

    # USD per 1K tokens: $3 input / $15 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=0.003, output=0.015)
    model: str = "claude-3-sonnet-20240229"
    context_window: int = 200_000
class AnthropicClaude3Haiku20240229(AnthoropicLLM):
    """Claude 3 Haiku, model identifier ``claude-3-haiku-20240307``.

    NOTE: the ``20240229`` suffix in this class name does not match the
    ``20240307`` snapshot the class actually targets. The name is kept for
    backward compatibility; ``AnthropicClaude3Haiku20240307`` is a correctly
    named alias for the same class.
    """

    # USD per 1K tokens: $0.25 input / $1.25 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=0.25 / 1000, output=1.25 / 1000)
    model: str = "claude-3-haiku-20240307"
    context_window: int = 200_000


# Correctly dated name for the Haiku wrapper; the same class object, not a copy.
AnthropicClaude3Haiku20240307 = AnthropicClaude3Haiku20240229