Spaces:
Configuration error
Configuration error
File size: 5,913 Bytes
88435ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from __future__ import annotations
import time
from typing import Any, Generator, Literal, Optional, cast
from neollm.myllm.abstract_myllm import AbstractMyLLM
from neollm.myllm.myllm import MyLLM
from neollm.myllm.print_utils import TITLE_COLOR
from neollm.types import (
InputType,
OutputType,
PriceInfo,
StreamOutputType,
TimeInfo,
TokenInfo,
)
from neollm.utils.utils import cprint
class MyL3M2(AbstractMyLLM[InputType, OutputType]):
    """Class that groups multiple LLM requests into one unit.

    Subclasses implement ``_link`` (and ``_stream_link`` for stream mode).
    Sub-calls registered in ``myllm_list`` are aggregated by the ``token``,
    ``price`` and ``logs`` properties; a child MyL3M2 created with
    ``parent=self`` appends itself to this list once its ``_call`` finishes.
    """

    # There is no stream_verbose, so __call__ does not use streaming.
    do_stream: bool = False

    def __init__(
        self,
        parent: Optional["MyL3M2[Any, Any]"] = None,
        verbose: bool = False,
        silent_list: list[Literal["inputs", "outputs", "metadata", "all_myllm"]] | None = None,
    ) -> None:
        """Initialize a MyL3M2.

        Args:
            parent (MyL3M2, optional):
                Parent MyL3M2 instance (the caller's ``self``) or None. When set,
                this instance is appended to ``parent.myllm_list`` after it is called.
            verbose (bool, optional):
                Whether to print output. Defaults to False.
            silent_list (list[Literal["inputs", "outputs", "metadata", "all_myllm"]], optional):
                Sections whose printing is suppressed. Defaults to None (= []).
        """
        self.parent = parent
        self.verbose = verbose
        self.silent_set = set(silent_list or [])
        # Registered sub-calls; child MyL3M2 instances append themselves in _call.
        self.myllm_list: list[MyL3M2[Any, Any] | MyLLM[Any, Any]] = []
        self.inputs: InputType | None = None
        self.outputs: OutputType | None = None
        self.called: bool = False

    def _link(self, inputs: InputType) -> OutputType:
        """Run the multi-LLM processing (non-stream). Must be overridden.

        Args:
            inputs (InputType): Dictionary holding the input data.

        Returns:
            OutputType: Output data produced by the processing.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("_link(self, inputs: InputType) -> OutputType:を実装してください")

    def _stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
        """Run the multi-LLM processing as a stream. Must be overridden for stream mode.

        Args:
            inputs (InputType): Dictionary holding the input data.

        Yields:
            StreamOutputType: Streamed chunks of the processing result.

        Returns:
            OutputType: The value to be stored in ``self.outputs``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # The message now matches the real signature (the return type is OutputType, not None).
        raise NotImplementedError(
            "_stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]を実装してください"
        )

    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """Run this MyL3M2 once: process, print, record timing, register with parent.

        This method is always a generator. With ``stream=True`` it re-yields every
        chunk from ``_stream_link`` and uses that generator's return value as the
        outputs; otherwise it yields nothing and returns the result of ``_link``.

        Args:
            inputs (InputType): Input data passed to ``_link`` / ``_stream_link``.
            stream (bool, optional): Use ``_stream_link`` instead of ``_link``. Defaults to False.

        Yields:
            StreamOutputType: Chunks from ``_stream_link`` (stream mode only).

        Returns:
            OutputType: The processing result (also stored in ``self.outputs``).

        Raises:
            RuntimeError: If this instance has already been called.
        """
        if self.called:
            raise RuntimeError("MyLLMは1回しか呼び出せない")
        self._print_start(sep="=")

        # main -----------------------------------------------------------
        # perf_counter is monotonic, so the measured duration cannot jump with clock changes.
        t_start = time.perf_counter()
        self.inputs = inputs
        if stream:
            # `yield from` forwards each chunk, captures the sub-generator's return
            # value, and propagates close()/throw() to _stream_link.
            result = yield from self._stream_link(inputs)
            self.outputs = cast(OutputType, result)
        else:
            self.outputs = self._link(inputs)
        self._print_inputs()
        self._print_outputs()
        self._print_all_myllm()
        self.time = time.perf_counter() - t_start
        self.time_detail = TimeInfo(total=self.time, main=self.time)

        # metadata -----------------------------------------------------------
        self._print_metadata()
        self._print_end(sep="=")

        # Append to the parent MyL3M2 so it can aggregate token/price/logs.
        if self.parent is not None:
            self.parent.myllm_list.append(self)
        self.called = True
        return self.outputs

    @property
    def token(self) -> TokenInfo:
        """Token usage summed over every entry in ``myllm_list``."""
        token = TokenInfo(input=0, output=0, total=0)
        for myllm in self.myllm_list:
            # Read the child's (possibly recursive) property once, not once per field.
            child_token = myllm.token
            # TODO: token += child_token
            token.input += child_token.input
            token.output += child_token.output
            token.total += child_token.total
        return token

    @property
    def price(self) -> PriceInfo:
        """Price summed over every entry in ``myllm_list``."""
        price = PriceInfo(input=0.0, output=0.0, total=0.0)
        for myllm in self.myllm_list:
            # Read the child's (possibly recursive) property once, not once per field.
            child_price = myllm.price
            # TODO: price += child_price
            price.input += child_price.input
            price.output += child_price.output
            price.total += child_price.total
        return price

    @property
    def logs(self) -> list[Any]:
        """Flattened logs: each MyLLM's ``log``, recursing into MyL3M2 children."""
        logs: list[Any] = []
        for myllm in self.myllm_list:
            if isinstance(myllm, MyLLM):
                logs.append(myllm.log)
            elif isinstance(myllm, MyL3M2):
                logs.extend(myllm.logs)
        return logs

    def _print_all_myllm(self, prefix: str = "", title: bool = True) -> None:
        """Print the tree of sub-calls when verbose and "all_myllm" is not silenced.

        Args:
            prefix (str, optional): Prefix prepended to each child line. Defaults to "".
            title (bool, optional): Print the "[all_myllm]" header and this instance's
                name first. Defaults to True.
        """
        if not self.verbose or "all_myllm" in self.silent_set:
            return
        try:
            if title:
                cprint("[all_myllm]", color=TITLE_COLOR)
                print(" ", end="")
                cprint(f"{self}", color="magenta", bold=True, underline=True)
            for myllm in self.myllm_list:
                if isinstance(myllm, MyLLM):
                    cprint(f" {prefix}- {myllm}", color="cyan")
                elif isinstance(myllm, MyL3M2):
                    cprint(f" {prefix}- {myllm}", color="magenta")
                    myllm._print_all_myllm(prefix=prefix + " ", title=False)
        except Exception as e:
            # Printing is best-effort: report the error rather than abort the call.
            cprint(e, color="red", background=True)

    def __repr__(self) -> str:
        return f"MyL3M2({self.__class__.__name__})"
|