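"""Claude (Anthropic) adapter for neollm.

Maps the Anthropic Messages API (Anthropic, AnthropicBedrock, AnthropicVertex)
onto neollm's OpenAI-style ChatCompletion / ChatCompletionChunk types.
"""
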
import time
from abc import abstractmethod
from typing import Any, Literal, cast

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, Stream
from anthropic.types import MessageParam as AnthropicMessageParam
from anthropic.types import MessageStreamEvent as AnthropicMessageStreamEvent
from anthropic.types.message import Message as AnthropicMessage

from neollm.llm.abstract_llm import AbstractLLM
from neollm.types import (
    ChatCompletion,
    LLMSettings,
    Message,
    Messages,
    Response,
    StreamResponse,
)
from neollm.types.openai.chat_completion import (
    ChatCompletionMessage,
    Choice,
    CompletionUsage,
    FinishReason,
)
from neollm.types.openai.chat_completion_chunk import (
    ChatCompletionChunk,
    ChoiceDelta,
    ChunkChoice,
)
from neollm.utils.utils import cprint

DEFAULT_MAX_TOKENS = 4_096


class AbstractClaude(AbstractLLM):
    @property
    @abstractmethod
    def client(self) -> Anthropic | AnthropicVertex | AnthropicBedrock: ...

    @property
    def _client_for_token(self) -> Anthropic:
        """Get an Anthropic client for token counting.

        (AnthropicBedrock and AnthropicVertex do not provide the tokenizer methods.)

        Returns:
            Anthropic: Anthropic client
        """
        return Anthropic()

    def encode(self, text: str) -> list[int]:
        tokenizer = self._client_for_token.get_tokenizer()
        encoded = cast(list[int], tokenizer.encode(text).ids)
        return encoded

    def decode(self, decoded: list[int]) -> str:
        tokenizer = self._client_for_token.get_tokenizer()
        text = cast(str, tokenizer.decode(decoded))
        return text

    def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
        """Count tokens.

        Args:
            messages (Messages): messages

        Returns:
            int: number of tokens
        """
        if messages is None:
            return 0
        tokens = 0
        for message in messages:
            content = message["content"]
            if isinstance(content, str):
                tokens += self._client_for_token.count_tokens(content)
                continue
            if isinstance(content, list):
                for content_i in content:
                    if content_i["type"] == "text":
                        tokens += self._client_for_token.count_tokens(content_i["text"])
                        continue
        return tokens

    def _convert_finish_reason(
        self, stop_reason: Literal["end_turn", "max_tokens", "stop_sequence"] | None
    ) -> FinishReason | None:
        if stop_reason == "max_tokens":
            return "length"
        if stop_reason == "stop_sequence":
            return "stop"
        return None

    def _convert_to_response(self, platform_response: AnthropicMessage) -> Response:
        return ChatCompletion(
            id=platform_response.id,
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        content=platform_response.content[0].text if len(platform_response.content) > 0 else "",
                        role="assistant",
                    ),
                    finish_reason=self._convert_finish_reason(platform_response.stop_reason),
                )
            ],
            created=int(time.time()),
            model=self.model,
            object="messages.create",
            system_fingerprint=None,
            usage=CompletionUsage(
                prompt_tokens=platform_response.usage.input_tokens,
                completion_tokens=platform_response.usage.output_tokens,
                total_tokens=platform_response.usage.input_tokens + platform_response.usage.output_tokens,
            ),
        )

    def _convert_to_platform_messages(self, messages: Messages) -> tuple[str, list[AnthropicMessageParam]]:
        _system = ""
        _message: list[AnthropicMessageParam] = []
        for message in messages:
            if message["role"] == "system":
                _system += "\n" + message["content"]
            elif message["role"] == "user":
                if isinstance(message["content"], str):
                    _message.append({"role": "user", "content": message["content"]})
                else:
                    cprint("WARNING: not supported", color="yellow", background=True)
            elif message["role"] == "assistant":
                if isinstance(message["content"], str):
                    _message.append({"role": "assistant", "content": message["content"]})
                else:
                    cprint("WARNING: not supported", color="yellow", background=True)
            else:
                cprint("WARNING: not supported", color="yellow", background=True)
        return _system, _message
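
    # Note: the Anthropic stream yields message_start / content_block_start /
    # content_block_delta / message_delta / *_stop events; the generator below
    # re-emits each text-bearing event as an OpenAI-style ChatCompletionChunk.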

    def _convert_to_streamresponse(
        self, platform_streamresponse: Stream[AnthropicMessageStreamEvent]
    ) -> StreamResponse:
        created = int(time.time())
        model = ""
        id_ = ""
        content: str | None = None
        finish_reason: FinishReason | None = None
        for chunk in platform_streamresponse:
            input_tokens = 0
            output_tokens = 0
            if chunk.type == "message_stop" or chunk.type == "content_block_stop":
                continue
            if chunk.type == "message_start":
                model = model or chunk.message.model
                id_ = id_ or chunk.message.id
                input_tokens = chunk.message.usage.input_tokens
                output_tokens = chunk.message.usage.output_tokens
                content = "".join([content_block.text for content_block in chunk.message.content])
                finish_reason = self._convert_finish_reason(chunk.message.stop_reason)
            elif chunk.type == "message_delta":
                content = ""
                finish_reason = self._convert_finish_reason(chunk.delta.stop_reason)
                output_tokens = chunk.usage.output_tokens
            elif chunk.type == "content_block_start":
                content = chunk.content_block.text
                finish_reason = None
            elif chunk.type == "content_block_delta":
                content = chunk.delta.text
                finish_reason = None
            yield ChatCompletionChunk(
                id=id_,
                choices=[
                    ChunkChoice(
                        delta=ChoiceDelta(
                            content=content,
                            role="assistant",
                        ),
                        finish_reason=finish_reason,
                        index=0,  # may not be 0-indexed, so force it to 0
                    )
                ],
                created=created,
                model=model,
                object="chat.completion.chunk",
                tokens={"input_tokens": input_tokens, "output_tokens": output_tokens},  # type: ignore
            )

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        _system, _message = self._convert_to_platform_messages(messages)
        llm_settings = self._set_max_tokens(llm_settings)
        response = self.client.messages.create(
            model=self.model,
            system=_system,
            messages=_message,
            stream=False,
            **llm_settings,
        )
        return self._convert_to_response(platform_response=response)

    def generate_stream(self, messages: Any, llm_settings: LLMSettings) -> StreamResponse:
        _system, _message = self._convert_to_platform_messages(messages)
        llm_settings = self._set_max_tokens(llm_settings)
        response = self.client.messages.create(
            model=self.model,
            system=_system,
            messages=_message,
            stream=True,
            **llm_settings,
        )
        return self._convert_to_streamresponse(platform_streamresponse=response)

    def _set_max_tokens(self, llm_settings: LLMSettings) -> LLMSettings:
        # Claude requires max_tokens to be set.
        if not llm_settings.get("max_tokens"):
            cprint(f"max_tokens is not set. Set to {DEFAULT_MAX_TOKENS}.", color="yellow")
            llm_settings["max_tokens"] = DEFAULT_MAX_TOKENS
        return llm_settings
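

# Usage sketch (illustrative only, not part of the module): a concrete subclass
# is assumed to provide `model` and `client`; the model id and constructor
# arguments below are hypothetical.
#
#     class ExampleClaude(AbstractClaude):
#         model = "claude-3-haiku-20240307"  # hypothetical model id
#
#         @property
#         def client(self) -> Anthropic:
#             return Anthropic()  # reads ANTHROPIC_API_KEY from the environment
#
#     llm = ExampleClaude()
#     response = llm.generate(
#         messages=[{"role": "user", "content": "Hello"}],
#         llm_settings={"max_tokens": 256},
#     )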