Spaces:
Configuration error
Configuration error
from typing import Literal, cast, get_args | |
from anthropic import AnthropicVertex | |
from neollm.llm.abstract_llm import AbstractLLM | |
from neollm.llm.claude.abstract_claude import AbstractClaude | |
from neollm.types import APIPricing, ClientSettings | |
from neollm.utils.utils import cprint | |
# Claude 3 model IDs accepted by this module, in Vertex AI's `<model>@<YYYYMMDD>` form.
# Pricing reference: https://www.anthropic.com/api
# Model overview:    https://docs.anthropic.com/claude/docs/models-overview
SUPPORTED_MODELS = Literal[
    "claude-3-opus@20240229",  # Claude 3 Opus
    "claude-3-sonnet@20240229",  # Claude 3 Sonnet
    "claude-3-haiku@20240307",  # Claude 3 Haiku
]
# TODO: (translated from the original Japanese note) "would like to get Google working".
def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Return the Vertex AI (GCP) Claude wrapper matching ``model_name``.

    Args:
        model_name: A dated model ID from ``SUPPORTED_MODELS``
            (e.g. ``"claude-3-opus@20240229"``). An undated alias such as
            ``"claude-3-opus"`` is also accepted: a warning is printed and the
            alias is rewritten to its dated ID.
        client_settings: Settings passed to the selected model class's constructor.

    Returns:
        A newly constructed instance of the matching ``GCPClaude3*`` class.

    Raises:
        ValueError: If ``model_name`` is neither a supported dated ID nor a known
            undated alias.
    """
    # Undated alias -> dated model ID.
    replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
        "claude-3-opus": "claude-3-opus@20240229",
        "claude-3-sonnet": "claude-3-sonnet@20240229",
        "claude-3-haiku": "claude-3-haiku@20240307",
    }
    if model_name in replace_map_for_nodate:
        # Warning text means: "please include the date in model_name".
        cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
        print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
        model_name = replace_map_for_nodate[model_name]

    # Map each ID to its class rather than an instance, so that only the requested
    # model is constructed. Building every wrapper up front wastes work and lets a
    # failure in an unrelated model's constructor break the lookup.
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "claude-3-opus@20240229": GCPClaude3Opus20240229,
        "claude-3-sonnet@20240229": GCPClaude3Sonnet20240229,
        "claude-3-haiku@20240307": GCPClaude3Haiku20240229,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
class GoogleLLM(AbstractClaude):
    """Base class for Claude models reached through Google Cloud Vertex AI."""

    def client(self) -> AnthropicVertex:
        """Build an ``AnthropicVertex`` client from ``self.client_settings``.

        The settings mapping is unpacked as keyword arguments, and a fresh client
        is created on every call.
        """
        # NOTE(review): declared as a plain method here; confirm this matches how
        # AbstractClaude declares and uses ``client`` (method vs. property).
        return AnthropicVertex(**self.client_settings)
class GCPClaude3Opus20240229(GoogleLLM):
    """Claude 3 Opus (``claude-3-opus@20240229``) served through Vertex AI."""

    # USD per 1K tokens, i.e. $15 input / $75 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=15 / 1000, output=75 / 1000)
    model: str = "claude-3-opus@20240229"
    context_window: int = 200_000
class GCPClaude3Sonnet20240229(GoogleLLM):
    """Claude 3 Sonnet (``claude-3-sonnet@20240229``) served through Vertex AI."""

    # USD per 1K tokens, i.e. $3 input / $15 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=3 / 1000, output=15 / 1000)
    model: str = "claude-3-sonnet@20240229"
    context_window: int = 200_000
class GCPClaude3Haiku20240307(GoogleLLM):
    """Claude 3 Haiku (``claude-3-haiku@20240307``) served through Vertex AI."""

    # USD per 1K tokens, i.e. $0.25 input / $1.25 output per 1M tokens.
    dollar_per_ktoken = APIPricing(input=0.25 / 1000, output=1.25 / 1000)
    model: str = "claude-3-haiku@20240307"
    context_window: int = 200_000


# Backward-compatible alias. This class was originally published under a name carrying
# the Opus/Sonnet snapshot date (20240229), although the Haiku model it wraps is
# dated 20240307. Existing imports and references of the old name keep working.
GCPClaude3Haiku20240229 = GCPClaude3Haiku20240307