BREAKING: Merge 'expansive': code refactoring, local model support (#572)
Major new features:
- Support for more parameters
- Support for ChatGLM
- Support for local embeddings
- Support for local LLaMA models
Possible issues:
- A lot of error-handling code was removed, so more errors will now show up in the terminal instead
- Local embeddings do not support Chinese very well
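For context on the refactoring: the per-request functions in modules/chat_func.py are replaced by a model-object API. ChuanhuChatbot.py now keeps a single current_model state (created via ModelManager) and calls methods such as predict, retry, set_key and set_temperature on it, while modules/base_model.py provides the shared BaseLLMModel logic. The sketch below is a minimal illustration, not code from this commit, of how an additional local backend could plug into that interface; the EchoModel class is hypothetical and implements only get_answer_at_once, leaning on the base class for everything else.

from modules.base_model import BaseLLMModel

class EchoModel(BaseLLMModel):
    """Hypothetical local backend, shown only as an illustration.

    BaseLLMModel.predict() already handles chat history, web search /
    file indexing, token accounting and the streaming fallback, so a
    subclass only needs to turn self.history (OpenAI-style message
    dicts) into a reply.
    """

    def __init__(self, model_name="echo", **kwargs):
        super().__init__(model_name=model_name, **kwargs)

    def get_answer_at_once(self):
        # The most recent user message is the last entry in self.history.
        last_user_message = self.history[-1]["content"]
        reply = "echo: " + last_user_message
        # Return (answer, total token count); the count here is a stub.
        return reply, len(reply)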
- .gitignore +3 -0
- ChuanhuChatbot.py +129 -89
- configs/ds_config_chatbot.json +17 -0
- modules/__init__.py +0 -0
- modules/base_model.py +519 -0
- modules/chat_func.py +0 -497
- modules/config.py +2 -0
- modules/llama_func.py +74 -47
- modules/models.py +568 -0
- modules/openai_func.py +0 -65
- modules/presets.py +46 -44
- modules/utils.py +22 -103
- requirements_advanced.txt +7 -0
.gitignore
CHANGED
@@ -133,7 +133,10 @@ dmypy.json
 # Mac system file
 **/.DS_Store
 
+# 配置文件/模型文件
 api_key.txt
 config.json
 auth.json
+models/
+lora/
 .idea
ChuanhuChatbot.py
CHANGED
@@ -10,8 +10,7 @@ from modules.config import *
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
-
from modules.
|
14 |
-
from modules.openai_func import get_usage
|
15 |
|
16 |
gr.Chatbot.postprocess = postprocess
|
17 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
@@ -21,16 +20,14 @@ with open("assets/custom.css", "r", encoding="utf-8") as f:
|
|
21 |
|
22 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
23 |
user_name = gr.State("")
|
24 |
-
history = gr.State([])
|
25 |
-
token_count = gr.State([])
|
26 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
27 |
-
user_api_key = gr.State(my_api_key)
|
28 |
user_question = gr.State("")
|
29 |
-
|
|
|
30 |
topic = gr.State("未命名对话历史记录")
|
31 |
|
32 |
with gr.Row():
|
33 |
-
gr.HTML(
|
34 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
35 |
with gr.Row(elem_id="float_display"):
|
36 |
user_info = gr.Markdown(value="getting user info...", elem_id="user_info")
|
@@ -64,11 +61,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
64 |
retryBtn = gr.Button("🔄 重新生成")
|
65 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
66 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
67 |
-
reduceTokenBtn = gr.Button("♻️ 总结对话")
|
68 |
|
69 |
with gr.Column():
|
70 |
with gr.Column(min_width=50, scale=1):
|
71 |
-
with gr.Tab(label="
|
72 |
keyTxt = gr.Textbox(
|
73 |
show_label=True,
|
74 |
placeholder=f"OpenAI API-key...",
|
@@ -82,10 +78,13 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
82 |
else:
|
83 |
usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
|
84 |
model_select_dropdown = gr.Dropdown(
|
85 |
-
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[
|
|
|
|
|
|
|
86 |
)
|
87 |
use_streaming_checkbox = gr.Checkbox(
|
88 |
-
label="实时传输回答", value=True, visible=
|
89 |
)
|
90 |
use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
|
91 |
language_select_dropdown = gr.Dropdown(
|
@@ -94,7 +93,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
94 |
multiselect=False,
|
95 |
value=REPLY_LANGUAGES[0],
|
96 |
)
|
97 |
-
index_files = gr.Files(label="上传索引文件", type="file"
|
98 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
99 |
# TODO: 公式ocr
|
100 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
@@ -104,7 +103,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
104 |
show_label=True,
|
105 |
placeholder=f"在这里输入System Prompt...",
|
106 |
label="System prompt",
|
107 |
-
value=
|
108 |
lines=10,
|
109 |
).style(container=False)
|
110 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
@@ -160,24 +159,84 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
160 |
|
161 |
with gr.Tab(label="高级"):
|
162 |
gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
|
163 |
-
|
164 |
-
gr.HTML(appearance_switcher, elem_classes="insert_block")
|
165 |
with gr.Accordion("参数", open=False):
|
166 |
-
|
167 |
minimum=-0,
|
168 |
maximum=1.0,
|
169 |
value=1.0,
|
170 |
step=0.05,
|
171 |
interactive=True,
|
172 |
-
label="
|
173 |
)
|
174 |
-
|
175 |
-
minimum
|
176 |
maximum=2.0,
|
177 |
-
value=
|
178 |
-
step=0.
|
179 |
interactive=True,
|
180 |
-
label="
|
181 |
)
|
182 |
|
183 |
with gr.Accordion("网络设置", open=False):
|
@@ -198,27 +257,21 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
198 |
lines=2,
|
199 |
)
|
200 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
|
|
201 |
|
202 |
-
gr.Markdown(
|
203 |
-
gr.HTML(
|
204 |
chatgpt_predict_args = dict(
|
205 |
-
fn=predict,
|
206 |
inputs=[
|
207 |
-
user_api_key,
|
208 |
-
systemPromptTxt,
|
209 |
-
history,
|
210 |
user_question,
|
211 |
chatbot,
|
212 |
-
token_count,
|
213 |
-
top_p,
|
214 |
-
temperature,
|
215 |
use_streaming_checkbox,
|
216 |
-
model_select_dropdown,
|
217 |
use_websearch_checkbox,
|
218 |
index_files,
|
219 |
language_select_dropdown,
|
220 |
],
|
221 |
-
outputs=[chatbot,
|
222 |
show_progress=True,
|
223 |
)
|
224 |
|
@@ -242,12 +295,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
242 |
)
|
243 |
|
244 |
get_usage_args = dict(
|
245 |
-
fn=
|
246 |
)
|
247 |
|
248 |
|
249 |
# Chatbot
|
250 |
-
cancelBtn.click(
|
251 |
|
252 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
253 |
user_input.submit(**get_usage_args)
|
@@ -256,70 +315,49 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
256 |
submitBtn.click(**get_usage_args)
|
257 |
|
258 |
emptyBtn.click(
|
259 |
-
|
260 |
-
outputs=[chatbot,
|
261 |
show_progress=True,
|
262 |
)
|
263 |
emptyBtn.click(**reset_textbox_args)
|
264 |
|
265 |
retryBtn.click(**start_outputing_args).then(
|
266 |
-
retry,
|
267 |
[
|
268 |
-
user_api_key,
|
269 |
-
systemPromptTxt,
|
270 |
-
history,
|
271 |
chatbot,
|
272 |
-
token_count,
|
273 |
-
top_p,
|
274 |
-
temperature,
|
275 |
use_streaming_checkbox,
|
276 |
-
|
|
|
277 |
language_select_dropdown,
|
278 |
],
|
279 |
-
[chatbot,
|
280 |
show_progress=True,
|
281 |
).then(**end_outputing_args)
|
282 |
retryBtn.click(**get_usage_args)
|
283 |
|
284 |
delFirstBtn.click(
|
285 |
-
delete_first_conversation,
|
286 |
-
|
287 |
-
[
|
288 |
)
|
289 |
|
290 |
delLastBtn.click(
|
291 |
-
delete_last_conversation,
|
292 |
-
[chatbot
|
293 |
-
[chatbot,
|
294 |
-
show_progress=
|
295 |
-
)
|
296 |
-
|
297 |
-
reduceTokenBtn.click(
|
298 |
-
reduce_token_size,
|
299 |
-
[
|
300 |
-
user_api_key,
|
301 |
-
systemPromptTxt,
|
302 |
-
history,
|
303 |
-
chatbot,
|
304 |
-
token_count,
|
305 |
-
top_p,
|
306 |
-
temperature,
|
307 |
-
gr.State(sum(token_count.value[-4:])),
|
308 |
-
model_select_dropdown,
|
309 |
-
language_select_dropdown,
|
310 |
-
],
|
311 |
-
[chatbot, history, status_display, token_count],
|
312 |
-
show_progress=True,
|
313 |
)
|
314 |
-
reduceTokenBtn.click(**get_usage_args)
|
315 |
|
316 |
two_column.change(update_doc_config, [two_column], None)
|
317 |
|
318 |
-
#
|
319 |
-
keyTxt.change(
|
320 |
keyTxt.submit(**get_usage_args)
|
|
|
|
|
321 |
|
322 |
# Template
|
|
|
323 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
324 |
templateFileSelectDropdown.change(
|
325 |
load_template,
|
@@ -336,32 +374,34 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
336 |
|
337 |
# S&L
|
338 |
saveHistoryBtn.click(
|
339 |
-
save_chat_history,
|
340 |
-
[saveFileName,
|
341 |
downloadFile,
|
342 |
show_progress=True,
|
343 |
)
|
344 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
345 |
exportMarkdownBtn.click(
|
346 |
-
export_markdown,
|
347 |
-
[saveFileName,
|
348 |
downloadFile,
|
349 |
show_progress=True,
|
350 |
)
|
351 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
352 |
-
historyFileSelectDropdown.change(
|
353 |
-
|
354 |
-
[historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
|
355 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
356 |
-
show_progress=True,
|
357 |
-
)
|
358 |
-
downloadFile.change(
|
359 |
-
load_chat_history,
|
360 |
-
[downloadFile, systemPromptTxt, history, chatbot, user_name],
|
361 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
362 |
-
)
|
363 |
|
364 |
# Advanced
|
365 |
default_btn.click(
|
366 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
367 |
)
|
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
+
from modules.models import ModelManager
|
|
|
14 |
|
15 |
gr.Chatbot.postprocess = postprocess
|
16 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
|
|
20 |
|
21 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
22 |
user_name = gr.State("")
|
|
|
|
|
23 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
|
|
24 |
user_question = gr.State("")
|
25 |
+
current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
|
26 |
+
|
27 |
topic = gr.State("未命名对话历史记录")
|
28 |
|
29 |
with gr.Row():
|
30 |
+
gr.HTML(CHUANHU_TITLE, elem_id="app_title")
|
31 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
32 |
with gr.Row(elem_id="float_display"):
|
33 |
user_info = gr.Markdown(value="getting user info...", elem_id="user_info")
|
|
|
61 |
retryBtn = gr.Button("🔄 重新生成")
|
62 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
63 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
|
|
64 |
|
65 |
with gr.Column():
|
66 |
with gr.Column(min_width=50, scale=1):
|
67 |
+
with gr.Tab(label="模型"):
|
68 |
keyTxt = gr.Textbox(
|
69 |
show_label=True,
|
70 |
placeholder=f"OpenAI API-key...",
|
|
|
78 |
else:
|
79 |
usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
|
80 |
model_select_dropdown = gr.Dropdown(
|
81 |
+
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
|
82 |
+
)
|
83 |
+
lora_select_dropdown = gr.Dropdown(
|
84 |
+
label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
|
85 |
)
|
86 |
use_streaming_checkbox = gr.Checkbox(
|
87 |
+
label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
|
88 |
)
|
89 |
use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
|
90 |
language_select_dropdown = gr.Dropdown(
|
|
|
93 |
multiselect=False,
|
94 |
value=REPLY_LANGUAGES[0],
|
95 |
)
|
96 |
+
index_files = gr.Files(label="上传索引文件", type="file")
|
97 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
98 |
# TODO: 公式ocr
|
99 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
|
|
103 |
show_label=True,
|
104 |
placeholder=f"在这里输入System Prompt...",
|
105 |
label="System prompt",
|
106 |
+
value=INITIAL_SYSTEM_PROMPT,
|
107 |
lines=10,
|
108 |
).style(container=False)
|
109 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
|
|
159 |
|
160 |
with gr.Tab(label="高级"):
|
161 |
gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
|
162 |
+
gr.HTML(APPEARANCE_SWITCHER, elem_classes="insert_block")
|
|
|
163 |
with gr.Accordion("参数", open=False):
|
164 |
+
temperature_slider = gr.Slider(
|
165 |
+
minimum=-0,
|
166 |
+
maximum=2.0,
|
167 |
+
value=1.0,
|
168 |
+
step=0.1,
|
169 |
+
interactive=True,
|
170 |
+
label="temperature",
|
171 |
+
)
|
172 |
+
top_p_slider = gr.Slider(
|
173 |
minimum=-0,
|
174 |
maximum=1.0,
|
175 |
value=1.0,
|
176 |
step=0.05,
|
177 |
interactive=True,
|
178 |
+
label="top-p",
|
179 |
)
|
180 |
+
n_choices_slider = gr.Slider(
|
181 |
+
minimum=1,
|
182 |
+
maximum=10,
|
183 |
+
value=1,
|
184 |
+
step=1,
|
185 |
+
interactive=True,
|
186 |
+
label="n choices",
|
187 |
+
)
|
188 |
+
stop_sequence_txt = gr.Textbox(
|
189 |
+
show_label=True,
|
190 |
+
placeholder=f"在这里输入停止符,用英文逗号隔开...",
|
191 |
+
label="stop",
|
192 |
+
value="",
|
193 |
+
lines=1,
|
194 |
+
)
|
195 |
+
max_context_length_slider = gr.Slider(
|
196 |
+
minimum=1,
|
197 |
+
maximum=32768,
|
198 |
+
value=2000,
|
199 |
+
step=1,
|
200 |
+
interactive=True,
|
201 |
+
label="max context",
|
202 |
+
)
|
203 |
+
max_generation_slider = gr.Slider(
|
204 |
+
minimum=1,
|
205 |
+
maximum=32768,
|
206 |
+
value=1000,
|
207 |
+
step=1,
|
208 |
+
interactive=True,
|
209 |
+
label="max generations",
|
210 |
+
)
|
211 |
+
presence_penalty_slider = gr.Slider(
|
212 |
+
minimum=-2.0,
|
213 |
maximum=2.0,
|
214 |
+
value=0.0,
|
215 |
+
step=0.01,
|
216 |
+
interactive=True,
|
217 |
+
label="presence penalty",
|
218 |
+
)
|
219 |
+
frequency_penalty_slider = gr.Slider(
|
220 |
+
minimum=-2.0,
|
221 |
+
maximum=2.0,
|
222 |
+
value=0.0,
|
223 |
+
step=0.01,
|
224 |
interactive=True,
|
225 |
+
label="frequency penalty",
|
226 |
+
)
|
227 |
+
logit_bias_txt = gr.Textbox(
|
228 |
+
show_label=True,
|
229 |
+
placeholder=f"word:likelihood",
|
230 |
+
label="logit bias",
|
231 |
+
value="",
|
232 |
+
lines=1,
|
233 |
+
)
|
234 |
+
user_identifier_txt = gr.Textbox(
|
235 |
+
show_label=True,
|
236 |
+
placeholder=f"用于定位滥用行为",
|
237 |
+
label="用户名",
|
238 |
+
value=user_name.value,
|
239 |
+
lines=1,
|
240 |
)
|
241 |
|
242 |
with gr.Accordion("网络设置", open=False):
|
|
|
257 |
lines=2,
|
258 |
)
|
259 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
260 |
+
default_btn = gr.Button("🔙 恢复默认设置")
|
261 |
|
262 |
+
gr.Markdown(CHUANHU_DESCRIPTION)
|
263 |
+
gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
|
264 |
chatgpt_predict_args = dict(
|
265 |
+
fn=current_model.value.predict,
|
266 |
inputs=[
|
|
|
|
|
|
|
267 |
user_question,
|
268 |
chatbot,
|
|
|
|
|
|
|
269 |
use_streaming_checkbox,
|
|
|
270 |
use_websearch_checkbox,
|
271 |
index_files,
|
272 |
language_select_dropdown,
|
273 |
],
|
274 |
+
outputs=[chatbot, status_display],
|
275 |
show_progress=True,
|
276 |
)
|
277 |
|
|
|
295 |
)
|
296 |
|
297 |
get_usage_args = dict(
|
298 |
+
fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
|
299 |
+
)
|
300 |
+
|
301 |
+
load_history_from_file_args = dict(
|
302 |
+
fn=current_model.value.load_chat_history,
|
303 |
+
inputs=[historyFileSelectDropdown, chatbot, user_name],
|
304 |
+
outputs=[saveFileName, systemPromptTxt, chatbot]
|
305 |
)
|
306 |
|
307 |
|
308 |
# Chatbot
|
309 |
+
cancelBtn.click(current_model.value.interrupt, [], [])
|
310 |
|
311 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
312 |
user_input.submit(**get_usage_args)
|
|
|
315 |
submitBtn.click(**get_usage_args)
|
316 |
|
317 |
emptyBtn.click(
|
318 |
+
current_model.value.reset,
|
319 |
+
outputs=[chatbot, status_display],
|
320 |
show_progress=True,
|
321 |
)
|
322 |
emptyBtn.click(**reset_textbox_args)
|
323 |
|
324 |
retryBtn.click(**start_outputing_args).then(
|
325 |
+
current_model.value.retry,
|
326 |
[
|
|
|
|
|
|
|
327 |
chatbot,
|
|
|
|
|
|
|
328 |
use_streaming_checkbox,
|
329 |
+
use_websearch_checkbox,
|
330 |
+
index_files,
|
331 |
language_select_dropdown,
|
332 |
],
|
333 |
+
[chatbot, status_display],
|
334 |
show_progress=True,
|
335 |
).then(**end_outputing_args)
|
336 |
retryBtn.click(**get_usage_args)
|
337 |
|
338 |
delFirstBtn.click(
|
339 |
+
current_model.value.delete_first_conversation,
|
340 |
+
None,
|
341 |
+
[status_display],
|
342 |
)
|
343 |
|
344 |
delLastBtn.click(
|
345 |
+
current_model.value.delete_last_conversation,
|
346 |
+
[chatbot],
|
347 |
+
[chatbot, status_display],
|
348 |
+
show_progress=False
|
349 |
)
|
|
|
350 |
|
351 |
two_column.change(update_doc_config, [two_column], None)
|
352 |
|
353 |
+
# LLM Models
|
354 |
+
keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
|
355 |
keyTxt.submit(**get_usage_args)
|
356 |
+
model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
|
357 |
+
lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
|
358 |
|
359 |
# Template
|
360 |
+
systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
|
361 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
362 |
templateFileSelectDropdown.change(
|
363 |
load_template,
|
|
|
374 |
|
375 |
# S&L
|
376 |
saveHistoryBtn.click(
|
377 |
+
current_model.value.save_chat_history,
|
378 |
+
[saveFileName, chatbot, user_name],
|
379 |
downloadFile,
|
380 |
show_progress=True,
|
381 |
)
|
382 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
383 |
exportMarkdownBtn.click(
|
384 |
+
current_model.value.export_markdown,
|
385 |
+
[saveFileName, chatbot, user_name],
|
386 |
downloadFile,
|
387 |
show_progress=True,
|
388 |
)
|
389 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
390 |
+
historyFileSelectDropdown.change(**load_history_from_file_args)
|
391 |
+
downloadFile.change(**load_history_from_file_args)
|
392 |
|
393 |
# Advanced
|
394 |
+
max_context_length_slider.change(current_model.value.set_token_upper_limit, [max_context_length_slider], None)
|
395 |
+
temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
|
396 |
+
top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
|
397 |
+
n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
|
398 |
+
stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
|
399 |
+
max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
|
400 |
+
presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
|
401 |
+
frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
|
402 |
+
logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
|
403 |
+
user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
|
404 |
+
|
405 |
default_btn.click(
|
406 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
407 |
)
|
configs/ds_config_chatbot.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "fp16": {
+        "enabled": false
+    },
+    "bf16": {
+        "enabled": true
+    },
+    "comms_logger": {
+        "enabled": false,
+        "verbose": false,
+        "prof_all": false,
+        "debug": false
+    },
+    "steps_per_print": 20000000000000000,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
modules/__init__.py
ADDED
File without changes
|
modules/base_model.py
ADDED
@@ -0,0 +1,519 @@
1 |
+
from __future__ import annotations
|
2 |
+
from typing import TYPE_CHECKING, List
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
+
import commentjson as cjson
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import requests
|
10 |
+
import urllib3
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
import colorama
|
14 |
+
from duckduckgo_search import ddg
|
15 |
+
import asyncio
|
16 |
+
import aiohttp
|
17 |
+
from enum import Enum
|
18 |
+
|
19 |
+
from .presets import *
|
20 |
+
from .llama_func import *
|
21 |
+
from .utils import *
|
22 |
+
from . import shared
|
23 |
+
from .config import retrieve_proxy
|
24 |
+
|
25 |
+
|
26 |
+
class ModelType(Enum):
|
27 |
+
Unknown = -1
|
28 |
+
OpenAI = 0
|
29 |
+
ChatGLM = 1
|
30 |
+
LLaMA = 2
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def get_type(cls, model_name: str):
|
34 |
+
model_type = None
|
35 |
+
model_name_lower = model_name.lower()
|
36 |
+
if "gpt" in model_name_lower:
|
37 |
+
model_type = ModelType.OpenAI
|
38 |
+
elif "chatglm" in model_name_lower:
|
39 |
+
model_type = ModelType.ChatGLM
|
40 |
+
elif "llama" in model_name_lower:
|
41 |
+
model_type = ModelType.LLaMA
|
42 |
+
else:
|
43 |
+
model_type = ModelType.Unknown
|
44 |
+
return model_type
|
45 |
+
|
46 |
+
|
47 |
+
class BaseLLMModel:
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
model_name,
|
51 |
+
system_prompt="",
|
52 |
+
temperature=1.0,
|
53 |
+
top_p=1.0,
|
54 |
+
n_choices=1,
|
55 |
+
stop=None,
|
56 |
+
max_generation_token=None,
|
57 |
+
presence_penalty=0,
|
58 |
+
frequency_penalty=0,
|
59 |
+
logit_bias=None,
|
60 |
+
user="",
|
61 |
+
) -> None:
|
62 |
+
self.history = []
|
63 |
+
self.all_token_counts = []
|
64 |
+
self.model_name = model_name
|
65 |
+
self.model_type = ModelType.get_type(model_name)
|
66 |
+
try:
|
67 |
+
self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
|
68 |
+
except KeyError:
|
69 |
+
self.token_upper_limit = DEFAULT_TOKEN_LIMIT
|
70 |
+
self.interrupted = False
|
71 |
+
self.system_prompt = system_prompt
|
72 |
+
self.api_key = None
|
73 |
+
self.need_api_key = False
|
74 |
+
|
75 |
+
self.temperature = temperature
|
76 |
+
self.top_p = top_p
|
77 |
+
self.n_choices = n_choices
|
78 |
+
self.stop_sequence = stop
|
79 |
+
self.max_generation_token = None
|
80 |
+
self.presence_penalty = presence_penalty
|
81 |
+
self.frequency_penalty = frequency_penalty
|
82 |
+
self.logit_bias = logit_bias
|
83 |
+
self.user_identifier = user
|
84 |
+
|
85 |
+
def get_answer_stream_iter(self):
|
86 |
+
"""stream predict, need to be implemented
|
87 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
88 |
+
should return a generator, each time give the next word (str) in the answer
|
89 |
+
"""
|
90 |
+
logging.warning("stream predict not implemented, using at once predict instead")
|
91 |
+
response, _ = self.get_answer_at_once()
|
92 |
+
yield response
|
93 |
+
|
94 |
+
def get_answer_at_once(self):
|
95 |
+
"""predict at once, need to be implemented
|
96 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
97 |
+
Should return:
|
98 |
+
the answer (str)
|
99 |
+
total token count (int)
|
100 |
+
"""
|
101 |
+
logging.warning("at once predict not implemented, using stream predict instead")
|
102 |
+
response_iter = self.get_answer_stream_iter()
|
103 |
+
count = 0
|
104 |
+
for response in response_iter:
|
105 |
+
count += 1
|
106 |
+
return response, sum(self.all_token_counts) + count
|
107 |
+
|
108 |
+
def billing_info(self):
|
109 |
+
"""get billing infomation, inplement if needed"""
|
110 |
+
logging.warning("billing info not implemented, using default")
|
111 |
+
return BILLING_NOT_APPLICABLE_MSG
|
112 |
+
|
113 |
+
def count_token(self, user_input):
|
114 |
+
"""get token count from input, implement if needed"""
|
115 |
+
logging.warning("token count not implemented, using default")
|
116 |
+
return len(user_input)
|
117 |
+
|
118 |
+
def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
|
119 |
+
def get_return_value():
|
120 |
+
return chatbot, status_text
|
121 |
+
|
122 |
+
status_text = "开始实时传输回答……"
|
123 |
+
if fake_input:
|
124 |
+
chatbot.append((fake_input, ""))
|
125 |
+
else:
|
126 |
+
chatbot.append((inputs, ""))
|
127 |
+
|
128 |
+
user_token_count = self.count_token(inputs)
|
129 |
+
self.all_token_counts.append(user_token_count)
|
130 |
+
logging.debug(f"输入token计数: {user_token_count}")
|
131 |
+
|
132 |
+
stream_iter = self.get_answer_stream_iter()
|
133 |
+
|
134 |
+
for partial_text in stream_iter:
|
135 |
+
chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
|
136 |
+
self.all_token_counts[-1] += 1
|
137 |
+
status_text = self.token_message()
|
138 |
+
yield get_return_value()
|
139 |
+
if self.interrupted:
|
140 |
+
self.recover()
|
141 |
+
break
|
142 |
+
self.history.append(construct_assistant(partial_text))
|
143 |
+
|
144 |
+
def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
|
145 |
+
if fake_input:
|
146 |
+
chatbot.append((fake_input, ""))
|
147 |
+
else:
|
148 |
+
chatbot.append((inputs, ""))
|
149 |
+
if fake_input is not None:
|
150 |
+
user_token_count = self.count_token(fake_input)
|
151 |
+
else:
|
152 |
+
user_token_count = self.count_token(inputs)
|
153 |
+
self.all_token_counts.append(user_token_count)
|
154 |
+
ai_reply, total_token_count = self.get_answer_at_once()
|
155 |
+
self.history.append(construct_assistant(ai_reply))
|
156 |
+
if fake_input is not None:
|
157 |
+
self.history[-2] = construct_user(fake_input)
|
158 |
+
chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
|
159 |
+
if fake_input is not None:
|
160 |
+
self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
|
161 |
+
else:
|
162 |
+
self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
|
163 |
+
status_text = self.token_message()
|
164 |
+
return chatbot, status_text
|
165 |
+
|
166 |
+
def predict(
|
167 |
+
self,
|
168 |
+
inputs,
|
169 |
+
chatbot,
|
170 |
+
stream=False,
|
171 |
+
use_websearch=False,
|
172 |
+
files=None,
|
173 |
+
reply_language="中文",
|
174 |
+
should_check_token_count=True,
|
175 |
+
): # repetition_penalty, top_k
|
176 |
+
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
177 |
+
from llama_index.indices.query.schema import QueryBundle
|
178 |
+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
179 |
+
from langchain.chat_models import ChatOpenAI
|
180 |
+
from llama_index import (
|
181 |
+
GPTSimpleVectorIndex,
|
182 |
+
ServiceContext,
|
183 |
+
LangchainEmbedding,
|
184 |
+
OpenAIEmbedding,
|
185 |
+
)
|
186 |
+
|
187 |
+
logging.info(
|
188 |
+
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
|
189 |
+
)
|
190 |
+
if should_check_token_count:
|
191 |
+
yield chatbot + [(inputs, "")], "开始生成回答……"
|
192 |
+
if reply_language == "跟随问题语言(不稳定)":
|
193 |
+
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
194 |
+
old_inputs = None
|
195 |
+
display_reference = []
|
196 |
+
limited_context = False
|
197 |
+
if files:
|
198 |
+
limited_context = True
|
199 |
+
old_inputs = inputs
|
200 |
+
msg = "加载索引中……(这可能需要几分钟)"
|
201 |
+
logging.info(msg)
|
202 |
+
yield chatbot + [(inputs, "")], msg
|
203 |
+
index = construct_index(self.api_key, file_src=files)
|
204 |
+
assert index is not None, "索引构建失败"
|
205 |
+
msg = "索引构建完成,获取回答中……"
|
206 |
+
if local_embedding:
|
207 |
+
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
|
208 |
+
else:
|
209 |
+
embed_model = OpenAIEmbedding()
|
210 |
+
logging.info(msg)
|
211 |
+
yield chatbot + [(inputs, "")], msg
|
212 |
+
with retrieve_proxy():
|
213 |
+
prompt_helper = PromptHelper(
|
214 |
+
max_input_size=4096,
|
215 |
+
num_output=5,
|
216 |
+
max_chunk_overlap=20,
|
217 |
+
chunk_size_limit=600,
|
218 |
+
)
|
219 |
+
from llama_index import ServiceContext
|
220 |
+
|
221 |
+
service_context = ServiceContext.from_defaults(
|
222 |
+
prompt_helper=prompt_helper, embed_model=embed_model
|
223 |
+
)
|
224 |
+
query_object = GPTVectorStoreIndexQuery(
|
225 |
+
index.index_struct,
|
226 |
+
service_context=service_context,
|
227 |
+
similarity_top_k=5,
|
228 |
+
vector_store=index._vector_store,
|
229 |
+
docstore=index._docstore,
|
230 |
+
)
|
231 |
+
query_bundle = QueryBundle(inputs)
|
232 |
+
nodes = query_object.retrieve(query_bundle)
|
233 |
+
reference_results = [n.node.text for n in nodes]
|
234 |
+
reference_results = add_source_numbers(reference_results, use_source=False)
|
235 |
+
display_reference = add_details(reference_results)
|
236 |
+
display_reference = "\n\n" + "".join(display_reference)
|
237 |
+
inputs = (
|
238 |
+
replace_today(PROMPT_TEMPLATE)
|
239 |
+
.replace("{query_str}", inputs)
|
240 |
+
.replace("{context_str}", "\n\n".join(reference_results))
|
241 |
+
.replace("{reply_language}", reply_language)
|
242 |
+
)
|
243 |
+
elif use_websearch:
|
244 |
+
limited_context = True
|
245 |
+
search_results = ddg(inputs, max_results=5)
|
246 |
+
old_inputs = inputs
|
247 |
+
reference_results = []
|
248 |
+
for idx, result in enumerate(search_results):
|
249 |
+
logging.debug(f"搜索结果{idx + 1}:{result}")
|
250 |
+
domain_name = urllib3.util.parse_url(result["href"]).host
|
251 |
+
reference_results.append([result["body"], result["href"]])
|
252 |
+
display_reference.append(
|
253 |
+
f"{idx+1}. [{domain_name}]({result['href']})\n"
|
254 |
+
)
|
255 |
+
reference_results = add_source_numbers(reference_results)
|
256 |
+
display_reference = "\n\n" + "".join(display_reference)
|
257 |
+
inputs = (
|
258 |
+
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
259 |
+
.replace("{query}", inputs)
|
260 |
+
.replace("{web_results}", "\n\n".join(reference_results))
|
261 |
+
.replace("{reply_language}", reply_language)
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
display_reference = ""
|
265 |
+
|
266 |
+
if (
|
267 |
+
self.need_api_key and
|
268 |
+
self.api_key is None
|
269 |
+
and not shared.state.multi_api_key
|
270 |
+
):
|
271 |
+
status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
|
272 |
+
logging.info(status_text)
|
273 |
+
chatbot.append((inputs, ""))
|
274 |
+
if len(self.history) == 0:
|
275 |
+
self.history.append(construct_user(inputs))
|
276 |
+
self.history.append("")
|
277 |
+
self.all_token_counts.append(0)
|
278 |
+
else:
|
279 |
+
self.history[-2] = construct_user(inputs)
|
280 |
+
yield chatbot + [(inputs, "")], status_text
|
281 |
+
return
|
282 |
+
elif len(inputs.strip()) == 0:
|
283 |
+
status_text = STANDARD_ERROR_MSG + NO_INPUT_MSG
|
284 |
+
logging.info(status_text)
|
285 |
+
yield chatbot + [(inputs, "")], status_text
|
286 |
+
return
|
287 |
+
|
288 |
+
self.history.append(construct_user(inputs))
|
289 |
+
|
290 |
+
try:
|
291 |
+
if stream:
|
292 |
+
logging.debug("使用流式传输")
|
293 |
+
iter = self.stream_next_chatbot(
|
294 |
+
inputs,
|
295 |
+
chatbot,
|
296 |
+
fake_input=old_inputs,
|
297 |
+
display_append=display_reference,
|
298 |
+
)
|
299 |
+
for chatbot, status_text in iter:
|
300 |
+
yield chatbot, status_text
|
301 |
+
else:
|
302 |
+
logging.debug("不使用流式传输")
|
303 |
+
chatbot, status_text = self.next_chatbot_at_once(
|
304 |
+
inputs,
|
305 |
+
chatbot,
|
306 |
+
fake_input=old_inputs,
|
307 |
+
display_append=display_reference,
|
308 |
+
)
|
309 |
+
yield chatbot, status_text
|
310 |
+
except Exception as e:
|
311 |
+
status_text = STANDARD_ERROR_MSG + str(e)
|
312 |
+
yield chatbot, status_text
|
313 |
+
|
314 |
+
if len(self.history) > 1 and self.history[-1]["content"] != inputs:
|
315 |
+
logging.info(
|
316 |
+
"回答为:"
|
317 |
+
+ colorama.Fore.BLUE
|
318 |
+
+ f"{self.history[-1]['content']}"
|
319 |
+
+ colorama.Style.RESET_ALL
|
320 |
+
)
|
321 |
+
|
322 |
+
if limited_context:
|
323 |
+
self.history = self.history[-4:]
|
324 |
+
self.all_token_counts = self.all_token_counts[-2:]
|
325 |
+
|
326 |
+
max_token = self.token_upper_limit - TOKEN_OFFSET
|
327 |
+
|
328 |
+
if sum(self.all_token_counts) > max_token and should_check_token_count:
|
329 |
+
count = 0
|
330 |
+
while (
|
331 |
+
sum(self.all_token_counts)
|
332 |
+
> self.token_upper_limit * REDUCE_TOKEN_FACTOR
|
333 |
+
and sum(self.all_token_counts) > 0
|
334 |
+
):
|
335 |
+
count += 1
|
336 |
+
del self.all_token_counts[0]
|
337 |
+
del self.history[:2]
|
338 |
+
logging.info(status_text)
|
339 |
+
status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
|
340 |
+
yield chatbot, status_text
|
341 |
+
|
342 |
+
def retry(
|
343 |
+
self,
|
344 |
+
chatbot,
|
345 |
+
stream=False,
|
346 |
+
use_websearch=False,
|
347 |
+
files=None,
|
348 |
+
reply_language="中文",
|
349 |
+
):
|
350 |
+
logging.debug("重试中……")
|
351 |
+
if len(self.history) == 0:
|
352 |
+
yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
|
353 |
+
return
|
354 |
+
|
355 |
+
inputs = self.history[-2]["content"]
|
356 |
+
del self.history[-2:]
|
357 |
+
self.all_token_counts.pop()
|
358 |
+
iter = self.predict(
|
359 |
+
inputs,
|
360 |
+
chatbot,
|
361 |
+
stream=stream,
|
362 |
+
use_websearch=use_websearch,
|
363 |
+
files=files,
|
364 |
+
reply_language=reply_language,
|
365 |
+
)
|
366 |
+
for x in iter:
|
367 |
+
yield x
|
368 |
+
logging.debug("重试完毕")
|
369 |
+
|
370 |
+
# def reduce_token_size(self, chatbot):
|
371 |
+
# logging.info("开始减少token数量……")
|
372 |
+
# chatbot, status_text = self.next_chatbot_at_once(
|
373 |
+
# summarize_prompt,
|
374 |
+
# chatbot
|
375 |
+
# )
|
376 |
+
# max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
|
377 |
+
# num_chat = find_n(self.all_token_counts, max_token_count)
|
378 |
+
# logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
|
379 |
+
# chatbot = chatbot[:-1]
|
380 |
+
# self.history = self.history[-2*num_chat:] if num_chat > 0 else []
|
381 |
+
# self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
|
382 |
+
# msg = f"保留了最近{num_chat}轮对话"
|
383 |
+
# logging.info(msg)
|
384 |
+
# logging.info("减少token数量完毕")
|
385 |
+
# return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
|
386 |
+
|
387 |
+
def interrupt(self):
|
388 |
+
self.interrupted = True
|
389 |
+
|
390 |
+
def recover(self):
|
391 |
+
self.interrupted = False
|
392 |
+
|
393 |
+
def set_token_upper_limit(self, new_upper_limit):
|
394 |
+
self.token_upper_limit = new_upper_limit
|
395 |
+
print(f"token上限设置为{new_upper_limit}")
|
396 |
+
|
397 |
+
def set_temperature(self, new_temperature):
|
398 |
+
self.temperature = new_temperature
|
399 |
+
|
400 |
+
def set_top_p(self, new_top_p):
|
401 |
+
self.top_p = new_top_p
|
402 |
+
|
403 |
+
def set_n_choices(self, new_n_choices):
|
404 |
+
self.n_choices = new_n_choices
|
405 |
+
|
406 |
+
def set_stop_sequence(self, new_stop_sequence: str):
|
407 |
+
new_stop_sequence = new_stop_sequence.split(",")
|
408 |
+
self.stop_sequence = new_stop_sequence
|
409 |
+
|
410 |
+
def set_max_tokens(self, new_max_tokens):
|
411 |
+
self.max_generation_token = new_max_tokens
|
412 |
+
|
413 |
+
def set_presence_penalty(self, new_presence_penalty):
|
414 |
+
self.presence_penalty = new_presence_penalty
|
415 |
+
|
416 |
+
def set_frequency_penalty(self, new_frequency_penalty):
|
417 |
+
self.frequency_penalty = new_frequency_penalty
|
418 |
+
|
419 |
+
def set_logit_bias(self, logit_bias):
|
420 |
+
logit_bias = logit_bias.split()
|
421 |
+
bias_map = {}
|
422 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
423 |
+
for line in logit_bias:
|
424 |
+
word, bias_amount = line.split(":")
|
425 |
+
if word:
|
426 |
+
for token in encoding.encode(word):
|
427 |
+
bias_map[token] = float(bias_amount)
|
428 |
+
self.logit_bias = bias_map
|
429 |
+
|
430 |
+
def set_user_identifier(self, new_user_identifier):
|
431 |
+
self.user_identifier = new_user_identifier
|
432 |
+
|
433 |
+
def set_system_prompt(self, new_system_prompt):
|
434 |
+
self.system_prompt = new_system_prompt
|
435 |
+
|
436 |
+
def set_key(self, new_access_key):
|
437 |
+
self.api_key = new_access_key.strip()
|
438 |
+
msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
|
439 |
+
logging.info(msg)
|
440 |
+
return msg
|
441 |
+
|
442 |
+
def reset(self):
|
443 |
+
self.history = []
|
444 |
+
self.all_token_counts = []
|
445 |
+
self.interrupted = False
|
446 |
+
return [], self.token_message([0])
|
447 |
+
|
448 |
+
def delete_first_conversation(self):
|
449 |
+
if self.history:
|
450 |
+
del self.history[:2]
|
451 |
+
del self.all_token_counts[0]
|
452 |
+
return self.token_message()
|
453 |
+
|
454 |
+
def delete_last_conversation(self, chatbot):
|
455 |
+
if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
|
456 |
+
msg = "由于包含报错信息,只删除chatbot记录"
|
457 |
+
chatbot.pop()
|
458 |
+
return chatbot, self.history
|
459 |
+
if len(self.history) > 0:
|
460 |
+
self.history.pop()
|
461 |
+
self.history.pop()
|
462 |
+
if len(chatbot) > 0:
|
463 |
+
msg = "删除了一组chatbot对话"
|
464 |
+
chatbot.pop()
|
465 |
+
if len(self.all_token_counts) > 0:
|
466 |
+
msg = "删除了一组对话的token计数记录"
|
467 |
+
self.all_token_counts.pop()
|
468 |
+
msg = "删除了一组对话"
|
469 |
+
return chatbot, msg
|
470 |
+
|
471 |
+
def token_message(self, token_lst=None):
|
472 |
+
if token_lst is None:
|
473 |
+
token_lst = self.all_token_counts
|
474 |
+
token_sum = 0
|
475 |
+
for i in range(len(token_lst)):
|
476 |
+
token_sum += sum(token_lst[: i + 1])
|
477 |
+
return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
|
478 |
+
|
479 |
+
def save_chat_history(self, filename, chatbot, user_name):
|
480 |
+
if filename == "":
|
481 |
+
return
|
482 |
+
if not filename.endswith(".json"):
|
483 |
+
filename += ".json"
|
484 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
485 |
+
|
486 |
+
def export_markdown(self, filename, chatbot, user_name):
|
487 |
+
if filename == "":
|
488 |
+
return
|
489 |
+
if not filename.endswith(".md"):
|
490 |
+
filename += ".md"
|
491 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
492 |
+
|
493 |
+
def load_chat_history(self, filename, chatbot, user_name):
|
494 |
+
logging.debug(f"{user_name} 加载对话历史中……")
|
495 |
+
if type(filename) != str:
|
496 |
+
filename = filename.name
|
497 |
+
try:
|
498 |
+
with open(os.path.join(HISTORY_DIR, user_name, filename), "r") as f:
|
499 |
+
json_s = json.load(f)
|
500 |
+
try:
|
501 |
+
if type(json_s["history"][0]) == str:
|
502 |
+
logging.info("历史记录格式为旧版,正在转换……")
|
503 |
+
new_history = []
|
504 |
+
for index, item in enumerate(json_s["history"]):
|
505 |
+
if index % 2 == 0:
|
506 |
+
new_history.append(construct_user(item))
|
507 |
+
else:
|
508 |
+
new_history.append(construct_assistant(item))
|
509 |
+
json_s["history"] = new_history
|
510 |
+
logging.info(new_history)
|
511 |
+
except:
|
512 |
+
# 没有对话历史
|
513 |
+
pass
|
514 |
+
logging.debug(f"{user_name} 加载对话历史完毕")
|
515 |
+
self.history = json_s["history"]
|
516 |
+
return filename, json_s["system"], json_s["chatbot"]
|
517 |
+
except FileNotFoundError:
|
518 |
+
logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
|
519 |
+
return filename, self.system_prompt, chatbot
|
modules/chat_func.py
DELETED
@@ -1,497 +0,0 @@
|
|
1 |
-
# -*- coding:utf-8 -*-
|
2 |
-
from __future__ import annotations
|
3 |
-
from typing import TYPE_CHECKING, List
|
4 |
-
|
5 |
-
import logging
|
6 |
-
import json
|
7 |
-
import os
|
8 |
-
import requests
|
9 |
-
import urllib3
|
10 |
-
|
11 |
-
from tqdm import tqdm
|
12 |
-
import colorama
|
13 |
-
from duckduckgo_search import ddg
|
14 |
-
import asyncio
|
15 |
-
import aiohttp
|
16 |
-
|
17 |
-
|
18 |
-
from modules.presets import *
|
19 |
-
from modules.llama_func import *
|
20 |
-
from modules.utils import *
|
21 |
-
from . import shared
|
22 |
-
from modules.config import retrieve_proxy
|
23 |
-
|
24 |
-
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
25 |
-
|
26 |
-
if TYPE_CHECKING:
|
27 |
-
from typing import TypedDict
|
28 |
-
|
29 |
-
class DataframeData(TypedDict):
|
30 |
-
headers: List[str]
|
31 |
-
data: List[List[str | int | bool]]
|
32 |
-
|
33 |
-
|
34 |
-
initial_prompt = "You are a helpful assistant."
|
35 |
-
HISTORY_DIR = "history"
|
36 |
-
TEMPLATES_DIR = "templates"
|
37 |
-
|
38 |
-
@shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
|
39 |
-
def get_response(
|
40 |
-
openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
|
41 |
-
):
|
42 |
-
headers = {
|
43 |
-
"Content-Type": "application/json",
|
44 |
-
"Authorization": f"Bearer {openai_api_key}",
|
45 |
-
}
|
46 |
-
|
47 |
-
history = [construct_system(system_prompt), *history]
|
48 |
-
|
49 |
-
payload = {
|
50 |
-
"model": selected_model,
|
51 |
-
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
|
52 |
-
"temperature": temperature, # 1.0,
|
53 |
-
"top_p": top_p, # 1.0,
|
54 |
-
"n": 1,
|
55 |
-
"stream": stream,
|
56 |
-
"presence_penalty": 0,
|
57 |
-
"frequency_penalty": 0,
|
58 |
-
}
|
59 |
-
if stream:
|
60 |
-
timeout = timeout_streaming
|
61 |
-
else:
|
62 |
-
timeout = timeout_all
|
63 |
-
|
64 |
-
|
65 |
-
# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
|
66 |
-
if shared.state.completion_url != COMPLETION_URL:
|
67 |
-
logging.info(f"使用自定义API URL: {shared.state.completion_url}")
|
68 |
-
|
69 |
-
with retrieve_proxy():
|
70 |
-
response = requests.post(
|
71 |
-
shared.state.completion_url,
|
72 |
-
headers=headers,
|
73 |
-
json=payload,
|
74 |
-
stream=True,
|
75 |
-
timeout=timeout,
|
76 |
-
)
|
77 |
-
|
78 |
-
return response
|
79 |
-
|
80 |
-
|
81 |
-
def stream_predict(
|
82 |
-
openai_api_key,
|
83 |
-
system_prompt,
|
84 |
-
history,
|
85 |
-
inputs,
|
86 |
-
chatbot,
|
87 |
-
all_token_counts,
|
88 |
-
top_p,
|
89 |
-
temperature,
|
90 |
-
selected_model,
|
91 |
-
fake_input=None,
|
92 |
-
display_append=""
|
93 |
-
):
|
94 |
-
def get_return_value():
|
95 |
-
return chatbot, history, status_text, all_token_counts
|
96 |
-
|
97 |
-
logging.info("实时回答模式")
|
98 |
-
partial_words = ""
|
99 |
-
counter = 0
|
100 |
-
status_text = "开始实时传输回答……"
|
101 |
-
history.append(construct_user(inputs))
|
102 |
-
history.append(construct_assistant(""))
|
103 |
-
if fake_input:
|
104 |
-
chatbot.append((fake_input, ""))
|
105 |
-
else:
|
106 |
-
chatbot.append((inputs, ""))
|
107 |
-
user_token_count = 0
|
108 |
-
if fake_input is not None:
|
109 |
-
input_token_count = count_token(construct_user(fake_input))
|
110 |
-
else:
|
111 |
-
input_token_count = count_token(construct_user(inputs))
|
112 |
-
if len(all_token_counts) == 0:
|
113 |
-
system_prompt_token_count = count_token(construct_system(system_prompt))
|
114 |
-
user_token_count = (
|
115 |
-
input_token_count + system_prompt_token_count
|
116 |
-
)
|
117 |
-
else:
|
118 |
-
user_token_count = input_token_count
|
119 |
-
all_token_counts.append(user_token_count)
|
120 |
-
logging.info(f"输入token计数: {user_token_count}")
|
121 |
-
yield get_return_value()
|
122 |
-
try:
|
123 |
-
response = get_response(
|
124 |
-
openai_api_key,
|
125 |
-
system_prompt,
|
126 |
-
history,
|
127 |
-
temperature,
|
128 |
-
top_p,
|
129 |
-
True,
|
130 |
-
selected_model,
|
131 |
-
)
|
132 |
-
except requests.exceptions.ConnectTimeout:
|
133 |
-
status_text = (
|
134 |
-
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
135 |
-
)
|
136 |
-
yield get_return_value()
|
137 |
-
return
|
138 |
-
except requests.exceptions.ReadTimeout:
|
139 |
-
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
140 |
-
yield get_return_value()
|
141 |
-
return
|
142 |
-
|
143 |
-
yield get_return_value()
|
144 |
-
error_json_str = ""
|
145 |
-
|
146 |
-
if fake_input is not None:
|
147 |
-
history[-2] = construct_user(fake_input)
|
148 |
-
for chunk in tqdm(response.iter_lines()):
|
149 |
-
if counter == 0:
|
150 |
-
counter += 1
|
151 |
-
continue
|
152 |
-
counter += 1
|
153 |
-
# check whether each line is non-empty
|
154 |
-
if chunk:
|
155 |
-
chunk = chunk.decode()
|
156 |
-
chunklength = len(chunk)
|
157 |
-
try:
|
158 |
-
chunk = json.loads(chunk[6:])
|
159 |
-
except json.JSONDecodeError:
|
160 |
-
logging.info(chunk)
|
161 |
-
error_json_str += chunk
|
162 |
-
status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
|
163 |
-
yield get_return_value()
|
164 |
-
continue
|
165 |
-
# decode each line as response data is in bytes
|
166 |
-
if chunklength > 6 and "delta" in chunk["choices"][0]:
|
167 |
-
finish_reason = chunk["choices"][0]["finish_reason"]
|
168 |
-
status_text = construct_token_message(all_token_counts)
|
169 |
-
if finish_reason == "stop":
|
170 |
-
yield get_return_value()
|
171 |
-
break
|
172 |
-
try:
|
173 |
-
partial_words = (
|
174 |
-
partial_words + chunk["choices"][0]["delta"]["content"]
|
175 |
-
)
|
176 |
-
except KeyError:
|
177 |
-
status_text = (
|
178 |
-
standard_error_msg
|
179 |
-
+ "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
|
180 |
-
+ str(sum(all_token_counts))
|
181 |
-
)
|
182 |
-
yield get_return_value()
|
183 |
-
break
|
184 |
-
history[-1] = construct_assistant(partial_words)
|
185 |
-
chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
|
186 |
-
all_token_counts[-1] += 1
|
187 |
-
yield get_return_value()
|
188 |
-
|
189 |
-
|
190 |
-
def predict_all(
|
191 |
-
openai_api_key,
|
192 |
-
system_prompt,
|
193 |
-
history,
|
194 |
-
inputs,
|
195 |
-
chatbot,
|
196 |
-
all_token_counts,
|
197 |
-
top_p,
|
198 |
-
temperature,
|
199 |
-
selected_model,
|
200 |
-
fake_input=None,
|
201 |
-
display_append=""
|
202 |
-
):
|
203 |
-
logging.info("一次性回答模式")
|
204 |
-
history.append(construct_user(inputs))
|
205 |
-
history.append(construct_assistant(""))
|
206 |
-
if fake_input:
|
207 |
-
chatbot.append((fake_input, ""))
|
208 |
-
else:
|
209 |
-
chatbot.append((inputs, ""))
|
210 |
-
if fake_input is not None:
|
211 |
-
all_token_counts.append(count_token(construct_user(fake_input)))
|
212 |
-
else:
|
213 |
-
all_token_counts.append(count_token(construct_user(inputs)))
|
214 |
-
try:
|
215 |
-
response = get_response(
|
216 |
-
openai_api_key,
|
217 |
-
system_prompt,
|
218 |
-
history,
|
219 |
-
temperature,
|
220 |
-
top_p,
|
221 |
-
False,
|
222 |
-
selected_model,
|
223 |
-
)
|
224 |
-
except requests.exceptions.ConnectTimeout:
|
225 |
-
status_text = (
|
226 |
-
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
227 |
-
)
|
228 |
-
return chatbot, history, status_text, all_token_counts
|
229 |
-
except requests.exceptions.ProxyError:
|
230 |
-
status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
|
231 |
-
return chatbot, history, status_text, all_token_counts
|
232 |
-
except requests.exceptions.SSLError:
|
233 |
-
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
|
234 |
-
return chatbot, history, status_text, all_token_counts
|
235 |
-
response = json.loads(response.text)
|
236 |
-
if fake_input is not None:
|
237 |
-
history[-2] = construct_user(fake_input)
|
238 |
-
try:
|
239 |
-
content = response["choices"][0]["message"]["content"]
|
240 |
-
history[-1] = construct_assistant(content)
|
241 |
-
chatbot[-1] = (chatbot[-1][0], content+display_append)
|
242 |
-
total_token_count = response["usage"]["total_tokens"]
|
243 |
-
if fake_input is not None:
|
244 |
-
all_token_counts[-1] += count_token(construct_assistant(content))
|
245 |
-
else:
|
246 |
-
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
247 |
-
status_text = construct_token_message([total_token_count])
|
248 |
-
return chatbot, history, status_text, all_token_counts
|
249 |
-
except KeyError:
|
250 |
-
status_text = standard_error_msg + str(response)
|
251 |
-
return chatbot, history, status_text, all_token_counts
|
252 |
-
|
253 |
-
|
254 |
-
def predict(
|
255 |
-
openai_api_key,
|
256 |
-
system_prompt,
|
257 |
-
history,
|
258 |
-
inputs,
|
259 |
-
chatbot,
|
260 |
-
all_token_counts,
|
261 |
-
top_p,
|
262 |
-
temperature,
|
263 |
-
stream=False,
|
264 |
-
selected_model=MODELS[0],
|
265 |
-
use_websearch=False,
|
266 |
-
files = None,
|
267 |
-
reply_language="中文",
|
268 |
-
should_check_token_count=True,
|
269 |
-
): # repetition_penalty, top_k
|
270 |
-
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
271 |
-
from llama_index.indices.query.schema import QueryBundle
|
272 |
-
from langchain.llms import OpenAIChat
|
273 |
-
|
274 |
-
|
275 |
-
logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
276 |
-
if should_check_token_count:
|
277 |
-
yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
|
278 |
-
if reply_language == "跟随问题语言(不稳定)":
|
279 |
-
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
280 |
-
old_inputs = None
|
281 |
-
display_reference = []
|
282 |
-
limited_context = False
|
283 |
-
if files:
|
284 |
-
limited_context = True
|
285 |
-
old_inputs = inputs
|
286 |
-
msg = "加载索引中……(这可能需要几分钟)"
|
287 |
-
logging.info(msg)
|
288 |
-
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
289 |
-
index = construct_index(openai_api_key, file_src=files)
|
290 |
-
msg = "索引构建完成,获取回答中……"
|
291 |
-
logging.info(msg)
|
292 |
-
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
293 |
-
with retrieve_proxy():
|
294 |
-
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
|
295 |
-
prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
|
296 |
-
from llama_index import ServiceContext
|
297 |
-
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
298 |
-
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
|
299 |
-
query_bundle = QueryBundle(inputs)
|
300 |
-
nodes = query_object.retrieve(query_bundle)
|
301 |
-
reference_results = [n.node.text for n in nodes]
|
302 |
-
reference_results = add_source_numbers(reference_results, use_source=False)
|
303 |
-
display_reference = add_details(reference_results)
|
304 |
-
display_reference = "\n\n" + "".join(display_reference)
|
305 |
-
inputs = (
|
306 |
-
replace_today(PROMPT_TEMPLATE)
|
307 |
-
.replace("{query_str}", inputs)
|
308 |
-
.replace("{context_str}", "\n\n".join(reference_results))
|
309 |
-
.replace("{reply_language}", reply_language )
|
310 |
-
)
|
311 |
-
elif use_websearch:
|
312 |
-
limited_context = True
|
313 |
-
search_results = ddg(inputs, max_results=5)
|
314 |
-
old_inputs = inputs
|
315 |
-
reference_results = []
|
316 |
-
for idx, result in enumerate(search_results):
|
317 |
-
logging.info(f"搜索结果{idx + 1}:{result}")
|
318 |
-
domain_name = urllib3.util.parse_url(result["href"]).host
|
319 |
-
reference_results.append([result["body"], result["href"]])
|
320 |
-
display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
|
321 |
-
reference_results = add_source_numbers(reference_results)
|
322 |
-
display_reference = "\n\n" + "".join(display_reference)
|
323 |
-
inputs = (
|
324 |
-
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
325 |
-
.replace("{query}", inputs)
|
326 |
-
.replace("{web_results}", "\n\n".join(reference_results))
|
-                .replace("{reply_language}", reply_language )
-            )
-        else:
-            display_reference = ""
-
-    if len(openai_api_key) == 0 and not shared.state.multi_api_key:
-        status_text = standard_error_msg + no_apikey_msg
-        logging.info(status_text)
-        chatbot.append((inputs, ""))
-        if len(history) == 0:
-            history.append(construct_user(inputs))
-            history.append("")
-            all_token_counts.append(0)
-        else:
-            history[-2] = construct_user(inputs)
-        yield chatbot+[(inputs, "")], history, status_text, all_token_counts
-        return
-    elif len(inputs.strip()) == 0:
-        status_text = standard_error_msg + no_input_msg
-        logging.info(status_text)
-        yield chatbot+[(inputs, "")], history, status_text, all_token_counts
-        return
-
-    if stream:
-        logging.info("使用流式传输")
-        iter = stream_predict(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-            fake_input=old_inputs,
-            display_append=display_reference
-        )
-        for chatbot, history, status_text, all_token_counts in iter:
-            if shared.state.interrupted:
-                shared.state.recover()
-                return
-            yield chatbot, history, status_text, all_token_counts
-    else:
-        logging.info("不使用流式传输")
-        chatbot, history, status_text, all_token_counts = predict_all(
-            openai_api_key,
-            system_prompt,
-            history,
-            inputs,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            selected_model,
-            fake_input=old_inputs,
-            display_append=display_reference
-        )
-        yield chatbot, history, status_text, all_token_counts
-
-    logging.info(f"传输完毕。当前token计数为{all_token_counts}")
-    if len(history) > 1 and history[-1]["content"] != inputs:
-        logging.info(
-            "回答为:"
-            + colorama.Fore.BLUE
-            + f"{history[-1]['content']}"
-            + colorama.Style.RESET_ALL
-        )
-
-    if limited_context:
-        history = history[-4:]
-        all_token_counts = all_token_counts[-2:]
-        yield chatbot, history, status_text, all_token_counts
-
-    if stream:
-        max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
-    else:
-        max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
-
-    if sum(all_token_counts) > max_token and should_check_token_count:
-        print(all_token_counts)
-        count = 0
-        while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
-            count += 1
-            del all_token_counts[0]
-            del history[:2]
-        logging.info(status_text)
-        status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
-        yield chatbot, history, status_text, all_token_counts
-
-
-def retry(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    stream=False,
-    selected_model=MODELS[0],
-    reply_language="中文",
-):
-    logging.info("重试中……")
-    if len(history) == 0:
-        yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
-        return
-    history.pop()
-    inputs = history.pop()["content"]
-    token_count.pop()
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        inputs,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        stream=stream,
-        selected_model=selected_model,
-        reply_language=reply_language,
-    )
-    logging.info("重试中……")
-    for x in iter:
-        yield x
-    logging.info("重试完毕")
-
-
-def reduce_token_size(
-    openai_api_key,
-    system_prompt,
-    history,
-    chatbot,
-    token_count,
-    top_p,
-    temperature,
-    max_token_count,
-    selected_model=MODELS[0],
-    reply_language="中文",
-):
-    logging.info("开始减少token数量……")
-    iter = predict(
-        openai_api_key,
-        system_prompt,
-        history,
-        summarize_prompt,
-        chatbot,
-        token_count,
-        top_p,
-        temperature,
-        selected_model=selected_model,
-        should_check_token_count=False,
-        reply_language=reply_language,
-    )
-    logging.info(f"chatbot: {chatbot}")
-    flag = False
-    for chatbot, history, status_text, previous_token_count in iter:
-        num_chat = find_n(previous_token_count, max_token_count)
-        logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
-        if flag:
-            chatbot = chatbot[:-1]
-        flag = True
-        history = history[-2*num_chat:] if num_chat > 0 else []
-        token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
-        msg = f"保留了最近{num_chat}轮对话"
-        yield chatbot, history, msg + "," + construct_token_message(
-            token_count if len(token_count) > 0 else [0],
-        ), token_count
-        logging.info(msg)
-    logging.info("减少token数量完毕")
modules/config.py
CHANGED
@@ -117,6 +117,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
 os.environ["HTTP_PROXY"] = ""
 os.environ["HTTPS_PROXY"] = ""
 
+local_embedding = config.get("local_embedding", False)  # 是否使用本地embedding
+
 @contextmanager
 def retrieve_proxy(proxy=None):
     """
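A minimal sketch, not part of this commit, of how a user could turn the new `local_embedding` switch on. It assumes `config.json` is plain JSON in the working directory and simply writes the key that the line above reads; `enable_local_embedding` is a hypothetical helper used here only for illustration.

import json

# Hypothetical helper: flip "local_embedding" in config.json so that document
# indexing uses a local HuggingFace embedding instead of the OpenAI embedding API.
def enable_local_embedding(path="config.json"):
    try:
        with open(path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except FileNotFoundError:
        cfg = {}
    cfg["local_embedding"] = True
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=4)

enable_local_embedding()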
modules/llama_func.py
CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
 
 from modules.presets import *
 from modules.utils import *
+from modules.config import local_embedding
+
 
 def get_index_name(file_src):
     file_paths = [x.name for x in file_src]
@@ -28,6 +30,7 @@ def get_index_name(file_src):
 
     return md5_hash.hexdigest()
 
+
 def block_split(text):
     blocks = []
     while len(text) > 0:
@@ -35,6 +38,7 @@ def block_split(text):
         text = text[1000:]
     return blocks
 
+
 def get_documents(file_src):
     documents = []
     logging.debug("Loading documents...")
@@ -44,40 +48,45 @@ def get_documents(file_src):
         filename = os.path.basename(filepath)
         file_type = os.path.splitext(filepath)[1]
         logging.info(f"loading file: {filename}")
+        try:
+            if file_type == ".pdf":
+                logging.debug("Loading PDF...")
+                try:
+                    from modules.pdf_func import parse_pdf
+                    from modules.config import advance_docs
+
+                    two_column = advance_docs["pdf"].get("two_column", False)
+                    pdftext = parse_pdf(filepath, two_column).text
+                except:
+                    pdftext = ""
+                    with open(filepath, "rb") as pdfFileObj:
+                        pdfReader = PyPDF2.PdfReader(pdfFileObj)
+                        for page in tqdm(pdfReader.pages):
+                            pdftext += page.extract_text()
+                text_raw = pdftext
+            elif file_type == ".docx":
+                logging.debug("Loading Word...")
+                DocxReader = download_loader("DocxReader")
+                loader = DocxReader()
+                text_raw = loader.load_data(file=filepath)[0].text
+            elif file_type == ".epub":
+                logging.debug("Loading EPUB...")
+                EpubReader = download_loader("EpubReader")
+                loader = EpubReader()
+                text_raw = loader.load_data(file=filepath)[0].text
+            elif file_type == ".xlsx":
+                logging.debug("Loading Excel...")
+                text_list = excel_to_string(filepath)
+                for elem in text_list:
+                    documents.append(Document(elem))
+                continue
+            else:
+                logging.debug("Loading text file...")
+                with open(filepath, "r", encoding="utf-8") as f:
+                    text_raw = f.read()
+        except Exception as e:
+            logging.error(f"Error loading file: {filename}")
+            pass
         text = add_space(text_raw)
         # text = block_split(text)
         # documents += text
@@ -87,19 +96,21 @@ def get_documents(file_src):
 
 
 def construct_index(
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=5,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
 ):
     from langchain.chat_models import ChatOpenAI
+    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+    from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
 
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
     embedding_limit = None if embedding_limit == 0 else embedding_limit
     separator = " " if separator == "" else separator
@@ -107,7 +118,14 @@ def construct_index(
     llm_predictor = LLMPredictor(
         llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
+    prompt_helper = PromptHelper(
+        max_input_size=max_input_size,
+        num_output=num_outputs,
+        max_chunk_overlap=max_chunk_overlap,
+        embedding_limit=embedding_limit,
+        chunk_size_limit=600,
+        separator=separator,
+    )
     index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
@@ -115,11 +133,20 @@ def construct_index(
     else:
         try:
             documents = get_documents(file_src)
+            if local_embedding:
+                embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
+            else:
+                embed_model = OpenAIEmbedding()
             logging.info("构建索引中……")
             with retrieve_proxy():
+                service_context = ServiceContext.from_defaults(
+                    llm_predictor=llm_predictor,
+                    prompt_helper=prompt_helper,
+                    chunk_size_limit=chunk_size_limit,
+                    embed_model=embed_model,
+                )
                 index = GPTSimpleVectorIndex.from_documents(
+                    documents, service_context=service_context
                 )
             logging.debug("索引构建完成!")
             os.makedirs("./index", exist_ok=True)
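A small usage sketch, not part of the commit, of how construct_index is typically driven. It assumes the signature shown in the hunk above; the file objects only need a .name path attribute, which is what Gradio's upload component provides, and the path used here is made up for illustration.

from types import SimpleNamespace

from modules.llama_func import construct_index

# Gradio hands the upload callback objects exposing a .name attribute;
# SimpleNamespace stands in for that here with a hypothetical file path.
files = [SimpleNamespace(name="docs/manual.pdf")]

index = construct_index(
    api_key="",          # may stay empty when local_embedding is enabled in config.json
    file_src=files,
    chunk_size_limit=600,
)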
modules/models.py
ADDED
@@ -0,0 +1,568 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING, List
+
+import logging
+import json
+import commentjson as cjson
+import os
+import sys
+import requests
+import urllib3
+import platform
+
+from tqdm import tqdm
+import colorama
+from duckduckgo_search import ddg
+import asyncio
+import aiohttp
+from enum import Enum
+
+from .presets import *
+from .llama_func import *
+from .utils import *
+from . import shared
+from .config import retrieve_proxy
+from modules import config
+from .base_model import BaseLLMModel, ModelType
+
+
+class OpenAIClient(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        api_key,
+        system_prompt=INITIAL_SYSTEM_PROMPT,
+        temperature=1.0,
+        top_p=1.0,
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            temperature=temperature,
+            top_p=top_p,
+            system_prompt=system_prompt,
+        )
+        self.api_key = api_key
+        self.need_api_key = True
+        self._refresh_header()
+
+    def get_answer_stream_iter(self):
+        response = self._get_response(stream=True)
+        if response is not None:
+            iter = self._decode_chat_response(response)
+            partial_text = ""
+            for i in iter:
+                partial_text += i
+                yield partial_text
+        else:
+            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
+
+    def get_answer_at_once(self):
+        response = self._get_response()
+        response = json.loads(response.text)
+        content = response["choices"][0]["message"]["content"]
+        total_token_count = response["usage"]["total_tokens"]
+        return content, total_token_count
+
+    def count_token(self, user_input):
+        input_token_count = count_token(construct_user(user_input))
+        if self.system_prompt is not None and len(self.all_token_counts) == 0:
+            system_prompt_token_count = count_token(
+                construct_system(self.system_prompt)
+            )
+            return input_token_count + system_prompt_token_count
+        return input_token_count
+
+    def billing_info(self):
+        try:
+            curr_time = datetime.datetime.now()
+            last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
+            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
+            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
+            try:
+                usage_data = self._get_billing_data(usage_url)
+            except Exception as e:
+                logging.error(f"获取API使用情况失败:" + str(e))
+                return f"**获取API使用情况失败**"
+            rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
+            return f"**本月使用金额** \u3000 ${rounded_usage}"
+        except requests.exceptions.ConnectTimeout:
+            status_text = (
+                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            )
+            return status_text
+        except requests.exceptions.ReadTimeout:
+            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            return status_text
+        except Exception as e:
+            logging.error(f"获取API使用情况失败:" + str(e))
+            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
+
+    def set_token_upper_limit(self, new_upper_limit):
+        pass
+
+    def set_key(self, new_access_key):
+        self.api_key = new_access_key.strip()
+        self._refresh_header()
+        msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
+        logging.info(msg)
+        return msg
+
+    @shared.state.switching_api_key  # 在不开启多账号模式的时候,这个装饰器不会起作用
+    def _get_response(self, stream=False):
+        openai_api_key = self.api_key
+        system_prompt = self.system_prompt
+        history = self.history
+        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {openai_api_key}",
+        }
+
+        if system_prompt is not None:
+            history = [construct_system(system_prompt), *history]
+
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n_choices,
+            "stream": stream,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+        }
+
+        if self.max_generation_token is not None:
+            payload["max_tokens"] = self.max_generation_token
+        if self.stop_sequence is not None:
+            payload["stop"] = self.stop_sequence
+        if self.logit_bias is not None:
+            payload["logit_bias"] = self.logit_bias
+        if self.user_identifier is not None:
+            payload["user"] = self.user_identifier
+
+        if stream:
+            timeout = TIMEOUT_STREAMING
+        else:
+            timeout = TIMEOUT_ALL
+
+        # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
+        if shared.state.completion_url != COMPLETION_URL:
+            logging.info(f"使用自定义API URL: {shared.state.completion_url}")
+
+        with retrieve_proxy():
+            try:
+                response = requests.post(
+                    shared.state.completion_url,
+                    headers=headers,
+                    json=payload,
+                    stream=stream,
+                    timeout=timeout,
+                )
+            except:
+                return None
+        return response
+
+    def _refresh_header(self):
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+
+    def _get_billing_data(self, billing_url):
+        with retrieve_proxy():
+            response = requests.get(
+                billing_url,
+                headers=self.headers,
+                timeout=TIMEOUT_ALL,
+            )
+
+        if response.status_code == 200:
+            data = response.json()
+            return data
+        else:
+            raise Exception(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
+
+    def _decode_chat_response(self, response):
+        error_msg = ""
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk = chunk.decode()
+                chunk_length = len(chunk)
+                try:
+                    chunk = json.loads(chunk[6:])
+                except json.JSONDecodeError:
+                    print(f"JSON解析错误,收到的内容: {chunk}")
+                    error_msg += chunk
+                    continue
+                if chunk_length > 6 and "delta" in chunk["choices"][0]:
+                    if chunk["choices"][0]["finish_reason"] == "stop":
+                        break
+                    try:
+                        yield chunk["choices"][0]["delta"]["content"]
+                    except Exception as e:
+                        # logging.error(f"Error: {e}")
+                        continue
+        if error_msg:
+            raise Exception(error_msg)
+
+
+class ChatGLM_Client(BaseLLMModel):
+    def __init__(self, model_name) -> None:
+        super().__init__(model_name=model_name)
+        from transformers import AutoTokenizer, AutoModel
+        import torch
+
+        system_name = platform.system()
+        model_path = None
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            model_source = f"THUDM/{model_name}"
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_source, trust_remote_code=True
+        )
+        quantified = False
+        if "int4" in model_name:
+            quantified = True
+        if quantified:
+            model = AutoModel.from_pretrained(
+                model_source, trust_remote_code=True
+            ).float()
+        else:
+            model = AutoModel.from_pretrained(
+                model_source, trust_remote_code=True
+            ).half()
+        if torch.cuda.is_available():
+            # run on CUDA
+            logging.info("CUDA is available, using CUDA")
+            model = model.cuda()
+        # mps加速还存在一些问题,暂时不使用
+        elif system_name == "Darwin" and model_path is not None and not quantified:
+            logging.info("Running on macOS, using MPS")
+            # running on macOS and model already downloaded
+            model = model.to("mps")
+        else:
+            logging.info("GPU is not available, using CPU")
+        model = model.eval()
+        self.model = model
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
+        assert (
+            len(history) % 2 == 0
+        ), f"History should be even length. current history is: {history}"
+        history = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        response, _ = self.model.chat(self.tokenizer, query, history=history)
+        return response, len(response)
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        for response, history in self.model.stream_chat(
+            self.tokenizer,
+            query,
+            history,
+            max_length=self.token_upper_limit,
+            top_p=self.top_p,
+            temperature=self.temperature,
+        ):
+            yield response
+
+
+class LLaMA_Client(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        lora_path=None,
+    ) -> None:
+        super().__init__(model_name=model_name)
+        from lmflow.datasets.dataset import Dataset
+        from lmflow.pipeline.auto_pipeline import AutoPipeline
+        from lmflow.models.auto_model import AutoModel
+        from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
+        model_path = None
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            raise Exception(f"models目录下没有这个模型: {model_name}")
+        if lora_path is not None:
+            lora_path = f"lora/{lora_path}"
+        self.max_generation_token = 1000
+        pipeline_name = "inferencer"
+        model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+        pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+
+        with open(pipeline_args.deepspeed, "r") as f:
+            ds_config = json.load(f)
+
+        self.model = AutoModel.get_model(
+            model_args,
+            tune_strategy="none",
+            ds_config=ds_config,
+        )
+
+        # We don't need input data
+        data_args = DatasetArguments(dataset_path=None)
+        self.dataset = Dataset(data_args)
+
+        self.inferencer = AutoPipeline.get_pipeline(
+            pipeline_name=pipeline_name,
+            model_args=model_args,
+            data_args=data_args,
+            pipeline_args=pipeline_args,
+        )
+
+        # Chats
+        model_name = model_args.model_name_or_path
+        if model_args.lora_model_path is not None:
+            model_name += f" + {model_args.lora_model_path}"
+
+        # context = (
+        #     "You are a helpful assistant who follows the given instructions"
+        #     " unconditionally."
+        # )
+        self.end_string = "\n\n"
+
+    def _get_llama_style_input(self):
+        history = [x["content"] for x in self.history]
+        context = "\n".join(history)
+        return context
+
+    def get_answer_at_once(self):
+        context = self._get_llama_style_input()
+
+        input_dataset = self.dataset.from_dict(
+            {"type": "text_only", "instances": [{"text": context}]}
+        )
+
+        output_dataset = self.inferencer.inference(
+            model=self.model,
+            dataset=input_dataset,
+            max_new_tokens=self.max_generation_token,
+            temperature=self.temperature,
+        )
+
+        response = output_dataset.to_dict()["instances"][0]["text"]
+
+        try:
+            index = response.index(self.end_string)
+        except ValueError:
+            response += self.end_string
+            index = response.index(self.end_string)
+
+        response = response[: index + 1]
+        return response, len(response)
+
+    def get_answer_stream_iter(self):
+        context = self._get_llama_style_input()
+
+        input_dataset = self.dataset.from_dict(
+            {"type": "text_only", "instances": [{"text": context}]}
+        )
+
+        output_dataset = self.inferencer.inference(
+            model=self.model,
+            dataset=input_dataset,
+            max_new_tokens=self.max_generation_token,
+            temperature=self.temperature,
+        )
+
+        response = output_dataset.to_dict()["instances"][0]["text"]
+
+        try:
+            index = response.index(self.end_string)
+        except ValueError:
+            response += self.end_string
+            index = response.index(self.end_string)
+
+        response = response[: index + 1]
+        yield response
+
+
+class ModelManager:
+    def __init__(self, **kwargs) -> None:
+        self.get_model(**kwargs)
+
+    def get_model(
+        self,
+        model_name,
+        lora_model_path=None,
+        access_key=None,
+        temperature=None,
+        top_p=None,
+        system_prompt=None,
+    ) -> BaseLLMModel:
+        msg = f"模型设置为了: {model_name}"
+        model_type = ModelType.get_type(model_name)
+        lora_selector_visibility = False
+        lora_choices = []
+        dont_change_lora_selector = False
+        if model_type != ModelType.OpenAI:
+            config.local_embedding = True
+        model = None
+        try:
+            if model_type == ModelType.OpenAI:
+                model = OpenAIClient(
+                    model_name=model_name,
+                    api_key=access_key,
+                    system_prompt=system_prompt,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            elif model_type == ModelType.ChatGLM:
+                model = ChatGLM_Client(model_name)
+            elif model_type == ModelType.LLaMA and lora_model_path == "":
+                msg = "现在请选择LoRA模型"
+                logging.info(msg)
+                lora_selector_visibility = True
+                if os.path.isdir("lora"):
+                    lora_choices = get_file_names("lora", plain=True, filetypes=[""])
+                lora_choices = ["No LoRA"] + lora_choices
+            elif model_type == ModelType.LLaMA and lora_model_path != "":
+                dont_change_lora_selector = True
+                if lora_model_path == "No LoRA":
+                    lora_model_path = None
+                    msg += " + No LoRA"
+                else:
+                    msg += f" + {lora_model_path}"
+                model = LLaMA_Client(model_name, lora_model_path)
+                pass
+            elif model_type == ModelType.Unknown:
+                raise ValueError(f"未知模型: {model_name}")
+            logging.info(msg)
+        except Exception as e:
+            logging.error(e)
+            msg = f"{STANDARD_ERROR_MSG}: {e}"
+        if model is not None:
+            self.model = model
+        if dont_change_lora_selector:
+            return msg
+        else:
+            return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
+
+    def predict(self, *args):
+        iter = self.model.predict(*args)
+        for i in iter:
+            yield i
+
+    def billing_info(self):
+        return self.model.billing_info()
+
+    def set_key(self, *args):
+        return self.model.set_key(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def interrupt(self, *args):
+        return self.model.interrupt(*args)
+
+    def reset(self, *args):
+        return self.model.reset(*args)
+
+    def retry(self, *args):
+        iter = self.model.retry(*args)
+        for i in iter:
+            yield i
+
+    def delete_first_conversation(self, *args):
+        return self.model.delete_first_conversation(*args)
+
+    def delete_last_conversation(self, *args):
+        return self.model.delete_last_conversation(*args)
+
+    def set_system_prompt(self, *args):
+        return self.model.set_system_prompt(*args)
+
+    def save_chat_history(self, *args):
+        return self.model.save_chat_history(*args)
+
+    def export_markdown(self, *args):
+        return self.model.export_markdown(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def set_token_upper_limit(self, *args):
+        return self.model.set_token_upper_limit(*args)
+
+    def set_temperature(self, *args):
+        self.model.set_temperature(*args)
+
+    def set_top_p(self, *args):
+        self.model.set_top_p(*args)
+
+    def set_n_choices(self, *args):
+        self.model.set_n_choices(*args)
+
+    def set_stop_sequence(self, *args):
+        self.model.set_stop_sequence(*args)
+
+    def set_max_tokens(self, *args):
+        self.model.set_max_tokens(*args)
+
+    def set_presence_penalty(self, *args):
+        self.model.set_presence_penalty(*args)
+
+    def set_frequency_penalty(self, *args):
+        self.model.set_frequency_penalty(*args)
+
+    def set_logit_bias(self, *args):
+        self.model.set_logit_bias(*args)
+
+    def set_user_identifier(self, *args):
+        self.model.set_user_identifier(*args)
+
+
+
+
+if __name__ == "__main__":
+    with open("config.json", "r") as f:
+        openai_api_key = cjson.load(f)["openai_api_key"]
+    # set logging level to debug
+    logging.basicConfig(level=logging.DEBUG)
+    # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
+    client = ModelManager(model_name="chatglm-6b-int4")
+    chatbot = []
+    stream = False
+    # 测试账单功能
+    logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
+    logging.info(client.billing_info())
+    # 测试问答
+    logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
+    question = "巴黎是中国的首都吗?"
+    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
+        logging.info(i)
+    logging.info(f"测试问答后history : {client.history}")
+    # 测试记忆力
+    logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
+    question = "我刚刚问了你什么问题?"
+    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
+        logging.info(i)
+    logging.info(f"测试记忆力后history : {client.history}")
+    # 测试重试功能
+    logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
+    for i in client.retry(chatbot=chatbot, stream=stream):
+        logging.info(i)
+    logging.info(f"重试后history : {client.history}")
+    # # 测试总结功能
+    # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
+    # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
+    # print(chatbot, msg)
+    # print(f"总结后history: {client.history}")
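For reference, OpenAIClient._decode_chat_response above strips the "data: " prefix (hence chunk[6:]) from each server-sent-event line of the streaming completions endpoint. A minimal sketch with a fabricated chunk, not from the repository, of the shape it expects:

import json

# Fabricated example of one streaming line as the chat completions endpoint sends it
# (SSE framing); the exact fields can vary by model and API version.
raw = 'data: {"choices": [{"delta": {"content": "你好"}, "finish_reason": null, "index": 0}]}'

chunk = json.loads(raw[6:])              # drop the leading "data: " prefix, as in _decode_chat_response
delta = chunk["choices"][0]["delta"]
print(delta.get("content", ""))          # -> 你好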
modules/openai_func.py
DELETED
@@ -1,65 +0,0 @@
-import requests
-import logging
-from modules.presets import (
-    timeout_all,
-    USAGE_API_URL,
-    BALANCE_API_URL,
-    standard_error_msg,
-    connection_timeout_prompt,
-    error_retrieve_prompt,
-    read_timeout_prompt
-)
-
-from . import shared
-from modules.config import retrieve_proxy
-import os, datetime
-
-def get_billing_data(openai_api_key, billing_url):
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {openai_api_key}"
-    }
-
-    timeout = timeout_all
-    with retrieve_proxy():
-        response = requests.get(
-            billing_url,
-            headers=headers,
-            timeout=timeout,
-        )
-
-    if response.status_code == 200:
-        data = response.json()
-        return data
-    else:
-        raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
-
-
-def get_usage(openai_api_key):
-    try:
-        curr_time = datetime.datetime.now()
-        last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
-        first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
-        usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
-        try:
-            usage_data = get_billing_data(openai_api_key, usage_url)
-        except Exception as e:
-            logging.error(f"获取API使用情况失败:"+str(e))
-            return f"**获取API使用情况失败**"
-        rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
-        return f"**本月使用金额** \u3000 ${rounded_usage}"
-    except requests.exceptions.ConnectTimeout:
-        status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-        return status_text
-    except requests.exceptions.ReadTimeout:
-        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
-        return status_text
-    except Exception as e:
-        logging.error(f"获取API使用情况失败:"+str(e))
-        return standard_error_msg + error_retrieve_prompt
-
-def get_last_day_of_month(any_day):
-    # The day 28 exists in every month. 4 days later, it's always next month
-    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
-    # subtracting the number of the current day brings us back one month
-    return next_month - datetime.timedelta(days=next_month.day)
modules/presets.py
CHANGED
@@ -3,48 +3,50 @@ import gradio as gr
 from pathlib import Path
 
 # ChatGPT 设置
+INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
 API_HOST = "api.openai.com"
 COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
 BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
 USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
 HISTORY_DIR = Path("history")
+HISTORY_DIR = "history"
 TEMPLATES_DIR = "templates"
 
 # 错误信息
+STANDARD_ERROR_MSG = "☹️发生了错误:"  # 错误信息的标准前缀
+GENERAL_ERROR_MSG = "获取对话时发生错误,请查看后台日志"
+ERROR_RETRIEVE_MSG = "请检查网络连接,或者API-Key是否有效。"
+CONNECTION_TIMEOUT_MSG = "连接超时,无法获取对话。"  # 连接超时
+READ_TIMEOUT_MSG = "读取超时,无法获取对话。"  # 读取超时
+PROXY_ERROR_MSG = "代理错误,无法获取对话。"  # 代理错误
+SSL_ERROR_PROMPT = "SSL错误,无法获取对话。"  # SSL 错误
+NO_APIKEY_MSG = "API key为空,请检查是否输入正确。"  # API key 长度不足 51 位
+NO_INPUT_MSG = "请输入对话内容。"  # 未输入对话内容
+BILLING_NOT_APPLICABLE_MSG = "模型本地运行中"  # 本地运行的模型返回的账单信息
+
+TIMEOUT_STREAMING = 60  # 流式对话时的超时时间
+TIMEOUT_ALL = 200  # 非流式对话时的超时时间
+ENABLE_STREAMING_OPTION = True  # 是否启用选择选择是否实时显示回答的勾选框
 HIDE_MY_KEY = False  # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
 CONCURRENT_COUNT = 100  # 允许同时使用的用户数量
 
 SIM_K = 5
 INDEX_QUERY_TEMPRATURE = 1.0
 
+CHUANHU_TITLE = """<h1 align="left">川虎ChatGPT 🚀</h1>"""
+CHUANHU_DESCRIPTION = """\
 <div align="center" style="margin:16px 0">
 
 由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
 
 访问川虎ChatGPT的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
 
-此App使用 `gpt-3.5-turbo` 大语言模型
 </div>
 """
 
+FOOTER = """<div class="versions">{versions}</div>"""
 
+APPEARANCE_SWITCHER = """
 <div style="display: flex; justify-content: space-between;">
 <span style="margin-top: 4px !important;">切换亮暗色主题</span>
 <span><label class="apSwitch" for="checkbox">
@@ -53,7 +55,8 @@ appearance_switcher = """
 </label></span>
 </div>
 """
+
+SUMMARIZE_PROMPT = "你是谁?我们刚才聊了什么?"  # 总结对话时的 prompt
 
 MODELS = [
     "gpt-3.5-turbo",
@@ -62,35 +65,34 @@ MODELS = [
     "gpt-4-0314",
     "gpt-4-32k",
     "gpt-4-32k-0314",
+    "chatglm-6b",
+    "chatglm-6b-int4",
+    "chatglm-6b-int4-qe",
+    "llama-7b-hf",
+    "llama-7b-hf-int4",
+    "llama-7b-hf-int8",
+    "llama-13b-hf",
+    "llama-13b-hf-int4",
+    "llama-30b-hf",
+    "llama-30b-hf-int4",
+    "llama-65b-hf",
 ] # 可选的模型
 
+DEFAULT_MODEL = 0  # 默认的模型在MODELS中的序号,从0开始数
+
+MODEL_TOKEN_LIMIT = {
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-0301": 4096,
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0314": 32768
-    "gpt-4": {
-        "streaming": 7500,
-        "all": 7500
-    },
-    "gpt-4-0314": {
-        "streaming": 7500,
-        "all": 7500
-    },
-    "gpt-4-32k": {
-        "streaming": 31000,
-        "all": 31000
-    },
-    "gpt-4-32k-0314": {
-        "streaming": 31000,
-        "all": 31000
-    }
 }
 
+TOKEN_OFFSET = 1000  # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
+DEFAULT_TOKEN_LIMIT = 3000  # 默认的token上限
+REDUCE_TOKEN_FACTOR = 0.5  # 与模型token上限想乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
+
 REPLY_LANGUAGES = [
     "简体中文",
     "繁體中文",
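A small worked example, not part of the commit, of how the new token constants combine. It is based only on the comments above; the logic that actually applies them lives in modules/base_model.py, which is not shown here, so treat this as a sketch.

MODEL_TOKEN_LIMIT = {"gpt-3.5-turbo": 4096}
TOKEN_OFFSET = 1000
DEFAULT_TOKEN_LIMIT = 3000
REDUCE_TOKEN_FACTOR = 0.5

def limits_for(model_name):
    hard_limit = MODEL_TOKEN_LIMIT.get(model_name, DEFAULT_TOKEN_LIMIT)
    soft_limit = hard_limit - TOKEN_OFFSET            # start trimming once usage passes this
    reduce_target = hard_limit * REDUCE_TOKEN_FACTOR  # trim history until usage drops below this
    return hard_limit, soft_limit, reduce_target

print(limits_for("gpt-3.5-turbo"))  # (4096, 3096, 2048.0)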
modules/utils.py
CHANGED
@@ -153,107 +153,22 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 
-def construct_token_message(tokens: List[int]):
-    token_sum = 0
-    for i in range(len(tokens)):
-        token_sum += sum(tokens[: i + 1])
-    return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
-
-
-def delete_first_conversation(history, previous_token_count):
-    if history:
-        del history[:2]
-        del previous_token_count[0]
-    return (
-        history,
-        previous_token_count,
-        construct_token_message(previous_token_count),
-    )
-
-
-def delete_last_conversation(chatbot, history, previous_token_count):
-    if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
-        logging.info("由于包含报错信息,只删除chatbot记录")
-        chatbot.pop()
-        return chatbot, history
-    if len(history) > 0:
-        logging.info("删除了一组对话历史")
-        history.pop()
-        history.pop()
-    if len(chatbot) > 0:
-        logging.info("删除了一组chatbot对话")
-        chatbot.pop()
-    if len(previous_token_count) > 0:
-        logging.info("删除了一组对话的token计数记录")
-        previous_token_count.pop()
-    return (
-        chatbot,
-        history,
-        previous_token_count,
-        construct_token_message(previous_token_count),
-    )
-
-
 def save_file(filename, system, history, chatbot, user_name):
+    logging.debug(f"{user_name} 保存对话历史中……")
+    os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
     if filename.endswith(".json"):
         json_s = {"system": system, "history": history, "chatbot": chatbot}
         print(json_s)
+        with open(os.path.join(HISTORY_DIR, user_name, filename), "w") as f:
             json.dump(json_s, f)
     elif filename.endswith(".md"):
         md_s = f"system: \n- {system} \n"
         for data in history:
             md_s += f"\n{data['role']}: \n- {data['content']} \n"
+        with open(os.path.join(HISTORY_DIR, user_name, filename), "w", encoding="utf8") as f:
             f.write(md_s)
+    logging.debug(f"{user_name} 保存对话历史完毕")
+    return os.path.join(HISTORY_DIR, user_name, filename)
-
-
-def save_chat_history(filename, system, history, chatbot, user_name):
-    if filename == "":
-        return
-    if not filename.endswith(".json"):
-        filename += ".json"
-    return save_file(filename, system, history, chatbot, user_name)
-
-
-def export_markdown(filename, system, history, chatbot, user_name):
-    if filename == "":
-        return
-    if not filename.endswith(".md"):
-        filename += ".md"
-    return save_file(filename, system, history, chatbot, user_name)
-
-
-def load_chat_history(filename, system, history, chatbot, user_name):
-    logging.info(f"{user_name} 加载对话历史中……")
-    if type(filename) != str:
-        filename = filename.name
-    try:
-        with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
-            json_s = json.load(f)
-        try:
-            if type(json_s["history"][0]) == str:
-                logging.info("历史记录格式为旧版,正在转换……")
-                new_history = []
-                for index, item in enumerate(json_s["history"]):
-                    if index % 2 == 0:
-                        new_history.append(construct_user(item))
-                    else:
-                        new_history.append(construct_assistant(item))
-                json_s["history"] = new_history
-                logging.info(new_history)
-        except:
-            # 没有对话历史
-            pass
-        logging.info(f"{user_name} 加载对话历史完毕")
-        return filename, json_s["system"], json_s["history"], json_s["chatbot"]
-    except FileNotFoundError:
-        logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
-        return filename, system, history, chatbot
 
 
 def sorted_by_pinyin(list):
@@ -261,7 +176,7 @@ def sorted_by_pinyin(list):
 
 
 def get_file_names(dir, plain=False, filetypes=[".json"]):
+    logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
     files = []
     try:
         for type in filetypes:
@@ -279,14 +194,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
 
 
 def get_history_names(plain=False, user_name=""):
+    logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
+    return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
 
 
 def load_template(filename, mode=0):
+    logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
     lines = []
-    logging.info("Loading template...")
     if filename.endswith(".json"):
         with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
             lines = json.load(f)
@@ -310,23 +224,18 @@ def load_template(filename, mode=0):
 
 
 def get_template_names(plain=False):
+    logging.debug("获取模板文件名列表")
     return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 
 def get_template_content(templates, selection, original_system_prompt):
+    logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
     try:
         return templates[selection]
     except:
        return original_system_prompt
 
 
-def reset_state():
-    logging.info("重置状态")
-    return [], [], [], construct_token_message([0])
-
-
 def reset_textbox():
     logging.debug("重置文本框")
     return gr.update(value="")
@@ -530,3 +439,13 @@ def excel_to_string(file_path):
 
 
     return result
+
+def get_last_day_of_month(any_day):
+    # The day 28 exists in every month. 4 days later, it's always next month
+    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+    # subtracting the number of the current day brings us back one month
+    return next_month - datetime.timedelta(days=next_month.day)
+
+def get_model_source(model_name, alternative_source):
+    if model_name == "gpt2-medium":
+        return "https://huggingface.co/gpt2-medium"
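A quick check, not part of the commit, of what the relocated get_last_day_of_month helper returns; billing_info in modules/models.py uses it together with replace(day=1) to build the start_date/end_date query range for the usage API.

import datetime

def get_last_day_of_month(any_day):
    # The day 28 exists in every month. 4 days later, it's always next month
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    # subtracting the number of the current day brings us back one month
    return next_month - datetime.timedelta(days=next_month.day)

today = datetime.datetime(2023, 4, 15)
print(today.replace(day=1).strftime("%Y-%m-%d"))          # 2023-04-01  (start_date)
print(get_last_day_of_month(today).strftime("%Y-%m-%d"))  # 2023-04-30  (end_date)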
requirements_advanced.txt
ADDED
@@ -0,0 +1,7 @@
+transformers
+torch
+icetk
+protobuf==3.19.0
+git+https://github.com/OptimalScale/LMFlow.git#egg=lmflow
+cpm-kernels
+sentence_transformers