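"""Vertex AI Gemini adapter for neollm.

Converts OpenAI-style ``Messages`` into Vertex AI ``Content`` objects, calls
``GenerativeModel.generate_content``, and wraps the results in
OpenAI-compatible ``ChatCompletion`` / ``ChatCompletionChunk`` types.
"""
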
import base64  # needed to decode base64 image payloads before Part.from_data
import time
from abc import abstractmethod
from typing import Iterable, cast

from google.cloud.aiplatform_v1beta1.types import CountTokensResponse
from google.cloud.aiplatform_v1beta1.types.content import Candidate
from vertexai.generative_models import (
    Content,
    GenerationConfig,
    GenerationResponse,
    GenerativeModel,
    Part,
)
from vertexai.generative_models._generative_models import ContentsType

from neollm.llm.abstract_llm import AbstractLLM
from neollm.types import (
    ChatCompletion,
    CompletionUsageForCustomPriceCalculation,
    LLMSettings,
    Message,
    Messages,
    Response,
    StreamResponse,
)
from neollm.types.openai.chat_completion import (
    ChatCompletionMessage,
    Choice,
    CompletionUsage,
)
from neollm.types.openai.chat_completion import FinishReason as FinishReasonVertex
from neollm.types.openai.chat_completion_chunk import (
    ChatCompletionChunk,
    ChoiceDelta,
    ChunkChoice,
)
from neollm.utils.utils import cprint


class AbstractGemini(AbstractLLM):
    @abstractmethod
    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig: ...

    # Unused: a naive codepoint-based stand-in, not a real tokenizer.
    def encode(self, text: str) -> list[int]:
        return [ord(char) for char in text]

    # Unused.
    def decode(self, tokens: list[int]) -> str:
        return "".join([chr(number) for number in tokens])

    def _count_tokens_vertex(self, contents: ContentsType) -> CountTokensResponse:
        """Issue a count_tokens RPC for ``contents`` (a str or a list of Content)."""
        model = GenerativeModel(model_name=self.model)
        return cast(CountTokensResponse, model.count_tokens(contents))

    def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
        """Count tokens for the given messages.

        Args:
            messages: OpenAI-style messages; ``None`` counts as 0 tokens.
            only_response: currently unused.

        Returns:
            int: total token count across the system instruction and contents.
        """
        if messages is None:
            return 0
        _system, _message = self._convert_to_platform_messages(messages)
        total_tokens = 0
        if _system:
            total_tokens += int(self._count_tokens_vertex(_system).total_tokens)
        if _message:
            # Accumulate rather than overwrite, so the system tokens are not lost.
            total_tokens += int(self._count_tokens_vertex(_message).total_tokens)
        return total_tokens
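
    # Example (hypothetical model name): with self.model = "gemini-1.0-pro",
    #   count_tokens([{"role": "system", "content": "Be brief."},
    #                 {"role": "user", "content": "Hi"}])
    # issues two count_tokens RPCs (one for the system string, one for the
    # Content list) and returns the summed total_tokens.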

    def _convert_to_platform_messages(self, messages: Messages) -> tuple[str | None, list[Content]]:
        _system = None
        _message: list[Content] = []
        for message in messages:
            if message["role"] == "system":
                _system = "\n" + message["content"]
            elif message["role"] == "user":
                if isinstance(message["content"], str):
                    _message.append(Content(role="user", parts=[Part.from_text(message["content"])]))
                else:
                    try:
                        if isinstance(message["content"], list) and message["content"][1]["type"] == "image_url":
                            # Data URLs look like "data:image/jpeg;base64,<payload>"; keep only
                            # the payload and decode it, since Part.from_data expects raw bytes.
                            encoded_image = message["content"][1]["image_url"]["url"].split(",")[-1]
                            _message.append(
                                Content(
                                    role="user",
                                    parts=[
                                        Part.from_text(message["content"][0]["text"]),
                                        Part.from_data(data=base64.b64decode(encoded_image), mime_type="image/jpeg"),
                                    ],
                                )
                            )
                    except (KeyError, IndexError):
                        cprint("WARNING: unsupported message format", color="yellow", background=True)
                    except Exception as e:
                        cprint(e, color="red", background=True)
            elif message["role"] == "assistant":
                if isinstance(message["content"], str):
                    _message.append(Content(role="model", parts=[Part.from_text(message["content"])]))
                else:
                    cprint("WARNING: unsupported message format", color="yellow", background=True)
        return _system, _message
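
    # Example of the conversion (values follow from the logic above):
    #   [{"role": "system", "content": "You are terse."},
    #    {"role": "user", "content": "Hello"}]
    # becomes
    #   ("\nYou are terse.",
    #    [Content(role="user", parts=[Part.from_text("Hello")])])
    # Note that only the last system message is kept, and assistant turns are
    # mapped to role="model".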

    def _convert_finish_reason(self, stop_reason: Candidate.FinishReason) -> FinishReasonVertex | None:
        """
        Reference: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
        0: FINISH_REASON_UNSPECIFIED
            Default value. This value is unused.
        1: STOP
            Natural stop point of the model or provided stop sequence.
        2: MAX_TOKENS
            The maximum number of tokens as specified in the request was reached.
        3: SAFETY
            The candidate content was flagged for safety reasons.
        4: RECITATION
            The candidate content was flagged for recitation reasons.
        5: OTHER
            Unknown reason.
        """
        # STOP, SAFETY, RECITATION, and OTHER all terminate the completion;
        # FINISH_REASON_UNSPECIFIED (0) falls through to None, which leaves
        # mid-stream chunks marked as unfinished.
        if stop_reason.value in [1, 3, 4, 5]:
            return "stop"
        if stop_reason.value in [2]:
            return "length"
        return None

    def _convert_to_response(
        self, platform_response: GenerationResponse, system: str | None, message: list[Content]
    ) -> Response:
        # Billable character count for the input (Vertex AI prices Gemini by characters).
        input_billable_characters = 0
        if system:
            input_billable_characters += self._count_tokens_vertex(system).total_billable_characters
        if message:
            input_billable_characters += self._count_tokens_vertex(message).total_billable_characters
        # Billable character count for the output.
        output_billable_characters = 0
        if platform_response.text:
            output_billable_characters += self._count_tokens_vertex(platform_response.text).total_billable_characters
        return ChatCompletion(  # type: ignore [call-arg]
            id="",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        content=platform_response.text,
                        role="assistant",
                    ),
                    finish_reason=self._convert_finish_reason(platform_response.candidates[0].finish_reason),
                )
            ],
            created=int(time.time()),
            model=self.model,
            object="chat.completion",
            system_fingerprint=None,
            usage=CompletionUsage(
                prompt_tokens=platform_response.usage_metadata.prompt_token_count,
                completion_tokens=platform_response.usage_metadata.candidates_token_count,
                total_tokens=platform_response.usage_metadata.prompt_token_count
                + platform_response.usage_metadata.candidates_token_count,
            ),
            usage_for_price=CompletionUsageForCustomPriceCalculation(
                prompt_tokens=input_billable_characters,
                completion_tokens=output_billable_characters,
                total_tokens=input_billable_characters + output_billable_characters,
            ),
        )
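
    # usage_for_price deliberately reports billable *characters*, not tokens:
    # CompletionUsageForCustomPriceCalculation reuses the token fields to carry
    # the character counts that Vertex AI's character-based Gemini pricing uses,
    # while usage keeps the true token counts from usage_metadata.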

    def _convert_to_streamresponse(self, platform_streamresponse: Iterable[GenerationResponse]) -> StreamResponse:
        created = int(time.time())
        content: str | None = None
        for chunk in platform_streamresponse:
            content = chunk.text
            yield ChatCompletionChunk(
                id="",
                choices=[
                    ChunkChoice(
                        delta=ChoiceDelta(
                            content=content,
                            role="assistant",
                        ),
                        finish_reason=self._convert_finish_reason(chunk.candidates[0].finish_reason),
                        index=0,  # force to 0, since the platform index may not be 0-based
                    )
                ],
                created=created,
                model=self.model,
                object="chat.completion.chunk",
            )

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        _system, _message = self._convert_to_platform_messages(messages)
        model = GenerativeModel(
            model_name=self.model,
            system_instruction=_system,
        )
        response = model.generate_content(
            contents=_message,
            stream=False,
            generation_config=self.generate_config(llm_settings),
        )
        return self._convert_to_response(platform_response=response, system=_system, message=_message)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        _system, _message = self._convert_to_platform_messages(messages)
        model = GenerativeModel(
            model_name=self.model,
            system_instruction=_system,
        )
        response = model.generate_content(
            contents=_message,
            stream=True,
            generation_config=self.generate_config(llm_settings),
        )
        return self._convert_to_streamresponse(platform_streamresponse=response)
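

# Usage sketch (hypothetical: assumes a concrete subclass that sets self.model
# and implements generate_config; the class name and settings are illustrative):
#
#     class GeminiPro(AbstractGemini):
#         def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
#             return GenerationConfig(**llm_settings)
#
#     llm = GeminiPro(...)
#     response = llm.generate(
#         [{"role": "user", "content": "Hello"}],
#         llm_settings={"temperature": 0.2},
#     )
#     print(response.choices[0].message.content)
#
#     for chunk in llm.generate_stream([{"role": "user", "content": "Hello"}], {}):
#         print(chunk.choices[0].delta.content, end="")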