Kpenciler's picture
Upload 53 files
88435ed verified
raw
history blame
5.91 kB
from __future__ import annotations
import time
from typing import Any, Generator, Literal, Optional, cast
from neollm.myllm.abstract_myllm import AbstractMyLLM
from neollm.myllm.myllm import MyLLM
from neollm.myllm.print_utils import TITLE_COLOR
from neollm.types import (
InputType,
OutputType,
PriceInfo,
StreamOutputType,
TimeInfo,
TokenInfo,
)
from neollm.utils.utils import cprint
class MyL3M2(AbstractMyLLM[InputType, OutputType]):
    """Bundle several LLM requests into a single callable unit.

    Subclasses implement ``_link`` (and optionally ``_stream_link``) to
    orchestrate child ``MyLLM`` / ``MyL3M2`` calls. Token counts, prices and
    logs are aggregated over ``myllm_list``.
    """

    # There is no stream_verbose, so __call__ does not use streaming.
    do_stream: bool = False

    def __init__(
        self,
        parent: Optional["MyL3M2[Any, Any]"] = None,
        verbose: bool = False,
        silent_list: list[Literal["inputs", "outputs", "metadata", "all_myllm"]] | None = None,
    ) -> None:
        """Initialize the MyL3M2.

        Args:
            parent (MyL3M2, optional):
                Parent MyL3M2 instance (self or None). When set, this
                instance appends itself to ``parent.myllm_list`` after a
                successful call.
            verbose (bool, optional):
                Whether to print output. Defaults to False.
            silent_list (list[Literal["inputs", "outputs", "metadata", "all_myllm"]], optional):
                Elements whose output should be suppressed.
                Defaults to None (= []).
        """
        self.parent = parent
        self.verbose = verbose
        self.silent_set = set(silent_list or [])
        # Children executed under this instance, in completion order.
        self.myllm_list: list[MyL3M2[Any, Any] | MyLLM[Any, Any]] = []
        self.inputs: InputType | None = None
        self.outputs: OutputType | None = None
        self.called: bool = False

    def _link(self, inputs: InputType) -> OutputType:
        """Run the chained LLM processing (to be implemented by subclasses).

        Args:
            inputs (InputType): Input data.

        Returns:
            OutputType: Result of the processing.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("_link(self, inputs: InputType) -> OutputType:を実装してください")

    def _stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
        """Run the chained LLM processing as a stream (to be implemented by subclasses).

        Args:
            inputs (InputType): Input data.

        Yields:
            StreamOutputType: Streamed chunks of the result.

        Returns:
            OutputType: The value to store in ``self.outputs``.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        # The message states the real contract: the generator must *return*
        # OutputType, because _call stores that return value in self.outputs.
        raise NotImplementedError(
            "_stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]を実装してください"
        )

    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """Execute the link once, print results, and register with the parent.

        This is always a generator: with ``stream=False`` it yields nothing
        and only its return value is meaningful.

        Args:
            inputs (InputType): Input data passed to ``_link`` / ``_stream_link``.
            stream (bool, optional): Use ``_stream_link`` and forward its
                chunks when True; otherwise use ``_link``. Defaults to False.

        Yields:
            StreamOutputType: Chunks produced by ``_stream_link`` (stream mode only).

        Returns:
            OutputType: The outputs, also stored in ``self.outputs``.

        Raises:
            RuntimeError: If this instance has already been called successfully.
        """
        if self.called:
            raise RuntimeError("MyLLMは1回しか呼び出せない")

        self._print_start(sep="=")

        # main -----------------------------------------------------------
        t_start = time.time()
        self.inputs = inputs
        if stream:
            # Delegate with `yield from` so close()/throw() from the consumer
            # reach the sub-generator; its return value becomes the outputs.
            self.outputs = yield from self._stream_link(inputs)
        else:
            self.outputs = self._link(inputs)
        self._print_inputs()
        self._print_outputs()
        self._print_all_myllm()
        self.time = time.time() - t_start
        # Only one phase is measured here, so "main" equals "total".
        self.time_detail = TimeInfo(total=self.time, main=self.time)

        # metadata -----------------------------------------------------------
        self._print_metadata()
        self._print_end(sep="=")

        # append to the parent MyL3M2 ----------------------------------------
        if self.parent is not None:
            self.parent.myllm_list.append(self)
        # Marked only after everything succeeded; a failed call may be retried.
        self.called = True
        return self.outputs

    @property
    def token(self) -> TokenInfo:
        """Total token usage summed over every entry in ``myllm_list``."""
        total = TokenInfo(input=0, output=0, total=0)
        for myllm in self.myllm_list:
            # Read the property once: for a nested MyL3M2 each access
            # recomputes the whole subtree.
            # TODO: total += myllm.token
            child = myllm.token
            total.input += child.input
            total.output += child.output
            total.total += child.total
        return total

    @property
    def price(self) -> PriceInfo:
        """Total price summed over every entry in ``myllm_list``."""
        total = PriceInfo(input=0.0, output=0.0, total=0.0)
        for myllm in self.myllm_list:
            # Read the property once: for a nested MyL3M2 each access
            # recomputes the whole subtree.
            # TODO: total += myllm.price
            child = myllm.price
            total.input += child.input
            total.output += child.output
            total.total += child.total
        return total

    @property
    def logs(self) -> list[Any]:
        """Flat list of logs: one per MyLLM, recursing into nested MyL3M2s."""
        logs: list[Any] = []
        for myllm in self.myllm_list:
            if isinstance(myllm, MyLLM):
                logs.append(myllm.log)
            elif isinstance(myllm, MyL3M2):
                logs.extend(myllm.logs)
        return logs

    def _print_all_myllm(self, prefix: str = "", title: bool = True) -> None:
        """Print the tree of children held in ``myllm_list``.

        Does nothing unless ``verbose`` is set and ``"all_myllm"`` is not in
        ``silent_set``. Printing is best-effort: any exception is reported in
        red instead of being propagated.

        Args:
            prefix (str, optional): Indentation prepended to each child line;
                grows by two spaces per nesting level. Defaults to "".
            title (bool, optional): Print the "[all_myllm]" header and this
                instance's own line first. Defaults to True.
        """
        if not self.verbose or "all_myllm" in self.silent_set:
            return
        try:
            if title:
                cprint("[all_myllm]", color=TITLE_COLOR)
                print(" ", end="")
                cprint(f"{self}", color="magenta", bold=True, underline=True)
            for myllm in self.myllm_list:
                if isinstance(myllm, MyLLM):
                    cprint(f" {prefix}- {myllm}", color="cyan")
                elif isinstance(myllm, MyL3M2):
                    cprint(f" {prefix}- {myllm}", color="magenta")
                    myllm._print_all_myllm(prefix=prefix + " ", title=False)
        except Exception as e:
            # Display failures must never break the actual LLM call.
            cprint(e, color="red", background=True)

    def __repr__(self) -> str:
        """Return ``MyL3M2(<concrete class name>)``."""
        return f"MyL3M2({self.__class__.__name__})"