Kpenciler committed on
Commit 88435ed
1 parent: 8a87a53

Upload 53 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +26 -10
  3. asset/external_view.png +0 -0
  4. asset/external_view.pptx +3 -0
  5. makefile +14 -0
  6. neollm.code-workspace +63 -0
  7. neollm/__init__.py +5 -0
  8. neollm/exceptions.py +2 -0
  9. neollm/llm/__init__.py +4 -0
  10. neollm/llm/abstract_llm.py +188 -0
  11. neollm/llm/claude/abstract_claude.py +214 -0
  12. neollm/llm/claude/anthropic_llm.py +66 -0
  13. neollm/llm/claude/gcp_llm.py +67 -0
  14. neollm/llm/gemini/abstract_gemini.py +229 -0
  15. neollm/llm/gemini/gcp_llm.py +114 -0
  16. neollm/llm/get_llm.py +47 -0
  17. neollm/llm/gpt/abstract_gpt.py +81 -0
  18. neollm/llm/gpt/azure_llm.py +215 -0
  19. neollm/llm/gpt/openai_llm.py +222 -0
  20. neollm/llm/gpt/token.py +247 -0
  21. neollm/llm/platform.py +16 -0
  22. neollm/llm/utils.py +72 -0
  23. neollm/myllm/abstract_myllm.py +148 -0
  24. neollm/myllm/myl3m2.py +165 -0
  25. neollm/myllm/myllm.py +449 -0
  26. neollm/myllm/print_utils.py +235 -0
  27. neollm/types/__init__.py +4 -0
  28. neollm/types/_model.py +8 -0
  29. neollm/types/info.py +82 -0
  30. neollm/types/mytypes.py +31 -0
  31. neollm/types/openai/__init__.py +2 -0
  32. neollm/types/openai/chat_completion.py +170 -0
  33. neollm/types/openai/chat_completion_chunk.py +109 -0
  34. neollm/utils/inference.py +70 -0
  35. neollm/utils/postprocess.py +120 -0
  36. neollm/utils/preprocess.py +107 -0
  37. neollm/utils/prompt_checker.py +110 -0
  38. neollm/utils/tokens.py +229 -0
  39. neollm/utils/utils.py +98 -0
  40. poetry.lock +0 -0
  41. project/.env.template +24 -0
  42. project/ex_module/ex_profile_extractor.py +113 -0
  43. project/ex_module/ex_translated_profile_extractor.py +49 -0
  44. project/ex_module/ex_translator.py +62 -0
  45. project/neollm-tutorial.ipynb +713 -0
  46. pyproject.toml +81 -0
  47. test/llm/claude/test_claude_llm.py +37 -0
  48. test/llm/gpt/test_azure_llm.py +92 -0
  49. test/llm/gpt/test_openai_llm.py +37 -0
  50. test/llm/platform.py +32 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ asset/external_view.pptx filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,26 @@
- ---
- title: Neo Llm Module V1.3.5
- emoji: 🐢
- colorFrom: pink
- colorTo: blue
- sdk: static
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # neoLLM Module
+
+ The foundation module for neoAI's LLM solutions.
+ [neoLLM Module Document](https://www.notion.so/neoLLM-Module-Document-64399d1d1db24d92bce8f9b88472833f)
+
+ ## Setup
+ [How to install neoLLM](https://www.notion.so/c760d96f1b4240e6880a32bee96bba35)
+ 1. Install the neoLLM Module (requires Python 3.10)
+ ```bash
+ $ pip install git+https://github.com/neoAI-inc/neo-llm-module.git@v1.x.x
+ ```
+
+ 2. Configure API keys
+ Place a `.env` file:
+ - Define the environment variables in a `.env` file and put it in the directory you run from
+ - Rename `project/example_env.txt` to `.env` and fill in the required values
+
+ ## Usage
+ ### Overview
+ You only need to develop the parts shown with a gray background.
+ - MyLLM: wraps a single request to an LLM
+ - MyL3M2: wraps multiple LLM requests
+
+ For details, see `project/neollm-tutorial.ipynb` and `project/ex_module`.
+ ![Overview diagram](asset/external_view.png)
+
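As a low-level sketch of what this commit wires together (the MyLLM / MyL3M2 wrappers themselves are covered in `project/neollm-tutorial.ipynb`), something like the following should work against the `get_llm` factory added below. The model name, the assumption that the platform string `"openai"` maps to `Platform.OPENAI`, and the empty client settings are illustrative; valid OpenAI credentials are expected in the environment:

```python
# A minimal sketch, not repository code: get_llm(model_name, platform,
# client_settings) returns an AbstractLLM whose generate() takes
# OpenAI-style messages. Assumes OPENAI_API_KEY is set in the env / .env.
from neollm.llm import get_llm

llm = get_llm("gpt-3.5-turbo-0125", platform="openai", client_settings={})
response = llm.generate(
    messages=[{"role": "user", "content": "Hello!"}],
    llm_settings={"temperature": 0.0, "max_tokens": 128},
)
print(response.choices[0].message.content)
```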
asset/external_view.png ADDED
asset/external_view.pptx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b3e9d7dbbb6f9ca5750edd9eaad8fe7ce5fcb5797e8027ae11dea90a0a47a2c
+ size 8728033
makefile ADDED
@@ -0,0 +1,14 @@
+ .PHONY: lint
+ lint: ## run linters with poetry (black, isort, pflake8, mypy)
+ 	poetry run black neollm
+ 	poetry run isort neollm
+ 	poetry run pflake8 neollm
+ 	poetry run mypy neollm --explicit-package-bases
+
+ .PHONY: test
+ test:
+ 	poetry run pytest
+
+ .PHONY: unit-test
+ unit-test:
+ 	poetry run pytest -k "not test_neollm"
neollm.code-workspace ADDED
@@ -0,0 +1,63 @@
+ {
+   "folders": [
+     {
+       "name": "neo-llm-module",
+       "path": "."
+     }
+   ],
+   "settings": {
+     "editor.codeActionsOnSave": {
+       "source.fixAll.eslint": "explicit",
+       "source.fixAll.stylelint": "explicit"
+     },
+     "editor.formatOnSave": true,
+     "editor.formatOnPaste": true,
+     "editor.formatOnType": true,
+     "json.format.keepLines": true,
+     "[javascript]": {
+       "editor.defaultFormatter": "esbenp.prettier-vscode"
+     },
+     "[typescript]": {
+       "editor.defaultFormatter": "esbenp.prettier-vscode"
+     },
+     "[typescriptreact]": {
+       "editor.defaultFormatter": "esbenp.prettier-vscode"
+     },
+     "[css]": {
+       "editor.defaultFormatter": "esbenp.prettier-vscode"
+     },
+     "[json]": {
+       "editor.defaultFormatter": "vscode.json-language-features"
+     },
+     "search.exclude": {
+       "**/node_modules": true,
+       "static": true
+     },
+     "[python]": {
+       "editor.defaultFormatter": "ms-python.black-formatter",
+       "editor.codeActionsOnSave": {
+         "source.organizeImports": "explicit"
+       }
+     },
+     "flake8.args": [
+       "--max-line-length=119",
+       "--max-complexity=15",
+       "--ignore=E203,E501,E704,W503",
+       "--exclude=.venv,.git,__pycache__,.mypy_cache,.hg"
+     ],
+     "isort.args": ["--settings-path=pyproject.toml"],
+     "black-formatter.args": ["--config=pyproject.toml"],
+     "mypy-type-checker.args": ["--config-file=pyproject.toml"],
+     "python.analysis.extraPaths": ["./backend"]
+   },
+   "extensions": {
+     "recommendations": [
+       "esbenp.prettier-vscode",
+       "dbaeumer.vscode-eslint",
+       "ms-python.flake8",
+       "ms-python.isort",
+       "ms-python.black-formatter",
+       "ms-python.mypy-type-checker"
+     ]
+   }
+ }
neollm/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from neollm.myllm.abstract_myllm import AbstractMyLLM
+ from neollm.myllm.myl3m2 import MyL3M2
+ from neollm.myllm.myllm import MyLLM
+
+ __all__ = ["AbstractMyLLM", "MyLLM", "MyL3M2"]
neollm/exceptions.py ADDED
@@ -0,0 +1,2 @@
+ class ContentFilterError(Exception):
+     pass
neollm/llm/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.llm.get_llm import get_llm
+
+ __all__ = ["AbstractLLM", "get_llm"]
neollm/llm/abstract_llm.py ADDED
@@ -0,0 +1,188 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from neollm.llm.utils import get_entity
+ from neollm.types import (
+     APIPricing,
+     ChatCompletion,
+     ChatCompletionMessage,
+     ChatCompletionMessageToolCall,
+     Choice,
+     ChoiceDeltaToolCall,
+     Chunk,
+     ClientSettings,
+     CompletionUsage,
+     Function,
+     FunctionCall,
+     LLMSettings,
+     Messages,
+     Response,
+     StreamResponse,
+ )
+ from neollm.utils.utils import cprint
+
+
+ # Currently supports Azure and OpenAI
+ class AbstractLLM(ABC):
+     dollar_per_ktoken: APIPricing
+     model: str
+     context_window: int
+     _custom_price_calculation: bool = False  # True when self.custom_token is used instead of self.token
+
+     def __init__(self, client_settings: ClientSettings):
+         """Initialize the LLM class.
+
+         Args:
+             client_settings (ClientSettings): client settings
+         """
+         self.client_settings = client_settings
+
+     def calculate_price(self, num_input_tokens: int = 0, num_output_tokens: int = 0) -> float:
+         """
+         Calculate the API cost.
+
+         Args:
+             num_input_tokens (int, optional): number of input tokens. Defaults to 0.
+             num_output_tokens (int, optional): number of output tokens. Defaults to 0.
+
+         Returns:
+             float: API usage cost (USD)
+         """
+         price = (
+             self.dollar_per_ktoken.input * num_input_tokens + self.dollar_per_ktoken.output * num_output_tokens
+         ) / 1000
+         return price
+
+     @abstractmethod
+     def count_tokens(self, messages: Messages | None = None, only_response: bool = False) -> int: ...
+
+     @abstractmethod
+     def encode(self, text: str) -> list[int]: ...
+
+     @abstractmethod
+     def decode(self, encoded: list[int]) -> str: ...
+
+     @abstractmethod
+     def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
+         """Generate a completion.
+
+         Args:
+             messages (Messages): OpenAI-style Messages (list[dict])
+
+         Returns:
+             Response: an OpenAI-like Response
+         """
+
+     @abstractmethod
+     def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse: ...
+
+     def __repr__(self) -> str:
+         return f"{self.__class__}()"
+
+     def convert_nonstream_response(
+         self, chunk_list: list[Chunk], messages: Messages, functions: Any = None
+     ) -> Response:
+         # messages and functions are needed for token counting
+         _chunk_choices = [chunk.choices[0] for chunk in chunk_list if len(chunk.choices) > 0]
+         # TODO: warn when n >= 2
+
+         # FunctionCall --------------------------------------------------
+         function_call: FunctionCall | None
+         if all([_c.delta.function_call is None for _c in _chunk_choices]):
+             function_call = None
+         else:
+             function_call = FunctionCall(
+                 arguments="".join(
+                     [
+                         _c.delta.function_call.arguments
+                         for _c in _chunk_choices
+                         if _c.delta.function_call is not None and _c.delta.function_call.arguments is not None
+                     ]
+                 ),
+                 name=get_entity(
+                     [_c.delta.function_call.name for _c in _chunk_choices if _c.delta.function_call is not None],
+                     default="",
+                 ),
+             )
+
+         # ToolCalls --------------------------------------------------
+         _tool_calls_dict: dict[int, list[ChoiceDeltaToolCall]] = {}  # key=index
+         for _chunk in _chunk_choices:
+             if _chunk.delta.tool_calls is None:
+                 continue
+             for _tool_call in _chunk.delta.tool_calls:
+                 _tool_calls_dict.setdefault(_tool_call.index, []).append(_tool_call)
+
+         tool_calls: list[ChatCompletionMessageToolCall] | None
+         if sum(len(_tool_calls) for _tool_calls in _tool_calls_dict.values()) == 0:
+             tool_calls = None
+         else:
+             tool_calls = []
+             for _tool_calls in _tool_calls_dict.values():
+                 tool_calls.append(
+                     ChatCompletionMessageToolCall(
+                         id=get_entity([_tc.id for _tc in _tool_calls], default=""),
+                         function=Function(
+                             arguments="".join(
+                                 [
+                                     _tc.function.arguments
+                                     for _tc in _tool_calls
+                                     if _tc.function is not None and _tc.function.arguments is not None
+                                 ]
+                             ),
+                             name=get_entity(
+                                 [_tc.function.name for _tc in _tool_calls if _tc.function is not None], default=""
+                             ),
+                         ),
+                         type=get_entity([_tc.type for _tc in _tool_calls], default="function"),
+                     )
+                 )
+         message = ChatCompletionMessage(
+             content="".join([_c.delta.content for _c in _chunk_choices if _c.delta.content is not None]),
+             # TODO: why does ChoiceDelta's role allow values other than assistant?
+             role=get_entity([_c.delta.role for _c in _chunk_choices], default="assistant"),  # type: ignore
+             function_call=function_call,
+             tool_calls=tool_calls,
+         )
+         choice = Choice(
+             index=get_entity([_c.index for _c in _chunk_choices], default=0),
+             message=message,
+             finish_reason=get_entity([_c.finish_reason for _c in _chunk_choices], default=None),
+         )
+
+         # Usage --------------------------------------------------
+         try:
+             for chunk in chunk_list:
+                 if getattr(chunk, "tokens"):
+                     prompt_tokens = int(getattr(chunk, "tokens")["input_tokens"])
+                     completion_tokens = int(getattr(chunk, "tokens")["output_tokens"])
+             assert prompt_tokens
+             assert completion_tokens
+         except Exception:
+             prompt_tokens = self.count_tokens(messages)  # TODO: function calls etc.
+             completion_tokens = self.count_tokens([message.to_typeddict_message()], only_response=True)
+         usages = CompletionUsage(
+             completion_tokens=completion_tokens,
+             prompt_tokens=prompt_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+
+         # ChatCompletion ------------------------------------------
+         response = ChatCompletion(
+             id=get_entity([chunk.id for chunk in chunk_list], default=""),
+             object="chat.completion",
+             created=get_entity([getattr(chunk, "created", 0) for chunk in chunk_list], default=0),
+             model=get_entity([getattr(chunk, "model", "") for chunk in chunk_list], default=""),
+             choices=[choice],
+             system_fingerprint=get_entity(
+                 [getattr(chunk, "system_fingerprint", None) for chunk in chunk_list], default=None
+             ),
+             usage=usages,
+         )
+
+         return response
+
+     @property
+     def max_tokens(self) -> int:
+         cprint("max_tokens is deprecated. Use context_window instead.")
+         return self.context_window
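Since `dollar_per_ktoken` is expressed in USD per 1K tokens, `calculate_price` reduces to a weighted sum. A worked example of the arithmetic, using the gpt-4o pricing declared later in this commit ($0.005 input / $0.015 output per 1K tokens); the token counts are made up:

```python
# Worked example of AbstractLLM.calculate_price with gpt-4o pricing.
dollar_per_ktoken_input, dollar_per_ktoken_output = 0.005, 0.015
num_input_tokens, num_output_tokens = 1_200, 300

price = (dollar_per_ktoken_input * num_input_tokens
         + dollar_per_ktoken_output * num_output_tokens) / 1000
print(f"${price:.4f}")  # -> $0.0105
```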
neollm/llm/claude/abstract_claude.py ADDED
@@ -0,0 +1,214 @@
+ import time
+ from abc import abstractmethod
+ from typing import Any, Literal, cast
+
+ from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, Stream
+ from anthropic.types import MessageParam as AnthropicMessageParam
+ from anthropic.types import MessageStreamEvent as AnthropicMessageStreamEvent
+ from anthropic.types.message import Message as AnthropicMessage
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.types import (
+     ChatCompletion,
+     LLMSettings,
+     Message,
+     Messages,
+     Response,
+     StreamResponse,
+ )
+ from neollm.types.openai.chat_completion import (
+     ChatCompletionMessage,
+     Choice,
+     CompletionUsage,
+     FinishReason,
+ )
+ from neollm.types.openai.chat_completion_chunk import (
+     ChatCompletionChunk,
+     ChoiceDelta,
+     ChunkChoice,
+ )
+ from neollm.utils.utils import cprint
+
+ DEFAULT_MAX_TOKENS = 4_096
+
+
+ class AbstractClaude(AbstractLLM):
+     @property
+     @abstractmethod
+     def client(self) -> Anthropic | AnthropicVertex | AnthropicBedrock: ...
+
+     @property
+     def _client_for_token(self) -> Anthropic:
+         """Get an Anthropic client for token counting
+         (because AnthropicBedrock and AnthropicVertex do not provide the method).
+
+         Returns:
+             Anthropic: Anthropic client
+         """
+         return Anthropic()
+
+     def encode(self, text: str) -> list[int]:
+         tokenizer = self._client_for_token.get_tokenizer()
+         encoded = cast(list[int], tokenizer.encode(text).ids)
+         return encoded
+
+     def decode(self, decoded: list[int]) -> str:
+         tokenizer = self._client_for_token.get_tokenizer()
+         text = cast(str, tokenizer.decode(decoded))
+         return text
+
+     def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
+         """
+         Count tokens.
+
+         Args:
+             messages (Messages): messages
+
+         Returns:
+             int: number of tokens
+         """
+         if messages is None:
+             return 0
+         tokens = 0
+         for message in messages:
+             content = message["content"]
+             if isinstance(content, str):
+                 tokens += self._client_for_token.count_tokens(content)
+                 continue
+             if isinstance(content, list):
+                 for content_i in content:
+                     if content_i["type"] == "text":
+                         tokens += self._client_for_token.count_tokens(content_i["text"])
+                         continue
+         return tokens
+
+     def _convert_finish_reason(
+         self, stop_reason: Literal["end_turn", "max_tokens", "stop_sequence"] | None
+     ) -> FinishReason | None:
+         if stop_reason == "max_tokens":
+             return "length"
+         if stop_reason == "stop_sequence":
+             return "stop"
+         return None
+
+     def _convert_to_response(self, platform_response: AnthropicMessage) -> Response:
+         return ChatCompletion(
+             id=platform_response.id,
+             choices=[
+                 Choice(
+                     index=0,
+                     message=ChatCompletionMessage(
+                         content=platform_response.content[0].text if len(platform_response.content) > 0 else "",
+                         role="assistant",
+                     ),
+                     finish_reason=self._convert_finish_reason(platform_response.stop_reason),
+                 )
+             ],
+             created=int(time.time()),
+             model=self.model,
+             object="messages.create",
+             system_fingerprint=None,
+             usage=CompletionUsage(
+                 prompt_tokens=platform_response.usage.input_tokens,
+                 completion_tokens=platform_response.usage.output_tokens,
+                 total_tokens=platform_response.usage.input_tokens + platform_response.usage.output_tokens,
+             ),
+         )
+
+     def _convert_to_platform_messages(self, messages: Messages) -> tuple[str, list[AnthropicMessageParam]]:
+         _system = ""
+         _message: list[AnthropicMessageParam] = []
+         for message in messages:
+             if message["role"] == "system":
+                 _system += "\n" + message["content"]
+             elif message["role"] == "user":
+                 if isinstance(message["content"], str):
+                     _message.append({"role": "user", "content": message["content"]})
+                 else:
+                     cprint("WARNING: not supported", color="yellow", background=True)
+             elif message["role"] == "assistant":
+                 if isinstance(message["content"], str):
+                     _message.append({"role": "assistant", "content": message["content"]})
+                 else:
+                     cprint("WARNING: not supported", color="yellow", background=True)
+             else:
+                 cprint("WARNING: not supported", color="yellow", background=True)
+         return _system, _message
+
+     def _convert_to_streamresponse(
+         self, platform_streamresponse: Stream[AnthropicMessageStreamEvent]
+     ) -> StreamResponse:
+         created = int(time.time())
+         model = ""
+         id_ = ""
+         content: str | None = None
+         for chunk in platform_streamresponse:
+             input_tokens = 0
+             output_tokens = 0
+             if chunk.type == "message_stop" or chunk.type == "content_block_stop":
+                 continue
+             if chunk.type == "message_start":
+                 model = model or chunk.message.model
+                 id_ = id_ or chunk.message.id
+                 input_tokens = chunk.message.usage.input_tokens
+                 output_tokens = chunk.message.usage.output_tokens
+                 content = "".join([content_block.text for content_block in chunk.message.content])
+                 finish_reason = self._convert_finish_reason(chunk.message.stop_reason)
+             elif chunk.type == "message_delta":
+                 content = ""
+                 finish_reason = self._convert_finish_reason(chunk.delta.stop_reason)
+                 output_tokens = chunk.usage.output_tokens
+             elif chunk.type == "content_block_start":
+                 content = chunk.content_block.text
+                 finish_reason = None
+             elif chunk.type == "content_block_delta":
+                 content = chunk.delta.text
+                 finish_reason = None
+             yield ChatCompletionChunk(
+                 id=id_,
+                 choices=[
+                     ChunkChoice(
+                         delta=ChoiceDelta(
+                             content=content,
+                             role="assistant",
+                         ),
+                         finish_reason=finish_reason,
+                         index=0,  # may not be 0-indexed, so overwrite with 0
+                     )
+                 ],
+                 created=created,
+                 model=model,
+                 object="chat.completion.chunk",
+                 tokens={"input_tokens": input_tokens, "output_tokens": output_tokens},  # type: ignore
+             )
+
+     def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
+         _system, _message = self._convert_to_platform_messages(messages)
+         llm_settings = self._set_max_tokens(llm_settings)
+         response = self.client.messages.create(
+             model=self.model,
+             system=_system,
+             messages=_message,
+             stream=False,
+             **llm_settings,
+         )
+         return self._convert_to_response(platform_response=response)
+
+     def generate_stream(self, messages: Any, llm_settings: LLMSettings) -> StreamResponse:
+         _system, _message = self._convert_to_platform_messages(messages)
+         llm_settings = self._set_max_tokens(llm_settings)
+         response = self.client.messages.create(
+             model=self.model,
+             system=_system,
+             messages=_message,
+             stream=True,
+             **llm_settings,
+         )
+         return self._convert_to_streamresponse(platform_streamresponse=response)
+
+     def _set_max_tokens(self, llm_settings: LLMSettings) -> LLMSettings:
+         # Claude requires max_tokens
+         if not llm_settings.get("max_tokens"):
+             cprint(f"max_tokens is not set. Set to {DEFAULT_MAX_TOKENS}.", color="yellow")
+             llm_settings["max_tokens"] = DEFAULT_MAX_TOKENS
+         return llm_settings
neollm/llm/claude/anthropic_llm.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Literal, cast, get_args
+
+ from anthropic import Anthropic
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.llm.claude.abstract_claude import AbstractClaude
+ from neollm.types import APIPricing, ClientSettings
+ from neollm.utils.utils import cprint
+
+ # price: https://www.anthropic.com/api
+ # models: https://docs.anthropic.com/claude/docs/models-overview
+
+ SUPPORTED_MODELS = Literal[
+     "claude-3-opus-20240229",
+     "claude-3-sonnet-20240229",
+     "claude-3-haiku-20240307",
+ ]
+
+
+ def get_anthoropic_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
+     # Add the date
+     replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
+         "claude-3-opus": "claude-3-opus-20240229",
+         "claude-3-sonnet": "claude-3-sonnet-20240229",
+         "claude-3-haiku": "claude-3-haiku-20240307",
+     }
+     if model_name in replace_map_for_nodate:
+         cprint("WARNING: please specify a dated model_name", color="yellow", background=True)
+         print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
+         model_name = replace_map_for_nodate[model_name]
+
+     # map to LLM
+     supported_model_map: dict[SUPPORTED_MODELS, AbstractLLM] = {
+         "claude-3-opus-20240229": AnthropicClaude3Opus20240229(client_settings),
+         "claude-3-sonnet-20240229": AnthropicClaude3Sonnet20240229(client_settings),
+         "claude-3-haiku-20240307": AnthropicClaude3Haiku20240229(client_settings),
+     }
+     if model_name in supported_model_map:
+         model_name = cast(SUPPORTED_MODELS, model_name)
+         return supported_model_map[model_name]
+     raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
+
+
+ class AnthoropicLLM(AbstractClaude):
+     @property
+     def client(self) -> Anthropic:
+         client = Anthropic(**self.client_settings)
+         return client
+
+
+ class AnthropicClaude3Opus20240229(AnthoropicLLM):
+     dollar_per_ktoken = APIPricing(input=15 / 1000, output=75 / 1000)
+     model: str = "claude-3-opus-20240229"
+     context_window: int = 200_000
+
+
+ class AnthropicClaude3Sonnet20240229(AnthoropicLLM):
+     dollar_per_ktoken = APIPricing(input=3 / 1000, output=15 / 1000)
+     model: str = "claude-3-sonnet-20240229"
+     context_window: int = 200_000
+
+
+ class AnthropicClaude3Haiku20240229(AnthoropicLLM):
+     dollar_per_ktoken = APIPricing(input=0.25 / 1000, output=1.25 / 1000)
+     model: str = "claude-3-haiku-20240307"
+     context_window: int = 200_000
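Because the Anthropic API requires `max_tokens`, `_set_max_tokens` in `AbstractClaude` fills in `DEFAULT_MAX_TOKENS` (4,096) whenever it is missing. A minimal sketch of that guard, assuming the `anthropic` package is installed; no API call is made, since `client` is a lazy property:

```python
from neollm.llm.claude.anthropic_llm import AnthropicClaude3Haiku20240229

llm = AnthropicClaude3Haiku20240229(client_settings={})
settings = llm._set_max_tokens({"temperature": 0.0})  # prints a warning
print(settings["max_tokens"])  # -> 4096
```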
neollm/llm/claude/gcp_llm.py ADDED
@@ -0,0 +1,67 @@
+ from typing import Literal, cast, get_args
+
+ from anthropic import AnthropicVertex
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.llm.claude.abstract_claude import AbstractClaude
+ from neollm.types import APIPricing, ClientSettings
+ from neollm.utils.utils import cprint
+
+ # price: https://www.anthropic.com/api
+ # models: https://docs.anthropic.com/claude/docs/models-overview
+
+ SUPPORTED_MODELS = Literal[
+     "claude-3-opus@20240229",
+     "claude-3-sonnet@20240229",
+     "claude-3-haiku@20240307",
+ ]
+
+
+ # TODO: get Google working
+ def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
+     # Add the date
+     replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
+         "claude-3-opus": "claude-3-opus@20240229",
+         "claude-3-sonnet": "claude-3-sonnet@20240229",
+         "claude-3-haiku": "claude-3-haiku@20240307",
+     }
+     if model_name in replace_map_for_nodate:
+         cprint("WARNING: please specify a dated model_name", color="yellow", background=True)
+         print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
+         model_name = replace_map_for_nodate[model_name]
+
+     # map to LLM
+     supported_model_map: dict[SUPPORTED_MODELS, AbstractLLM] = {
+         "claude-3-opus@20240229": GCPClaude3Opus20240229(client_settings),
+         "claude-3-sonnet@20240229": GCPClaude3Sonnet20240229(client_settings),
+         "claude-3-haiku@20240307": GCPClaude3Haiku20240229(client_settings),
+     }
+     if model_name in supported_model_map:
+         model_name = cast(SUPPORTED_MODELS, model_name)
+         return supported_model_map[model_name]
+     raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
+
+
+ class GoogleLLM(AbstractClaude):
+     @property
+     def client(self) -> AnthropicVertex:
+         client = AnthropicVertex(**self.client_settings)
+         return client
+
+
+ class GCPClaude3Opus20240229(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=15 / 1000, output=75 / 1000)
+     model: str = "claude-3-opus@20240229"
+     context_window: int = 200_000
+
+
+ class GCPClaude3Sonnet20240229(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=3 / 1000, output=15 / 1000)
+     model: str = "claude-3-sonnet@20240229"
+     context_window: int = 200_000
+
+
+ class GCPClaude3Haiku20240229(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=0.25 / 1000, output=1.25 / 1000)
+     model: str = "claude-3-haiku@20240307"
+     context_window: int = 200_000
neollm/llm/gemini/abstract_gemini.py ADDED
@@ -0,0 +1,229 @@
+ import time
+ from abc import abstractmethod
+ from typing import Iterable, cast
+
+ from google.cloud.aiplatform_v1beta1.types import CountTokensResponse
+ from google.cloud.aiplatform_v1beta1.types.content import Candidate
+ from vertexai.generative_models import (
+     Content,
+     GenerationConfig,
+     GenerationResponse,
+     GenerativeModel,
+     Part,
+ )
+ from vertexai.generative_models._generative_models import ContentsType
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.types import (
+     ChatCompletion,
+     CompletionUsageForCustomPriceCalculation,
+     LLMSettings,
+     Message,
+     Messages,
+     Response,
+     StreamResponse,
+ )
+ from neollm.types.openai.chat_completion import (
+     ChatCompletionMessage,
+     Choice,
+     CompletionUsage,
+ )
+ from neollm.types.openai.chat_completion import FinishReason as FinishReasonVertex
+ from neollm.types.openai.chat_completion_chunk import (
+     ChatCompletionChunk,
+     ChoiceDelta,
+     ChunkChoice,
+ )
+ from neollm.utils.utils import cprint
+
+
+ class AbstractGemini(AbstractLLM):
+
+     @abstractmethod
+     def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig: ...
+
+     # unused
+     def encode(self, text: str) -> list[int]:
+         return [ord(char) for char in text]
+
+     # unused
+     def decode(self, decoded: list[int]) -> str:
+         return "".join([chr(number) for number in decoded])
+
+     def _count_tokens_vertex(self, contents: ContentsType) -> CountTokensResponse:
+         model = GenerativeModel(model_name=self.model)
+         return cast(CountTokensResponse, model.count_tokens(contents))
+
+     def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
+         """
+         Count tokens.
+
+         Args:
+             messages (Messages): messages
+
+         Returns:
+             int: number of tokens
+         """
+         if messages is None:
+             return 0
+         _system, _message = self._convert_to_platform_messages(messages)
+         total_tokens = 0
+         if _system:
+             total_tokens += int(self._count_tokens_vertex(_system).total_tokens)
+         if _message:
+             total_tokens += int(self._count_tokens_vertex(_message).total_tokens)
+         return total_tokens
+
+     def _convert_to_platform_messages(self, messages: Messages) -> tuple[str | None, list[Content]]:
+         _system = None
+         _message: list[Content] = []
+
+         for message in messages:
+             if message["role"] == "system":
+                 _system = "\n" + message["content"]
+             elif message["role"] == "user":
+                 if isinstance(message["content"], str):
+                     _message.append(Content(role="user", parts=[Part.from_text(message["content"])]))
+                 else:
+                     try:
+                         if isinstance(message["content"], list) and message["content"][1]["type"] == "image_url":
+                             encoded_image = message["content"][1]["image_url"]["url"].split(",")[-1]
+                             _message.append(
+                                 Content(
+                                     role="user",
+                                     parts=[
+                                         Part.from_text(message["content"][0]["text"]),
+                                         Part.from_data(data=encoded_image, mime_type="image/jpeg"),
+                                     ],
+                                 )
+                             )
+                     except KeyError:
+                         cprint("WARNING: not supported", color="yellow", background=True)
+                     except IndexError:
+                         cprint("WARNING: not supported", color="yellow", background=True)
+                     except Exception as e:
+                         cprint(e, color="red", background=True)
+             elif message["role"] == "assistant":
+                 if isinstance(message["content"], str):
+                     _message.append(Content(role="model", parts=[Part.from_text(message["content"])]))
+                 else:
+                     cprint("WARNING: not supported", color="yellow", background=True)
+         return _system, _message
+
+     def _convert_finish_reason(self, stop_reason: Candidate.FinishReason) -> FinishReasonVertex | None:
+         """
+         Reference: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
+
+         0: FINISH_REASON_UNSPECIFIED
+             Default value. This value is unused.
+         1: STOP
+             Natural stop point of the model or provided stop sequence.
+         2: MAX_TOKENS
+             The maximum number of tokens as specified in the request was reached.
+         3: SAFETY
+             The candidate content was flagged for safety reasons.
+         4: RECITATION
+             The candidate content was flagged for recitation reasons.
+         5: OTHER
+             Unknown reason.
+         """
+
+         if stop_reason.value in [0, 3, 4, 5]:
+             return "stop"
+
+         if stop_reason.value in [2]:
+             return "length"
+
+         return None
+
+     def _convert_to_response(
+         self, platform_response: GenerationResponse, system: str | None, message: list[Content]
+     ) -> Response:
+         # billable character count for the input
+         input_billable_characters = 0
+         if system:
+             input_billable_characters += self._count_tokens_vertex(system).total_billable_characters
+         if message:
+             input_billable_characters += self._count_tokens_vertex(message).total_billable_characters
+         # billable character count for the output
+         output_billable_characters = 0
+         if platform_response.text:
+             output_billable_characters += self._count_tokens_vertex(platform_response.text).total_billable_characters
+         return ChatCompletion(  # type: ignore [call-arg]
+             id="",
+             choices=[
+                 Choice(
+                     index=0,
+                     message=ChatCompletionMessage(
+                         content=platform_response.text,
+                         role="assistant",
+                     ),
+                     finish_reason=self._convert_finish_reason(platform_response.candidates[0].finish_reason),
+                 )
+             ],
+             created=int(time.time()),
+             model=self.model,
+             object="messages.create",
+             system_fingerprint=None,
+             usage=CompletionUsage(
+                 prompt_tokens=platform_response.usage_metadata.prompt_token_count,
+                 completion_tokens=platform_response.usage_metadata.candidates_token_count,
+                 total_tokens=platform_response.usage_metadata.prompt_token_count
+                 + platform_response.usage_metadata.candidates_token_count,
+             ),
+             usage_for_price=CompletionUsageForCustomPriceCalculation(
+                 prompt_tokens=input_billable_characters,
+                 completion_tokens=output_billable_characters,
+                 total_tokens=input_billable_characters + output_billable_characters,
+             ),
+         )
+
+     def _convert_to_streamresponse(self, platform_streamresponse: Iterable[GenerationResponse]) -> StreamResponse:
+         created = int(time.time())
+         content: str | None = None
+         for chunk in platform_streamresponse:
+             content = chunk.text
+             yield ChatCompletionChunk(
+                 id="",
+                 choices=[
+                     ChunkChoice(
+                         delta=ChoiceDelta(
+                             content=content,
+                             role="assistant",
+                         ),
+                         finish_reason=self._convert_finish_reason(chunk.candidates[0].finish_reason),
+                         index=0,  # may not be 0-indexed, so overwrite with 0
+                     )
+                 ],
+                 created=created,
+                 model=self.model,
+                 object="chat.completion.chunk",
+             )
+
+     def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
+         _system, _message = self._convert_to_platform_messages(messages)
+         model = GenerativeModel(
+             model_name=self.model,
+             system_instruction=_system,
+         )
+
+         response = model.generate_content(
+             contents=_message,
+             stream=False,
+             generation_config=self.generate_config(llm_settings),
+         )
+
+         return self._convert_to_response(platform_response=response, system=_system, message=_message)
+
+     def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
+         _system, _message = self._convert_to_platform_messages(messages)
+         model = GenerativeModel(
+             model_name=self.model,
+             system_instruction=_system,
+         )
+         response = model.generate_content(
+             contents=_message,
+             stream=True,
+             generation_config=self.generate_config(llm_settings),
+         )
+         return self._convert_to_streamresponse(platform_streamresponse=response)
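Note that Vertex AI bills Gemini by characters rather than tokens, which is why `_convert_to_response` carries a separate `usage_for_price` built from `total_billable_characters`. A rough cost sketch, assuming the billable-character counts feed the same per-1K formula as `calculate_price`, using the gemini-1.0-pro figures declared in the next file; the character counts are made up:

```python
# Hypothetical arithmetic only; not repository code.
input_chars, output_chars = 2_000, 500
price_per_kchar_in, price_per_kchar_out = 0.125 / 1000, 0.375 / 1000

price = (price_per_kchar_in * input_chars + price_per_kchar_out * output_chars) / 1000
print(f"${price:.7f}")  # -> $0.0004375
```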
neollm/llm/gemini/gcp_llm.py ADDED
@@ -0,0 +1,114 @@
+ from copy import deepcopy
+ from typing import Literal, cast, get_args
+
+ import vertexai
+ from vertexai.generative_models import GenerationConfig
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.llm.gemini.abstract_gemini import AbstractGemini
+ from neollm.types import APIPricing, ClientSettings, LLMSettings, StreamResponse
+ from neollm.types.mytypes import Messages, Response
+ from neollm.utils.utils import cprint
+
+ # price: https://ai.google.dev/pricing?hl=ja
+ # models: https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja
+
+ SUPPORTED_MODELS = Literal["gemini-1.0-pro", "gemini-1.0-pro-vision", "gemini-1.5-pro-preview-0409"]
+ AVAILABLE_CONFIG_VARIABLES = [
+     "candidate_count",
+     "stop_sequences",
+     "temperature",
+     "max_tokens",  # used when max_output_tokens is not set
+     "max_output_tokens",
+     "top_p",
+     "top_k",
+ ]
+
+
+ def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
+
+     vertexai.init(**client_settings)
+
+     # map to LLM
+     supported_model_map: dict[SUPPORTED_MODELS, AbstractLLM] = {
+         "gemini-1.0-pro": GCPGemini10Pro(client_settings),
+         "gemini-1.0-pro-vision": GCPGemini10ProVision(client_settings),
+         "gemini-1.5-pro-preview-0409": GCPGemini15Pro0409(client_settings),
+     }
+     if model_name in supported_model_map:
+         model_name = cast(SUPPORTED_MODELS, model_name)
+         return supported_model_map[model_name]
+     raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
+
+
+ class GoogleLLM(AbstractGemini):
+
+     def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
+         """
+         Reference: https://ai.google.dev/api/rest/v1/GenerationConfig?hl=ja
+         """
+         # gemini
+         candidate_count = llm_settings.pop("candidate_count", None)
+         stop_sequences = llm_settings.pop("stop_sequences", None)
+         temperature = llm_settings.pop("temperature", None)
+         max_output_tokens = llm_settings.pop("max_output_tokens", None)
+         top_p = llm_settings.pop("top_p", None)
+         top_k = llm_settings.pop("top_k", None)
+
+         # also accept neollm's argument name
+         if max_output_tokens is None:
+             max_output_tokens = llm_settings.pop("max_tokens", None)
+
+         if len(llm_settings) > 0 and "max_tokens" not in llm_settings:
+             raise ValueError(f"llm_settings has unknown keys: {llm_settings}")
+
+         return GenerationConfig(
+             candidate_count=candidate_count,
+             stop_sequences=stop_sequences,
+             temperature=temperature,
+             max_output_tokens=max_output_tokens,
+             top_p=top_p,
+             top_k=top_k,
+         )
+
+
+ class GCPGemini10Pro(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
+     model: str = "gemini-1.0-pro"
+     context_window: int = 32_000
+
+
+ class GCPGemini10ProVision(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
+     model: str = "gemini-1.0-pro-vision"
+     context_window: int = 32_000
+
+     def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
+         messages = self._preprocess_message_to_use_system(messages)
+         return super().generate(messages, llm_settings)
+
+     def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
+         messages = self._preprocess_message_to_use_system(messages)
+         return super().generate_stream(messages, llm_settings)
+
+     def _preprocess_message_to_use_system(self, message: Messages) -> Messages:
+         if message[0]["role"] != "system":
+             return message
+         preprocessed_message = deepcopy(message)
+         system = preprocessed_message[0]["content"]
+         del preprocessed_message[0]
+         if (
+             isinstance(system, str)
+             and isinstance(preprocessed_message[0]["content"], list)
+             and isinstance(preprocessed_message[0]["content"][0]["text"], str)
+         ):
+             preprocessed_message[0]["content"][0]["text"] = system + preprocessed_message[0]["content"][0]["text"]
+         else:
+             cprint("WARNING: invalid input format", color="yellow", background=True)
+         return preprocessed_message
+
+
+ class GCPGemini15Pro0409(GoogleLLM):
+     dollar_per_ktoken = APIPricing(input=2.5 / 1000, output=7.5 / 1000)
+     model: str = "gemini-1.5-pro-preview-0409"
+     context_window: int = 1_000_000
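`generate_config` accepts neollm's `max_tokens` as an alias for Gemini's `max_output_tokens`, so GPT-oriented call sites keep working unchanged. A sketch of that aliasing, assuming `vertexai` is installed; no API call is made:

```python
from neollm.llm.gemini.gcp_llm import GCPGemini10Pro

llm = GCPGemini10Pro(client_settings={})
config = llm.generate_config({"temperature": 0.2, "max_tokens": 512})
# config is a vertexai GenerationConfig with max_output_tokens=512
```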
neollm/llm/get_llm.py ADDED
@@ -0,0 +1,47 @@
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.types import ClientSettings
+
+ from .platform import Platform
+
+ SUPPORTED_CLAUDE_MODELS = [
+     "claude-3-opus",
+     "claude-3-sonnet",
+     "claude-3-haiku",
+     "claude-3-opus@20240229",
+     "claude-3-sonnet@20240229",
+     "claude-3-haiku@20240307",
+ ]
+
+ SUPPORTED_GEMINI_MODELS = [
+     "gemini-1.5-pro-preview-0409",
+     "gemini-1.0-pro",
+     "gemini-1.0-pro-vision",
+ ]
+
+
+ def get_llm(model_name: str, platform: str, client_settings: ClientSettings) -> AbstractLLM:
+     platform = Platform(platform)
+     # get the llm
+     if platform == Platform.AZURE:
+         from neollm.llm.gpt.azure_llm import get_azure_llm
+
+         return get_azure_llm(model_name, client_settings)
+     if platform == Platform.OPENAI:
+         from neollm.llm.gpt.openai_llm import get_openai_llm
+
+         return get_openai_llm(model_name, client_settings)
+     if platform == Platform.ANTHROPIC:
+         from neollm.llm.claude.anthropic_llm import get_anthoropic_llm
+
+         return get_anthoropic_llm(model_name, client_settings)
+     if platform == Platform.GCP:
+         if model_name in SUPPORTED_CLAUDE_MODELS:
+             from neollm.llm.claude.gcp_llm import get_gcp_llm as get_gcp_llm_for_claude
+
+             return get_gcp_llm_for_claude(model_name, client_settings)
+         elif model_name in SUPPORTED_GEMINI_MODELS:
+             from neollm.llm.gemini.gcp_llm import get_gcp_llm as get_gcp_llm_for_gemini
+
+             return get_gcp_llm_for_gemini(model_name, client_settings)
+         else:
+             raise ValueError(f"{model_name} is not supported in GCP.")
neollm/llm/gpt/abstract_gpt.py ADDED
@@ -0,0 +1,81 @@
+ import tiktoken
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     Message,
+     Messages,
+     OpenAIMessages,
+     OpenAIResponse,
+     OpenAIStreamResponse,
+     Response,
+     StreamResponse,
+ )
+
+
+ class AbstractGPT(AbstractLLM):
+     def encode(self, text: str) -> list[int]:
+         tokenizer = tiktoken.encoding_for_model(self.model or "gpt-3.5-turbo")
+         return tokenizer.encode(text)
+
+     def decode(self, encoded: list[int]) -> str:
+         tokenizer = tiktoken.encoding_for_model(self.model or "gpt-3.5-turbo")
+         return tokenizer.decode(encoded)
+
+     def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
+         """
+         Count tokens.
+
+         Args:
+             messages (Messages): messages
+
+         Returns:
+             int: number of tokens
+         """
+         if messages is None:
+             return 0
+
+         # count tokens
+         num_tokens: int = 0
+         # messages ---------------------------------------------------------------------------
+         for message in messages:
+             # per message -------------------------------------------
+             num_tokens += 4
+             # content -----------------------------------------------
+             content = message.get("content", None)
+             if content is None:
+                 num_tokens += 0
+             elif isinstance(content, str):
+                 num_tokens += len(self.encode(content))
+                 continue
+             elif isinstance(content, list):
+                 for content_params in content:
+                     if content_params["type"] == "text":
+                         num_tokens += len(self.encode(content_params["text"]))
+             # TODO: ChatCompletionFunctionMessageParam.name
+             # tokens_per_name = 1
+             # tool calls ------------------------------------------------
+             # TODO: ChatCompletionAssistantMessageParam.function_call
+             # TODO: ChatCompletionAssistantMessageParam.tool_calls
+
+         if only_response:
+             if len(messages) != 1:
+                 raise ValueError("If only_response=True, messages must contain exactly one message.")
+             num_tokens -= 4  # remove the per-message overhead
+         else:
+             num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+         return num_tokens
+
+     def _convert_to_response(self, platform_response: OpenAIResponse) -> Response:
+         return ChatCompletion(**platform_response.model_dump())
+
+     def _convert_to_platform_messages(self, messages: Messages) -> OpenAIMessages:
+         # OpenAI Messages are the default, so no conversion is needed
+         platform_messages: OpenAIMessages = messages
+         return platform_messages
+
+     def _convert_to_streamresponse(self, platform_streamresponse: OpenAIStreamResponse) -> StreamResponse:
+         for chunk in platform_streamresponse:
+             yield ChatCompletionChunk(**chunk.model_dump())
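The accounting above follows OpenAI's published chat-counting recipe: roughly 4 overhead tokens per message plus 3 tokens priming the assistant reply (role tokens are ignored here, as the TODOs note). A standalone sketch of the same arithmetic:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi!"},
]
# 4 tokens of per-message overhead, +3 for the assistant reply priming,
# mirroring AbstractGPT.count_tokens.
num_tokens = sum(4 + len(enc.encode(m["content"])) for m in messages) + 3
print(num_tokens)
```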
neollm/llm/gpt/azure_llm.py ADDED
@@ -0,0 +1,215 @@
+ from typing import Literal, cast
+
+ from openai import AzureOpenAI
+
+ from neollm.llm.abstract_llm import AbstractLLM
+ from neollm.llm.gpt.abstract_gpt import AbstractGPT
+ from neollm.types import (
+     APIPricing,
+     ClientSettings,
+     LLMSettings,
+     Messages,
+     Response,
+     StreamResponse,
+ )
+ from neollm.utils.utils import cprint, ensure_env_var, suport_unrecomended_env_var
+
+ suport_unrecomended_env_var(old_key="AZURE_API_BASE", new_key="AZURE_OPENAI_ENDPOINT")
+ suport_unrecomended_env_var(old_key="AZURE_API_VERSION", new_key="OPENAI_API_VERSION")
+ # old keys without the 0613 date
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT35", new_key="AZURE_ENGINE_GPT35T_0613")
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT35_16k", new_key="AZURE_ENGINE_GPT35T_16K_0613")
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT4", new_key="AZURE_ENGINE_GPT4_0613")
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT4_32k", new_key="AZURE_ENGINE_GPT4_32K_0613")
+ # old keys missing "turbo"
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT35_0613", new_key="AZURE_ENGINE_GPT35T_0613")
+ suport_unrecomended_env_var(old_key="AZURE_ENGINE_GPT35_16K_0613", new_key="AZURE_ENGINE_GPT35T_16K_0613")
+
+ # Pricing: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
+
+ SUPPORTED_MODELS = Literal[
+     "gpt-4o-2024-05-13",
+     "gpt-4-turbo-2024-04-09",
+     "gpt-3.5-turbo-0125",
+     "gpt-4-turbo-0125",
+     "gpt-3.5-turbo-1106",
+     "gpt-4-turbo-1106",
+     "gpt-4v-turbo-1106",
+     "gpt-3.5-turbo-0613",
+     "gpt-3.5-turbo-16k-0613",
+     "gpt-4-0613",
+     "gpt-4-32k-0613",
+ ]
+
+
+ def get_azure_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
+     # normalize the name
+     model_name = model_name.replace("gpt-35-turbo", "gpt-3.5-turbo")
+     # Add the date
+     replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
+         "gpt-4o": "gpt-4o-2024-05-13",
+         "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
+         "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
+         "gpt-4": "gpt-4-0613",
+         "gpt-4-32k": "gpt-4-32k-0613",
+         "gpt-4-turbo": "gpt-4-turbo-1106",
+         "gpt-4v-turbo": "gpt-4v-turbo-1106",
+     }
+     if model_name in replace_map_for_nodate:
+         cprint("WARNING: please specify a dated model_name", color="yellow", background=True)
+         print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
+         model_name = replace_map_for_nodate[model_name]
+
+     # map to LLM
+     supported_model_map: dict[SUPPORTED_MODELS, AbstractLLM] = {
+         "gpt-4o-2024-05-13": AzureGPT4O_20240513(client_settings),
+         "gpt-4-turbo-2024-04-09": AzureGPT4T_20240409(client_settings),
+         "gpt-3.5-turbo-0125": AzureGPT35T_0125(client_settings),
+         "gpt-4-turbo-0125": AzureGPT4T_0125(client_settings),
+         "gpt-3.5-turbo-1106": AzureGPT35T_1106(client_settings),
+         "gpt-4-turbo-1106": AzureGPT4T_1106(client_settings),
+         "gpt-4v-turbo-1106": AzureGPT4VT_1106(client_settings),
+         "gpt-3.5-turbo-0613": AzureGPT35T_0613(client_settings),
+         "gpt-3.5-turbo-16k-0613": AzureGPT35T16k_0613(client_settings),
+         "gpt-4-0613": AzureGPT4_0613(client_settings),
+         "gpt-4-32k-0613": AzureGPT432k_0613(client_settings),
+     }
+     # standard models
+     if model_name in supported_model_map:
+         model_name = cast(SUPPORTED_MODELS, model_name)
+         return supported_model_map[model_name]
+     # fine-tuned models
+     return AzureGPT35FT(model_name, client_settings)
+
+
+ class AzureLLM(AbstractGPT):
+     _engine_name_env_key: str | None = None
+
+     @property
+     def client(self) -> AzureOpenAI:
+         client: AzureOpenAI = AzureOpenAI(**self.client_settings)
+         # api_key: str | None = (None,)
+         # timeout: httpx.Timeout(timeout=600.0, connect=5.0)
+         # max_retries: int = 2
+         return client
+
+     @property
+     def engine(self) -> str:
+         return ensure_env_var(self._engine_name_env_key)
+
+     def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
+         openai_response = self.client.chat.completions.create(
+             model=self.engine,
+             messages=self._convert_to_platform_messages(messages),
+             stream=False,
+             **llm_settings,
+         )
+         response = self._convert_to_response(openai_response)
+         return response
+
+     def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
+         platform_stream_response = self.client.chat.completions.create(
+             model=self.engine,
+             messages=self._convert_to_platform_messages(messages),
+             stream=True,
+             **llm_settings,
+         )
+         stream_response = self._convert_to_streamresponse(platform_stream_response)
+         return stream_response
+
+
+ # omni 2024-05-13 --------------------------------------------------------------------------------------------
+ class AzureGPT4O_20240513(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.005, output=0.015)  # 30x/45x
+     model: str = "gpt-4o-2024-05-13"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4O_20240513"
+     context_window: int = 128_000
+
+
+ # 2024-04-09 --------------------------------------------------------------------------------------------
+ class AzureGPT4T_20240409(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
+     model: str = "gpt-4-turbo-2024-04-09"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4T_20240409"
+     context_window: int = 128_000
+
+
+ # 0125 --------------------------------------------------------------------------------------------
+ class AzureGPT35T_0125(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.0005, output=0.0015)
+     model: str = "gpt-3.5-turbo-0125"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT35T_0125"
+     context_window: int = 16_385
+
+
+ class AzureGPT4T_0125(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
+     model: str = "gpt-4-turbo-0125"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4T_0125"
+     context_window: int = 128_000
+
+
+ # 1106 --------------------------------------------------------------------------------------------
+ class AzureGPT35T_1106(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.001, output=0.002)
+     model: str = "gpt-3.5-turbo-1106"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT35T_1106"
+     context_window: int = 16_385
+
+
+ class AzureGPT4VT_1106(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.01, output=0.03)  # 10x/15x
+     model: str = "gpt-4-1106-vision-preview"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4VT_1106"
+     context_window: int = 128_000
+
+
+ class AzureGPT4T_1106(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
+     model: str = "gpt-4-turbo-1106"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4T_1106"
+     context_window: int = 128_000
+
+
+ # FT --------------------------------------------------------------------------------------------
+ class AzureGPT35FT(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.0005, output=0.0015)  # 1x + session uptime
+     model: str = "gpt-3.5-turbo-ft"
+     context_window: int = 4_096
+
+     def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
+         super().__init__(client_setting)
+         self._engine = model_name
+
+     @property
+     def engine(self) -> str:
+         return self._engine
+
+
+ # 0613 --------------------------------------------------------------------------------------------
+ class AzureGPT35T_0613(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.0015, output=0.002)
+     model: str = "gpt-3.5-turbo-0613"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT35T_0613"
+     context_window: int = 4_096
+
+
+ class AzureGPT35T16k_0613(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.003, output=0.004)  # 2x
+     model: str = "gpt-3.5-turbo-16k-0613"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT35T_16K_0613"
+     context_window: int = 16_385
+
+
+ class AzureGPT4_0613(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.03, output=0.06)  # 20x/30x
+     model: str = "gpt-4-0613"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4_0613"
+     context_window: int = 8_192
+
+
+ class AzureGPT432k_0613(AzureLLM):
+     dollar_per_ktoken = APIPricing(input=0.06, output=0.12)  # 40x/60x
+     model: str = "gpt-4-32k-0613"
+     _engine_name_env_key: str = "AZURE_ENGINE_GPT4_32K_0613"
+     context_window: int = 32_768
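Unlike the OpenAI classes, each Azure class resolves its deployment name from a per-model environment variable via `ensure_env_var`. A sketch of the variables a `.env` would need for gpt-4o, with placeholder values; `AZURE_OPENAI_API_KEY` is the standard `openai` SDK variable, and the rest come from the code above:

```python
# Placeholder values for illustration only.
import os

os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"
os.environ["AZURE_OPENAI_API_KEY"] = "..."
os.environ["AZURE_ENGINE_GPT4O_20240513"] = "my-gpt-4o-deployment"  # deployment name
```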
neollm/llm/gpt/openai_llm.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, cast
2
+
3
+ from openai import OpenAI
4
+
5
+ from neollm.llm.abstract_llm import AbstractLLM
6
+ from neollm.llm.gpt.abstract_gpt import AbstractGPT
7
+ from neollm.types import (
8
+ APIPricing,
9
+ ClientSettings,
10
+ LLMSettings,
11
+ Messages,
12
+ Response,
13
+ StreamResponse,
14
+ )
15
+ from neollm.utils.utils import cprint
16
+
17
+ # Models: https://platform.openai.com/docs/models/continuous-model-upgrades
18
+ # Pricing: https://openai.com/pricing
19
+
20
+ SUPPORTED_MODELS = Literal[
21
+ "gpt-4o-2024-05-13",
22
+ "gpt-4-turbo-2024-04-09",
23
+ "gpt-3.5-turbo-0125",
24
+ "gpt-4-turbo-0125",
25
+ "gpt-3.5-turbo-1106",
26
+ "gpt-4-turbo-1106",
27
+ "gpt-4v-turbo-1106",
28
+ "gpt-3.5-turbo-0613",
29
+ "gpt-3.5-turbo-16k-0613",
30
+ "gpt-4-0613",
31
+ "gpt-4-32k-0613",
32
+ ]
33
+
34
+
35
+ def get_openai_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
36
+ # Add 日付
37
+ replace_map_for_nodate: dict[str, SUPPORTED_MODELS] = {
38
+ "gpt-4o": "gpt-4o-2024-05-13",
39
+ "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
40
+ "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
41
+ "gpt-4": "gpt-4-0613",
42
+ "gpt-4-32k": "gpt-4-32k-0613",
43
+ "gpt-4-turbo": "gpt-4-turbo-1106",
44
+ "gpt-4v-turbo": "gpt-4v-turbo-1106",
45
+ }
46
+ if model_name in replace_map_for_nodate:
47
+ cprint("WARNING: model_nameに日付を指定してください", color="yellow", background=True)
48
+ print(f"model_name: {model_name} -> {replace_map_for_nodate[model_name]}")
49
+ model_name = replace_map_for_nodate[model_name]
50
+
51
+ # map to LLM
52
+ supported_model_map: dict[SUPPORTED_MODELS, AbstractLLM] = {
53
+ "gpt-4o-2024-05-13": OpenAIGPT4O_20240513(client_settings),
54
+ "gpt-4-turbo-2024-04-09": OpenAIGPT4T_20240409(client_settings),
55
+ "gpt-3.5-turbo-0125": OpenAIGPT35T_0125(client_settings),
56
+ "gpt-4-turbo-0125": OpenAIGPT4T_0125(client_settings),
57
+ "gpt-3.5-turbo-1106": OpenAIGPT35T_1106(client_settings),
58
+ "gpt-4-turbo-1106": OpenAIGPT4T_1106(client_settings),
59
+ "gpt-4v-turbo-1106": OpenAIGPT4VT_1106(client_settings),
60
+ "gpt-3.5-turbo-0613": OpenAIGPT35T_0613(client_settings),
61
+ "gpt-3.5-turbo-16k-0613": OpenAIGPT35T16k_0613(client_settings),
62
+ "gpt-4-0613": OpenAIGPT4_0613(client_settings),
63
+ "gpt-4-32k-0613": OpenAIGPT432k_0613(client_settings),
64
+ }
65
+ # 通常モデル
66
+ if model_name in supported_model_map:
67
+ model_name = cast(SUPPORTED_MODELS, model_name)
68
+ return supported_model_map[model_name]
69
+ # FTモデル
70
+ if "gpt-3.5-turbo-1106" in model_name:
71
+ return OpenAIGPT35TFT_1106(model_name, client_settings)
72
+ if "gpt-3.5-turbo-0613" in model_name:
73
+ return OpenAIGPT35TFT_0613(model_name, client_settings)
74
+ if "gpt-3.5-turbo-0125" in model_name:
75
+ return OpenAIGPT35TFT_0125(model_name, client_settings)
76
+ if "gpt4" in model_name.replace("-", ""): # TODO! もっといい条件に修正
77
+ return OpenAIGPT4FT_0613(model_name, client_settings)
78
+
79
+ cprint(
80
+ f"WARNING: このFTモデルは何?: {model_name} -> OpenAIGPT35TFT_1106として設定", color="yellow", background=True
81
+ )
82
+ return OpenAIGPT35TFT_1106(model_name, client_settings)
83
+
84
+
85
+ class OpenAILLM(AbstractGPT):
86
+ model: str
87
+
88
+ @property
89
+ def client(self) -> OpenAI:
90
+ client: OpenAI = OpenAI(**self.client_settings)
91
+ # api_key: str | None = (None,)
92
+ # timeout: httpx.Timeout(timeout=600.0, connect=5.0)
93
+ # max_retries: int = 2
94
+ return client
95
+
96
+ def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
97
+ openai_response = self.client.chat.completions.create(
98
+ model=self.model,
99
+ messages=self._convert_to_platform_messages(messages),
100
+ stream=False,
101
+ **llm_settings,
102
+ )
103
+ response = self._convert_to_response(openai_response)
104
+ return response
105
+
106
+ def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
107
+ platform_stream_response = self.client.chat.completions.create(
108
+ model=self.model,
109
+ messages=self._convert_to_platform_messages(messages),
110
+ stream=True,
111
+ **llm_settings,
112
+ )
113
+ stream_response = self._convert_to_streamresponse(platform_stream_response)
114
+ return stream_response
115
+
116
+
117
+ # omni 2024-05-13 --------------------------------------------------------------------------------------------
118
+ class OpenAIGPT4O_20240513(OpenAILLM):
119
+ dollar_per_ktoken = APIPricing(input=0.005, output=0.015)
120
+ model: str = "gpt-4o-2024-05-13"
121
+ context_window: int = 128_000
122
+
123
+
124
+ # 2024-04-09 --------------------------------------------------------------------------------------------
125
+ class OpenAIGPT4T_20240409(OpenAILLM):
126
+ dollar_per_ktoken = APIPricing(input=0.01, output=0.03) # 10倍/15倍
127
+ model: str = "gpt-4-turbo-2024-04-09"
129
+ context_window: int = 128_000
130
+
131
+
132
+ # 0125 --------------------------------------------------------------------------------------------
133
+ class OpenAIGPT35T_0125(OpenAILLM):
134
+ dollar_per_ktoken = APIPricing(input=0.0005, output=0.0015)
135
+ model: str = "gpt-3.5-turbo-0125"
136
+ context_window: int = 16_385
137
+
138
+
139
+ class OpenAIGPT4T_0125(OpenAILLM):
140
+ dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
141
+ model: str = "gpt-4-0125-preview"
142
+ context_window: int = 128_000
143
+
144
+
145
+ class OpenAIGPT35TFT_0125(OpenAILLM):
146
+ dollar_per_ktoken = APIPricing(input=0.003, output=0.006)
147
+ context_window: int = 16_385
148
+
149
+ def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
150
+ super().__init__(client_setting)
151
+ self.model = model_name
152
+
153
+
154
+ # 1106 --------------------------------------------------------------------------------------------
155
+ class OpenAIGPT35T_1106(OpenAILLM):
156
+ dollar_per_ktoken = APIPricing(input=0.0010, output=0.0020)
157
+ model: str = "gpt-3.5-turbo-1106"
158
+ context_window: int = 16_385
159
+
160
+
161
+ class OpenAIGPT4T_1106(OpenAILLM):
162
+ dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
163
+ model: str = "gpt-4-1106-preview"
164
+ context_window: int = 128_000
165
+
166
+
167
+ class OpenAIGPT4VT_1106(OpenAILLM):
168
+ dollar_per_ktoken = APIPricing(input=0.01, output=0.03)
169
+ model: str = "gpt-4-1106-vision-preview"
170
+ context_window: int = 128_000
171
+
172
+
173
+ class OpenAIGPT35TFT_1106(OpenAILLM):
174
+ dollar_per_ktoken = APIPricing(input=0.003, output=0.006)
175
+ context_window: int = 4_096
176
+
177
+ def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
178
+ super().__init__(client_setting)
179
+ self.model = model_name
180
+
181
+
182
+ # 0613 --------------------------------------------------------------------------------------------
183
+ class OpenAIGPT35T_0613(OpenAILLM):
184
+ dollar_per_ktoken = APIPricing(input=0.0015, output=0.002)
185
+ model: str = "gpt-3.5-turbo-0613"
186
+ context_window: int = 4_096
187
+
188
+
189
+ class OpenAIGPT35T16k_0613(OpenAILLM):
190
+ dollar_per_ktoken = APIPricing(input=0.003, output=0.004)
191
+ model: str = "gpt-3.5-turbo-16k-0613"
192
+ context_window: int = 16_385
193
+
194
+
195
+ class OpenAIGPT4_0613(OpenAILLM):
196
+ dollar_per_ktoken = APIPricing(input=0.03, output=0.06)
197
+ model: str = "gpt-4-0613"
198
+ context_window: int = 8_192
199
+
200
+
201
+ class OpenAIGPT432k_0613(OpenAILLM):
202
+ dollar_per_ktoken = APIPricing(input=0.06, output=0.12)
203
+ model: str = "gpt-4-32k-0613"
204
+ context_window: int = 32_768
205
+
206
+
207
+ class OpenAIGPT35TFT_0613(OpenAILLM):
208
+ dollar_per_ktoken = APIPricing(input=0.003, output=0.006)
209
+ context_window: int = 4_096
210
+
211
+ def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
212
+ super().__init__(client_setting)
213
+ self.model = model_name
214
+
215
+
216
+ class OpenAIGPT4FT_0613(OpenAILLM):
217
+ dollar_per_ktoken = APIPricing(input=0.045, output=0.090)
218
+ context_window: int = 8_192
219
+
220
+ def __init__(self, model_name: str, client_setting: ClientSettings) -> None:
221
+ super().__init__(client_setting)
222
+ self.model = model_name
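A minimal usage sketch for the model-name resolution above (no API call is made at construction time; the empty `client_settings` is a placeholder, not a recommendation):

```python
from neollm.llm.gpt.openai_llm import get_openai_llm

llm = get_openai_llm("gpt-4", client_settings={})  # warns, then resolves to the dated name
print(type(llm).__name__)             # OpenAIGPT4_0613
print(llm.model, llm.context_window)  # gpt-4-0613 8192
```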
neollm/llm/gpt/token.py ADDED
@@ -0,0 +1,247 @@
1
+ import json
2
+ import textwrap
3
+ from typing import Any, Iterator, overload
4
+
5
+ import tiktoken
6
+
7
+ from neollm.types import Function
8
+ from neollm.utils.utils import cprint # , Functions, Messages
9
+
10
+ DEFAULT_MODEL_NAME = "gpt-3.5-turbo"
11
+
12
+
13
+ def get_tokenizer(model_name: str) -> tiktoken.Encoding:
14
+ # 参考: https://platform.openai.com/docs/models/gpt-3-5
15
+ MODEL_NAME_MAP = [
16
+ ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"),
17
+ ("gpt-3.5-turbo", "gpt-3.5-turbo-0613"),
18
+ ("gpt-4-32k", "gpt-4-32k-0613"),
19
+ ("gpt-4", "gpt-4-0613"),
20
+ ]
21
+ ALL_VERSION_MODELS = [
22
+ # gpt-3.5-turbo
23
+ "gpt-3.5-turbo-0125",
24
+ "gpt-3.5-turbo-1106",
25
+ "gpt-3.5-turbo-0613",
26
+ "gpt-3.5-turbo-16k-0613",
27
+ "gpt-3.5-turbo-0301", # Legacy
28
+ # gpt-4
29
+ "gpt-4o-2024-05-13",
30
+ "gpt-4-turbo-0125",
31
+ "gpt-4-turbo-1106",
32
+ "gpt-4-0613",
33
+ "gpt-4-32k-0613",
34
+ "gpt-4-0314", # Legacy
35
+ "gpt-4-32k-0314", # Legacy
36
+ ]
37
+ # Azure表記 → OpenAI表記に統一
38
+ model_name = model_name.replace("gpt-35", "gpt-3.5")
39
+ # 最新モデルを正式名称に & 新モデル, FTモデルをキャッチ
40
+ if model_name not in ALL_VERSION_MODELS:
41
+ for key, model_name_version in MODEL_NAME_MAP:
42
+ if key in model_name:
43
+ model_name = model_name_version
44
+ break
45
+ try:
46
+ return tiktoken.encoding_for_model(model_name)
47
+ except Exception as e:
48
+ cprint(f"WARNING: Tokenizerの取得に失敗。{model_name}: {e}", color="yellow", background=True)
49
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
50
+
51
+
52
+ @overload
53
+ def count_tokens(messages: str, model_name: str | None = None) -> int: ...
54
+
55
+
56
+ @overload
57
+ def count_tokens(
58
+ messages: Iterator[dict[str, str]], model_name: str | None = None, functions: Any | None = None
59
+ ) -> int: ...
60
+
61
+
62
+ def count_tokens(
63
+ messages: Iterator[dict[str, str]] | str,
64
+ model_name: str | None = None,
65
+ functions: Any | None = None,
66
+ ) -> int:
67
+ if isinstance(messages, str):
68
+ tokenizer = get_tokenizer(model_name or DEFAULT_MODEL_NAME)
69
+ encoded = tokenizer.encode(messages)
70
+ return len(encoded)
71
+ return _count_messages_and_function_tokens(messages, model_name, functions)
72
+
73
+
74
+ def _count_messages_and_function_tokens(
75
+ messages: Iterator[dict[str, str]], model_name: str | None = None, functions: Any | None = None
76
+ ) -> int:
77
+ """トークン数計測
78
+
79
+ Args:
80
+ messages (Messages): GPTAPIの入力のmessages
81
+ model_name (str | None, optional): モデル名. Defaults to None.
82
+ functions (Functions | None, optional): GPTAPIの入力のfunctions. Defaults to None.
83
+
84
+ Returns:
85
+ int: トークン数
86
+ """
87
+ num_tokens = _count_messages_tokens(messages, model_name or DEFAULT_MODEL_NAME)
88
+ if functions is not None:
89
+ num_tokens += _count_functions_tokens(functions, model_name)
90
+ return num_tokens
91
+
92
+
93
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
94
+ def _count_messages_tokens(messages: Iterator[dict[str, str]] | None, model_name: str) -> int:
95
+ """メッセージのトークン数を計算
96
+
97
+ Args:
98
+ messages (Messages): ChatGPT等APIに入力するmessages
99
+ model_name (str, optional): 使用するモデルの名前
100
+ "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314"
101
+ "gpt-4-0613", "gpt-4-32k-0613", "gpt-3.5-turbo", "gpt-4"
102
+
103
+ Returns:
104
+ int: トークン数の合計
105
+ """
106
+ if messages is None:
107
+ return 0
108
+ # setting model
109
+ encoding_model = get_tokenizer(model_name)
110
+
111
+ # config
112
+ if "gpt-3.5-turbo-0301" in model_name:
113
+ tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
114
+ tokens_per_name = -1 # if there's a name, the role is omitted
115
+ else:
116
+ tokens_per_message = 3
117
+ tokens_per_name = 1
118
+
119
+ # count tokens
120
+ num_tokens = 3 # every reply is primed with <|start|>assistant<|message|>
121
+ for message in messages:
122
+ num_tokens += tokens_per_message
123
+ for key, value in message.items():
124
+ if isinstance(value, str):
125
+ num_tokens += len(encoding_model.encode(value))
126
+ if key == "name":
127
+ num_tokens += tokens_per_name
128
+ return num_tokens
129
+
130
+
131
+ # https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5
132
+ def _count_functions_tokens(functions: Any, model_name: str | None = None) -> int:
133
+ """
134
+ functionsのトークン数計測
135
+
136
+ Args:
137
+ functions (Functions): GPTAPIの入力のfunctions
138
+ model_name (str | None, optional): モデル名. Defaults to None.
139
+
140
+ Returns:
141
+ _type_: トークン数
142
+ """
143
+ encoding_model = get_tokenizer(model_name or DEFAULT_MODEL_NAME)
144
+ num_tokens = 3 + len(encoding_model.encode(__functions2string(functions)))
145
+ return num_tokens
146
+
147
+
148
+ # functionsのstring化、補助関数 ---------------------------------------------------------------------------
149
+ def __functions2string(functions: Any) -> str:
150
+ """functionsの文字列化
151
+
152
+ Args:
153
+ functions (Functions): GPTAPIの入力のfunctions
154
+
155
+ Returns:
156
+ str: functionsの文字列
157
+ """
158
+ prefix = "# Tools\n\n## functions\n\nnamespace functions {\n\n} // namespace functions\n"
159
+ functions_string = prefix + "".join(__function2string(function) for function in functions)
160
+ return functions_string
161
+
162
+
163
+ def __function2string(function: Function) -> str:
164
+ """functionの文字列化
165
+
166
+ Args:
167
+ function (Function): GPTAPIのfunctionの要素
168
+
169
+ Returns:
170
+ str: functionの文字列
171
+ """
172
+ object_string = __format_object(function["parameters"])
173
+ if object_string is not None:
174
+ object_string = "_: " + object_string
175
+ else:
176
+ object_string = ""
177
+
178
+ functions_string: str = (
179
+ f"// {function['description']}\ntype {function['name']} = (" + object_string + ") => any;\n\n"
180
+ )
181
+ return functions_string
182
+
183
+
184
+ def __format_object(schema: dict[str, Any], indent: int = 0) -> str | None:
185
+ if "properties" not in schema or len(schema["properties"]) == 0:
186
+ if schema.get("additionalProperties", False):
187
+ return "object"
188
+ return None
189
+
190
+ result = "{\n"
191
+ for key, value in dict(schema["properties"]).items():
192
+ # value <- resolve_ref(value)
193
+ value_rendered = __format_schema(value, indent + 1)
194
+ if value_rendered is None:
195
+ continue
196
+ # description (initialized first so it is always bound below)
+ description = ""
197
+ if "description" in value:
198
+ description = "".join(
199
+ " " * indent + f"// {description_i}\n"
200
+ for description_i in textwrap.dedent(value["description"]).strip().split("\n")
201
+ )
202
+ # optional
203
+ optional = "" if key in schema.get("required", {}) else "?"
204
+ # default
205
+ default_comment = "" if "default" not in value else f" // default: {__format_default(value)}"
206
+ # add string
207
+ result += description + " " * indent + f"{key}{optional}: {value_rendered},{default_comment}\n"
208
+ result += (" " * (indent - 1)) + "}"
209
+ return result
210
+
211
+
212
+ # よくわからん
213
+ # def resolve_ref(schema):
214
+ # if schema.get("$ref") is not None:
215
+ # ref = schema["$ref"][14:]
216
+ # schema = json_schema["definitions"][ref]
217
+ # return schema
218
+
219
+
220
+ def __format_schema(schema: dict[str, Any], indent: int) -> str | None:
221
+ # schema <- resolve_ref(schema)
222
+ if "enum" in schema:
223
+ return __format_enum(schema)
224
+ elif schema["type"] == "object":
225
+ return __format_object(schema, indent)
226
+ elif schema["type"] in {"integer", "number"}:
227
+ return "number"
228
+ elif schema["type"] in {"string"}:
229
+ return "string"
230
+ elif schema["type"] == "array":
231
+ return str(__format_schema(schema["items"], indent)) + "[]"
232
+ else:
233
+ raise ValueError("unknown schema type " + schema["type"])
234
+
235
+
236
+ def __format_enum(schema: dict[str, Any]) -> str:
237
+ # "A" | "B" | "C"
238
+ return " | ".join(json.dumps(element, ensure_ascii=False) for element in schema["enum"])
239
+
240
+
241
+ def __format_default(schema: dict[str, Any]) -> str:
242
+ default = schema["default"]
243
+ if schema["type"] == "number" and float(default).is_integer():
244
+ # numberの時、0 → 0.0
245
+ return f"{default:.1f}"
246
+ else:
247
+ return str(default)
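A short sketch of the token counters above (tiktoken fetches the encoding on first use):

```python
from neollm.llm.gpt.token import count_tokens

# plain string: the text is simply encoded
print(count_tokens("hello world", model_name="gpt-3.5-turbo-0613"))

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "こんにちは"},
]
# 3 priming tokens + 3 tokens per message + content tokens
print(count_tokens(messages, model_name="gpt-3.5-turbo-0613"))
```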
neollm/llm/platform.py ADDED
@@ -0,0 +1,16 @@
1
+ from enum import Enum
2
+
3
+
4
+ class Platform(str, Enum):
5
+ AZURE = "azure"
6
+ OPENAI = "openai"
7
+ ANTHROPIC = "anthropic"
8
+ GCP = "gcp"
9
+
10
+ @classmethod
11
+ def from_string(cls, platform: str) -> "Platform":
12
+ platform = platform.lower().strip()
13
+ try:
14
+ return cls(platform)
15
+ except Exception:
16
+ raise ValueError(f"platform must be {cls.__members__}, but got {platform}.")
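Behavior sketch for `Platform.from_string`:

```python
from neollm.llm.platform import Platform

Platform.from_string(" Azure ")  # -> Platform.AZURE (lowercased and stripped first)
Platform.from_string("bedrock")  # -> ValueError listing the supported members
```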
neollm/llm/utils.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from neollm.utils.utils import cprint
4
+
5
+ Immutable = tuple[Any, ...] | str | int | float | bool
6
+ _T = TypeVar("_T")
7
+ _TD = TypeVar("_TD")
8
+
9
+
10
+ def _to_immutable(x: Any) -> Immutable:
11
+ """list, dictをtupleに変換して, setに格納できるようにする
12
+
13
+ Args:
14
+ x (Any): 要素
15
+
16
+ Returns:
17
+ Immutable: Immutableな要素(dict, listはtupleに変換)
18
+ """
19
+ if isinstance(x, list):
20
+ return tuple(map(_to_immutable, x))
21
+ if isinstance(x, dict):
22
+ return tuple((key, _to_immutable(value)) for key, value in sorted(x.items()))
23
+ if isinstance(x, (set, frozenset)):
24
+ return tuple(sorted(map(_to_immutable, x)))
25
+ if isinstance(x, (str, int, float, bool)):
26
+ return x
27
+ cprint("_to_immutable: not supported: 無理やりstr(*)", color="yellow", background=True)
28
+ return str(x)
29
+
30
+
31
+ def _remove_duplicate(arr: list[_T | None]) -> list[_T]:
32
+ """listの重複と初期値を削除する
33
+
34
+ Args:
35
+ arr (list[Any]): リスト
36
+
37
+ Returns:
38
+ list[Any]: 重複削除済みのlist
39
+ """
40
+ seen_set: set[Immutable] = set()
41
+ unique_list: list[_T] = []
42
+ for x in arr:
43
+ if not x:  # skip None and falsy defaults
44
+ continue
45
+ x_immutable = _to_immutable(x)
46
+ if x_immutable not in seen_set:
47
+ unique_list.append(x)
48
+ seen_set.add(x_immutable)
49
+ return unique_list
50
+
51
+
52
+ def get_entity(arr: list[_T | None], default: _TD, index: int | None = None) -> _T | _TD:
53
+ """listから必要な1要素を取得する
54
+
55
+ Args:
56
+ arr (list[Any]): list
57
+ default (Any): 初期値
58
+ index (int | None, optional): 複数ある場合、指定のindex. Defaults to None.
59
+
60
+ Returns:
61
+ Any: 要素
62
+ """
63
+ arr_cleaned = _remove_duplicate(arr)
64
+ if len(arr_cleaned) == 0:
65
+ return default
66
+ if len(arr_cleaned) == 1:
67
+ return arr_cleaned[0]
68
+ if index is not None:
69
+ return arr_cleaned[index]
70
+ cprint("get_entity: not unique", color="yellow", background=True)
71
+ cprint(arr_cleaned, color="yellow", background=True)
72
+ return arr_cleaned[0]
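A sketch of how the helpers above behave (duplicates and falsy values are dropped before an entity is picked):

```python
from neollm.llm.utils import get_entity

get_entity([None, "a", "a"], default="-")       # -> "a" (deduplicated)
get_entity([], default="-")                     # -> "-" (falls back to default)
get_entity([{"k": 1}, {"k": 1}], default=None)  # -> {"k": 1}; dicts compare via tuples
```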
neollm/myllm/abstract_myllm.py ADDED
@@ -0,0 +1,148 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Generator, Generic, Optional, TypeAlias, cast
3
+
4
+ from neollm.myllm.print_utils import print_inputs, print_metadata, print_outputs
5
+ from neollm.types import (
6
+ InputType,
7
+ OutputType,
8
+ PriceInfo,
9
+ StreamOutputType,
10
+ TimeInfo,
11
+ TokenInfo,
12
+ )
13
+ from neollm.utils.utils import cprint
14
+
15
+ if TYPE_CHECKING:
16
+ from typing import Any
17
+
18
+ from neollm.myllm.myl3m2 import MyL3M2
19
+
20
+ _MyL3M2: TypeAlias = MyL3M2[Any, Any]
21
+
22
+
23
+ class AbstractMyLLM(ABC, Generic[InputType, OutputType]):
24
+ """MyLLM, MyL3M2の抽象クラス"""
25
+
26
+ inputs: InputType | None
27
+ outputs: OutputType | None
28
+ silent_set: set[str]
29
+ verbose: bool
30
+ time: float = 0.0
31
+ time_detail: TimeInfo = TimeInfo()
32
+ parent: Optional["_MyL3M2"] = None
33
+ do_stream: bool
34
+
35
+ @property
36
+ @abstractmethod
37
+ def token(self) -> TokenInfo:
38
+ """LLMの利用トークン数
39
+
40
+ Returns:
41
+ TokenInfo: トークン数 (入力, 出力, 合計)
42
+ >>> TokenInfo(input=1588, output=128, total=1716)
43
+ """
44
+
45
+ @property
46
+ def custom_token(self) -> TokenInfo | None:
47
+ """料金計算用トークン(Gemini用)"""
48
+ return None
49
+
50
+ @property
51
+ @abstractmethod
52
+ def price(self) -> PriceInfo:
53
+ """LLMの利用料金 (USD)
54
+
55
+ Returns:
56
+ PriceInfo: 利用料金 (USD) (入力, 出力, 合計)
57
+ >>> PriceInfo(input=0.002382, output=0.000256, total=0.002638)
58
+ """
59
+
60
+ @abstractmethod
61
+ def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
62
+ """MyLLMの子クラスのメインロジック
63
+
64
+ streamとnon-streamの両方のコードを書く必要がある
65
+
66
+ Args:
67
+ inputs (InputType): LLMへの入力
68
+ stream (bool, optional): streamの有無. Defaults to False.
69
+
70
+ Yields:
71
+ Generator[StreamOutputType, None, OutputType]: LLMのstream出力
72
+
73
+ Returns:
74
+ OutputType: LLMの出力
75
+ """
76
+
77
+ def __call__(self, inputs: InputType) -> OutputType:
78
+ """MyLLMのメインロジック
79
+
80
+ Args:
81
+ inputs (InputType): LLMへの入力
82
+
83
+ Returns:
84
+ OutputType: LLMの出力
85
+ """
86
+ it: Generator[StreamOutputType, None, OutputType] = self._call(inputs, stream=self.do_stream)
87
+ while True:
88
+ try:
89
+ next(it)
90
+ except StopIteration as e:
91
+ outputs = cast(OutputType, e.value)
92
+ return outputs
93
+ except Exception as e:
94
+ raise e
95
+
96
+ def call_stream(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
97
+ """MyLLMのメインロジック(stream処理)
98
+
99
+ Args:
100
+ inputs (InputType): LLMへの入力
101
+
102
+ Yields:
103
+ Generator[StreamOutputType, None, OutputType]: LLMのstream出力
104
+
105
+ Returns:
106
+ LLMの出力
107
+ """
108
+ it: Generator[StreamOutputType, None, OutputType] = self._call(inputs, stream=True)
109
+ while True:
110
+ try:
111
+ delta_content = next(it)
112
+ yield delta_content
113
+ except StopIteration as e:
114
+ outputs = cast(OutputType, e.value)
115
+ return outputs
116
+ except Exception as e:
117
+ raise e
118
+
119
+ def _print_inputs(self) -> None:
120
+ if self.inputs is None:
121
+ return
122
+ if not ("inputs" not in self.silent_set and self.verbose):
123
+ return
124
+ print_inputs(self.inputs)
125
+
126
+ def _print_outputs(self) -> None:
127
+ if self.outputs is None:
128
+ return
129
+ if not ("outputs" not in self.silent_set and self.verbose):
130
+ return
131
+ print_outputs(self.outputs)
132
+
133
+ def _print_metadata(self) -> None:
134
+ if not ("metadata" not in self.silent_set and self.verbose):
135
+ return
136
+ print_metadata(self.time, self.token, self.price)
137
+
138
+ def _print_start(self, sep: str = "-") -> None:
139
+ if not self.verbose:
140
+ return
141
+ if self.parent is None:
142
+ cprint("PARENT", color="red", background=True)
143
+ print(self, sep * (99 - len(str(self))))
144
+
145
+ def _print_end(self, sep: str = "-") -> None:
146
+ if not self.verbose:
147
+ return
148
+ print(sep * 100)
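`__call__` and `call_stream` above share one trick: `_call` is a generator whose final result is recovered from `StopIteration.value`. A standalone sketch of that protocol, independent of the class:

```python
from typing import Generator

def _call(stream: bool = False) -> Generator[str, None, str]:
    if stream:
        yield "he"
        yield "llo"
    return "hello"  # surfaces as StopIteration.value

it = _call(stream=False)  # non-stream: no yields, drain to reach the return value
try:
    while True:
        next(it)
except StopIteration as e:
    print(e.value)  # "hello"
```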
neollm/myllm/myl3m2.py ADDED
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import Any, Generator, Literal, Optional, cast
5
+
6
+ from neollm.myllm.abstract_myllm import AbstractMyLLM
7
+ from neollm.myllm.myllm import MyLLM
8
+ from neollm.myllm.print_utils import TITLE_COLOR
9
+ from neollm.types import (
10
+ InputType,
11
+ OutputType,
12
+ PriceInfo,
13
+ StreamOutputType,
14
+ TimeInfo,
15
+ TokenInfo,
16
+ )
17
+ from neollm.utils.utils import cprint
18
+
19
+
20
+ class MyL3M2(AbstractMyLLM[InputType, OutputType]):
21
+ """LLMの複数リクエストをまとめるクラス"""
22
+
23
+ do_stream: bool = False # stream_verboseがないため、__call__ではstreamを使わない
24
+
25
+ def __init__(
26
+ self,
27
+ parent: Optional["MyL3M2[Any, Any]"] = None,
28
+ verbose: bool = False,
29
+ silent_list: list[Literal["inputs", "outputs", "metadata", "all_myllm"]] | None = None,
30
+ ) -> None:
31
+ """
32
+ MyL3M2の初期化
33
+
34
+ Args:
35
+ parent (MyL3M2, optional):
36
+ 親のMyL3M2のインスタンス(self or None)
37
+ verbose (bool, optional):
38
+ 出力をするかどうかのフラグ. Defaults to False.
39
+ silent_list (list[Literal["inputs", "outputs", "metadata", "all_myllm"]], optional):
40
+ サイレントモードのリスト。出力を抑制する要素を指定する。. Defaults to None(=[]).
41
+ """
42
+ self.parent = parent
43
+ self.verbose = verbose
44
+ self.silent_set = set(silent_list or [])
45
+ self.myllm_list: list["MyL3M2[Any, Any]" | MyLLM[Any, Any]] = []
46
+ self.inputs: InputType | None = None
47
+ self.outputs: OutputType | None = None
48
+ self.called: bool = False
49
+
50
+ def _link(self, inputs: InputType) -> OutputType:
51
+ """複数のLLMの処理を行う
52
+
53
+ Args:
54
+ inputs (InputType): 入力データを保持する辞書
55
+
56
+ Returns:
57
+ OutputType: 処理結果の出力データ
58
+ """
59
+ raise NotImplementedError("_link(self, inputs: InputType) -> OutputType:を実装してください")
60
+
61
+ def _stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
62
+ """複数のLLMの処理を行う(stream処理)
63
+
64
+ Args:
65
+ inputs (InputType): 入力データを保持する辞書
66
+
67
+ Yields:
68
+ Generator[StreamOutputType, None, OutputType]: 処理結果の出力データ(stream)
69
+
70
+ Returns:
71
+ self.outputsに入れたいもの
72
+ """
73
+ raise NotImplementedError(
74
+ "_stream_link(self, inputs: InputType) -> Generator[StreamOutputType, None, None]を実装してください"
75
+ )
76
+
77
+ def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
78
+ if self.called:
79
+ raise RuntimeError("MyLLMは1回しか呼び出せない")
80
+
81
+ self._print_start(sep="=")
82
+
83
+ # main -----------------------------------------------------------
84
+ t_start = time.time()
85
+ self.inputs = inputs
86
+ # [stream]
87
+ if stream:
88
+ it = self._stream_link(inputs)
89
+ while True:
90
+ try:
91
+ yield next(it)
92
+ except StopIteration as e:
93
+ self.outputs = cast(OutputType, e.value)
94
+ break
95
+ except Exception as e:
96
+ raise e
97
+ # [non-stream]
98
+ else:
99
+ self.outputs = self._link(inputs)
100
+ self._print_inputs()
101
+ self._print_outputs()
102
+ self._print_all_myllm()
103
+ self.time = time.time() - t_start
104
+ self.time_detail = TimeInfo(total=self.time, main=self.time)
105
+
106
+ # metadata -----------------------------------------------------------
107
+ self._print_metadata()
108
+ self._print_end(sep="=")
109
+
110
+ # 親MyL3M2にAppend -----------------------------------------------------------
111
+ if self.parent is not None:
112
+ self.parent.myllm_list.append(self)
113
+ self.called = True
114
+
115
+ return self.outputs
116
+
117
+ @property
118
+ def token(self) -> TokenInfo:
119
+ token = TokenInfo(input=0, output=0, total=0)
120
+ for myllm in self.myllm_list:
121
+ # TODO: token += myllm.token
122
+ token.input += myllm.token.input
123
+ token.output += myllm.token.output
124
+ token.total += myllm.token.total
125
+ return token
126
+
127
+ @property
128
+ def price(self) -> PriceInfo:
129
+ price = PriceInfo(input=0.0, output=0.0, total=0.0)
130
+ for myllm in self.myllm_list:
131
+ # TODO: price += myllm.price
132
+ price.input += myllm.price.input
133
+ price.output += myllm.price.output
134
+ price.total += myllm.price.total
135
+ return price
136
+
137
+ @property
138
+ def logs(self) -> list[Any]:
139
+ logs: list[Any] = []
140
+ for myllm in self.myllm_list:
141
+ if isinstance(myllm, MyLLM):
142
+ logs.append(myllm.log)
143
+ elif isinstance(myllm, MyL3M2):
144
+ logs.extend(myllm.logs)
145
+ return logs
146
+
147
+ def _print_all_myllm(self, prefix: str = "", title: bool = True) -> None:
148
+ if not ("all_myllm" not in self.silent_set and self.verbose):
149
+ return
150
+ try:
151
+ if title:
152
+ cprint("[all_myllm]", color=TITLE_COLOR)
153
+ print(" ", end="")
154
+ cprint(f"{self}", color="magenta", bold=True, underline=True)
155
+ for myllm in self.myllm_list:
156
+ if isinstance(myllm, MyLLM):
157
+ cprint(f" {prefix}- {myllm}", color="cyan")
158
+ elif isinstance(myllm, MyL3M2):
159
+ cprint(f" {prefix}- {myllm}", color="magenta")
160
+ myllm._print_all_myllm(prefix=prefix + " ", title=False)
161
+ except Exception as e:
162
+ cprint(e, color="red", background=True)
163
+
164
+ def __repr__(self) -> str:
165
+ return f"MyL3M2({self.__class__.__name__})"
neollm/myllm/myllm.py ADDED
@@ -0,0 +1,449 @@
1
+ import os
2
+ import time
3
+ from abc import abstractmethod
4
+ from typing import TYPE_CHECKING, Any, Final, Generator, Literal, Optional
5
+
6
+ from neollm.exceptions import ContentFilterError
7
+ from neollm.llm import AbstractLLM, get_llm
8
+ from neollm.llm.gpt.azure_llm import AzureLLM
9
+ from neollm.myllm.abstract_myllm import AbstractMyLLM
10
+ from neollm.myllm.print_utils import (
11
+ print_client_settings,
12
+ print_delta,
13
+ print_llm_settings,
14
+ print_messages,
15
+ )
16
+ from neollm.types import (
17
+ Chunk,
18
+ ClientSettings,
19
+ Functions,
20
+ InputType,
21
+ LLMSettings,
22
+ Message,
23
+ Messages,
24
+ OutputType,
25
+ PriceInfo,
26
+ Response,
27
+ StreamOutputType,
28
+ TimeInfo,
29
+ TokenInfo,
30
+ Tools,
31
+ )
32
+ from neollm.types.openai.chat_completion import CompletionUsageForCustomPriceCalculation
33
+ from neollm.utils.preprocess import dict2json
34
+ from neollm.utils.utils import cprint
35
+
36
+ if TYPE_CHECKING:
37
+ from neollm.myllm.myl3m2 import MyL3M2
38
+
39
+ _MyL3M2 = MyL3M2[Any, Any]
40
+ _State = dict[Any, Any]
41
+
42
+ DEFAULT_LLM_SETTINGS: LLMSettings = {"temperature": 0}
43
+ DEFAULT_PLATFORM: Final[str] = "azure"
44
+
45
+
46
+ class MyLLM(AbstractMyLLM[InputType, OutputType]):
47
+ """LLMの単一リクエストをまとめるクラス"""
48
+
49
+ def __init__(
50
+ self,
51
+ model: str,
52
+ parent: Optional["_MyL3M2"] = None,
53
+ llm_settings: LLMSettings | None = None,
54
+ client_settings: ClientSettings | None = None,
55
+ platform: str | None = None,
56
+ verbose: bool = False,
57
+ stream_verbose: bool = False,
58
+ silent_list: list[Literal["llm_settings", "inputs", "outputs", "messages", "metadata"]] | None = None,
59
+ log_dir: str | None = None,
60
+ ) -> None:
61
+ """
62
+ MyLLMクラスの初期化
63
+
64
+ Args:
65
+ model (Optional[str]): LLMモデル名
66
+ parent (Optional[MyL3M2]): 親のMyL3M2のインスタンス (self or None)
67
+ llm_settings (LLMSettings): LLMの設定パラメータ
68
+ client_settings (ClientSettings): llmのclientの設定パラメータ
69
+ platform (Optional[str]): LLMのプラットフォーム名 (デフォルト: os.environ["PLATFORM"] or "azure")
70
+ (enum: openai, azure)
71
+ verbose (bool): 出力をするかどうかのフラグ
72
+ stream_verbose (bool): assistantをstreamで出力するか(verbose=False, message in "messages"の時、無効)
73
+ silent_list (list[Literal["llm_settings", "inputs", "outputs", "messages", "metadata"]]):
74
+ verbose=True時, 出力を抑制する要素のリスト
75
+ log_dir (Optional[str]): ログを保存するディレクトリのパス Noneの時、保存しない
76
+ """
77
+ self.parent: _MyL3M2 | None = parent
78
+ self.llm_settings = llm_settings or DEFAULT_LLM_SETTINGS
79
+ self.client_settings = client_settings or {}
80
+ self.model: str = model
81
+ self.platform: str = platform or os.environ.get("LLM_PLATFORM", DEFAULT_PLATFORM) or DEFAULT_PLATFORM
82
+ self.verbose: bool = verbose and (self.parent.verbose if self.parent is not None else True)  # 親に合わせる
83
+ self.silent_set = set(silent_list or [])
84
+ self.stream_verbose: bool = stream_verbose if verbose and ("messages" not in self.silent_set) else False
85
+ self.log_dir: str | None = log_dir
86
+
87
+ self.inputs: InputType | None = None
88
+ self.outputs: OutputType | None = None
89
+ self.messages: Messages | None = None
90
+ self.functions: Functions | None = None
91
+ self.tools: Tools | None = None
92
+ self.response: Response | None = None
93
+ self.called: bool = False
94
+ self.do_stream: bool = self.stream_verbose
95
+
96
+ self.llm: AbstractLLM = get_llm(
97
+ model_name=self.model, platform=self.platform, client_settings=self.client_settings
98
+ )
99
+
100
+ @abstractmethod
101
+ def _preprocess(self, inputs: InputType) -> Messages:
102
+ """
103
+ inputs を API入力 の messages に前処理する
104
+
105
+ Args:
106
+ inputs (InputType): 入力
107
+
108
+ Returns:
109
+ Messages: API入力 の messages
110
+ >>> [{"role": "system", "content": "system_prompt"}, {"role": "user", "content": "user_prompt"}]
111
+ """
112
+
113
+ @abstractmethod
114
+ def _postprocess(self, response: Response) -> OutputType:
115
+ """
116
+ API の response を outputs に後処理する
117
+
118
+ Args:
119
+ response (Response): API の response
120
+ >>> {"choices": [{"message": {"role": "assistant",
121
+ >>> "content": "This is a test!"}}]}
122
+ >>> {"choices": [{"message": {"role": "assistant",
123
+ >>> "function_call": {"name": "func", "arguments": "{a: 1}"}}]}
124
+
125
+ Returns:
126
+ OutputType: 出力
127
+ """
128
+
129
+ def _ruleprocess(self, inputs: InputType) -> OutputType | None:
130
+ """
131
+ ルールベース処理 or APIリクエスト の判断
132
+
133
+ Args:
134
+ inputs (InputType): MyLLMの入力
135
+
136
+ Returns:
137
+ RuleOutputs:
138
+ ルールベース処理の時、MyLLMの出力を返す
139
+ APIリクエストの時、Noneを返す
140
+ """
141
+ return None
142
+
143
+ def _update_settings(self) -> None:
144
+ """
145
+ APIの設定の更新
146
+ Note:
147
+ messageのトークン数
148
+ >>> self.llm.count_tokens(self.messsage)
149
+
150
+ モデル変更
151
+ >>> self.model = "gpt-3.5-turbo-16k"
152
+
153
+ パラメータ変更
154
+ >>> self.llm_settings = {"temperature": 0.2}
155
+ """
156
+ return None
157
+
158
+ def _add_tools(self, inputs: InputType) -> Tools | None:
159
+ return None
160
+
161
+ def _add_functions(self, inputs: InputType) -> Functions | None:
162
+ """
163
+ functions の追加
164
+
165
+ Args:
166
+ inputs (InputType): 入力
167
+
168
+ Returns:
169
+ Functions | None: functions。追加しない場合None
170
+ https://json-schema.org/understanding-json-schema/reference/index.html
171
+ >>> {
172
+ >>> "name": "関数名",
173
+ >>> "description": "関数の動作の説明。GPTは説明を見て利用するか選ぶ",
174
+ >>> "parameters": {
175
+ >>> "type": "object", "properties": {"city_name": {"type": "string", "description": "都市名"}},
176
+ >>> json-schema[https://json-schema.org/understanding-json-schema/reference/index.html]
177
+ >>> }
178
+ >>> }
179
+ """
180
+ return None
181
+
182
+ def _stream_postprocess(
183
+ self,
184
+ new_chunk: Chunk,
185
+ state: "_State",
186
+ ) -> StreamOutputType:
187
+ """call_streamのGeneratorのpostprocess
188
+
189
+ Args:
190
+ new_chunk (OpenAIChunkResponse): 新しいchunk
191
+ state (dict[Any, Any]): 状態を持てるdict. 初めは、default {}. 状態が消えてしまうのでoverwriteしない。
192
+
193
+ Returns:
194
+ StreamOutputType: 一時的なoutput
195
+ """
196
+ if len(new_chunk.choices) == 0:
197
+ return ""
198
+ return new_chunk.choices[0].delta.content
199
+
200
+ def _generate(self, stream: bool) -> Generator[StreamOutputType, None, None]:
201
+ """
202
+ LLMの出力を得て、`self.response`に格納する
203
+
204
+ Args:
205
+ messages (list[dict[str, str]]): LLMの入力メッセージ
206
+ """
207
+ # 例外処理 -----------------------------------------------------------
208
+ if self.messages is None:
209
+ raise ValueError("MessagesがNoneです。")
210
+
211
+ # kwargs -----------------------------------------------------------
212
+ generate_kwargs = dict(**self.llm_settings)
213
+ if self.functions is not None:
214
+ generate_kwargs["functions"] = self.functions
215
+ if self.tools is not None:
216
+ generate_kwargs["tools"] = self.tools
217
+
218
+ # generate ----------------------------------------------------------
219
+ self._print_messages() # verbose
220
+ self.llm = get_llm(model_name=self.model, platform=self.platform, client_settings=self.client_settings)
221
+ # [stream]
222
+ if stream or self.stream_verbose:
223
+ it = self.llm.generate_stream(messages=self.messages, llm_settings=generate_kwargs)
224
+ chunk_list: list[Chunk] = []
225
+ state: "_State" = {}
226
+ for chunk in it:
227
+ chunk_list.append(chunk)
228
+ self._print_delta(chunk=chunk) # verbose: stop→改行、conent, TODO: fc→出力
229
+ yield self._stream_postprocess(new_chunk=chunk, state=state)
230
+ self.response = self.llm.convert_nonstream_response(chunk_list, self.messages, self.functions)
231
+ # [non-stream]
232
+ else:
233
+ try:
234
+ self.response = self.llm.generate(messages=self.messages, llm_settings=generate_kwargs)
235
+ self._print_message_assistant()
236
+ except Exception as e:
237
+ raise e
238
+
239
+ # ContentFilterError -------------------------------------------------
240
+ if len(self.response.choices) == 0:
241
+ cprint(self.response, color="red", background=True)
242
+ raise ContentFilterError("入力のコンテンツフィルターに引っかかりました。")
243
+ if self.response.choices[0].finish_reason == "content_filter":
244
+ cprint(self.response, color="red", background=True)
245
+ raise ContentFilterError("出力のコンテンツフィルターに引っかかりました。")
246
+
247
+ def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
248
+ """
249
+ LLMの処理を行う (preprocess, check_input, generate, postprocess)
250
+
251
+ Args:
252
+ inputs (InputType): 入力データを保持する辞書
253
+
254
+ Returns:
255
+ OutputType: 処理結果の出力データ
256
+
257
+ Raises:
258
+ RuntimeError: 既に呼び出されている場合に発生
259
+ """
260
+ if self.called:
261
+ raise RuntimeError("MyLLMは1回しか呼び出せない")
262
+
263
+ self._print_start(sep="-")
264
+
265
+ # main -----------------------------------------------------------
266
+ t_start = time.time()
267
+ self.inputs = inputs
268
+ self._print_inputs()
269
+ rulebase_output = self._ruleprocess(inputs)
270
+ if rulebase_output is None: # API リクエストを送る場合
271
+ self._update_settings()
272
+ self.messages = self._preprocess(inputs)
273
+ self.functions = self._add_functions(inputs)
274
+ self.tools = self._add_tools(inputs)
275
+ t_preprocessed = time.time()
276
+ # [generate]
277
+ it = self._generate(stream=stream)
278
+ for delta_content in it: # stream=Falseの時、空のGenerator
279
+ yield delta_content
280
+ if self.response is None:
281
+ raise ValueError("responseがNoneです。")
282
+ t_generated = time.time()
283
+ # [postprocess]
284
+ self.outputs = self._postprocess(self.response)
285
+ t_postprocessed = time.time()
286
+ else: # ルールベースの場合
287
+ self.outputs = rulebase_output
288
+ t_preprocessed = t_generated = t_postprocessed = time.time()
289
+ self.time_detail = TimeInfo(
290
+ total=t_postprocessed - t_start,
291
+ preprocess=t_preprocessed - t_start,
292
+ main=t_generated - t_preprocessed,
293
+ postprocess=t_postprocessed - t_generated,
294
+ )
295
+ self.time = t_postprocessed - t_start
296
+
297
+ # print -----------------------------------------------------------
298
+ self._print_outputs()
299
+ self._print_client_settings()
300
+ self._print_llm_settings()
301
+ self._print_metadata()
302
+ self._print_end(sep="-")
303
+
304
+ # 親MyL3M2にAppend -----------------------------------------------------------
305
+ if self.parent is not None:
306
+ self.parent.myllm_list.append(self)
307
+ self.called = True
308
+
309
+ # log -----------------------------------------------------------
310
+ self._save_log()
311
+
312
+ return self.outputs
313
+
314
+ @property
315
+ def log(self) -> dict[str, Any]:
316
+ return {
317
+ "inputs": self.inputs,
318
+ "outputs": self.outputs,
319
+ "resposnse": self.response.model_dump() if self.response is not None else None,
320
+ "input_token": self.token.input,
321
+ "output_token": self.token.output,
322
+ "total_token": self.token.total,
323
+ "input_price": self.price.input,
324
+ "output_price": self.price.output,
325
+ "total_price": self.price.total,
326
+ "time": self.time,
327
+ "time_stamp": time.time(),
328
+ "llm_settings": self.llm_settings,
329
+ "client_settings": self.client_settings,
330
+ "model": self.model,
331
+ "platform": self.platform,
332
+ "verbose": self.verbose,
333
+ "messages": self.messages,
334
+ "assistant_message": self.assistant_message,
335
+ "functions": self.functions,
336
+ "tools": self.tools,
337
+ }
338
+
339
+ def _save_log(self) -> None:
340
+ if self.log_dir is None:
341
+ return
342
+ try:
343
+ log = self.log
344
+ json_string = dict2json(log)
345
+
346
+ save_log_path = os.path.join(self.log_dir, f"{log['time_stamp']}.json")
347
+ os.makedirs(self.log_dir, exist_ok=True)
348
+ with open(save_log_path, mode="w") as f:
349
+ f.write(json_string)
350
+ except Exception as e:
351
+ cprint(e, color="red", background=True)
352
+
353
+ @property
354
+ def token(self) -> TokenInfo:
355
+ if self.response is None or self.response.usage is None:
356
+ return TokenInfo(input=0, output=0, total=0)
357
+ return TokenInfo(
358
+ input=self.response.usage.prompt_tokens,
359
+ output=self.response.usage.completion_tokens,
360
+ total=self.response.usage.total_tokens,
361
+ )
362
+
363
+ @property
364
+ def custom_token(self) -> TokenInfo | None:
365
+ if not self.llm._custom_price_calculation:
366
+ return None
367
+ if self.response is None:
368
+ return TokenInfo(input=0, output=0, total=0)
369
+ usage_for_price = getattr(self.response, "usage_for_price", None)
370
+ if not isinstance(usage_for_price, CompletionUsageForCustomPriceCalculation):
371
+ cprint("usage_for_priceがNoneです。正しくトークン計算できません", color="red", background=True)
372
+ return TokenInfo(input=0, output=0, total=0)
373
+ return TokenInfo(
374
+ input=usage_for_price.prompt_tokens,
375
+ output=usage_for_price.completion_tokens,
376
+ total=usage_for_price.total_tokens,
377
+ )
378
+
379
+ @property
380
+ def price(self) -> PriceInfo:
381
+ if self.response is None:
382
+ return PriceInfo(input=0.0, output=0.0, total=0.0)
383
+ if self.llm._custom_price_calculation:
384
+ # Geniniの時は必ずcustom_tokenがある想定
385
+ if self.custom_token is None:
386
+ cprint("custom_tokenがNoneです。正しくトークン計算できません", color="red", background=True)
387
+ else:
388
+ return PriceInfo(
389
+ input=self.llm.calculate_price(num_input_tokens=self.custom_token.input),
390
+ output=self.llm.calculate_price(num_output_tokens=self.custom_token.output),
391
+ total=self.llm.calculate_price(
392
+ num_input_tokens=self.custom_token.input, num_output_tokens=self.custom_token.output
393
+ ),
394
+ )
395
+ return PriceInfo(
396
+ input=self.llm.calculate_price(num_input_tokens=self.token.input),
397
+ output=self.llm.calculate_price(num_output_tokens=self.token.output),
398
+ total=self.llm.calculate_price(num_input_tokens=self.token.input, num_output_tokens=self.token.output),
399
+ )
400
+
401
+ @property
402
+ def assistant_message(self) -> Message | None:
403
+ if self.response is None or len(self.response.choices) == 0:
404
+ return None
405
+ return self.response.choices[0].message.to_typeddict_message()
406
+
407
+ @property
408
+ def chat_history(self) -> Messages:
409
+ chat_history: Messages = []
410
+ if self.messages:
411
+ chat_history += self.messages
412
+ if self.assistant_message is not None:
413
+ chat_history.append(self.assistant_message)
414
+ return chat_history
415
+
416
+ def _print_llm_settings(self) -> None:
417
+ if not ("llm_settings" not in self.silent_set and self.verbose):
418
+ return
419
+ print_llm_settings(
420
+ llm_settings=self.llm_settings,
421
+ model=self.model,
422
+ platform=self.platform,
423
+ engine=self.llm.engine if isinstance(self.llm, AzureLLM) else None,
424
+ )
425
+
426
+ def _print_messages(self) -> None:
427
+ if not ("messages" not in self.silent_set and self.verbose):
428
+ return
429
+ print_messages(self.messages, title=True)
430
+
431
+ def _print_message_assistant(self) -> None:
432
+ if self.response is None or len(self.response.choices) == 0:
433
+ return
434
+ if not ("messages" not in self.silent_set and self.verbose):
435
+ return
436
+ print_messages(messages=[self.response.choices[0].message], title=False)
437
+
438
+ def _print_delta(self, chunk: Chunk) -> None:
439
+ if not ("messages" not in self.silent_set and self.verbose):
440
+ return
441
+ print_delta(chunk)
442
+
443
+ def _print_client_settings(self) -> None:
444
+ if not ("client_settings" not in self.silent_set and self.verbose):
445
+ return
446
+ print_client_settings(self.llm.client_settings)
447
+
448
+ def __repr__(self) -> str:
449
+ return f"MyLLM({self.__class__.__name__})"
neollm/myllm/print_utils.py ADDED
@@ -0,0 +1,235 @@
1
+ import json
2
+ from typing import Any
3
+
4
+ from openai.types.chat import ChatCompletionAssistantMessageParam
5
+ from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
6
+ from openai.types.chat.chat_completion_message_tool_call_param import (
7
+ ChatCompletionMessageToolCallParam,
8
+ Function,
9
+ )
10
+
11
+ from neollm.types import (
12
+ ChatCompletionMessage,
13
+ Chunk,
14
+ ClientSettings,
15
+ InputType,
16
+ LLMSettings,
17
+ Message,
18
+ Messages,
19
+ OutputType,
20
+ PriceInfo,
21
+ PrintColor,
22
+ Role,
23
+ TokenInfo,
24
+ )
25
+ from neollm.utils.postprocess import json2dict
26
+ from neollm.utils.utils import CPrintParam, cprint
27
+
28
+ TITLE_COLOR: PrintColor = "blue"
29
+ YEN_PAR_DOLLAR: float = 140.0 # 150円になってしまったぴえん(231027)
30
+
31
+
32
+ def _ChatCompletionMessage2dict(message: ChatCompletionMessage) -> Message:
33
+ message_dict = ChatCompletionAssistantMessageParam(content=message.content, role=message.role)
34
+ if message.function_call is not None:
35
+ message_dict["function_call"] = FunctionCall(
36
+ arguments=message.function_call.arguments, name=message.function_call.name
37
+ )
38
+ if message.tool_calls is not None:
39
+ message_dict["tool_calls"] = [
40
+ ChatCompletionMessageToolCallParam(
41
+ id=tool_call.id,
42
+ function=Function(arguments=tool_call.function.arguments, name=tool_call.function.name),
43
+ type=tool_call.type,
44
+ )
45
+ for tool_call in message.tool_calls
46
+ ]
47
+ return message_dict
48
+
49
+
50
+ def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
51
+ tool_calls: list[ChatCompletionMessageToolCallParam] = []
52
+ if "tool_calls" in message_dict:
53
+ _tool_calls = message_dict.get("tool_calls", None)
54
+ if _tool_calls is not None and isinstance(_tool_calls, list): # isinstance(_tool_calls, list)ないと通らん,,,
55
+ for _tool_call in _tool_calls:
56
+ tool_call = ChatCompletionMessageToolCallParam(
57
+ id=_tool_call["id"],
58
+ function=Function(
59
+ arguments=_tool_call["function"]["arguments"],
60
+ name=_tool_call["function"]["name"],
61
+ ),
62
+ type=_tool_call["type"],
63
+ )
64
+ tool_calls.append(tool_call)
65
+ if "function_call" in message_dict:
66
+ function_call = message_dict.get("function_call", None)
67
+ if function_call is not None and isinstance(
68
+ function_call, dict
69
+ ): # isinstance(function_call, dict)ないと通らん,,,
70
+ tool_calls.append(
71
+ ChatCompletionMessageToolCallParam(
72
+ id="",
73
+ function=Function(
74
+ arguments=function_call["arguments"],
75
+ name=function_call["name"],
76
+ ),
77
+ type="function",
78
+ )
79
+ )
80
+ return tool_calls
81
+
82
+
83
+ def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
84
+ try:
85
+ cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
86
+ print(
87
+ f"{time:.1f}s; "
88
+ f"{token.total:,}({token.input:,}+{token.output:,})tokens; "
89
+ f"${price.total:.2g}; ¥{price.total*YEN_PAR_DOLLAR:.2g}"
90
+ )
91
+ except Exception as e:
92
+ cprint(e, color="red", background=True)
93
+
94
+
95
+ def print_inputs(inputs: InputType) -> None:
96
+ try:
97
+ cprint("[inputs]", color=TITLE_COLOR)
98
+ print(json.dumps(_arange_dumpable_object(inputs), indent=2, ensure_ascii=False))
99
+ except Exception as e:
100
+ cprint(e, color="red", background=True)
101
+
102
+
103
+ def print_outputs(outputs: OutputType) -> None:
104
+ try:
105
+ cprint("[outputs]", color=TITLE_COLOR)
106
+ print(json.dumps(_arange_dumpable_object(outputs), indent=2, ensure_ascii=False))
107
+ except Exception as e:
108
+ cprint(e, color="red", background=True)
109
+
110
+
111
+ def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
112
+ if messages is None:
113
+ cprint("Not yet running _preprocess", color="red")
114
+ return
115
+ # try:
116
+ if title:
117
+ cprint("[messages]", color=TITLE_COLOR)
118
+ role2params: dict[Role, CPrintParam] = {
119
+ "system": {"color": "green"},
120
+ "user": {"color": "green"},
121
+ "assistant": {"color": "green"},
122
+ "function": {"color": "green", "background": True},
123
+ "tool": {"color": "green", "background": True},
124
+ }
125
+ for message in messages:
126
+ message_dict: Message
127
+ if isinstance(message, ChatCompletionMessage):
128
+ message_dict = _ChatCompletionMessage2dict(message)
129
+ else:
130
+ message_dict = message
131
+
132
+ # roleの出力 ----------------------------------------
133
+ print(" ", end="")
134
+ role = message_dict["role"]
135
+ cprint(role, **role2params[role])
136
+
137
+ # contentの出力 ----------------------------------------
138
+ content = message_dict.get("content", None)
139
+ if isinstance(content, str):
140
+ print(" " + content.replace("\n", "\n "))
141
+ elif isinstance(content, list):
142
+ for content_part in content:
143
+ if content_part["type"] == "text":
144
+ print(" " + content_part["text"].replace("\n", "\n "))
145
+ elif content_part["type"] == "image_url":
146
+ cprint(" <image_url>", color="green", kwargs={"end": " "})
147
+ print(content_part["image_url"])
148
+ # TODO: 画像出力
149
+ # TODO: Preview用、content_part["image"]: str, dict両方いけてしまう
150
+ else:
151
+ # TODO: 未対応のcontentの出力
152
+ pass
153
+
154
+ # tool_callの出力 ----------------------------------------
155
+ for tool_call in _get_tool_calls(message_dict):
156
+ print(" ", end="")
157
+ cprint(tool_call["function"]["name"], color="green", background=True)
158
+ print(" " + str(json2dict(tool_call["function"]["arguments"], error_key=None)).replace("\n", "\n "))
159
+
160
+ # except Exception as e:
161
+ # cprint(e, color="red", background=True)
162
+
163
+
164
+ def print_delta(chunk: Chunk) -> None:
165
+ if len(chunk.choices) == 0:
166
+ return
167
+ choice = chunk.choices[0] # TODO: n>2の対応
168
+ if choice.delta.role is not None:
169
+ print(" ", end="")
170
+ cprint(choice.delta.role, color="green")
171
+ print(" ", end="")
172
+ if choice.delta.content is not None:
173
+ print(choice.delta.content.replace("\n", "\n "), end="")
174
+ if choice.delta.function_call is not None:
175
+ if choice.delta.function_call.name is not None:
176
+ cprint(choice.delta.function_call.name, color="green", background=True)
177
+ print(" ", end="")
178
+ if choice.delta.function_call.arguments is not None:
179
+ print(choice.delta.function_call.arguments.replace("\n", "\n "), end="")
180
+ if choice.delta.tool_calls is not None:
181
+ for tool_call in choice.delta.tool_calls:
182
+ if tool_call.function is not None:
183
+ if tool_call.function.name is not None:
184
+ if tool_call.index != 0:
185
+ print("\n ", end="")
186
+ cprint(tool_call.function.name, color="green", background=True)
187
+ print(" ", end="")
188
+ if tool_call.function.arguments is not None:
189
+ print(tool_call.function.arguments.replace("\n", "\n "), end="")
190
+ if choice.finish_reason is not None:
191
+ print()
192
+
193
+
194
+ def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
195
+ try:
196
+ cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
197
+ llm_settings_copy = dict(platform=platform, **llm_settings)
198
+ llm_settings_copy["model"] = model
199
+ # Azureの場合
200
+ if platform == "azure":
201
+ llm_settings_copy["engine"] = engine # engineを追加
202
+ print(llm_settings_copy or "-")
203
+ except Exception as e:
204
+ cprint(e, color="red", background=True)
205
+
206
+
207
+ def print_client_settings(client_settings: ClientSettings) -> None:
208
+ try:
209
+ cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
210
+ print(client_settings or "-")
211
+ except Exception as e:
212
+ cprint(e, color="red", background=True)
213
+
214
+
215
+ # -------
216
+
217
+ _DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any]
218
+ DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"]
219
+
220
+
221
+ def _arange_dumpable_object(obj: Any) -> DumplableType:
222
+ # 基本データ型の場合、そのまま返す
223
+ if isinstance(obj, (int, float, str, bool, type(None))):
224
+ return obj
225
+
226
+ # リストの場合、再帰的に各要素を変換
227
+ if isinstance(obj, list):
228
+ return [_arange_dumpable_object(item) for item in obj]
229
+
230
+ # 辞書の場合、再帰的に各キーと値を変換
231
+ if isinstance(obj, dict):
232
+ return {_arange_dumpable_object(key): _arange_dumpable_object(value) for key, value in obj.items()}
233
+
234
+ # それ以外の型の場合、型情報を含めて文字列に変換
235
+ return f"<{type(obj).__name__}>{str(obj)}"
neollm/types/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from neollm.types.info import * # NOQA
2
+ from neollm.types.mytypes import * # NOQA
3
+ from neollm.types.openai.chat_completion import * # NOQA
4
+ from neollm.types.openai.chat_completion_chunk import * # NOQA
neollm/types/_model.py ADDED
@@ -0,0 +1,8 @@
1
+ from typing import Any
2
+
3
+ from openai._models import BaseModel
4
+
5
+
6
+ class DictableBaseModel(BaseModel): # openaiのBaseModelをDictAccessできるようにした
7
+ def __getitem__(self, item: str) -> Any:
8
+ return getattr(self, item)
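A small sketch of the dict-style access `DictableBaseModel` adds (the `Point` model is ours, for illustration):

```python
from neollm.types._model import DictableBaseModel

class Point(DictableBaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
print(p["x"], p.y)  # 1 2; item access delegates to getattr
```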
neollm/types/info.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import Literal
2
+
3
+ from neollm.types._model import DictableBaseModel
4
+
5
+ PrintColor = Literal["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"]
6
+
7
+
8
+ class TimeInfo(DictableBaseModel):
9
+ total: float = 0.0
10
+ """処理時間合計 preprocess + main + postprocess"""
11
+ preprocess: float = 0.0
12
+ """前処理時間"""
13
+ main: float = 0.0
14
+ """メイン処理時間"""
15
+ postprocess: float = 0.0
16
+ """後処理時間"""
17
+
18
+ def __repr__(self) -> str:
19
+ return (
20
+ f"TimeInfo(total={self.total:.3f}, preprocess={self.preprocess:.3f}, main={self.main:.3f}, "
21
+ f"postprocess={self.postprocess:.3f})"
22
+ )
23
+
24
+
25
+ class TokenInfo(DictableBaseModel):
26
+ input: int
27
+ """入力部分のトークン数"""
28
+ output: int
29
+ """出力部分のトークン数"""
30
+ total: int
31
+ """合計トークン数"""
32
+
33
+ def __add__(self, other: "TokenInfo") -> "TokenInfo":
34
+ if not isinstance(other, TokenInfo):
35
+ raise TypeError(f"{other} is not TokenInfo")
36
+ return TokenInfo(
37
+ input=self.input + other.input, output=self.output + other.output, total=self.total + other.total
38
+ )
39
+
40
+ def __iadd__(self, other: "TokenInfo") -> "TokenInfo":
41
+ if not isinstance(other, TokenInfo):
42
+ raise TypeError(f"{other} is not TokenInfo")
43
+ self.input += other.input
44
+ self.output += other.output
45
+ self.total += other.total
46
+ return self
47
+
48
+
49
+ class PriceInfo(DictableBaseModel):
50
+ input: float
51
+ """入力部分の費用 (USD)"""
52
+ output: float
53
+ """出力部分の費用 (USD)"""
54
+ total: float
55
+ """合計費用 (USD)"""
56
+
57
+ def __add__(self, other: "PriceInfo") -> "PriceInfo":
58
+ if not isinstance(other, PriceInfo):
59
+ raise TypeError(f"{other} is not PriceInfo")
60
+ return PriceInfo(
61
+ input=self.input + other.input, output=self.output + other.output, total=self.total + other.total
62
+ )
63
+
64
+ def __iadd__(self, other: "PriceInfo") -> "PriceInfo":
65
+ if not isinstance(other, PriceInfo):
66
+ raise TypeError(f"{other} is not PriceInfo")
67
+ self.input += other.input
68
+ self.output += other.output
69
+ self.total += other.total
70
+ return self
71
+
72
+ def __repr__(self) -> str:
73
+ return f"PriceInfo(input={self.input:.3f}, output={self.output:.3f}, total={self.total:.3f})"
74
+
75
+
76
+ class APIPricing(DictableBaseModel):
77
+ """APIの価格設定に関する情報を表すクラス。"""
78
+
79
+ input: float
80
+ """入力 1k tokens 当たりのAPI利用料 (USD)"""
81
+ output: float
82
+ """出力 1k tokens 当たりのAPI利用料 (USD)"""
neollm/types/mytypes.py ADDED
@@ -0,0 +1,31 @@
1
+ from typing import Any, Iterator, Literal, TypeVar
2
+
3
+ import openai.types.chat as openai_types
4
+ from openai._streaming import Stream
5
+
6
+ from neollm.types.openai.chat_completion import ChatCompletion
7
+ from neollm.types.openai.chat_completion_chunk import ChatCompletionChunk
8
+
9
+ Role = Literal["system", "user", "assistant", "tool", "function"]
10
+ # Settings
11
+ LLMSettings = dict[str, Any]
12
+ ClientSettings = dict[str, Any]
13
+ # Message
14
+ Message = openai_types.ChatCompletionMessageParam
15
+ Messages = list[Message]
16
+ Tools = Any
17
+ Functions = Any
18
+ # Response
19
+ Response = ChatCompletion
20
+ Chunk = ChatCompletionChunk
21
+ StreamResponse = Iterator[Chunk]
22
+ # IO
23
+ InputType = TypeVar("InputType")
24
+ OutputType = TypeVar("OutputType")
25
+ StreamOutputType = Any
26
+
27
+ # OpenAI --------------------------------------------
28
+ OpenAIResponse = openai_types.ChatCompletion
29
+ OpenAIChunk = openai_types.ChatCompletionChunk
30
+ OpenAIStreamResponse = Stream[OpenAIChunk]  # OpenAI StreamResponse
31
+ OpenAIMessages = list[openai_types.ChatCompletionMessageParam]
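The aliases above in use (a minimal sketch):

```python
from neollm.types.mytypes import Messages

msgs: Messages = [{"role": "user", "content": "hi"}]  # Message is OpenAI's TypedDict union
```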
neollm/types/openai/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from neollm.types.openai.chat_completion import * # NOQA
2
+ from neollm.types.openai.chat_completion_chunk import * # NOQA
neollm/types/openai/chat_completion.py ADDED
@@ -0,0 +1,170 @@
1
+ from typing import List, Literal, Optional
2
+
3
+ from openai.types.chat import ChatCompletionAssistantMessageParam
4
+ from openai.types.chat.chat_completion_assistant_message_param import (
5
+ FunctionCall as FunctionCallParams,
6
+ )
7
+ from openai.types.chat.chat_completion_message_tool_call_param import (
8
+ ChatCompletionMessageToolCallParam,
9
+ )
10
+ from openai.types.chat.chat_completion_message_tool_call_param import (
11
+ Function as FunctionParams,
12
+ )
13
+ from pydantic import field_validator
14
+
15
+ from neollm.types._model import DictableBaseModel
16
+
17
+
18
+ class CompletionUsage(DictableBaseModel):
19
+ completion_tokens: int
20
+ """Number of tokens in the generated completion."""
21
+
22
+ prompt_tokens: int
23
+ """Number of tokens in the prompt."""
24
+
25
+ total_tokens: int
26
+ """Total number of tokens used in the request (prompt + completion)."""
27
+
28
+ # ADDED: gpt4v preview用(Noneを許容するため)
29
+ @field_validator("completion_tokens", "prompt_tokens", "total_tokens", mode="before")
30
+ def validate_name(cls, v: int | None) -> int:
31
+ return v or 0
32
+
33
+
34
+ class CompletionUsageForCustomPriceCalculation(DictableBaseModel):
35
+ completion_tokens: int
36
+ """Number of tokens in the generated completion."""
37
+
38
+ prompt_tokens: int
39
+ """Number of tokens in the prompt."""
40
+
41
+ total_tokens: int
42
+ """Total number of tokens used in the request (prompt + completion)."""
43
+
44
+ # ADDED: gpt4v preview用(Noneを許容するため)
45
+ @field_validator("completion_tokens", "prompt_tokens", "total_tokens", mode="before")
46
+ def validate_name(cls, v: int | None) -> int:
47
+ return v or 0
48
+
49
+
50
+ class Function(DictableBaseModel):
51
+ arguments: str
52
+ """
53
+ The arguments to call the function with, as generated by the model in JSON
54
+ format. Note that the model does not always generate valid JSON, and may
55
+ hallucinate parameters not defined by your function schema. Validate the
56
+ arguments in your code before calling your function.
57
+ """
58
+
59
+ name: str
60
+ """The name of the function to call."""
61
+
62
+
63
+ class ChatCompletionMessageToolCall(DictableBaseModel):
64
+ id: str
65
+ """The ID of the tool call."""
66
+
67
+ function: Function
68
+ """The function that the model called."""
69
+
70
+ type: Literal["function"]
71
+ """The type of the tool. Currently, only `function` is supported."""
72
+
73
+
74
+ class FunctionCall(DictableBaseModel):
75
+ arguments: str
76
+ """
77
+ The arguments to call the function with, as generated by the model in JSON
78
+ format. Note that the model does not always generate valid JSON, and may
79
+ hallucinate parameters not defined by your function schema. Validate the
80
+ arguments in your code before calling your function.
81
+ """
82
+
83
+ name: str
84
+ """The name of the function to call."""
85
+
86
+
87
+ class ChatCompletionMessage(DictableBaseModel):
88
+ content: Optional[str]
89
+ """The contents of the message."""
90
+
91
+ role: Literal["assistant"]
92
+ """The role of the author of this message."""
93
+
94
+ function_call: Optional[FunctionCall] = None
95
+ """Deprecated and replaced by `tool_calls`.
96
+
97
+ The name and arguments of a function that should be called, as generated by the
98
+ model.
99
+ """
100
+
101
+ tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
102
+ """The tool calls generated by the model, such as function calls."""
103
+
104
+ def to_typeddict_message(self) -> ChatCompletionAssistantMessageParam:
105
+ message_dict = ChatCompletionAssistantMessageParam(role=self.role, content=self.content)
106
+ if self.function_call is not None:
107
+ message_dict["function_call"] = FunctionCallParams(
108
+ arguments=self.function_call.arguments, name=self.function_call.name
109
+ )
110
+ if self.tool_calls is not None:
111
+ message_dict["tool_calls"] = [
112
+ ChatCompletionMessageToolCallParam(
113
+ id=tool_call.id,
114
+ function=FunctionParams(arguments=tool_call.function.arguments, name=tool_call.function.name),
115
+ type=tool_call.type,
116
+ )
117
+ for tool_call in self.tool_calls
118
+ ]
119
+ return message_dict
120
+
121
+
122
+ FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
123
+
124
+
125
+ class Choice(DictableBaseModel):
126
+ finish_reason: FinishReason | None = None # ADDED: gpt4v preview用
127
+ """The reason the model stopped generating tokens.
128
+
129
+ This will be `stop` if the model hit a natural stop point or a provided stop
130
+ sequence, `length` if the maximum number of tokens specified in the request was
131
+ reached, `content_filter` if content was omitted due to a flag from our content
132
+ filters, `tool_calls` if the model called a tool, or `function_call`
133
+ (deprecated) if the model called a function.
134
+ """
135
+
136
+ index: int
137
+ """The index of the choice in the list of choices."""
138
+
139
+ message: ChatCompletionMessage
140
+ """A chat completion message generated by the model."""
141
+
142
+
143
+ class ChatCompletion(DictableBaseModel):
144
+ id: str
145
+ """A unique identifier for the chat completion."""
146
+
147
+ choices: List[Choice]
148
+ """A list of chat completion choices.
149
+
150
+ Can be more than one if `n` is greater than 1.
151
+ """
152
+
153
+ created: int
154
+ """The Unix timestamp (in seconds) of when the chat completion was created."""
155
+
156
+ model: str
157
+ """The model used for the chat completion."""
158
+
159
+ object: Literal["chat.completion"] | str
160
+ """The object type, which is always `chat.completion`."""
161
+
162
+ system_fingerprint: Optional[str] = None
163
+ """This fingerprint represents the backend configuration that the model runs with.
164
+
165
+ Can be used in conjunction with the `seed` request parameter to understand when
166
+ backend changes have been made that might impact determinism.
167
+ """
168
+
169
+ usage: Optional[CompletionUsage] = None
170
+ """Usage statistics for the completion request."""
neollm/types/openai/chat_completion_chunk.py ADDED
@@ -0,0 +1,109 @@
1
+ from typing import List, Literal, Optional
2
+
3
+ from pydantic import field_validator
4
+
5
+ from neollm.types._model import DictableBaseModel
6
+ from neollm.utils.utils import cprint
7
+
8
+
9
+ class ChoiceDeltaFunctionCall(DictableBaseModel):
10
+ arguments: Optional[str] = None
11
+ """
12
+ The arguments to call the function with, as generated by the model in JSON
13
+ format. Note that the model does not always generate valid JSON, and may
14
+ hallucinate parameters not defined by your function schema. Validate the
15
+ arguments in your code before calling your function.
16
+ """
17
+
18
+ name: Optional[str] = None
19
+ """The name of the function to call."""
20
+
21
+
22
+ class ChoiceDeltaToolCallFunction(DictableBaseModel):
23
+ arguments: Optional[str] = None
24
+ """
25
+ The arguments to call the function with, as generated by the model in JSON
26
+ format. Note that the model does not always generate valid JSON, and may
27
+ hallucinate parameters not defined by your function schema. Validate the
28
+ arguments in your code before calling your function.
29
+ """
30
+
31
+ name: Optional[str] = None
32
+ """The name of the function to call."""
33
+
34
+
35
+ class ChoiceDeltaToolCall(DictableBaseModel):
36
+ index: int
37
+
38
+ id: Optional[str] = None
39
+ """The ID of the tool call."""
40
+
41
+ function: Optional[ChoiceDeltaToolCallFunction] = None
42
+
43
+ type: Optional[Literal["function"]] = None
44
+ """The type of the tool. Currently, only `function` is supported."""
45
+
46
+
47
+ class ChoiceDelta(DictableBaseModel):
48
+ content: Optional[str] = None
49
+ """The contents of the chunk message."""
50
+
51
+ function_call: Optional[ChoiceDeltaFunctionCall] = None
52
+ """Deprecated and replaced by `tool_calls`.
53
+
54
+ The name and arguments of a function that should be called, as generated by the
55
+ model.
56
+ """
57
+
58
+ role: Optional[Literal["system", "user", "assistant", "tool"]] = None
59
+ """The role of the author of this message."""
60
+
61
+ tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
62
+
63
+
64
+ class ChunkChoice(DictableBaseModel): # chat_completionと同名なため、改名(Choice->ChunkChoice)
65
+ delta: ChoiceDelta
66
+ """A chat completion delta generated by streamed model responses."""
67
+
68
+ finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
69
+ """The reason the model stopped generating tokens.
70
+
71
+ This will be `stop` if the model hit a natural stop point or a provided stop
72
+ sequence, `length` if the maximum number of tokens specified in the request was
73
+ reached, `content_filter` if content was omitted due to a flag from our content
74
+ filters, `tool_calls` if the model called a tool, or `function_call`
75
+ (deprecated) if the model called a function.
76
+ """
77
+
78
+ index: int
79
+ """The index of the choice in the list of choices."""
80
+
81
+
82
+ class ChatCompletionChunk(DictableBaseModel):
83
+ id: str
84
+ """A unique identifier for the chat completion. Each chunk has the same ID."""
85
+
86
+ choices: List[ChunkChoice]
87
+ """A list of chat completion choices.
88
+
89
+ Can be more than one if `n` is greater than 1.
90
+ """
91
+
92
+ created: int
93
+ """The Unix timestamp (in seconds) of when the chat completion was created.
94
+
95
+ Each chunk has the same timestamp.
96
+ """
97
+
98
+ model: str
99
+ """The model to generate the completion."""
100
+
101
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk" # for azure
102
+ """The object type, which is always `chat.completion.chunk`."""
103
+
104
+ # ADDED: azure用 (""を許容するため)
105
+ @field_validator("object", mode="before")
106
+ def validate_name(cls, v: str) -> Literal["chat.completion.chunk"]:
107
+ if v != "" and v != "chat.completion.chunk":
108
+ cprint(f"ChatCompletionChunk.object is not 'chat.completion.chunk': {v}", "yellow")
109
+ return "chat.completion.chunk"
neollm/utils/inference.py ADDED
@@ -0,0 +1,70 @@
1
+ import csv
2
+ import glob
3
+ import json
4
+ from concurrent.futures import Future, ThreadPoolExecutor
5
+ from typing import Any, Callable, TypeVar
6
+
7
+ _T = TypeVar("_T")
8
+
9
+
10
+ def execute_parallel(func: Callable[..., _T], kwargs_list: list[dict[str, Any]], max_workers: int) -> list[_T]:
11
+ """並行処理を行う
12
+
13
+ Args:
14
+ func (Callable): 並行処理したい関数
15
+ kwargs_list (list[dict[str, Any]]): 関数の引数(dict型)のリスト
16
+ max_workers (int): 並行処理数
17
+
18
+ Returns:
19
+ list[Any]: 関数の戻り値のリスト
20
+ """
21
+ response_list: list[Future[_T]] = []
22
+ with ThreadPoolExecutor(max_workers=max_workers) as e:
23
+ for kwargs in kwargs_list:
24
+ response: Future[_T] = e.submit(func, **kwargs)
25
+ response_list.append(response)
26
+ return [r.result() for r in response_list]
27
+
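+ # Usage sketch (the worker `double` and its arguments are hypothetical); results come
+ # back in submission order, not completion order:
+ # >>> def double(x: int) -> int:
+ # ... return x * 2
+ # >>> execute_parallel(double, kwargs_list=[{"x": 1}, {"x": 2}], max_workers=2)
+ # [2, 4]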
28
+
29
+ def _load_json_file(file_path: str) -> Any:
30
+ # TODO: Docstring追加
31
+ with open(file_path, "r", encoding="utf-8") as json_file:
32
+ data = json.load(json_file)
33
+ return data
34
+
35
+
36
+ def make_log_csv(log_dir: str, csv_file_name: str = "log.csv") -> None:
37
+ """ログデータのcsvを保存
38
+
39
+ Args:
40
+ log_dir (str): ログデータが保存されているディレクトリ
41
+ csv_file_name (str, optional): 保存するcsvファイル名. Defaults to "log.csv".
42
+ """
43
+ # ディレクトリ内のJSONファイルのリストを取得
44
+ # TODO: エラーキャッチ
45
+ json_files = sorted([f for f in glob.glob(f"{log_dir}/*.json")], key=lambda x: int(x.split("/")[-1].split(".")[0]))
46
+
47
+ # すべてのJSONファイルからユニークなキーを取得
48
+ columns = []
49
+ data_list: list[dict[Any, Any]] = []
50
+ keys_set = set()
51
+ for json_file in json_files:
52
+ data = _load_json_file(json_file)
53
+ if isinstance(data, dict):
54
+ for key in data.keys():
55
+ if key not in keys_set:
56
+ keys_set.add(key)
57
+ columns.append(key)
58
+ data_list.append(data)
59
+
60
+ # CSVファイルを作成し、ヘッダーを書き込む
61
+ with open(csv_file_name, "w", encoding="utf-8", newline="") as csv_file:
62
+ writer = csv.writer(csv_file)
63
+ writer.writerow(columns)
64
+
65
+ # JSONファイルからデータを読み取り、CSVファイルに書き込む
66
+ for data in data_list:
67
+ row = [data.get(key, "") for key in columns]
68
+ writer.writerow(row)
69
+
70
+ print(f"saved csv file: {csv_file_name}")
neollm/utils/postprocess.py ADDED
@@ -0,0 +1,120 @@
1
+ import json
2
+ from typing import Any, overload
3
+
4
+
5
+ # string ---------------------------------------
6
+ def _extract_string(text: str, start_string: str | None = None, end_string: str | None = None) -> str:
7
+ """
8
+ テキストから必要な文字列を抽出する
9
+
10
+ Args:
11
+ text (str): 抽出するテキスト
12
+
13
+ Returns:
14
+ str: 抽出された必要な文字列
15
+ """
16
+ # 最初の文字
17
+ if start_string is not None and start_string in text:
18
+ idx_head = text.index(start_string)
19
+ text = text[idx_head:]
20
+ # 最後の文字
21
+ if end_string is not None and end_string in text:
22
+ idx_tail = len(text) - text[::-1].index(end_string[::-1])
23
+ text = text[:idx_tail]
24
+ return text
25
+
26
+
27
+ def _delete_first_chapter_tag(text: str, first_character_tag: str | list[str]) -> str:
28
+ """_summary_
29
+
30
+ Args:
31
+ text (str): テキスト
32
+ first_character_tag (str | list[str]): 最初にある余分な文字列
33
+
34
+ Returns:
35
+ str: 除去済みのテキスト
36
+ """
37
+ # first_character_tagのlist化
38
+ if isinstance(first_character_tag, str):
39
+ first_character_tag = [first_character_tag]
40
+ # 最初のチャプタータグの消去
41
+ for first_character_i in first_character_tag:
42
+ if text.startswith(first_character_i):
43
+ text = text[len(first_character_i) :]
44
+ break
45
+ return text.strip()
46
+
47
+
48
+ def strip_string(
49
+ text: str,
50
+ first_character: str | list[str] = ["<output>", "<outputs>"],
51
+ start_string: str | None = None,
52
+ end_string: str | None = None,
53
+ strip_quotes: str | list[str] = ["'", '"'],
54
+ ) -> str:
55
+ """stringの前後の余分な文字を削除する
56
+
57
+ Args:
58
+ text (str): ChatGPTの出力文字列
59
+ first_character (str, optional): 出力の先頭につく文字 Defaults to ["<output>", "<outputs>"].
60
+ start_string (str, optional): 抽出開始位置を示す文字列 Defaults to None.
61
+ end_string (str, optional): 抽出終了位置を示す文字列 Defaults to None.
62
+ strip_quotes (str, optional): 前後の余分な'"を消す. Defaults to ["'", '"'].
63
+
64
+ Returns:
65
+ str: 余分な文字列を消去した文字列
66
+
67
+ Examples:
68
+ >>> strip_string("<output>'''ChatGPT is smart!'''", "<output>")
69
+ ChatGPT is smart!
70
+ >>> strip_string('{"a": 1}', start_string="{", end_string="}")
71
+ {"a": 1}
72
+ >>> strip_string("<outputs> `neoAI`", strip_quotes="`")
73
+ neoAI
74
+ """
75
+ # 余分な文字列消去
76
+ text = _delete_first_chapter_tag(text, first_character)
77
+ # 前後の'" を消す
78
+ if isinstance(strip_quotes, str):
79
+ strip_quotes = [strip_quotes]
80
+ for quote in strip_quotes:
81
+ text = text.strip(quote).strip()
82
+ text = _extract_string(text, start_string, end_string)
83
+ return text.strip()
84
+
85
+
86
+ # dict ---------------------------------------
87
+
88
+
89
+ @overload
90
+ def json2dict(json_string: str, error_key: None) -> dict[Any, Any] | str: ...
91
+
92
+
93
+ @overload
94
+ def json2dict(json_string: str, error_key: str) -> dict[Any, Any]: ...
95
+
96
+
97
+ def json2dict(json_string: str, error_key: str | None = "error") -> dict[Any, Any] | str:
98
+ """
99
+ JSON文字列をPython dictに変換する
100
+
101
+ Args:
102
+ json_string (str): 変換するJSON文字列
103
+ error_key (str, optional): エラーキーの値として代入する文字列. Defaults to "error".
104
+
105
+ Returns:
106
+ dict: 変換されたPython dict
107
+ """
108
+ try:
109
+ python_dict = json.loads(_extract_string(json_string, start_string="{", end_string="}"), strict=False)
110
+ except ValueError:
111
+ if error_key is None:
112
+ return json_string
113
+ python_dict = {error_key: json_string}
114
+ if isinstance(python_dict, dict):
115
+ return python_dict
116
+ return {error_key: python_dict}
117
+
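+ # Usage sketch (illustrative values): a JSON object embedded in surrounding text is
+ # extracted and parsed; on failure, the raw string is wrapped under `error_key`.
+ # >>> json2dict('output: {"a": 1} done', error_key="error")
+ # {'a': 1}
+ # >>> json2dict("not json", error_key="error")
+ # {'error': 'not json'}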
118
+
119
+ # calender
120
+ # YYYY年MM月YY日 -> YYYY-MM-DD
neollm/utils/preprocess.py ADDED
@@ -0,0 +1,107 @@
1
+ import json
2
+ import re
3
+ from typing import Any, Callable
4
+
5
+
6
+ # dict2json --------------------------------
7
+ def dict2json(python_dict: dict[str, Any]) -> str:
8
+ """
9
+ Python dictをJSON文字列に変換する
10
+
11
+ Args:
12
+ python_dict (dict): 変換するPython dict
13
+
14
+ Returns:
15
+ str: 変換されたJSON文字列
16
+ """
17
+ # ensure_ascii: 日本語とかを出力するため
18
+ json_string = json.dumps(python_dict, indent=2, ensure_ascii=False)
19
+ return json_string
20
+
21
+
22
+ # optimize token --------------------------------
23
+ def optimize_token(text: str, funcs: list[Callable[[str], str]] | None = None) -> str:
24
+ """
25
+ テキストのトークンを最適化をする
26
+
27
+ Args:
28
+ text (str): 最適化するテキスト
29
+
30
+ Returns:
31
+ str: 最適化されたテキスト
32
+ """
33
+ funcs = funcs or [minimize_newline, zenkaku_to_hankaku, remove_trailing_spaces]
34
+ for func in funcs:
35
+ text = func(text)
36
+ return text.strip()
37
+
38
+
39
+ def _replace_consecutive(text: str, pattern: str, replacing_text: str) -> str:
40
+ """
41
+ テキスト内の連続するパターンに対して、指定された置換テキストで置換する
42
+
43
+ Args:
44
+ text (str): テキスト
45
+ pattern (str): 置換するパターン
46
+ replacing_text (str): 置換テキスト
47
+
48
+ Returns:
49
+ str: 置換されたテキスト
50
+ """
51
+ p = re.compile(pattern)
52
+ matches = [(m.start(), m.end()) for m in p.finditer(text)][::-1]
53
+
54
+ text_replaced = list(text)
55
+
56
+ for i_start, i_end in matches:
57
+ text_replaced[i_start:i_end] = [replacing_text]
58
+ return "".join(text_replaced)
59
+
60
+
61
+ def minimize_newline(text: str) -> str:
62
+ """
63
+ テキスト内の連続する改行を2以下にする
64
+
65
+ Args:
66
+ text (str): テキスト
67
+
68
+ Returns:
69
+ str: 改行を最小限にしたテキスト
70
+ """
71
+ return _replace_consecutive(text, pattern="\n{2,}", replacing_text="\n\n")
72
+
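+ # Illustrative example: minimize_newline("a\n\n\n\nb") returns "a\n\nb"
+ # (any run of two or more newlines collapses to exactly two).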
73
+
74
+ def zenkaku_to_hankaku(text: str) -> str:
75
+ """
76
+ テキスト内の全角文字を半角文字に変換する
77
+
78
+ Args:
79
+ text (str): テキスト
80
+
81
+ Returns:
82
+ str: 半角文字に変換されたテキスト
83
+ """
84
+ mapping_dict = {" ": " ", ":": ": ", "‎": " ", ".": "。", ",": "、", "¥": "¥"}
85
+ hankaku_text = ""
86
+ for char in text:
87
+ # A-Za-z0-9!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
88
+ if char in mapping_dict:
89
+ hankaku_text += mapping_dict[char]
90
+ elif 65281 <= ord(char) <= 65374:
91
+ hankaku_text += chr(ord(char) - 65248)
92
+ else:
93
+ hankaku_text += char
94
+ return hankaku_text
95
+
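+ # Illustrative example: code points U+FF01-U+FF5E (ord 65281-65374) are shifted down by
+ # 65248 to their ASCII counterparts, so zenkaku_to_hankaku("ABC123!") returns "ABC123!".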
96
+
97
+ def remove_trailing_spaces(text: str) -> str:
98
+ """
99
+ テキスト内の各行の末尾のスペースを削除する
100
+
101
+ Args:
102
+ text (str): テキスト
103
+
104
+ Returns:
105
+ str: スペースを削除したテキスト
106
+ """
107
+ return "\n".join([line.rstrip() for line in text.split("\n")])
neollm/utils/prompt_checker.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from typing_extensions import TypedDict
6
+
7
+ from neollm import MyL3M2, MyLLM
8
+ from neollm.types import LLMSettings, Messages, Response
9
+
10
+ _MyLLM = MyLLM[Any, Any]
11
+ _MyL3M2 = MyL3M2[Any, Any]
12
+
13
+
14
+ class PromptCheckerInput(TypedDict):
15
+ myllm: _MyLLM | _MyL3M2
16
+ model: str
17
+ platform: str
18
+ llm_settings: LLMSettings | None
19
+
20
+
21
+ class APromptCheckerInput(TypedDict):
22
+ myllm: _MyLLM
23
+
24
+
25
+ class APromptChecker(MyLLM[APromptCheckerInput, str]):
26
+ def _preprocess(self, inputs: APromptCheckerInput) -> Messages:
27
+ system_prompt = (
28
+ "あなたは、AIへの指示(プロンプト)をより良くすることが仕事です。\n"
29
+ "あなたは言語能力が非常に高く、仕事も丁寧なので小さなミスも気づくことができる天才です。"
30
+ "誤字脱字・論理的でない点・指示が不明確な点を箇条書きで指摘し、より良いプロンプトを提案してください。\n"
31
+ "# 出力例: \n"
32
+ "[指示の誤字脱字/文法ミス]\n"
33
+ "- ...\n"
34
+ "- ...\n"
35
+ "[指示が論理的でない点]\n"
36
+ "- ...\n"
37
+ "- ...\n"
38
+ "[指示が不明確な点]\n"
39
+ "- ...\n"
40
+ "- ...\n"
41
+ "[その他気になる点]\n"
42
+ "- ...\n"
43
+ "- ...\n"
44
+ "[提案]\n"
45
+ "- ...\n"
46
+ "- ...\n"
47
+ )
48
+ if inputs["myllm"].messages is None:
49
+ return []
50
+ user_prompt = "# プロンプト\n" + "\n".join(
51
+ # [f"<{message['role']}>\n{message['content']}\n" for message in inputs.messages]
52
+ [str(message) for message in inputs["myllm"].messages]
53
+ )
54
+ messages: Messages = [
55
+ {"role": "system", "content": system_prompt},
56
+ {"role": "user", "content": user_prompt},
57
+ ]
58
+ return messages
59
+
60
+ def _postprocess(self, response: Response) -> str:
61
+ if response.choices[0].message.content is None:
62
+ return "contentがないンゴ"
63
+ return response.choices[0].message.content
64
+
65
+ def _ruleprocess(self, inputs: APromptCheckerInput) -> str | None:
66
+ if inputs["myllm"].messages is None:
67
+ return "ruleprocessが走って、リクエストしてないよ!"
68
+ return None
69
+
70
+ def __call__(self, inputs: APromptCheckerInput) -> str:
71
+ outputs: str = super().__call__(inputs)
72
+ return outputs
73
+
74
+
75
+ class PromptsChecker(MyL3M2[PromptCheckerInput, None]):
76
+ def _link(self, inputs: PromptCheckerInput) -> None:
77
+ if isinstance(inputs["myllm"], MyL3M2):
78
+ for myllm in inputs["myllm"].myllm_list:
79
+ prompts_checker = PromptsChecker(parent=self, verbose=True)
80
+ prompts_checker(
81
+ inputs={
82
+ "myllm": myllm,
83
+ "model": inputs["model"],
84
+ "platform": inputs["platform"],
85
+ "llm_settings": inputs["llm_settings"],
86
+ }
87
+ )
88
+ elif isinstance(inputs["myllm"], MyLLM):
89
+ a_prompt_checker = APromptChecker(
90
+ parent=self,
91
+ llm_settings=inputs["llm_settings"],
92
+ verbose=True,
93
+ platform=inputs["platform"],
94
+ model=inputs["model"],
95
+ )
96
+ a_prompt_checker(inputs={"myllm": inputs["myllm"]})
97
+
98
+ def __call__(self, inputs: PromptCheckerInput) -> None:
99
+ super().__call__(inputs)
100
+
101
+
102
+ def check_prompt(
103
+ myllm: _MyLLM | _MyL3M2,
104
+ llm_settings: LLMSettings | None = None,
105
+ model: str = "gpt-3.5-turbo",
106
+ platform: str = "openai",
107
+ ) -> MyL3M2[Any, Any]:
108
+ prompt_checker_2 = PromptsChecker(verbose=True)
109
+ prompt_checker_2(inputs={"myllm": myllm, "llm_settings": llm_settings, "model": model, "platform": platform})
110
+ return prompt_checker_2
neollm/utils/tokens.py ADDED
@@ -0,0 +1,229 @@
1
+ import json
2
+ import textwrap
3
+ from typing import Any
4
+
5
+ import tiktoken
6
+
7
+ from neollm.types import Function # , Functions, Messages
8
+
9
+
10
+ def normalize_model_name(model_name: str) -> str:
11
+ """model_nameのトークン数計測のための標準化
12
+
13
+ Args:
14
+ model_name (str): model_name
15
+ OpenAI: gpt-3.5-turbo-0613, gpt-3.5-turbo-16k-0613, gpt-4-0613, gpt-4-32k-0613
16
+ OpenAIFT: ft:gpt-3.5-turbo:org_id
17
+ Azure: gpt-35-turbo-0613, gpt-35-turbo-16k-0613, gpt-4-0613, gpt-4-32k-0613
18
+
19
+ Returns:
20
+ str: 標準化されたmodel_name
21
+
22
+ Raises:
23
+ ValueError: model_nameが不適切
24
+ """
25
+ # 参考: https://platform.openai.com/docs/models/gpt-3-5
26
+ NEWEST_MAP = [
27
+ ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"),
28
+ ("gpt-3.5-turbo", "gpt-3.5-turbo-0613"),
29
+ ("gpt-4-32k", "gpt-4-32k-0613"),
30
+ ("gpt-4", "gpt-4-0613"),
31
+ ]
32
+ ALL_VERSION_MODELS = [
33
+ # gpt-3.5-turbo
34
+ "gpt-3.5-turbo-0613",
35
+ "gpt-3.5-turbo-16k-0613",
36
+ "gpt-3.5-turbo-0301", # Legacy
37
+ # gpt-4
38
+ "gpt-4-0613",
39
+ "gpt-4-32k-0613",
40
+ "gpt-4-0314", # Legacy
41
+ "gpt-4-32k-0314", # Legacy
42
+ ]
43
+ # Azure表記 → OpenAI表記に統一
44
+ model_name = model_name.replace("gpt-35", "gpt-3.5")
45
+ # 最新モデルを正式名称に & 新モデル, FTモデルをキャッチ
46
+ if model_name not in ALL_VERSION_MODELS:
47
+ for key, model_name_version in NEWEST_MAP:
48
+ if key in model_name:
49
+ model_name = model_name_version
50
+ break
51
+ # Return
52
+ if model_name in ALL_VERSION_MODELS:
53
+ return model_name
54
+ raise ValueError("model_name は以下から選んで.\n" + ",".join(ALL_VERSION_MODELS))
55
+
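+ # Illustrative examples (the org id is hypothetical):
+ # normalize_model_name("gpt-35-turbo") -> "gpt-3.5-turbo-0613" (Azure spelling, latest alias)
+ # normalize_model_name("ft:gpt-3.5-turbo:my-org") -> "gpt-3.5-turbo-0613" (fine-tuned model)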
56
+
57
+ def count_tokens(messages: Any | None = None, model_name: str | None = None, functions: Any | None = None) -> int:
58
+ """トークン数計測
59
+
60
+ Args:
61
+ messages (Messages): GPTAPIの入力のmessages
62
+ model_name (str | None, optional): モデル名. Defaults to None.
63
+ functions (Functions | None, optional): GPTAPIの入力のfunctions. Defaults to None.
64
+
65
+ Returns:
66
+ int: トークン数
67
+ """
68
+ model_name = normalize_model_name(model_name or "gpt-3.5-turbo") # NOTE: "cl100k_base"はエンコーディング名でモデル名ではないため、既定はgpt-3.5-turbo
69
+ num_tokens = _count_messages_tokens(messages, model_name)
70
+ if functions is not None:
71
+ num_tokens += _count_functions_tokens(functions, model_name)
72
+ return num_tokens
73
+
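+ # Usage sketch (illustrative message): counts prompt tokens including per-message
+ # overhead; with cl100k_base this returns 8 (3 priming + 3 per-message + 2 content).
+ # >>> count_tokens([{"role": "user", "content": "hello"}], model_name="gpt-3.5-turbo")
+ # 8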
74
+
75
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
76
+ def _count_messages_tokens(messages: Any | None, model_name: str) -> int:
77
+ """メッセージのトークン数を計算
78
+
79
+ Args:
80
+ messages (Messages): ChatGPT等APIに入力するmessages
81
+ model_name (str, optional): 使用するモデルの名前
82
+ "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314"
83
+ "gpt-4-0613", "gpt-4-32k-0613", "gpt-3.5-turbo", "gpt-4"
84
+
85
+ Returns:
86
+ int: トークン数の合計
87
+ """
88
+ if messages is None:
89
+ return 0
90
+ # setting model
91
+ encoding_model = tiktoken.encoding_for_model(model_name) # "cl100k_base"
92
+
93
+ # config
94
+ if model_name == "gpt-3.5-turbo-0301":
95
+ tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
96
+ tokens_per_name = -1 # if there's a name, the role is omitted
97
+ else:
98
+ tokens_per_message = 3
99
+ tokens_per_name = 1
100
+
101
+ # count tokens
102
+ num_tokens = 3 # every reply is primed with <|start|>assistant<|message|>
103
+ for message in messages:
104
+ num_tokens += tokens_per_message
105
+ for key, value in message.items():
106
+ if isinstance(value, str):
107
+ num_tokens += len(encoding_model.encode(value))
108
+ if key == "name":
109
+ num_tokens += tokens_per_name
110
+ return num_tokens
111
+
112
+
113
+ # https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5
114
+ def _count_functions_tokens(functions: Any, model_name: str | None = None) -> int:
115
+ """
116
+ functionsのトークン数計測
117
+
118
+ Args:
119
+ functions (Functions): GPTAPIの入力のfunctions
120
+ model_name (str | None, optional): モデル名. Defaults to None.
121
+
122
+ Returns:
123
+ _type_: トークン数
124
+ """
125
+ encoding_model = tiktoken.encoding_for_model(model_name) if model_name else tiktoken.get_encoding("cl100k_base")
126
+ num_tokens = 3 + len(encoding_model.encode(__functions2string(functions)))
127
+ return num_tokens
128
+
129
+
130
+ # functionsのstring化、補助関数 ---------------------------------------------------------------------------
131
+ def __functions2string(functions: Any) -> str:
132
+ """functionsの文字列化
133
+
134
+ Args:
135
+ functions (Functions): GPTAPIの入力のfunctions
136
+
137
+ Returns:
138
+ str: functionsの文字列
139
+ """
140
+ prefix = "# Tools\n\n## functions\n\nnamespace functions {\n\n} // namespace functions\n"
141
+ functions_string = prefix + "".join(__function2string(function) for function in functions)
142
+ return functions_string
143
+
144
+
145
+ def __function2string(function: Function) -> str:
146
+ """functionの文字列化
147
+
148
+ Args:
149
+ function (Function): GPTAPIのfunctionの要素
150
+
151
+ Returns:
152
+ str: functionの文字列
153
+ """
154
+ object_string = __format_object(function["parameters"])
155
+ if object_string is not None:
156
+ object_string = "_: " + object_string
157
+ else:
158
+ object_string = ""
159
+
160
+ functions_string: str = (
161
+ f"// {function['description']}\ntype {function['name']} = (" + object_string + ") => any;\n\n"
162
+ )
163
+ return functions_string
164
+
165
+
166
+ def __format_object(schema: dict[str, Any], indent: int = 0) -> str | None:
167
+ if "properties" not in schema or len(schema["properties"]) == 0:
168
+ if schema.get("additionalProperties", False):
169
+ return "object"
170
+ return None
171
+
172
+ result = "{\n"
173
+ for key, value in dict(schema["properties"]).items():
174
+ # value <- resolve_ref(value)
175
+ value_rendered = __format_schema(value, indent + 1)
176
+ if value_rendered is None:
177
+ continue
178
+ # description
179
+ if "description" in value:
180
+ description = "".join(
181
+ " " * indent + f"// {description_i}\n"
182
+ for description_i in textwrap.dedent(value["description"]).strip().split("\n")
183
+ )
184
+ # optional
185
+ optional = "" if key in schema.get("required", {}) else "?"
186
+ # default
187
+ default_comment = "" if "default" not in value else f" // default: {__format_default(value)}"
188
+ # add string
189
+ result += description + " " * indent + f"{key}{optional}: {value_rendered},{default_comment}\n"
190
+ result += (" " * (indent - 1)) + "}"
191
+ return result
192
+
193
+
194
+ # NOTE: $ref の解決は未対応(以下は参考元 gist の該当部分)
195
+ # def resolve_ref(schema):
196
+ # if schema.get("$ref") is not None:
197
+ # ref = schema["$ref"][14:]
198
+ # schema = json_schema["definitions"][ref]
199
+ # return schema
200
+
201
+
202
+ def __format_schema(schema: dict[str, Any], indent: int) -> str | None:
203
+ # schema <- resolve_ref(schema)
204
+ if "enum" in schema:
205
+ return __format_enum(schema)
206
+ elif schema["type"] == "object":
207
+ return __format_object(schema, indent)
208
+ elif schema["type"] in {"integer", "number"}:
209
+ return "number"
210
+ elif schema["type"] in {"string"}:
211
+ return "string"
212
+ elif schema["type"] == "array":
213
+ return str(__format_schema(schema["items"], indent)) + "[]"
214
+ else:
215
+ raise ValueError("unknown schema type " + schema["type"])
216
+
217
+
218
+ def __format_enum(schema: dict[str, Any]) -> str:
219
+ # "A" | "B" | "C"
220
+ return " | ".join(json.dumps(element, ensure_ascii=False) for element in schema["enum"])
221
+
222
+
223
+ def __format_default(schema: dict[str, Any]) -> str:
224
+ default = schema["default"]
225
+ if schema["type"] == "number" and float(default).is_integer():
226
+ # numberの時、0 → 0.0
227
+ return f"{default:.1f}"
228
+ else:
229
+ return str(default)
neollm/utils/utils.py ADDED
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any
5
+
6
+ from typing_extensions import TypedDict
7
+
8
+ from neollm.types import PrintColor
9
+
10
+
11
+ class CPrintParam(TypedDict, total=False):
12
+ text: Any
13
+ color: PrintColor | None
14
+ background: bool
15
+ light: bool
16
+ bold: bool
17
+ italic: bool
18
+ underline: bool
19
+ kwargs: dict[str, Any]
20
+
21
+
22
+ def cprint(
23
+ *text: Any,
24
+ color: PrintColor | None = None,
25
+ background: bool = False,
26
+ light: bool = False,
27
+ bold: bool = False,
28
+ italic: bool = False,
29
+ underline: bool = False,
30
+ kwargs: dict[str, Any] = {},
31
+ ) -> None:
32
+ """
33
+ 色付けなどリッチにprint
34
+
35
+ Args:
36
+ *text: 表示するテキスト。
37
+ color (PrintColor): テキストの色: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'。
38
+ background (bool): 背景色
39
+ light (bool): 淡い色にするか
40
+ bold (bool): 太字
41
+ italic (bool): 斜体
42
+ underline (bool): 下線
43
+ kwargs (dict[str, Any]): printに渡すキーワード引数
44
+ """
45
+ # ANSIエスケープシーケンスを使用して、テキストを書式設定して表示する
46
+ format_string = ""
47
+
48
+ # 色の設定
49
+ color2code: dict[PrintColor, int] = {
50
+ "black": 30,
51
+ "red": 31,
52
+ "green": 32,
53
+ "yellow": 33,
54
+ "blue": 34,
55
+ "magenta": 35,
56
+ "cyan": 36,
57
+ "white": 37,
58
+ }
59
+ if color is not None and color in color2code:
60
+ code = color2code[color]
61
+ if background:
62
+ code += 10
63
+ elif light:
64
+ code += 60
65
+ format_string += f"\033[{code}m"
66
+ if bold:
67
+ format_string += "\033[1m"
68
+ if italic:
69
+ format_string += "\033[3m"
70
+ if underline:
71
+ format_string += "\033[4m"
72
+
73
+ # テキストの表示
74
+ for text_i in text:
75
+ print(format_string + str(text_i) + "\033[0m", **kwargs)
76
+
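+ # Usage sketch:
+ # cprint("done", color="green", bold=True) # bold green text
+ # cprint("WARNING", color="yellow", background=True) # yellow background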
77
+
78
+ def ensure_env_var(var_name: str | None = None, default: str | None = None) -> str:
79
+ if var_name is None:
80
+ return ""
81
+ if os.environ.get(var_name, "") == "":
82
+ if default is None:
83
+ raise ValueError(f"{var_name}をenvで設定しよう")
84
+ cprint(f"WARNING: {var_name}が設定されていません。{default}を使用します。", color="yellow", background=True)
85
+ os.environ[var_name] = default
86
+ return os.environ[var_name]
87
+
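+ # Usage sketch (variable names are examples): raises ValueError when the variable is
+ # unset and no default is given; otherwise warns and falls back to the default.
+ # >>> ensure_env_var("OPENAI_API_KEY") # raises if missing
+ # >>> ensure_env_var("LLM_PLATFORM", default="azure") # warns, sets, and returns "azure"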
88
+
89
+ def suport_unrecomended_env_var(old_key: str, new_key: str) -> None:
90
+ """非推奨の環境変数をサポートする
91
+
92
+ Args:
93
+ old_key (str): 非推奨の環境変数名
94
+ new_key (str): 推奨の環境変数名
95
+ """
96
+ if os.getenv(old_key) is not None and os.getenv(new_key) is None:
97
+ cprint(f"WARNING: {old_key}ではなく、{new_key}にしてね", color="yellow", background=True)
98
+ os.environ[new_key] = os.environ[old_key]
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
project/.env.template ADDED
@@ -0,0 +1,24 @@
1
+ LLM_PLATFORM=azure
2
+
3
+ # OpenAIkey
4
+ OPENAI_API_KEY=sk-XXX
5
+
6
+ # Azure OpenAIkey
7
+ AZURE_OPENAI_API_KEY=XXX # AZURE_OPENAI_AD_TOKEN=YYY
8
+ AZURE_OPENAI_ENDPOINT=https://neoai-pjname.openai.azure.com/ # (not-recommended): AZURE_API_BASE
9
+ OPENAI_API_VERSION=2024-02-01 # (not-recommended): AZURE_API_VERSION
10
+
11
+ # ENGINE
12
+ # 1106 ----------------------------------------------------------
13
+ AZURE_ENGINE_GPT35T_1106=xxx
14
+ AZURE_ENGINE_GPT4T_1106=xxx
15
+ # 0613 ----------------------------------------------------------
16
+ AZURE_ENGINE_GPT35T_0613=neoai-pjname-gpt-35 # (not-recommended): AZURE_ENGINE_GPT35, AZURE_ENGINE_GPT35_0613
17
+ AZURE_ENGINE_GPT35T_16K_0613=neoai-pjname-gpt-35-16k # (not-recommended): AZURE_ENGINE_GPT35_16k, AZURE_ENGINE_GPT35_16K_0613
18
+ AZURE_ENGINE_GPT4_0613=neoai-pjname-gpt4 # (not-recommended): AZURE_ENGINE_GPT4
19
+ AZURE_ENGINE_GPT4_32K_0613=neoai-pjname-gpt4-32k # (not-recommended): AZURE_ENGINE_GPT4_32k
20
+
21
+ # Anthropic
22
+ ANTHROPIC_API_KEY=xxx
23
+
24
+ # GCP
project/ex_module/ex_profile_extractor.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Any, Literal, TypedDict
2
+
3
+ from neollm import MyLLM
4
+ from neollm.types import Functions
5
+ from neollm.utils.postprocess import json2dict
6
+ from neollm.utils.preprocess import optimize_token
7
+
8
+
9
+ class ProfileExtractorInputType(TypedDict):
10
+ text: str
11
+
12
+
13
+ class ProfileExtractorOuputType(TypedDict):
14
+ name: str
15
+ birth_year: int
16
+ domain: str
17
+ lang: Literal["ENG", "JPN"]
18
+
19
+
20
+ class ProfileExtractor(MyLLM):
21
+ """情報を抽出するMyLLM
22
+
23
+ Notes:
24
+ inputs:
25
+ >>> {"text": str}
26
+ outputs:
27
+ >>> {"name": str, "birth_year": int, "domain": str, "lang": "ENG" | "JPN"}
28
+ """
29
+
30
+ def _preprocess(self, inputs: ProfileExtractorInputType):
31
+ system_prompt = "<input>より情報を抽出する。存在しない場合nullとする"
32
+ user_prompt = "<input>\n" f"'''{inputs['text'].strip()}'''"
33
+ messages = [
34
+ {"role": "system", "content": optimize_token(system_prompt)},
35
+ {"role": "user", "content": optimize_token(user_prompt)},
36
+ ]
37
+ return messages
38
+
39
+ def _check_input(
40
+ self, inputs: ProfileExtractorInputType, messages
41
+ ) -> tuple[bool, ProfileExtractorOuputType | None]:
42
+ # 入力がない場合の処理
43
+ if inputs["text"].strip() == "":
44
+ # requestしない, ルールベースのoutput
45
+ return False, {"name": "", "birth_year": -1, "domain": "", "lang": "JPN"}
46
+ # 入力が多い時に16kを使う
47
+ if self.llm.count_tokens(messages) >= 1600:
48
+ self.model = "gpt-3.5-turbo-16k"
49
+ else:
50
+ self.model = "gpt-3.5-turbo"
51
+ # requestする, _
52
+ return True, None
53
+
54
+ def _postprocess(self, response) -> ProfileExtractorOuputType:
55
+ if dict(response["choices"][0]["message"]).get("function_call"):
56
+ try:
57
+ extracted_data = json2dict(response["choices"][0]["message"]["function_call"]["arguments"])
58
+ except Exception:
59
+ extracted_data = {}
60
+ else:
61
+ extracted_data = {}
62
+
63
+ lang_ = extracted_data.get("lang")
64
+ if lang_ in {"ENG", "JPN"}:
65
+ lang = lang_
66
+ else:
67
+ lang = "JPN"
68
+
69
+ outputs: ProfileExtractorOuputType = {
70
+ "name": str(extracted_data.get("name") or ""),
71
+ "birth_year": int(extracted_data.get("birth_year") or -1),
72
+ "domain": str(extracted_data.get("domain") or ""),
73
+ "lang": lang,
74
+ }
75
+ return outputs
76
+
77
+ # Function Callingを使う場合必要
78
+ def _add_functions(self, inputs: Any) -> Functions | None:
79
+ functions: Functions = [
80
+ {
81
+ "name": "extract_profile",
82
+ "description": "extract profile of a person",
83
+ "parameters": {
84
+ "type": "object",
85
+ "properties": {
86
+ "name": {
87
+ "type": "string",
88
+ "description": "名前",
89
+ },
90
+ "domain": {
91
+ "type": "string",
92
+ "description": "研究ドメイン カンマ区切り",
93
+ },
94
+ "birth_year": {
95
+ "type": "integer",
96
+ "description": "the year of the birth YYYY",
97
+ },
98
+ "lang": {
99
+ "type": "string",
100
+ "description": "the language of the text",
101
+ "enum": ["ENG", "JPN"],
102
+ },
103
+ },
104
+ "required": ["name", "birth_year", "domain", "lang"],
105
+ },
106
+ }
107
+ ]
108
+ return functions
109
+
110
+ # 型定義のために必要
111
+ def __call__(self, inputs: ProfileExtractorInputType) -> ProfileExtractorOuputType:
112
+ outputs: ProfileExtractorOuputType = super().__call__(inputs)
113
+ return outputs
project/ex_module/ex_translated_profile_extractor.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import TypedDict
2
+
3
+ from ex_profile_extractor import ProfileExtractor, ProfileExtractorInputType
4
+ from ex_translator import Translator
5
+
6
+ from neollm import MyL3M2
7
+
8
+
9
+ class TranslatedProfileExtractorOutputType(TypedDict):
10
+ name_ENG: str
11
+ name_JPN: str
12
+ domain_ENG: str
13
+ domain_JPN: str
14
+ birth_year: int
15
+
16
+
17
+ class TranslatedProfileExtractor(MyL3M2):
18
+ def _link(self, inputs: ProfileExtractorInputType) -> TranslatedProfileExtractorOutputType:
19
+ # Profile Extract
20
+ profile_extractor = ProfileExtractor(parent=self, silent_list=["llm_settings", "inputs", "messages"])
21
+ profile = profile_extractor(inputs)
22
+ # Translator name
23
+ translator_name = Translator(parent=self, silent_list=["llm_settings", "inputs", "messages"])
24
+ translated_name = translator_name(inputs={"text": profile["name"]})["text_translated"]
25
+ # Translate domain
26
+ translator_domain = Translator(parent=self, silent_list=["llm_settings", "inputs", "messages"])
27
+ translated_domain = translator_domain(inputs={"text": profile["domain"]})["text_translated"]
28
+
29
+ outputs: TranslatedProfileExtractorOutputType = {
30
+ "name_ENG": profile["name"],
31
+ "name_JPN": profile["name"],
32
+ "domain_ENG": profile["domain"],
33
+ "domain_JPN": profile["domain"],
34
+ "birth_year": profile["birth_year"],
35
+ }
36
+
37
+ if profile["lang"] == "ENG":
38
+ outputs["name_JPN"] = translated_name
39
+ outputs["domain_JPN"] = translated_domain
40
+ else:
41
+ outputs["name_ENG"] = translated_name
42
+ outputs["domain_ENG"] = translated_domain
43
+
44
+ return outputs
45
+
46
+ # 型定義のために必要
47
+ def __call__(self, inputs: ProfileExtractorInputType) -> TranslatedProfileExtractorOutputType:
48
+ outputs: TranslatedProfileExtractorOutputType = super().__call__(inputs)
49
+ return outputs
project/ex_module/ex_translator.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing import TypedDict
2
+
3
+ from neollm import MyLLM
4
+ from neollm.types import Messages, OpenAIResponse
5
+ from neollm.utils.postprocess import strip_string
6
+ from neollm.utils.preprocess import optimize_token
7
+
8
+
9
+ class TranslatorInputType(TypedDict):
10
+ text: str
11
+
12
+
13
+ class TranslatorOuputType(TypedDict):
14
+ text_translated: str
15
+
16
+
17
+ class Translator(MyLLM):
18
+ """情報を抽出するMyLLM
19
+
20
+ Notes:
21
+ inputs:
22
+ >>> {"text": str}
23
+ outpus:
24
+ >>> {"text_translated": str | None(うまくいかなかった場合)}
25
+ """
26
+
27
+ def _preprocess(self, inputs: TranslatorInputType) -> Messages:
28
+ system_prompt = (
29
+ "You are a good translator. Translate Japanese into English or English into Japanese.\n"
30
+ "# output_format:\n<output>\n{translated text in English or Japanese}"
31
+ )
32
+ user_prompt = "<input>\n" f"'''{inputs['text'].strip()}'''"
33
+ messages: Messages = [
34
+ {"role": "system", "content": optimize_token(system_prompt)},
35
+ {"role": "user", "content": optimize_token(user_prompt)},
36
+ ]
37
+ return messages
38
+
39
+ def _ruleprocess(self, inputs: TranslatorInputType) -> None | TranslatorOuputType:
40
+ # 入力がない場合の処理
41
+ if inputs["text"].strip() == "":
42
+ return {"text_translated": ""}
43
+ return None
44
+
45
+ def _update_settings(self) -> None:
46
+ # 入力が多い時に16kを使う
47
+ if self.messages is not None:
48
+ if self.llm.count_tokens(self.messages) >= 1600:
49
+ self.model = "gpt-3.5-turbo-16k"
50
+ else:
51
+ self.model = "gpt-3.5-turbo"
52
+
53
+ def _postprocess(self, response: OpenAIResponse) -> TranslatorOuputType:
54
+ text_translated: str = str(response.choices[0].message.content)
55
+ text_translated = strip_string(text=text_translated, first_character=["<output>", "<outputs>"])
56
+ outputs: TranslatorOuputType = {"text_translated": text_translated}
57
+ return outputs
58
+
59
+ # 型定義のために必要
60
+ def __call__(self, inputs: TranslatorInputType) -> TranslatorOuputType:
61
+ outputs: TranslatorOuputType = super().__call__(inputs)
62
+ return outputs
project/neollm-tutorial.ipynb ADDED
@@ -0,0 +1,713 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# <font color=orange> settings\n"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### 1. install neollm\n",
15
+ "\n",
16
+ "[Document インストール方法](https://www.notion.so/c760d96f1b4240e6880a32bee96bba35)\n"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "# githubのssh接続してね\n",
26
+ "# versionは適宜変更してね\n",
27
+ "%pip install git+https://github.com/neoAI-inc/neo-llm-module.git@v1.2.6\n"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "metadata": {},
33
+ "source": [
34
+ "### 2 環境変数の設定方法\n",
35
+ "\n",
36
+ "[Document env ファイルの作り方](https://www.notion.so/env-32ebb04105684a77bbc730c39865df34)\n"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 1,
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stdout",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "環境変数読み込み成功\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "from dotenv import load_dotenv\n",
54
+ "\n",
55
+ "env_path = \".env\" # .envのpath 適宜変更\n",
56
+ "if load_dotenv(env_path):\n",
57
+ " print(\"環境変数読み込み成功\")\n",
58
+ "else:\n",
59
+ " print(\"path違うよ〜\")"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "# <font color=orange> neoLLM  使い方\n"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "metadata": {},
72
+ "source": [
73
+ "neollm は、前処理・LLM のリクエスト・後処理を 1 つのクラスにした、Pytorch 的な記法で書ける neoAI の LLM 統一ライブラリ。\n",
74
+ "\n",
75
+ "大きく 2 種類のクラスがあり、MyLLM は 1 つのリクエスト、MyL3M2 は複数のリクエストを受け持つことができる。\n",
76
+ "\n",
77
+ "![概観図](../asset/external_view.png)\n"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "##### モデルの定義\n"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 2,
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "name": "stdout",
94
+ "output_type": "stream",
95
+ "text": [
96
+ "\u001b[43mWARNING: AZURE_API_BASEではなく、AZURE_OPENAI_ENDPOINTにしてね\u001b[0m\n",
97
+ "\u001b[43mWARNING: AZURE_API_VERSIONではなく、OPENAI_API_VERSIONにしてね\u001b[0m\n"
98
+ ]
99
+ }
100
+ ],
101
+ "source": [
102
+ "from neollm import MyLLM\n",
103
+ "\n",
104
+ "# 例: 翻訳をするclass\n",
105
+ "# _preprocess, _postprocessを必ず書く\n",
106
+ "\n",
107
+ "\n",
108
+ "class Translator(MyLLM):\n",
109
+ " # _preprocessは、前処理をしてMessageを作る関数\n",
110
+ " def _preprocess(self, inputs: str):\n",
111
+ " messages = [\n",
112
+ " {\"role\": \"system\", \"content\": \"英語を日本語に翻訳するAIです。\"},\n",
113
+ " {\"role\": \"user\", \"content\": inputs},\n",
114
+ " ]\n",
115
+ " return messages\n",
116
+ "\n",
117
+ " # _postprocessは、APIのResponseを後処理をして、欲しいものを返す関数\n",
118
+ " def _postprocess(self, response):\n",
119
+ " text_translated: str = str(response.choices[0].message.content)\n",
120
+ " return text_translated"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "metadata": {},
126
+ "source": [
127
+ "##### モデルの呼び出し\n"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 16,
133
+ "metadata": {},
134
+ "outputs": [
135
+ {
136
+ "name": "stdout",
137
+ "output_type": "stream",
138
+ "text": [
139
+ "\u001b[41mPARENT\u001b[0m\n",
140
+ "MyLLM(Translator) ----------------------------------------------------------------------------------\n",
141
+ "\u001b[34m[inputs]\u001b[0m\n",
142
+ "\"Hello, We are neoAI.\"\n",
143
+ "\u001b[34m[messages]\u001b[0m\n",
144
+ " \u001b[32msystem\u001b[0m\n",
145
+ " 英語を日本語に翻訳するAIです。\n",
146
+ " \u001b[32muser\u001b[0m\n",
147
+ " Hello, We are neoAI.\n",
148
+ " \u001b[32massistant\u001b[0m\n",
149
+ " こんにちは、私たちはneoAIです。\n",
150
+ "\u001b[34m[outputs]\u001b[0m\n",
151
+ "\"こんにちは、私たちはneoAIです。\"\n",
152
+ "\u001b[34m[client_settings]\u001b[0m -\n",
153
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 1, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
154
+ "\u001b[34m[metadata]\u001b[0m 1.6s; 45(36+9)tokens; $6.8e-05; ¥0.0095\n",
155
+ "----------------------------------------------------------------------------------------------------\n",
156
+ "こんにちは、私たちはneoAIです。\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "# 初期化 (platformやmodelなど設定をしておく)\n",
162
+ "# 詳細: https://www.notion.so/neollm-MyLLM-581cd7562df9473b91c981d88469c452?pvs=4#ac5361a5e3fa46a48441fdd538858fee\n",
163
+ "translator = Translator(\n",
164
+ " platform=\"azure\", # azure or openai\n",
165
+ " model=\"gpt-3.5-turbo-0613\", # gpt-3.5-turbo-1106, gpt-4-turbo-1106\n",
166
+ " llm_settings={\"temperature\": 1}, # llmの設定 dictで渡す\n",
167
+ ")\n",
168
+ "\n",
169
+ "# 呼び出し\n",
170
+ "# preprocessでinputsとしたものを入力として、postprocessで処理したものを出力とする。\n",
171
+ "translated_text = translator(inputs=\"Hello, We are neoAI.\")\n",
172
+ "print(translated_text)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 17,
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "name": "stdout",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "時間 1.5658628940582275\n",
185
+ "token数 TokenInfo(input=36, output=9, total=45)\n",
186
+ "token数合計 45\n",
187
+ "値段(USD) PriceInfo(input=5.4e-05, output=1.8e-05, total=6.75e-05)\n",
188
+ "値段数合計(USD) 6.75e-05\n"
189
+ ]
190
+ }
191
+ ],
192
+ "source": [
193
+ "# 処理時間\n",
194
+ "print(\"時間\", translator.time)\n",
195
+ "# トークン数\n",
196
+ "print(\"token数\", translator.token)\n",
197
+ "print(\"token数合計\", translator.token.total)\n",
198
+ "# 値段の取得\n",
199
+ "print(\"値段(USD)\", translator.price)\n",
200
+ "print(\"値段数合計(USD)\", translator.price.total)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 20,
206
+ "metadata": {},
207
+ "outputs": [
208
+ {
209
+ "name": "stdout",
210
+ "output_type": "stream",
211
+ "text": [
212
+ "inputs Hello, We are neoAI.\n",
213
+ "messages [{'role': 'system', 'content': '英語を日本語に翻訳するAIです。'}, {'role': 'user', 'content': 'Hello, We are neoAI.'}]\n",
214
+ "response ChatCompletion(id='chatcmpl-8T5MkidV9bhqewdzcUwO1PioHOSHi', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content='こんにちは、私たちはneoAIです。', role='assistant', function_call=None, tool_calls=None), content_filter_results={'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}})], created=1701942830, model='gpt-35-turbo', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=9, prompt_tokens=36, total_tokens=45), prompt_filter_results=[{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}])\n",
215
+ "outputs こんにちは、私たちはneoAIです。\n",
216
+ "chat_history [{'role': 'system', 'content': '英語を日本語に翻訳するAIです。'}, {'role': 'user', 'content': 'Hello, We are neoAI.'}, {'content': 'こんにちは、私たちはneoAIです。', 'role': 'assistant'}]\n"
217
+ ]
218
+ }
219
+ ],
220
+ "source": [
221
+ "# その他property\n",
222
+ "print(\"inputs\", translator.inputs)\n",
223
+ "print(\"messages\", translator.messages)\n",
224
+ "print(\"response\", translator.response)\n",
225
+ "print(\"outputs\", translator.outputs)\n",
226
+ "\n",
227
+ "print(\"chat_history\", translator.chat_history)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "markdown",
232
+ "metadata": {},
233
+ "source": [
234
+ "# <font color=orange> neoLLM  例\n"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "markdown",
239
+ "metadata": {},
240
+ "source": [
241
+ "### 1-1 MyLLM (ex. 翻訳)\n"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 21,
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "from neollm import MyLLM\n",
251
+ "from neollm.utils.preprocess import optimize_token\n",
252
+ "from neollm.utils.postprocess import strip_string\n",
253
+ "\n",
254
+ "\n",
255
+ "class Translator(MyLLM):\n",
256
+ " def _preprocess(self, inputs):\n",
257
+ " system_prompt = (\n",
258
+ " \"You are a good translator. Translate Japanese into English or English into Japanese.\\n\"\n",
259
+ " \"# output_format:\\n<output>\\n{translated text in English or Japanese}\"\n",
260
+ " )\n",
261
+ " user_prompt = \"<input>\\n\" f\"'''{inputs['text'].strip()}'''\"\n",
262
+ " messages = [\n",
263
+ " {\"role\": \"system\", \"content\": optimize_token(system_prompt)},\n",
264
+ " {\"role\": \"user\", \"content\": optimize_token(user_prompt)},\n",
265
+ " ]\n",
266
+ " return messages\n",
267
+ "\n",
268
+ " def _ruleprocess(self, inputs):\n",
269
+ " # 例外処理\n",
270
+ " if inputs[\"text\"].strip() == \"\":\n",
271
+ " return {\"text_translated\": \"\"}\n",
272
+ " # APIリクエストを送る場合はNone\n",
273
+ " return None\n",
274
+ "\n",
275
+ " def _update_settings(self):\n",
276
+ " # 入力によってAPIの設定を変更する\n",
277
+ "\n",
278
+ " # トークン数: self.llm.count_tokens(self.messsage)\n",
279
+ "\n",
280
+ " # モデル変更: self.model = \"gpt-3.5-turbo-16k\"\n",
281
+ "\n",
282
+ " # パラメータ変更: self.llm_settings = {\"temperature\": 0.2}\n",
283
+ "\n",
284
+ " # 入力が多い時に16kを使う(1106の場合はやらなくていい)\n",
285
+ " if self.messages is not None:\n",
286
+ " if self.llm.count_tokens(self.messages) >= 1600:\n",
287
+ " self.model = \"gpt-3.5-turbo-16k-0613\"\n",
288
+ " else:\n",
289
+ " self.model = \"gpt-3.5-turbo-0613\"\n",
290
+ "\n",
291
+ " def _postprocess(self, response):\n",
292
+ " text_translated: str = str(response.choices[0].message.content)\n",
293
+ " text_translated = strip_string(text=text_translated, first_character=[\"<output>\", \"<outputs>\"])\n",
294
+ " outputs = {\"text_translated\": text_translated}\n",
295
+ " return outputs"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 23,
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "\u001b[41mPARENT\u001b[0m\n",
308
+ "MyLLM(Translator) ----------------------------------------------------------------------------------\n",
309
+ "\u001b[34m[inputs]\u001b[0m\n",
310
+ "{\n",
311
+ " \"text\": \"大規模LLMモデル\"\n",
312
+ "}\n",
313
+ "\u001b[34m[messages]\u001b[0m\n",
314
+ " \u001b[32msystem\u001b[0m\n",
315
+ " You are a good translator. Translate Japanese into English or English into Japanese.\n",
316
+ " # output_format:\n",
317
+ " <output>\n",
318
+ " {translated text in English or Japanese}\n",
319
+ " \u001b[32muser\u001b[0m\n",
320
+ " <input>\n",
321
+ " '''大規模LLMモデル'''\n",
322
+ " \u001b[32massistant\u001b[0m\n",
323
+ " <output>\n",
324
+ " \"Large-Scale LLM Model\"\n",
325
+ "\u001b[34m[outputs]\u001b[0m\n",
326
+ "{\n",
327
+ " \"text_translated\": \"Large-Scale LLM Model\"\n",
328
+ "}\n",
329
+ "\u001b[34m[client_settings]\u001b[0m -\n",
330
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 1, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
331
+ "\u001b[34m[metadata]\u001b[0m 1.5s; 66(55+11)tokens; $9.9e-05; ¥0.014\n",
332
+ "----------------------------------------------------------------------------------------------------\n",
333
+ "{'text_translated': 'Large-Scale LLM Model'}\n"
334
+ ]
335
+ }
336
+ ],
337
+ "source": [
338
+ "translator = Translator(\n",
339
+ " llm_settings={\"temperature\": 1}, # defaultは、{\"temperature\": 0}\n",
340
+ " model=\"gpt-3.5-turbo-0613\", # defaultは、DEFAULT_MODEL_NAME\n",
341
+ " platform=\"azure\", # defaultは、LLM_PLATFORM\n",
342
+ " verbose=True,\n",
343
+ " silent_list=[], # 表示しないもの\n",
344
+ ")\n",
345
+ "output_1 = translator(inputs={\"text\": \"大規模LLMモデル\"})\n",
346
+ "print(output_1)"
347
+ ]
348
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[43mWARNING: model_nameに日付を指定してください\u001b[0m\n",
+ "model_name: gpt-3.5-turbo -> gpt-3.5-turbo-0613\n",
+ "\u001b[41mPARENT\u001b[0m\n",
+ "MyLLM(Translator) ----------------------------------------------------------------------------------\n",
+ "\u001b[34m[inputs]\u001b[0m\n",
+ "{\n",
+ " \"text\": \"Large LLM Model\"\n",
+ "}\n",
+ "\u001b[34m[messages]\u001b[0m\n",
+ " \u001b[32msystem\u001b[0m\n",
+ " You are a good translator. Translate Japanese into English or English into Japanese.\n",
+ " # output_format:\n",
+ " <output>\n",
+ " {translated text in English or Japanese}\n",
+ " \u001b[32muser\u001b[0m\n",
+ " <input>\n",
+ " '''Large LLM Model'''\n",
+ "\u001b[43mWARNING: model_nameに日付を指定してください\u001b[0m\n",
+ "model_name: gpt-3.5-turbo -> gpt-3.5-turbo-0613\n",
+ " \u001b[32massistant\u001b[0m\n",
+ " <output>\n",
+ " 大きなLLMモデル\n",
+ "\u001b[34m[outputs]\u001b[0m\n",
+ "{\n",
+ " \"text_translated\": \"大きなLLMモデル\"\n",
+ "}\n",
+ "\u001b[34m[client_settings]\u001b[0m -\n",
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'openai', 'temperature': 0, 'model': 'gpt-3.5-turbo-0613'}\n",
+ "\u001b[34m[metadata]\u001b[0m 0.9s; 61(49+12)tokens; $9.2e-05; ¥0.013\n",
+ "----------------------------------------------------------------------------------------------------\n",
+ "{'text_translated': '大きなLLMモデル'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "translator = Translator(\n",
+ "    platform=\"openai\",  # <- try changing this\n",
+ "    verbose=True,\n",
+ ")\n",
+ "output_1 = translator(inputs={\"text\": \"Large LLM Model\"})\n",
+ "print(output_1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[43mWARNING: model_nameに日付を指定してください\u001b[0m\n",
+ "model_name: gpt-3.5-turbo -> gpt-3.5-turbo-0613\n",
+ "\u001b[41mPARENT\u001b[0m\n",
+ "MyLLM(Translator) ----------------------------------------------------------------------------------\n",
+ "\u001b[34m[inputs]\u001b[0m\n",
+ "{\n",
+ " \"text\": \"\"\n",
+ "}\n",
+ "\u001b[34m[outputs]\u001b[0m\n",
+ "{\n",
+ " \"text_translated\": \"\"\n",
+ "}\n",
+ "\u001b[34m[client_settings]\u001b[0m -\n",
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 0, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
+ "\u001b[34m[metadata]\u001b[0m 0.0s; 0(0+0)tokens; $0; ¥0\n",
+ "----------------------------------------------------------------------------------------------------\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'text_translated': ''}"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The rule-based path kicks in\n",
+ "data = {\"text\": \"\"}\n",
+ "translator = Translator(verbose=True)\n",
+ "translator(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[43mWARNING: model_nameに日付を指定してください\u001b[0m\n",
+ "model_name: gpt-3.5-turbo -> gpt-3.5-turbo-0613\n",
+ "\u001b[41mPARENT\u001b[0m\n",
+ "MyLLM(Translator) ----------------------------------------------------------------------------------\n",
+ "\u001b[34m[inputs]\u001b[0m\n",
+ "{\n",
+ " \"text\": \"こんにちは!!\\nこんにちは?こんにちは?\"\n",
+ "}\n",
+ "\u001b[34m[messages]\u001b[0m\n",
+ " \u001b[32msystem\u001b[0m\n",
+ " You are a good translator. Translate Japanese into English or English into Japanese.\n",
+ " # output_format:\n",
+ " <output>\n",
+ " {translated text in English or Japanese}\n",
+ " \u001b[32muser\u001b[0m\n",
+ " <input>\n",
+ " '''こんにちは!!\n",
+ " こんにちは?こんにちは?'''\n",
+ "\u001b[43mWARNING: model_nameに日付を指定してください\u001b[0m\n",
+ "model_name: gpt-3.5-turbo -> gpt-3.5-turbo-0613\n",
+ " \u001b[32massistant\u001b[0m\n",
+ " <output>\n",
+ " Hello!!\n",
+ " Hello? Hello?\n",
+ "\u001b[34m[outputs]\u001b[0m\n",
+ "{\n",
+ " \"text_translated\": \"Hello!!\\nHello? Hello?\"\n",
+ "}\n",
+ "\u001b[34m[client_settings]\u001b[0m -\n",
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 0, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
+ "\u001b[34m[metadata]\u001b[0m 1.4s; 60(51+9)tokens; $9e-05; ¥0.013\n",
+ "----------------------------------------------------------------------------------------------------\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'text_translated': 'Hello!!\\nHello? Hello?'}"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data = {\"text\": \"こんにちは!!\\nこんにちは?こんにちは?\"}\n",
+ "translator = Translator(verbose=True)\n",
+ "translator(data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Information Extraction\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from neollm import MyLLM\n",
+ "from neollm.utils.preprocess import optimize_token, dict2json\n",
+ "from neollm.utils.postprocess import json2dict\n",
+ "\n",
+ "\n",
+ "class Extractor(MyLLM):\n",
+ "    def _preprocess(self, inputs):\n",
+ "        system_prompt = \"<INFO>から、<OUTPUT_FORMAT>にしたがって、情報を抽出しなさい。\"\n",
+ "        output_format = {\"date\": \"yy-mm-dd形式 日付\", \"event\": \"起きたことを簡潔に。\"}\n",
+ "        user_prompt = (\n",
+ "            \"<INFO>\\n\"\n",
+ "            \"```\\n\"\n",
+ "            f\"{inputs['info'].strip()}\\n\"\n",
+ "            \"```\\n\"\n",
+ "            \"\\n\"\n",
+ "            \"<OUTPUT_FORMAT>\\n\"\n",
+ "            \"```json\\n\"\n",
+ "            f\"{dict2json(output_format)}\\n\"\n",
+ "            \"```\"\n",
+ "        )\n",
+ "\n",
+ "        messages = [\n",
+ "            {\"role\": \"system\", \"content\": optimize_token(system_prompt)},\n",
+ "            {\"role\": \"user\", \"content\": optimize_token(user_prompt)},\n",
+ "        ]\n",
+ "        return messages\n",
+ "\n",
+ "    def _ruleprocess(self, inputs):\n",
+ "        # Exception handling\n",
+ "        if inputs[\"info\"].strip() == \"\":\n",
+ "            return {\"date\": \"\", \"event\": \"\"}\n",
+ "        # Return None when the API request should be sent\n",
+ "        return None\n",
+ "\n",
+ "    def _postprocess(self, response):\n",
+ "        return json2dict(response.choices[0].message.content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[41mPARENT\u001b[0m\n",
+ "MyLLM(Extractor) -----------------------------------------------------------------------------------\n",
+ "\u001b[34m[inputs]\u001b[0m\n",
+ "{\n",
+ " \"info\": \"2021年6月13日に、neoAIのサービスが始まりました。\"\n",
+ "}\n",
+ "\u001b[34m[messages]\u001b[0m\n",
+ " \u001b[32msystem\u001b[0m\n",
+ " <INFO>から、<OUTPUT_FORMAT>にしたがって、情報を抽出しなさい。\n",
+ " \u001b[32muser\u001b[0m\n",
+ " <INFO>\n",
+ " ```\n",
+ " 2021年6月13日に、neoAIのサービスが始まりました。\n",
+ " ```\n",
+ " \n",
+ " <OUTPUT_FORMAT>\n",
+ " ```json\n",
+ " {\n",
+ " \"date\": \"yy-mm-dd形式 日付\",\n",
+ " \"event\": \"起きたことを簡潔に。\"\n",
+ " }\n",
+ " ```\n",
+ " \u001b[32massistant\u001b[0m\n",
+ " ```json\n",
+ " {\n",
+ " \"date\": \"2021-06-13\",\n",
+ " \"event\": \"neoAIのサービスが始まりました。\"\n",
+ " }\n",
+ " ```\n",
+ "\u001b[34m[outputs]\u001b[0m\n",
+ "{\n",
+ " \"date\": \"2021-06-13\",\n",
+ " \"event\": \"neoAIのサービスが始まりました。\"\n",
+ "}\n",
+ "\u001b[34m[client_settings]\u001b[0m -\n",
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 0, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
+ "\u001b[34m[metadata]\u001b[0m 1.6s; 143(106+37)tokens; $0.00021; ¥0.03\n",
+ "----------------------------------------------------------------------------------------------------\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'date': '2021-06-13', 'event': 'neoAIのサービスが始まりました。'}"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "extractor = Extractor(model=\"gpt-3.5-turbo-0613\")\n",
+ "\n",
+ "extractor(inputs={\"info\": \"2021年6月13日に、neoAIのサービスが始まりました。\"})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[41mPARENT\u001b[0m\n",
+ "MyLLM(Extractor) -----------------------------------------------------------------------------------\n",
+ "\u001b[34m[inputs]\u001b[0m\n",
+ "{\n",
+ " \"info\": \"1998年4月1日に、neoAI大学が設立されました。\"\n",
+ "}\n",
+ "\u001b[34m[messages]\u001b[0m\n",
+ " \u001b[32msystem\u001b[0m\n",
+ " <INFO>から、<OUTPUT_FORMAT>にしたがって、情報を抽出しなさい。\n",
+ " \u001b[32muser\u001b[0m\n",
+ " <INFO>\n",
+ " ```\n",
+ " 1998年4月1日に、neoAI大学が設立されました。\n",
+ " ```\n",
+ " \n",
+ " <OUTPUT_FORMAT>\n",
+ " ```json\n",
+ " {\n",
+ " \"date\": \"yy-mm-dd形式 日付\",\n",
+ " \"event\": \"起きたことを簡潔に。\"\n",
+ " }\n",
+ " ```\n",
+ " \u001b[32massistant\u001b[0m\n",
+ " <OUTPUT>\n",
+ " ```json\n",
+ " {\n",
+ " \"date\": \"1998-04-01\",\n",
+ " \"event\": \"neoAI大学の設立\"\n",
+ " }\n",
+ " ```\n",
+ "\u001b[34m[outputs]\u001b[0m\n",
+ "{\n",
+ " \"date\": \"1998-04-01\",\n",
+ " \"event\": \"neoAI大学の設立\"\n",
+ "}\n",
+ "\u001b[34m[client_settings]\u001b[0m -\n",
+ "\u001b[34m[llm_settings]\u001b[0m {'platform': 'azure', 'temperature': 0, 'model': 'gpt-3.5-turbo-0613', 'engine': 'neoai-free-swd-gpt-35-0613'}\n",
+ "\u001b[34m[metadata]\u001b[0m 1.6s; 139(104+35)tokens; $0.00021; ¥0.029\n",
+ "----------------------------------------------------------------------------------------------------\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'date': '1998-04-01', 'event': 'neoAI大学の設立'}"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "extractor = Extractor(model=\"gpt-3.5-turbo-0613\")\n",
+ "\n",
+ "extractor(inputs={\"info\": \"1998年4月1日に、neoAI大学が設立されました。\"})"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
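
The tutorial above wires a task into `MyLLM` through four hooks: `_preprocess` builds the chat messages, `_ruleprocess` can answer without an API call, `_update_settings` adjusts the model or parameters per request, and `_postprocess` maps the raw response to task outputs. As a quick reference, here is a minimal sketch of that lifecycle; `EchoLLM` and its prompts are illustrative stand-ins, not part of the library.

```python
# Minimal sketch of the MyLLM hook lifecycle exercised in the tutorial above.
# Only MyLLM and the four hook names come from neollm; EchoLLM is hypothetical.
from neollm import MyLLM


class EchoLLM(MyLLM):
    def _preprocess(self, inputs):
        # Turn task inputs into chat messages for the LLM.
        return [
            {"role": "system", "content": "Repeat the user's text verbatim."},
            {"role": "user", "content": inputs["text"]},
        ]

    def _ruleprocess(self, inputs):
        # Short-circuit without an API call; return None to proceed normally.
        if inputs["text"].strip() == "":
            return {"echo": ""}
        return None

    def _update_settings(self):
        # Tweak model/parameters after messages are built, e.g. by token count.
        if self.messages is not None and self.llm.count_tokens(self.messages) >= 1600:
            self.model = "gpt-3.5-turbo-16k-0613"

    def _postprocess(self, response):
        # Map the raw chat completion back to task outputs.
        return {"echo": str(response.choices[0].message.content)}


# Usage follows the tutorial's pattern: instantiate, then call with inputs.
# outputs = EchoLLM(verbose=True)(inputs={"text": "hello"})
```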
pyproject.toml ADDED
@@ -0,0 +1,81 @@
+ [tool.poetry]
+ name = "neollm"
+ version = "1.3.3"
+ description = "neo LLM Module for Python 3.10"
+ authors = ["KoshiroTerasawa <k.terasawa@neoai.jp>"]
+ readme = "README.md"
+ packages = [{ include = "neollm" }]
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ python-dotenv = "^1.0.0"
+ pydantic = "^2.4.2"
+ openai = "^1.1.1"
+ google-cloud-aiplatform = "^1.48.0"
+ anthropic = { version = "^0.18.1", extras = ["vertex"] }
+ typing-extensions = "^4.8.0"
+ google-generativeai = "0.5.2"
+ tiktoken = "0.7.0"
+
+
+ [tool.poetry.group.dev.dependencies]
+ isort = "^5.12.0"
+ black = "24.3.0"
+ mypy = "^1.8.0"
+ pyproject-flake8 = "^6.1.0"
+ ipykernel = "^6.26.0"
+ jupyter = "^1.0.0"
+ jupyter-client = "^8.6.0"
+ pytest = "^8.1.1"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
+
+ [tool.black]
+ line-length = 119
+ exclude = '''
+ /(
+ \.venv
+ | \.git
+ | \.hg
+ | __pycache__
+ | \.mypy_cache
+ )/
+ '''
+
+ [tool.isort]
+ profile = "black"
+ multi_line_output = 3
+
+ [tool.flake8]
+ max-line-length = 119
+ extend-ignore = ["E203", "W503", "E501", "E704"]
+ exclude = [".venv", ".git", "__pycache__", ".mypy_cache", ".hg"]
+ max-complexity = 15
+
+ [tool.mypy]
+ ignore_missing_imports = true
+ # follow_imports = normal
+ disallow_any_unimported = false
+ disallow_any_expr = false # forbid Any in expressions
+ disallow_any_decorated = false
+ disallow_any_explicit = false # forbid explicit Any annotations
+ disallow_any_generics = true # forbid unparameterized generics
+ disallow_subclassing_any = true # forbid subclassing Any
+
+ disallow_untyped_calls = true # forbid calls to untyped functions, e.g. `a: int = f()`
+ disallow_untyped_defs = true # forbid untyped function definitions, e.g. `def f(a: int) -> int`
+ disallow_incomplete_defs = true # forbid partially typed definitions, e.g. `def f(a: int, b)`
+ check_untyped_defs = true
+ disallow_untyped_decorators = true
+ no_implicit_optional = true
+
+ warn_redundant_casts = true
+ warn_unused_ignores = true
+ warn_return_any = true
+ warn_unreachable = true # detect unreachable code
+ allow_redefinition = false # forbid variable redefinition
+
+ show_error_context = true
+ show_column_numbers = true
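
The `[tool.mypy]` block above runs mypy in a fairly strict mode. As an illustration (a hypothetical module, not repository code), these are the kinds of definitions the flags accept and reject:

```python
# Illustration of the mypy flags configured above; not part of the repository.
from typing import Any


def parse(raw):  # rejected by disallow_untyped_defs: no annotations
    return raw


def parse_ok(raw: str) -> dict[str, Any]:  # annotated, so accepted
    return {"raw": raw}


def first(items: list) -> object:  # rejected by disallow_any_generics: bare `list`
    return items[0]


def first_ok(items: list[str]) -> str:  # parameterized generic, so accepted
    return items[0]
```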
test/llm/claude/test_claude_llm.py ADDED
@@ -0,0 +1,37 @@
+ # from neollm.llm.gpt.azure_llm import (
+ #     AzureGPT4_0613,
+ #     AzureGPT4T_0125,
+ #     AzureGPT4T_1106,
+ #     AzureGPT4T_20240409,
+ #     AzureGPT4VT_1106,
+ #     AzureGPT35FT,
+ #     AzureGPT35T16k_0613,
+ #     AzureGPT35T_0125,
+ #     AzureGPT35T_0613,
+ #     AzureGPT35T_1106,
+ #     AzureGPT432k_0613,
+ # )
+ # from neollm.types.info import APIPricing
+
+
+ # def test_check_price() -> None:
+ #     # https://azure.microsoft.com/ja-jp/pricing/details/cognitive-services/openai-service/
+
+ #     # Upcoming models --------------------------------------------------------
+ #     assert AzureGPT4T_20240409.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     # Updated --------------------------------------------------------
+ #     # GPT3.5T
+ #     assert AzureGPT35T_0125.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+ #     # GPT4
+ #     assert AzureGPT4T_0125.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4VT_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4T_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4_0613.dollar_per_ktoken == APIPricing(input=0.03, output=0.06)
+ #     assert AzureGPT432k_0613.dollar_per_ktoken == APIPricing(input=0.06, output=0.12)
+ #     # FT
+ #     assert AzureGPT35FT.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+ #     # Legacy ---------------------------------------------------------
+ #     # No AzureGPT35T_0301
+ #     assert AzureGPT35T_0613.dollar_per_ktoken == APIPricing(input=0.0015, output=0.002)
+ #     assert AzureGPT35T16k_0613.dollar_per_ktoken == APIPricing(input=0.003, output=0.004)
+ #     assert AzureGPT35T_1106.dollar_per_ktoken == APIPricing(input=0.001, output=0.002)
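
This file is currently a commented-out copy of the Azure pricing test. A Claude version could follow the same shape; the class name below is hypothetical (substitute whatever neollm/llm/claude/anthropic_llm.py actually exports), and only the Claude 3 Opus price figures ($15/$75 per million tokens) are real:

```python
# Hypothetical shape for an enabled Claude pricing test; AnthropicClaude3Opus
# is a placeholder name, not a class confirmed to exist in neollm.
# from neollm.llm.claude.anthropic_llm import AnthropicClaude3Opus
# from neollm.types.info import APIPricing
#
# def test_check_price() -> None:
#     assert AnthropicClaude3Opus.dollar_per_ktoken == APIPricing(input=0.015, output=0.075)
```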
test/llm/gpt/test_azure_llm.py ADDED
@@ -0,0 +1,92 @@
+ from neollm.llm.gpt.azure_llm import (
+     AzureGPT4_0613,
+     AzureGPT4O_20240513,
+     AzureGPT4T_0125,
+     AzureGPT4T_1106,
+     AzureGPT4T_20240409,
+     AzureGPT4VT_1106,
+     AzureGPT35FT,
+     AzureGPT35T16k_0613,
+     AzureGPT35T_0125,
+     AzureGPT35T_0613,
+     AzureGPT35T_1106,
+     AzureGPT432k_0613,
+     get_azure_llm,
+ )
+ from neollm.types.info import APIPricing
+
+
+ def test_get_azure_llm() -> None:
+
+     # no date
+     assert get_azure_llm("gpt-3.5-turbo", {}).__class__ == AzureGPT35T_0613
+     assert get_azure_llm("gpt-35-turbo", {}).__class__ == AzureGPT35T_0613
+     assert get_azure_llm("gpt-3.5-turbo-16k", {}).__class__ == AzureGPT35T16k_0613
+     assert get_azure_llm("gpt-35-turbo-16k", {}).__class__ == AzureGPT35T16k_0613
+     assert get_azure_llm("gpt-4", {}).__class__ == AzureGPT4_0613
+     assert get_azure_llm("gpt-4-32k", {}).__class__ == AzureGPT432k_0613
+     assert get_azure_llm("gpt-4-turbo", {}).__class__ == AzureGPT4T_1106
+     assert get_azure_llm("gpt-4v-turbo", {}).__class__ == AzureGPT4VT_1106
+     assert get_azure_llm("gpt-4o", {}).__class__ == AzureGPT4O_20240513
+     # with date
+     assert get_azure_llm("gpt-4o-2024-05-13", {}).__class__ == AzureGPT4O_20240513
+     assert get_azure_llm("gpt-4-turbo-2024-04-09", {}).__class__ == AzureGPT4T_20240409
+     assert get_azure_llm("gpt-3.5-turbo-0125", {}).__class__ == AzureGPT35T_0125
+     assert get_azure_llm("gpt-35-turbo-0125", {}).__class__ == AzureGPT35T_0125
+     assert get_azure_llm("gpt-4-turbo-0125", {}).__class__ == AzureGPT4T_0125
+     assert get_azure_llm("gpt-3.5-turbo-1106", {}).__class__ == AzureGPT35T_1106
+     assert get_azure_llm("gpt-35-turbo-1106", {}).__class__ == AzureGPT35T_1106
+     assert get_azure_llm("gpt-4-turbo-1106", {}).__class__ == AzureGPT4T_1106
+     assert get_azure_llm("gpt-4v-turbo-1106", {}).__class__ == AzureGPT4VT_1106
+     assert get_azure_llm("gpt-3.5-turbo-0613", {}).__class__ == AzureGPT35T_0613
+     assert get_azure_llm("gpt-35-turbo-0613", {}).__class__ == AzureGPT35T_0613
+     assert get_azure_llm("gpt-3.5-turbo-16k-0613", {}).__class__ == AzureGPT35T16k_0613
+     assert get_azure_llm("gpt-35-turbo-16k-0613", {}).__class__ == AzureGPT35T16k_0613
+     assert get_azure_llm("gpt-4-0613", {}).__class__ == AzureGPT4_0613
+     assert get_azure_llm("gpt-4-32k-0613", {}).__class__ == AzureGPT432k_0613
+     # ft
+     assert get_azure_llm("ft:gpt-3.5-turbo-1106-XXXX", {}).__class__ == AzureGPT35FT
+
+
+ def test_check_price() -> None:
+     # https://azure.microsoft.com/ja-jp/pricing/details/cognitive-services/openai-service/
+
+     # Upcoming models --------------------------------------------------------
+     assert AzureGPT4T_20240409.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+     # Updated --------------------------------------------------------
+     # GPT3.5T
+     assert AzureGPT35T_0125.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+     # GPT4
+     assert AzureGPT4O_20240513.dollar_per_ktoken == APIPricing(input=0.005, output=0.015)
+     assert AzureGPT4T_0125.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+     assert AzureGPT4VT_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+     assert AzureGPT4T_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+     assert AzureGPT4_0613.dollar_per_ktoken == APIPricing(input=0.03, output=0.06)
+     assert AzureGPT432k_0613.dollar_per_ktoken == APIPricing(input=0.06, output=0.12)
+     # FT
+     assert AzureGPT35FT.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+     # Legacy ---------------------------------------------------------
+     # No AzureGPT35T_0301
+     assert AzureGPT35T_0613.dollar_per_ktoken == APIPricing(input=0.0015, output=0.002)
+     assert AzureGPT35T16k_0613.dollar_per_ktoken == APIPricing(input=0.003, output=0.004)
+     assert AzureGPT35T_1106.dollar_per_ktoken == APIPricing(input=0.001, output=0.002)
+
+
+ def test_check_context_window() -> None:
+     # https://learn.microsoft.com/ja-jp/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview-models
+     assert AzureGPT4T_20240409.context_window == 128_000
+
+     assert AzureGPT4T_0125.context_window == 128_000
+     assert AzureGPT35T_0125.context_window == 16_385
+
+     assert AzureGPT4O_20240513.context_window == 128_000
+     assert AzureGPT4T_1106.context_window == 128_000
+     assert AzureGPT4VT_1106.context_window == 128_000
+     assert AzureGPT35T_1106.context_window == 16_385
+
+     assert AzureGPT35T_0613.context_window == 4_096
+     assert AzureGPT4_0613.context_window == 8_192
+     assert AzureGPT35T16k_0613.context_window == 16_385
+     assert AzureGPT432k_0613.context_window == 32_768
+
+     assert AzureGPT35FT.context_window == 4_096
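
The name-to-class and pricing asserts above are exhaustive but repetitive; a table-driven variant with `pytest.mark.parametrize` would keep the same coverage with less duplication. A sketch using only names already imported in this file:

```python
# Optional refactor sketch for the asserts above; same behavior, table-driven.
import pytest

from neollm.llm.gpt.azure_llm import AzureGPT4O_20240513, AzureGPT35T_0613, get_azure_llm


@pytest.mark.parametrize(
    ("model_name", "expected"),
    [
        ("gpt-3.5-turbo", AzureGPT35T_0613),
        ("gpt-4o", AzureGPT4O_20240513),
        ("gpt-4o-2024-05-13", AzureGPT4O_20240513),
    ],
)
def test_get_azure_llm_parametrized(model_name: str, expected: type) -> None:
    # Each row checks one name -> class mapping, as the long assert list does.
    assert get_azure_llm(model_name, {}).__class__ == expected
```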
test/llm/gpt/test_openai_llm.py ADDED
@@ -0,0 +1,37 @@
+ # from neollm.llm.gpt.azure_llm import (
+ #     AzureGPT4_0613,
+ #     AzureGPT4T_0125,
+ #     AzureGPT4T_1106,
+ #     AzureGPT4T_20240409,
+ #     AzureGPT4VT_1106,
+ #     AzureGPT35FT,
+ #     AzureGPT35T16k_0613,
+ #     AzureGPT35T_0125,
+ #     AzureGPT35T_0613,
+ #     AzureGPT35T_1106,
+ #     AzureGPT432k_0613,
+ # )
+ # from neollm.types.info import APIPricing
+
+
+ # def test_check_price() -> None:
+ #     # https://azure.microsoft.com/ja-jp/pricing/details/cognitive-services/openai-service/
+
+ #     # Upcoming models --------------------------------------------------------
+ #     assert AzureGPT4T_20240409.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     # Updated --------------------------------------------------------
+ #     # GPT3.5T
+ #     assert AzureGPT35T_0125.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+ #     # GPT4
+ #     assert AzureGPT4T_0125.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4VT_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4T_1106.dollar_per_ktoken == APIPricing(input=0.01, output=0.03)
+ #     assert AzureGPT4_0613.dollar_per_ktoken == APIPricing(input=0.03, output=0.06)
+ #     assert AzureGPT432k_0613.dollar_per_ktoken == APIPricing(input=0.06, output=0.12)
+ #     # FT
+ #     assert AzureGPT35FT.dollar_per_ktoken == APIPricing(input=0.0005, output=0.0015)
+ #     # Legacy ---------------------------------------------------------
+ #     # No AzureGPT35T_0301
+ #     assert AzureGPT35T_0613.dollar_per_ktoken == APIPricing(input=0.0015, output=0.002)
+ #     assert AzureGPT35T16k_0613.dollar_per_ktoken == APIPricing(input=0.003, output=0.004)
+ #     assert AzureGPT35T_1106.dollar_per_ktoken == APIPricing(input=0.001, output=0.002)
test/llm/platform.py ADDED
@@ -0,0 +1,32 @@
+ import pytest
+
+ from neollm.llm.platform import Platform
+
+
+ class TestPlatform:
+     def test_str(self) -> None:
+         assert Platform.AZURE == "azure"  # type: ignore
+         assert Platform.OPENAI == "openai"  # type: ignore
+         assert Platform.ANTHROPIC == "anthropic"  # type: ignore
+         assert Platform.GCP == "gcp"  # type: ignore
+
+     def test_init(self) -> None:
+         assert Platform("azure") == Platform.AZURE
+         assert Platform("openai") == Platform.OPENAI
+         assert Platform("anthropic") == Platform.ANTHROPIC
+         assert Platform("gcp") == Platform.GCP
+
+         assert Platform("Azure ") == Platform.AZURE
+         assert Platform(" OpenAI") == Platform.OPENAI
+         assert Platform("Anthropic ") == Platform.ANTHROPIC
+         assert Platform("GcP") == Platform.GCP
+
+     def test_from_string(self) -> None:
+         assert Platform.from_string("azure") == Platform.AZURE
+         assert Platform.from_string("openai") == Platform.OPENAI
+         assert Platform.from_string("anthropic") == Platform.ANTHROPIC
+         assert Platform.from_string("gcp") == Platform.GCP
+
+     def test_from_string_error(self) -> None:
+         with pytest.raises(ValueError):
+             Platform.from_string("error")
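
These tests pin down `Platform`'s normalization behavior: construction accepts mixed case and surrounding whitespace, and `from_string` raises `ValueError` on unknown names. Below is a sketch of an implementation consistent with that behavior; the actual neollm/llm/platform.py may differ.

```python
# A sketch consistent with the tests above, not the library source.
from enum import Enum
from typing import Optional


class Platform(str, Enum):
    AZURE = "azure"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GCP = "gcp"

    @classmethod
    def _missing_(cls, value: object) -> Optional["Platform"]:
        # Makes Platform("Azure ") resolve to Platform.AZURE:
        # strip whitespace and lowercase before the member lookup.
        if isinstance(value, str):
            normalized = value.strip().lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return None

    @classmethod
    def from_string(cls, value: str) -> "Platform":
        # Raises ValueError for unknown names, as test_from_string_error expects.
        return cls(value)
```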