from abc import ABC, abstractmethod
from typing import Any
from neollm.llm.utils import get_entity
from neollm.types import (
APIPricing,
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
Choice,
ChoiceDeltaToolCall,
Chunk,
ClientSettings,
CompletionUsage,
Function,
FunctionCall,
LLMSettings,
Messages,
Response,
StreamResponse,
)
from neollm.utils.utils import cprint
# Currently supports Azure and OpenAI.
class AbstractLLM(ABC):
    """Abstract base class for LLM provider clients (currently Azure and OpenAI).

    Subclasses implement tokenization (``count_tokens``, ``encode``, ``decode``)
    and generation (``generate``, ``generate_stream``); this base provides price
    calculation and conversion of streamed chunks into a non-stream response.
    """

    # USD price per 1000 tokens (separate input/output rates).
    dollar_per_ktoken: APIPricing
    # Provider-side model identifier.
    model: str
    # Maximum context size in tokens.
    context_window: int
    # True when usage comes from provider-reported token counts (self.custom_token)
    # instead of locally counted self.token.
    _custom_price_calculation: bool = False

    def __init__(self, client_settings: ClientSettings):
        """Initialize the LLM client.

        Args:
            client_settings (ClientSettings): Provider client configuration.
        """
        self.client_settings = client_settings

    def calculate_price(self, num_input_tokens: int = 0, num_output_tokens: int = 0) -> float:
        """Compute the API usage cost.

        Args:
            num_input_tokens (int, optional): Number of input tokens. Defaults to 0.
            num_output_tokens (int, optional): Number of output tokens. Defaults to 0.

        Returns:
            float: Estimated API cost in USD.
        """
        return (
            self.dollar_per_ktoken.input * num_input_tokens
            + self.dollar_per_ktoken.output * num_output_tokens
        ) / 1000

    @abstractmethod
    def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int: ...

    @abstractmethod
    def encode(self, text: str) -> list[int]: ...

    @abstractmethod
    def decode(self, encoded: list[int]) -> str: ...

    @abstractmethod
    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a completion.

        Args:
            messages (Messages): OpenAI-style messages (list[dict]).
            llm_settings (LLMSettings): Generation settings.

        Returns:
            Response: OpenAI-like response object.
        """

    @abstractmethod
    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse: ...

    def __repr__(self) -> str:
        return f"{self.__class__}()"

    @staticmethod
    def _merge_function_call(chunk_choices: list) -> FunctionCall | None:
        """Merge streamed function_call deltas into one FunctionCall (or None if absent)."""
        if all(c.delta.function_call is None for c in chunk_choices):
            return None
        return FunctionCall(
            arguments="".join(
                c.delta.function_call.arguments
                for c in chunk_choices
                if c.delta.function_call is not None and c.delta.function_call.arguments is not None
            ),
            name=get_entity(
                [c.delta.function_call.name for c in chunk_choices if c.delta.function_call is not None],
                default="",
            ),
        )

    @staticmethod
    def _merge_tool_calls(chunk_choices: list) -> list[ChatCompletionMessageToolCall] | None:
        """Group streamed tool_call deltas by index and merge each group (or None if absent)."""
        grouped: dict[int, list[ChoiceDeltaToolCall]] = {}
        for choice in chunk_choices:
            if choice.delta.tool_calls is None:
                continue
            for delta_call in choice.delta.tool_calls:
                grouped.setdefault(delta_call.index, []).append(delta_call)
        if not grouped:
            return None
        merged: list[ChatCompletionMessageToolCall] = []
        for deltas in grouped.values():
            merged.append(
                ChatCompletionMessageToolCall(
                    id=get_entity([tc.id for tc in deltas], default=""),
                    function=Function(
                        arguments="".join(
                            tc.function.arguments
                            for tc in deltas
                            if tc.function is not None and tc.function.arguments is not None
                        ),
                        name=get_entity(
                            [tc.function.name for tc in deltas if tc.function is not None], default=""
                        ),
                    ),
                    type=get_entity([tc.type for tc in deltas], default="function"),
                )
            )
        return merged

    def _compute_usage(
        self, chunk_list: list[Chunk], messages: Messages, message: ChatCompletionMessage
    ) -> CompletionUsage:
        """Build CompletionUsage from provider-reported counts, falling back to local counting.

        If no chunk carries a usable, non-zero ``tokens`` attribute (or the data is
        malformed), token counts are computed locally via ``count_tokens``.
        """
        # Initialize so empty chunk_list cannot leave these unbound (bug in the
        # previous version: NameError when no chunk raised and none set them).
        prompt_tokens = 0
        completion_tokens = 0
        try:
            for chunk in chunk_list:
                tokens = getattr(chunk, "tokens", None)
                if tokens:
                    prompt_tokens = int(tokens["input_tokens"])
                    completion_tokens = int(tokens["output_tokens"])
            reported = prompt_tokens > 0 and completion_tokens > 0
        except (KeyError, TypeError, ValueError):
            # Malformed provider usage payload: fall back to local counting.
            reported = False
        if not reported:
            prompt_tokens = self.count_tokens(messages)  # TODO: include function-call tokens
            completion_tokens = self.count_tokens([message.to_typeddict_message()], only_response=True)
        return CompletionUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

    def convert_nonstream_response(
        self, chunk_list: list[Chunk], messages: Messages, functions: Any = None
    ) -> Response:
        """Assemble streamed chunks into a single non-stream ChatCompletion.

        Args:
            chunk_list (list[Chunk]): Chunks received from a streaming response.
            messages (Messages): Original request messages (needed for token counting).
            functions (Any, optional): Function definitions (needed for token counting).

        Returns:
            Response: OpenAI-like ChatCompletion merging all chunk deltas.
        """
        # Only the first choice of each chunk is merged.
        # TODO: warn when n >= 2.
        chunk_choices = [chunk.choices[0] for chunk in chunk_list if len(chunk.choices) > 0]
        message = ChatCompletionMessage(
            content="".join(c.delta.content for c in chunk_choices if c.delta.content is not None),
            # TODO: why does ChoiceDelta permit roles other than "assistant"?
            role=get_entity([c.delta.role for c in chunk_choices], default="assistant"),  # type: ignore
            function_call=self._merge_function_call(chunk_choices),
            tool_calls=self._merge_tool_calls(chunk_choices),
        )
        choice = Choice(
            index=get_entity([c.index for c in chunk_choices], default=0),
            message=message,
            finish_reason=get_entity([c.finish_reason for c in chunk_choices], default=None),
        )
        return ChatCompletion(
            id=get_entity([chunk.id for chunk in chunk_list], default=""),
            object="chat.completion",
            created=get_entity([getattr(chunk, "created", 0) for chunk in chunk_list], default=0),
            model=get_entity([getattr(chunk, "model", "") for chunk in chunk_list], default=""),
            choices=[choice],
            system_fingerprint=get_entity(
                [getattr(chunk, "system_fingerprint", None) for chunk in chunk_list], default=None
            ),
            usage=self._compute_usage(chunk_list, messages, message),
        )

    @property
    def max_tokens(self) -> int:
        """Deprecated alias kept for backward compatibility; use ``context_window``."""
        cprint("max_tokensは非推奨です。context_windowを使用してください。")
        return self.context_window