from typing import Literal, cast, get_args

from anthropic import AnthropicVertex

from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.claude.abstract_claude import AbstractClaude
from neollm.types import APIPricing, ClientSettings
from neollm.utils.utils import cprint

# price: https://www.anthropic.com/api
# models: https://docs.anthropic.com/claude/docs/models-overview

SUPPORTED_MODELS = Literal[
    "claude-3-opus@20240229",
    "claude-3-sonnet@20240229",
    "claude-3-haiku@20240307",
]

# Undated model names accepted for convenience, mapped to their dated ids.
# Built once at import time instead of on every get_gcp_llm call.
_NODATE_ALIASES: dict[str, SUPPORTED_MODELS] = {
    "claude-3-opus": "claude-3-opus@20240229",
    "claude-3-sonnet": "claude-3-sonnet@20240229",
    "claude-3-haiku": "claude-3-haiku@20240307",
}


# TODO! Would like to get Google working.
def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Return the GCP-hosted Claude LLM wrapper for ``model_name``.

    Args:
        model_name: A dated model id from ``SUPPORTED_MODELS`` (e.g.
            ``"claude-3-opus@20240229"``). An undated name such as
            ``"claude-3-opus"`` is also accepted: it is rewritten to its dated
            id and a warning is printed.
        client_settings: Passed to the selected model class's constructor.

    Returns:
        A newly constructed instance of the model class for ``model_name``.

    Raises:
        ValueError: If ``model_name`` (after alias resolution) is not supported.
    """
    # Add the date: resolve undated names and nudge the caller to pin a date.
    if model_name in _NODATE_ALIASES:
        cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
        print(f"model_name: {model_name} -> {_NODATE_ALIASES[model_name]}")
        model_name = _NODATE_ALIASES[model_name]

    # Map to classes rather than instances so that only the requested model
    # is constructed (previously all three were built on every call).
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "claude-3-opus@20240229": GCPClaude3Opus20240229,
        "claude-3-sonnet@20240229": GCPClaude3Sonnet20240229,
        "claude-3-haiku@20240307": GCPClaude3Haiku20240307,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")


class GoogleLLM(AbstractClaude):
    """Claude model accessed through the ``AnthropicVertex`` client."""

    @property
    def client(self) -> AnthropicVertex:
        """Build a new ``AnthropicVertex`` client on every access.

        ``self.client_settings`` is unpacked as keyword arguments; it is
        presumably populated by the base class constructor (not visible here).
        """
        client = AnthropicVertex(**self.client_settings)
        return client


class GCPClaude3Opus20240229(GoogleLLM):
    # USD per 1K tokens (i.e. $15 / $75 per 1M input / output tokens).
    dollar_per_ktoken = APIPricing(input=15 / 1000, output=75 / 1000)
    model: str = "claude-3-opus@20240229"
    context_window: int = 200_000


class GCPClaude3Sonnet20240229(GoogleLLM):
    # USD per 1K tokens (i.e. $3 / $15 per 1M input / output tokens).
    dollar_per_ktoken = APIPricing(input=3 / 1000, output=15 / 1000)
    model: str = "claude-3-sonnet@20240229"
    context_window: int = 200_000


class GCPClaude3Haiku20240307(GoogleLLM):
    # USD per 1K tokens (i.e. $0.25 / $1.25 per 1M input / output tokens).
    dollar_per_ktoken = APIPricing(input=0.25 / 1000, output=1.25 / 1000)
    model: str = "claude-3-haiku@20240307"
    context_window: int = 200_000


# Backward-compatible alias: this class was originally published under a name
# whose date suffix (20240229) did not match its model id (@20240307).
GCPClaude3Haiku20240229 = GCPClaude3Haiku20240307