from abc import ABC, abstractmethod
from typing import Any
from neollm.llm.utils import get_entity
from neollm.types import (
APIPricing,
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
Choice,
ChoiceDeltaToolCall,
Chunk,
ClientSettings,
CompletionUsage,
Function,
FunctionCall,
LLMSettings,
Messages,
Response,
StreamResponse,
)
from neollm.utils.utils import cprint


# Currently supports Azure and OpenAI.
class AbstractLLM(ABC):
    dollar_per_ktoken: APIPricing
    model: str
    context_window: int
    _custom_price_calculation: bool = False  # True when self.custom_token is used instead of self.token

    def __init__(self, client_settings: ClientSettings):
        """Initialize the LLM wrapper.

        Args:
            client_settings (ClientSettings): Client settings.
        """
        self.client_settings = client_settings

    def calculate_price(self, num_input_tokens: int = 0, num_output_tokens: int = 0) -> float:
        """Calculate the API cost of a request.

        Args:
            num_input_tokens (int, optional): Number of input tokens. Defaults to 0.
            num_output_tokens (int, optional): Number of output tokens. Defaults to 0.

        Returns:
            float: API usage cost (USD).
        """
        price = (
            self.dollar_per_ktoken.input * num_input_tokens + self.dollar_per_ktoken.output * num_output_tokens
        ) / 1000
        return price
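
    # Worked example (illustrative numbers, not actual provider pricing):
    #   with dollar_per_ktoken = APIPricing(input=0.01, output=0.03),
    #   calculate_price(num_input_tokens=1000, num_output_tokens=500)
    #   returns (0.01 * 1000 + 0.03 * 500) / 1000 = 0.025 (USD).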

    @abstractmethod
    def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int: ...

    @abstractmethod
    def encode(self, text: str) -> list[int]: ...

    @abstractmethod
    def decode(self, encoded: list[int]) -> str: ...
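
    # Note: concrete subclasses bind count_tokens/encode/decode to the provider's
    # tokenizer (for the supported OpenAI/Azure models this is typically tiktoken).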

    @abstractmethod
    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a completion.

        Args:
            messages (Messages): OpenAI-style Messages (list[dict]).
            llm_settings (LLMSettings): Settings for the LLM call.

        Returns:
            Response: OpenAI-like Response.
        """

    @abstractmethod
    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse: ...

    def __repr__(self) -> str:
        return f"{self.__class__}()"

    def convert_nonstream_response(
        self, chunk_list: list[Chunk], messages: Messages, functions: Any = None
    ) -> Response:
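        """Aggregate streamed chunks into a single non-streaming, OpenAI-style ChatCompletion."""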
        # messages and functions are needed for token counting
        _chunk_choices = [chunk.choices[0] for chunk in chunk_list if len(chunk.choices) > 0]
        # TODO: warn when n >= 2
        # FunctionCall --------------------------------------------------
        function_call: FunctionCall | None
        if all([_c.delta.function_call is None for _c in _chunk_choices]):
            function_call = None
        else:
            function_call = FunctionCall(
                arguments="".join(
                    [
                        _c.delta.function_call.arguments
                        for _c in _chunk_choices
                        if _c.delta.function_call is not None and _c.delta.function_call.arguments is not None
                    ]
                ),
                name=get_entity(
                    [_c.delta.function_call.name for _c in _chunk_choices if _c.delta.function_call is not None],
                    default="",
                ),
            )
        # ToolCalls --------------------------------------------------
        _tool_calls_dict: dict[int, list[ChoiceDeltaToolCall]] = {}  # key=index
        for _chunk in _chunk_choices:
            if _chunk.delta.tool_calls is None:
                continue
            for _tool_call in _chunk.delta.tool_calls:
                _tool_calls_dict.setdefault(_tool_call.index, []).append(_tool_call)
        tool_calls: list[ChatCompletionMessageToolCall] | None
        if sum(len(_tool_calls) for _tool_calls in _tool_calls_dict.values()) == 0:
            tool_calls = None
        else:
            tool_calls = []
            for _tool_calls in _tool_calls_dict.values():
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=get_entity([_tc.id for _tc in _tool_calls], default=""),
                        function=Function(
                            arguments="".join(
                                [
                                    _tc.function.arguments
                                    for _tc in _tool_calls
                                    if _tc.function is not None and _tc.function.arguments is not None
                                ]
                            ),
                            name=get_entity(
                                [_tc.function.name for _tc in _tool_calls if _tc.function is not None], default=""
                            ),
                        ),
                        type=get_entity([_tc.type for _tc in _tool_calls], default="function"),
                    )
                )
        message = ChatCompletionMessage(
            content="".join([_c.delta.content for _c in _chunk_choices if _c.delta.content is not None]),
            # TODO: why does ChoiceDelta's role allow values other than "assistant"?
            role=get_entity([_c.delta.role for _c in _chunk_choices], default="assistant"),  # type: ignore
            function_call=function_call,
            tool_calls=tool_calls,
        )
        choice = Choice(
            index=get_entity([_c.index for _c in _chunk_choices], default=0),
            message=message,
            finish_reason=get_entity([_c.finish_reason for _c in _chunk_choices], default=None),
        )
        # Usage --------------------------------------------------
        try:
            # Some providers attach token usage to the chunks themselves; prefer it when present.
            for chunk in chunk_list:
                tokens = getattr(chunk, "tokens", None)
                if tokens:
                    prompt_tokens = int(tokens["input_tokens"])
                    completion_tokens = int(tokens["output_tokens"])
            assert prompt_tokens
            assert completion_tokens
        except Exception:
            # Fall back to counting tokens locally.
            prompt_tokens = self.count_tokens(messages)  # TODO: account for function calls etc.
            completion_tokens = self.count_tokens([message.to_typeddict_message()], only_response=True)
        usages = CompletionUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        # ChatCompletion ------------------------------------------
        response = ChatCompletion(
            id=get_entity([chunk.id for chunk in chunk_list], default=""),
            object="chat.completion",
            created=get_entity([getattr(chunk, "created", 0) for chunk in chunk_list], default=0),
            model=get_entity([getattr(chunk, "model", "") for chunk in chunk_list], default=""),
            choices=[choice],
            system_fingerprint=get_entity(
                [getattr(chunk, "system_fingerprint", None) for chunk in chunk_list], default=None
            ),
            usage=usages,
        )
        return response
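
    # Usage sketch (hedged): a concrete subclass could implement generate() on top
    # of generate_stream() by buffering the stream and aggregating it, e.g.:
    #
    #     chunks = list(self.generate_stream(messages, llm_settings))
    #     return self.convert_nonstream_response(chunks, messages)
    #
    # This assumes the StreamResponse is iterable over Chunk objects, which is what
    # the aggregation above expects.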

    @property
    def max_tokens(self) -> int:
        cprint("max_tokens is deprecated. Use context_window instead.")
        return self.context_window
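

# Minimal subclass sketch (hypothetical; the APIPricing construction, pricing
# values, and model name are assumptions for illustration only):
#
#     class MyLLM(AbstractLLM):
#         dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
#         model = "my-model"
#         context_window = 8192
#
#         def count_tokens(self, messages=None, only_response=False) -> int: ...
#         def encode(self, text): ...
#         def decode(self, encoded): ...
#         def generate(self, messages, llm_settings): ...
#         def generate_stream(self, messages, llm_settings): ...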