from abc import ABC, abstractmethod
from typing import Any

from neollm.llm.utils import get_entity
from neollm.types import (
    APIPricing,
    ChatCompletion,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    Choice,
    ChoiceDeltaToolCall,
    Chunk,
    ClientSettings,
    CompletionUsage,
    Function,
    FunctionCall,
    LLMSettings,
    Messages,
    Response,
    StreamResponse,
)
from neollm.utils.utils import cprint


# Currently supports Azure and OpenAI
class AbstractLLM(ABC):
    dollar_per_ktoken: APIPricing
    model: str
    context_window: int
    _custom_price_calculation: bool = False  # True when self.custom_token is used instead of self.token

    def __init__(self, client_settings: ClientSettings):
        """Initialize the LLM class.

        Args:
            client_settings (ClientSettings): client settings
        """
        self.client_settings = client_settings

    def calculate_price(self, num_input_tokens: int = 0, num_output_tokens: int = 0) -> float:
        """Calculate the API cost.

        Args:
            num_input_tokens (int, optional): number of input tokens. Defaults to 0.
            num_output_tokens (int, optional): number of output tokens. Defaults to 0.

        Returns:
            float: API usage cost (USD)
        """
        price = (
            self.dollar_per_ktoken.input * num_input_tokens + self.dollar_per_ktoken.output * num_output_tokens
        ) / 1000
        return price

    @abstractmethod
    def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int: ...

    @abstractmethod
    def encode(self, text: str) -> list[int]: ...

    @abstractmethod
    def decode(self, encoded: list[int]) -> str: ...

    @abstractmethod
    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a completion.

        Args:
            messages (Messages): OpenAI-style Messages (list[dict])

        Returns:
            Response: OpenAI-like Response
        """

    @abstractmethod
    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse: ...

    def __repr__(self) -> str:
        return f"{self.__class__}()"

    def convert_nonstream_response(
        self, chunk_list: list[Chunk], messages: Messages, functions: Any = None
    ) -> Response:
        # messages and functions are needed for token counting
        _chunk_choices = [chunk.choices[0] for chunk in chunk_list if len(chunk.choices) > 0]
        # TODO: emit a warning when n >= 2

        # FunctionCall --------------------------------------------------
        # Reassemble a streamed function call: argument fragments are concatenated
        # across chunks; the name appears once and is picked out via get_entity.
        function_call: FunctionCall | None
        if all([_c.delta.function_call is None for _c in _chunk_choices]):
            function_call = None
        else:
            function_call = FunctionCall(
                arguments="".join(
                    [
                        _c.delta.function_call.arguments
                        for _c in _chunk_choices
                        if _c.delta.function_call is not None and _c.delta.function_call.arguments is not None
                    ]
                ),
                name=get_entity(
                    [_c.delta.function_call.name for _c in _chunk_choices if _c.delta.function_call is not None],
                    default="",
                ),
            )

        # ToolCalls --------------------------------------------------
        # Group streamed tool-call deltas by index, then merge each group into a
        # single ChatCompletionMessageToolCall (arguments concatenated, name/id/type
        # taken from the one chunk that carries them).
        _tool_calls_dict: dict[int, list[ChoiceDeltaToolCall]] = {}  # key=index
        for _chunk in _chunk_choices:
            if _chunk.delta.tool_calls is None:
                continue
            for _tool_call in _chunk.delta.tool_calls:
                _tool_calls_dict.setdefault(_tool_call.index, []).append(_tool_call)
        tool_calls: list[ChatCompletionMessageToolCall] | None
        if sum(len(_tool_calls) for _tool_calls in _tool_calls_dict.values()) == 0:
            tool_calls = None
        else:
            tool_calls = []
            for _tool_calls in _tool_calls_dict.values():
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=get_entity([_tc.id for _tc in _tool_calls], default=""),
                        function=Function(
                            arguments="".join(
                                [
                                    _tc.function.arguments
                                    for _tc in _tool_calls
                                    if _tc.function is not None and _tc.function.arguments is not None
                                ]
                            ),
                            name=get_entity(
                                [_tc.function.name for _tc in _tool_calls if _tc.function is not None],
                                default="",
                            ),
                        ),
                        type=get_entity([_tc.type for _tc in _tool_calls], default="function"),
                    )
                )

        message = ChatCompletionMessage(
            content="".join([_c.delta.content for _c in _chunk_choices if _c.delta.content is not None]),
            # TODO: why does ChoiceDelta's role allow values other than assistant?
            role=get_entity([_c.delta.role for _c in _chunk_choices], default="assistant"),  # type: ignore
            function_call=function_call,
            tool_calls=tool_calls,
        )
        choice = Choice(
            index=get_entity([_c.index for _c in _chunk_choices], default=0),
            message=message,
            finish_reason=get_entity([_c.finish_reason for _c in _chunk_choices], default=None),
        )

        # Usage --------------------------------------------------
        # Prefer provider-reported token counts carried on the chunks. If any chunk
        # lacks the "tokens" attribute (AttributeError), the counts never get set
        # (NameError on the asserts), or they are zero (AssertionError), fall back
        # to counting tokens locally.
        try:
            for chunk in chunk_list:
                if getattr(chunk, "tokens"):
                    prompt_tokens = int(getattr(chunk, "tokens")["input_tokens"])
                    completion_tokens = int(getattr(chunk, "tokens")["output_tokens"])
            assert prompt_tokens
            assert completion_tokens
        except Exception:
            prompt_tokens = self.count_tokens(messages)  # TODO: include function calls etc.
            completion_tokens = self.count_tokens([message.to_typeddict_message()], only_response=True)
        usages = CompletionUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

        # ChatCompletion ------------------------------------------
        response = ChatCompletion(
            id=get_entity([chunk.id for chunk in chunk_list], default=""),
            object="chat.completion",
            created=get_entity([getattr(chunk, "created", 0) for chunk in chunk_list], default=0),
            model=get_entity([getattr(chunk, "model", "") for chunk in chunk_list], default=""),
            choices=[choice],
            system_fingerprint=get_entity(
                [getattr(chunk, "system_fingerprint", None) for chunk in chunk_list], default=None
            ),
            usage=usages,
        )
        return response

    @property
    def max_tokens(self) -> int:
        cprint("max_tokens is deprecated. Use context_window instead.")
        return self.context_window
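

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of neollm): a minimal concrete
# subclass showing what AbstractLLM requires of implementers. The whitespace
# "tokenizer", the pricing numbers, the APIPricing keyword constructor, and
# the dict-shaped messages/ClientSettings are all assumptions made for this
# demo; real subclasses (e.g. the Azure/OpenAI ones) wrap an actual client
# and tokenizer.
if __name__ == "__main__":

    class _DummyLLM(AbstractLLM):
        # Assumed APIPricing constructor signature; adjust to the real type.
        dollar_per_ktoken = APIPricing(input=0.001, output=0.002)
        model = "dummy-model"
        context_window = 8192

        def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int:
            # Crude whitespace token count, purely for illustration.
            if messages is None:
                return 0
            return sum(len(str(m.get("content", "")).split()) for m in messages)

        def encode(self, text: str) -> list[int]:
            # Placeholder encoding: one pseudo-token id per whitespace word.
            return [len(word) for word in text.split()]

        def decode(self, encoded: list[int]) -> str:
            # Placeholder decoding; lossy, matches encode() above in shape only.
            return " ".join("x" * n for n in encoded)

        def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
            raise NotImplementedError  # a real subclass calls its API client here

        def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
            raise NotImplementedError

    llm = _DummyLLM(client_settings={})  # assuming ClientSettings is dict-like
    # 1000 input + 500 output tokens -> (0.001 * 1000 + 0.002 * 500) / 1000 = 0.002 USD
    print(llm.calculate_price(num_input_tokens=1000, num_output_tokens=500))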