Spaces:
Configuration error
Configuration error
from __future__ import annotations | |
import time | |
from typing import Any, Generator, Literal, Optional, cast | |
from neollm.myllm.abstract_myllm import AbstractMyLLM | |
from neollm.myllm.myllm import MyLLM | |
from neollm.myllm.print_utils import TITLE_COLOR | |
from neollm.types import ( | |
InputType, | |
OutputType, | |
PriceInfo, | |
StreamOutputType, | |
TimeInfo, | |
TokenInfo, | |
) | |
from neollm.utils.utils import cprint | |
class MyL3M2(AbstractMyLLM[InputType, OutputType]):
    """Class that bundles multiple LLM requests into one unit.

    Subclasses implement ``_link`` (and optionally ``_stream_link``) to run the
    individual MyLLM / MyL3M2 calls. Instances created with ``parent=<this instance>``
    append themselves to ``self.myllm_list`` when their call finishes, and this
    class aggregates their token usage, price and logs.
    """

    # There is no stream_verbose, so __call__ does not use streaming.
    do_stream: bool = False

    def __init__(
        self,
        parent: Optional["MyL3M2[Any, Any]"] = None,
        verbose: bool = False,
        silent_list: list[Literal["inputs", "outputs", "metadata", "all_myllm"]] | None = None,
    ) -> None:
        """Initialize MyL3M2.

        Args:
            parent (MyL3M2, optional):
                Parent MyL3M2 instance (``self`` of the enclosing MyL3M2, or None).
                When set, this instance appends itself to ``parent.myllm_list``
                after its call completes.
            verbose (bool, optional):
                Whether to print output. Defaults to False.
            silent_list (list[Literal["inputs", "outputs", "metadata", "all_myllm"]], optional):
                Sections whose output is suppressed. Defaults to None (= []).
        """
        self.parent = parent
        self.verbose = verbose
        self.silent_set = set(silent_list or [])
        # Child MyLLM / MyL3M2 instances, appended by each child when its call finishes.
        self.myllm_list: list["MyL3M2[Any, Any]" | MyLLM[Any, Any]] = []
        self.inputs: InputType | None = None
        self.outputs: OutputType | None = None
        self.called: bool = False

    def _link(self, inputs: InputType) -> OutputType:
        """Run the processing of multiple LLM calls (non-stream).

        Args:
            inputs (InputType): Dictionary holding the input data.

        Returns:
            OutputType: The resulting output data.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("_link(self, inputs: InputType) -> OutputType:を実装してください")

    def _stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
        """Run the processing of multiple LLM calls (stream).

        Args:
            inputs (InputType): Dictionary holding the input data.

        Yields:
            StreamOutputType: Streamed chunks of the output.

        Returns:
            OutputType: The value to store in ``self.outputs``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            "_stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]を実装してください"
        )

    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """Run this chain once, print progress, and register with the parent.

        This is a generator. In stream mode it re-yields every chunk produced by
        ``_stream_link``. In both modes the generator's return value is the final
        output, which is also stored in ``self.outputs``.

        Args:
            inputs (InputType): Input data passed to ``_link`` / ``_stream_link``.
            stream (bool, optional): Use ``_stream_link`` instead of ``_link``. Defaults to False.

        Returns:
            OutputType: The final output.

        Raises:
            RuntimeError: If this instance has already been called.
        """
        if self.called:
            raise RuntimeError("MyLLMは1回しか呼び出せない")

        self._print_start(sep="=")

        # main -----------------------------------------------------------
        t_start = time.time()
        self.inputs = inputs
        if stream:
            # `yield from` re-yields every chunk, forwards send()/throw()/close()
            # to the inner generator, and evaluates to its return value.
            self.outputs = cast(OutputType, (yield from self._stream_link(inputs)))
        else:
            self.outputs = self._link(inputs)

        self._print_inputs()
        self._print_outputs()
        self._print_all_myllm()
        self.time = time.time() - t_start
        self.time_detail = TimeInfo(total=self.time, main=self.time)

        # metadata -----------------------------------------------------------
        self._print_metadata()
        self._print_end(sep="=")

        # Register with the parent MyL3M2 -------------------------------------
        if self.parent is not None:
            self.parent.myllm_list.append(self)
        self.called = True

        return self.outputs

    @property
    def token(self) -> TokenInfo:
        """Token usage summed over all child calls in ``myllm_list``."""
        token = TokenInfo(input=0, output=0, total=0)
        for myllm in self.myllm_list:
            # TODO: token += myllm.token
            # Read once: for a nested MyL3M2 this property walks its whole subtree.
            child_token = myllm.token
            token.input += child_token.input
            token.output += child_token.output
            token.total += child_token.total
        return token

    @property
    def price(self) -> PriceInfo:
        """Price summed over all child calls in ``myllm_list``."""
        price = PriceInfo(input=0.0, output=0.0, total=0.0)
        for myllm in self.myllm_list:
            # TODO: price += myllm.price
            # Read once: for a nested MyL3M2 this property walks its whole subtree.
            child_price = myllm.price
            price.input += child_price.input
            price.output += child_price.output
            price.total += child_price.total
        return price

    @property
    def logs(self) -> list[Any]:
        """Logs of all child MyLLM calls, in ``myllm_list`` order.

        A nested MyL3M2's logs are spliced in at its position in the list.
        """
        logs: list[Any] = []
        for myllm in self.myllm_list:
            if isinstance(myllm, MyLLM):
                logs.append(myllm.log)
            elif isinstance(myllm, MyL3M2):
                logs.extend(myllm.logs)
        return logs

    def _print_all_myllm(self, prefix: str = "", title: bool = True) -> None:
        """Print the tree of child MyLLM / MyL3M2 calls.

        Does nothing unless ``verbose`` is set and "all_myllm" is not silenced.
        Exceptions raised while printing are shown in red instead of propagating.

        Args:
            prefix (str, optional): Indentation prepended to each child line.
            title (bool, optional): Print the "[all_myllm]" header and this
                instance's name. Nested calls pass False because the parent has
                already listed them.
        """
        if "all_myllm" in self.silent_set or not self.verbose:
            return
        try:
            # Only the top-level call prints its own name; a nested call would
            # otherwise repeat the "- <name>" line its parent just printed.
            if title:
                cprint("[all_myllm]", color=TITLE_COLOR)
                print(" ", end="")
                cprint(f"{self}", color="magenta", bold=True, underline=True)
            for myllm in self.myllm_list:
                if isinstance(myllm, MyLLM):
                    cprint(f" {prefix}- {myllm}", color="cyan")
                elif isinstance(myllm, MyL3M2):
                    cprint(f" {prefix}- {myllm}", color="magenta")
                    myllm._print_all_myllm(prefix=prefix + " ", title=False)
        except Exception as e:
            cprint(e, color="red", background=True)

    def __repr__(self) -> str:
        return f"MyL3M2({self.__class__.__name__})"