from __future__ import annotations

import time
from typing import Any, Generator, Literal, Optional, cast

from neollm.myllm.abstract_myllm import AbstractMyLLM
from neollm.myllm.myllm import MyLLM
from neollm.myllm.print_utils import TITLE_COLOR
from neollm.types import (
    InputType,
    OutputType,
    PriceInfo,
    StreamOutputType,
    TimeInfo,
    TokenInfo,
)
from neollm.utils.utils import cprint


class MyL3M2(AbstractMyLLM[InputType, OutputType]):
    """Class that bundles multiple LLM requests into one unit.

    Subclasses implement ``_link`` (and ``_stream_link`` for streaming).
    Sub-calls are collected in ``myllm_list``; a nested MyL3M2 appends itself
    to its parent's list at the end of ``_call`` when ``parent`` is set.
    Token, price, and log totals are aggregated from ``myllm_list``.
    """

    # There is no stream_verbose, so __call__ does not use streaming.
    do_stream: bool = False

    def __init__(
        self,
        parent: Optional["MyL3M2[Any, Any]"] = None,
        verbose: bool = False,
        silent_list: list[Literal["inputs", "outputs", "metadata", "all_myllm"]] | None = None,
    ) -> None:
        """Initialize MyL3M2.

        Args:
            parent (MyL3M2, optional): Parent MyL3M2 instance (the enclosing
                MyL3M2's ``self``, or None). Defaults to None.
            verbose (bool, optional): Whether to print output. Defaults to False.
            silent_list (list[Literal["inputs", "outputs", "metadata", "all_myllm"]], optional):
                Sections whose output should be suppressed. Defaults to None (= []).
        """
        self.parent = parent
        self.verbose = verbose
        self.silent_set = set(silent_list or [])
        # Sub-calls (MyLLM or nested MyL3M2) used for token/price/log aggregation.
        self.myllm_list: list["MyL3M2[Any, Any]" | MyLLM[Any, Any]] = []
        self.inputs: InputType | None = None
        self.outputs: OutputType | None = None
        # Guards against calling the same instance more than once (see _call).
        self.called: bool = False

    def _link(self, inputs: InputType) -> OutputType:
        """Run the multiple LLM calls (non-streaming). Must be overridden.

        Args:
            inputs (InputType): Input data.

        Returns:
            OutputType: The processed output data.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("_link(self, inputs: InputType) -> OutputType:を実装してください")

    def _stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
        """Run the multiple LLM calls (streaming). Must be overridden.

        Args:
            inputs (InputType): Input data.

        Yields:
            StreamOutputType: Streamed output chunks.

        Returns:
            OutputType: The value to be stored in ``self.outputs``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "_stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]を実装してください"
        )

    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """Run the linked LLM calls once, with printing and bookkeeping.

        This is a generator in both modes: with ``stream=True`` it re-yields
        every chunk produced by ``_stream_link``; otherwise it yields nothing
        and returns the result of ``_link``.

        Args:
            inputs (InputType): Input data passed to ``_link`` / ``_stream_link``.
            stream (bool, optional): Use ``_stream_link`` instead of ``_link``.
                Defaults to False.

        Yields:
            StreamOutputType: Stream chunks (only when ``stream=True``).

        Returns:
            OutputType: The processed output, also stored in ``self.outputs``.

        Raises:
            RuntimeError: If this instance has already been called.
        """
        if self.called:
            raise RuntimeError("MyLLMは1回しか呼び出せない")

        self._print_start(sep="=")

        # main -----------------------------------------------------------
        t_start = time.time()
        self.inputs = inputs
        if stream:
            # `yield from` forwards send/throw/close to the inner generator, so a
            # consumer that stops early also closes `_stream_link` deterministically,
            # and the inner generator's return value is captured directly.
            stream_result = yield from self._stream_link(inputs)
            self.outputs = cast(OutputType, stream_result)
        else:
            self.outputs = self._link(inputs)
        self._print_inputs()
        self._print_outputs()
        self._print_all_myllm()
        # NOTE: the elapsed time also includes the _print_* calls above.
        self.time = time.time() - t_start
        self.time_detail = TimeInfo(total=self.time, main=self.time)

        # metadata -----------------------------------------------------------
        self._print_metadata()
        self._print_end(sep="=")

        # Append to the parent MyL3M2 -----------------------------------------
        if self.parent is not None:
            self.parent.myllm_list.append(self)

        self.called = True
        return self.outputs

    @property
    def token(self) -> TokenInfo:
        """Sum of input/output/total tokens over all sub-calls in ``myllm_list``."""
        token = TokenInfo(input=0, output=0, total=0)
        for myllm in self.myllm_list:
            # TODO: token += myllm.token
            token.input += myllm.token.input
            token.output += myllm.token.output
            token.total += myllm.token.total
        return token

    @property
    def price(self) -> PriceInfo:
        """Sum of input/output/total price over all sub-calls in ``myllm_list``."""
        price = PriceInfo(input=0.0, output=0.0, total=0.0)
        for myllm in self.myllm_list:
            # TODO: price += myllm.price
            price.input += myllm.price.input
            price.output += myllm.price.output
            price.total += myllm.price.total
        return price

    @property
    def logs(self) -> list[Any]:
        """Flattened logs: each MyLLM's ``log``, with nested MyL3M2 logs expanded."""
        logs: list[Any] = []
        for myllm in self.myllm_list:
            if isinstance(myllm, MyLLM):
                logs.append(myllm.log)
            elif isinstance(myllm, MyL3M2):
                logs.extend(myllm.logs)
        return logs

    def _print_all_myllm(self, prefix: str = "", title: bool = True) -> None:
        """Print this instance and its sub-calls as an indented tree.

        Does nothing when ``verbose`` is False or "all_myllm" is silenced.

        Args:
            prefix (str, optional): Indentation prefix for nested entries. Defaults to "".
            title (bool, optional): Whether to print the "[all_myllm]" header and
                this instance's name. Defaults to True.
        """
        if "all_myllm" in self.silent_set or not self.verbose:
            return
        try:
            if title:
                cprint("[all_myllm]", color=TITLE_COLOR)
                print(" ", end="")
                cprint(f"{self}", color="magenta", bold=True, underline=True)
            for myllm in self.myllm_list:
                if isinstance(myllm, MyLLM):
                    cprint(f" {prefix}- {myllm}", color="cyan")
                elif isinstance(myllm, MyL3M2):
                    cprint(f" {prefix}- {myllm}", color="magenta")
                    myllm._print_all_myllm(prefix=prefix + " ", title=False)
        except Exception as e:
            # Printing is best-effort: report the error instead of failing the call.
            cprint(e, color="red", background=True)

    def __repr__(self) -> str:
        return f"MyL3M2({self.__class__.__name__})"