from abc import ABC, abstractmethod
from typing import Any

from neollm.llm.utils import get_entity
from neollm.types import (
    APIPricing,
    ChatCompletion,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    Choice,
    ChoiceDeltaToolCall,
    Chunk,
    ClientSettings,
    CompletionUsage,
    Function,
    FunctionCall,
    LLMSettings,
    Messages,
    Response,
    StreamResponse,
)
from neollm.utils.utils import cprint


# Currently supports Azure and OpenAI.
class AbstractLLM(ABC):
    dollar_per_ktoken: APIPricing
    model: str
    context_window: int
    _custom_price_calculation: bool = False  # True when self.custom_token is used instead of self.token

    def __init__(self, client_settings: ClientSettings):
        """Initialize the LLM client.

        Args:
            client_settings (ClientSettings): Client settings.
        """
        self.client_settings = client_settings

    def calculate_price(self, num_input_tokens: int = 0, num_output_tokens: int = 0) -> float:
        """Calculate the API cost.

        Args:
            num_input_tokens (int, optional): Number of input tokens. Defaults to 0.
            num_output_tokens (int, optional): Number of output tokens. Defaults to 0.

        Returns:
            float: API usage cost (USD).
        """
        price = (
            self.dollar_per_ktoken.input * num_input_tokens + self.dollar_per_ktoken.output * num_output_tokens
        ) / 1000
        return price
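
    # Example (hypothetical figures): with dollar_per_ktoken.input = 0.01 and
    # dollar_per_ktoken.output = 0.03 (USD per 1K tokens),
    # calculate_price(num_input_tokens=2000, num_output_tokens=500)
    # returns (0.01 * 2000 + 0.03 * 500) / 1000 = 0.035 USD.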

    def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int: ...

    def encode(self, text: str) -> list[int]: ...

    def decode(self, encoded: list[int]) -> str: ...

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a completion.

        Args:
            messages (Messages): OpenAI-style Messages (list[dict]).
            llm_settings (LLMSettings): Generation settings.

        Returns:
            Response: OpenAI-like Response.
        """

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse: ...

    def __repr__(self) -> str:
        return f"{self.__class__}()"

    def convert_nonstream_response(
        self, chunk_list: list[Chunk], messages: Messages, functions: Any = None
    ) -> Response:
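        """Aggregate a list of streamed chunks into a single non-streaming ChatCompletion.

        Content, function_call and tool_call fragments are concatenated across chunks,
        and usage is taken from the chunks when available, otherwise counted locally.
        """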
        # messages and functions are needed for token counting
        _chunk_choices = [chunk.choices[0] for chunk in chunk_list if len(chunk.choices) > 0]
        # TODO: emit a warning when n >= 2
        # FunctionCall --------------------------------------------------
        function_call: FunctionCall | None
        if all([_c.delta.function_call is None for _c in _chunk_choices]):
            function_call = None
        else:
            function_call = FunctionCall(
                arguments="".join(
                    [
                        _c.delta.function_call.arguments
                        for _c in _chunk_choices
                        if _c.delta.function_call is not None and _c.delta.function_call.arguments is not None
                    ]
                ),
                name=get_entity(
                    [_c.delta.function_call.name for _c in _chunk_choices if _c.delta.function_call is not None],
                    default="",
                ),
            )
        # ToolCalls --------------------------------------------------
        _tool_calls_dict: dict[int, list[ChoiceDeltaToolCall]] = {}  # key=index
        for _chunk in _chunk_choices:
            if _chunk.delta.tool_calls is None:
                continue
            for _tool_call in _chunk.delta.tool_calls:
                _tool_calls_dict.setdefault(_tool_call.index, []).append(_tool_call)
        tool_calls: list[ChatCompletionMessageToolCall] | None
        if sum(len(_tool_calls) for _tool_calls in _tool_calls_dict.values()) == 0:
            tool_calls = None
        else:
            tool_calls = []
            for _tool_calls in _tool_calls_dict.values():
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=get_entity([_tc.id for _tc in _tool_calls], default=""),
                        function=Function(
                            arguments="".join(
                                [
                                    _tc.function.arguments
                                    for _tc in _tool_calls
                                    if _tc.function is not None and _tc.function.arguments is not None
                                ]
                            ),
                            name=get_entity(
                                [_tc.function.name for _tc in _tool_calls if _tc.function is not None], default=""
                            ),
                        ),
                        type=get_entity([_tc.type for _tc in _tool_calls], default="function"),
                    )
                )
        message = ChatCompletionMessage(
            content="".join([_c.delta.content for _c in _chunk_choices if _c.delta.content is not None]),
            # TODO: why does ChoiceDelta's role allow values other than "assistant"?
            role=get_entity([_c.delta.role for _c in _chunk_choices], default="assistant"),  # type: ignore
            function_call=function_call,
            tool_calls=tool_calls,
        )
        choice = Choice(
            index=get_entity([_c.index for _c in _chunk_choices], default=0),
            message=message,
            finish_reason=get_entity([_c.finish_reason for _c in _chunk_choices], default=None),
        )
        # Usage --------------------------------------------------
        # Prefer token counts reported in the chunks; fall back to local counting
        # when they are missing or zero.
        try:
            for chunk in chunk_list:
                if getattr(chunk, "tokens"):
                    prompt_tokens = int(getattr(chunk, "tokens")["input_tokens"])
                    completion_tokens = int(getattr(chunk, "tokens")["output_tokens"])
            assert prompt_tokens
            assert completion_tokens
        except Exception:
            prompt_tokens = self.count_tokens(messages)  # TODO: account for function_call, etc.
            completion_tokens = self.count_tokens([message.to_typeddict_message()], only_response=True)
        usages = CompletionUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        # ChatCompletion ------------------------------------------
        response = ChatCompletion(
            id=get_entity([chunk.id for chunk in chunk_list], default=""),
            object="chat.completion",
            created=get_entity([getattr(chunk, "created", 0) for chunk in chunk_list], default=0),
            model=get_entity([getattr(chunk, "model", "") for chunk in chunk_list], default=""),
            choices=[choice],
            system_fingerprint=get_entity(
                [getattr(chunk, "system_fingerprint", None) for chunk in chunk_list], default=None
            ),
            usage=usages,
        )
        return response

    def max_tokens(self) -> int:
        cprint("max_tokens is deprecated. Use context_window instead.")
        return self.context_window
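

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): a minimal hypothetical
# subclass showing how a concrete provider class is expected to plug into
# AbstractLLM. The class name, model name and prices are made up, and
# APIPricing is assumed to accept input=/output= keyword arguments (the class
# above only reads .input and .output).
#
#     class ExampleLLM(AbstractLLM):
#         dollar_per_ktoken = APIPricing(input=0.01, output=0.03)  # USD per 1K tokens (hypothetical)
#         model = "example-model"
#         context_window = 8192
#
#     llm = ExampleLLM(client_settings={})  # assumes ClientSettings is dict-like
#     llm.calculate_price(num_input_tokens=1000, num_output_tokens=1000)
# ---------------------------------------------------------------------------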