Keldos committed
Commit 4282926
Parents: d3b93fb 0d61cce

BREAKING: Merge 'expansive': code refactoring, local model support (#572)


Major new features:
- Support for more parameters
- ChatGLM support
- Local embedding support
- Local LLaMA model support

Possible issues:
- Much of the error-handling code has been removed; errors will now surface in the terminal more often
- Local embedding support for Chinese is not very good
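For context on the local-embedding item: the switch is read from config.json in modules/config.py (see that file's diff further down). A minimal sketch of how the flag is consumed, assuming a plain JSON load of config.json (the repo may use commentjson; only the key name is taken from this commit):

```python
import json

# Read config.json and pick up the flag introduced by this commit.
# Defaults to False, i.e. remote OpenAI embeddings unless opted in.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

local_embedding = config.get("local_embedding", False)
```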

.gitignore CHANGED
@@ -133,7 +133,10 @@ dmypy.json
 # Mac system file
 **/.DS_Store
 
+# 配置文件/模型文件
 api_key.txt
 config.json
 auth.json
+models/
+lora/
 .idea
ChuanhuChatbot.py CHANGED
@@ -10,8 +10,7 @@ from modules.config import *
10
  from modules.utils import *
11
  from modules.presets import *
12
  from modules.overwrites import *
13
- from modules.chat_func import *
14
- from modules.openai_func import get_usage
15
 
16
  gr.Chatbot.postprocess = postprocess
17
  PromptHelper.compact_text_chunks = compact_text_chunks
@@ -21,16 +20,14 @@ with open("assets/custom.css", "r", encoding="utf-8") as f:
21
 
22
  with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
23
  user_name = gr.State("")
24
- history = gr.State([])
25
- token_count = gr.State([])
26
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
27
- user_api_key = gr.State(my_api_key)
28
  user_question = gr.State("")
29
- outputing = gr.State(False)
 
30
  topic = gr.State("未命名对话历史记录")
31
 
32
  with gr.Row():
33
- gr.HTML(title, elem_id="app_title")
34
  status_display = gr.Markdown(get_geoip(), elem_id="status_display")
35
  with gr.Row(elem_id="float_display"):
36
  user_info = gr.Markdown(value="getting user info...", elem_id="user_info")
@@ -64,11 +61,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
64
  retryBtn = gr.Button("🔄 重新生成")
65
  delFirstBtn = gr.Button("🗑️ 删除最旧对话")
66
  delLastBtn = gr.Button("🗑️ 删除最新对话")
67
- reduceTokenBtn = gr.Button("♻️ 总结对话")
68
 
69
  with gr.Column():
70
  with gr.Column(min_width=50, scale=1):
71
- with gr.Tab(label="ChatGPT"):
72
  keyTxt = gr.Textbox(
73
  show_label=True,
74
  placeholder=f"OpenAI API-key...",
@@ -82,10 +78,13 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
82
  else:
83
  usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
84
  model_select_dropdown = gr.Dropdown(
85
- label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
86
  )
87
  use_streaming_checkbox = gr.Checkbox(
88
- label="实时传输回答", value=True, visible=enable_streaming_option
89
  )
90
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
91
  language_select_dropdown = gr.Dropdown(
@@ -94,7 +93,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
94
  multiselect=False,
95
  value=REPLY_LANGUAGES[0],
96
  )
97
- index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
98
  two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
99
  # TODO: 公式ocr
100
  # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
@@ -104,7 +103,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
104
  show_label=True,
105
  placeholder=f"在这里输入System Prompt...",
106
  label="System prompt",
107
- value=initial_prompt,
108
  lines=10,
109
  ).style(container=False)
110
  with gr.Accordion(label="加载Prompt模板", open=True):
@@ -160,24 +159,84 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
160
 
161
  with gr.Tab(label="高级"):
162
  gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
163
- default_btn = gr.Button("🔙 恢复默认设置")
164
- gr.HTML(appearance_switcher, elem_classes="insert_block")
165
  with gr.Accordion("参数", open=False):
166
- top_p = gr.Slider(
167
  minimum=-0,
168
  maximum=1.0,
169
  value=1.0,
170
  step=0.05,
171
  interactive=True,
172
- label="Top-p",
173
  )
174
- temperature = gr.Slider(
175
- minimum=-0,
176
  maximum=2.0,
177
- value=1.0,
178
- step=0.1,
179
  interactive=True,
180
- label="Temperature",
181
  )
182
 
183
  with gr.Accordion("网络设置", open=False):
@@ -198,27 +257,21 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
198
  lines=2,
199
  )
200
  changeProxyBtn = gr.Button("🔄 设置代理地址")
 
201
 
202
- gr.Markdown(description)
203
- gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
204
  chatgpt_predict_args = dict(
205
- fn=predict,
206
  inputs=[
207
- user_api_key,
208
- systemPromptTxt,
209
- history,
210
  user_question,
211
  chatbot,
212
- token_count,
213
- top_p,
214
- temperature,
215
  use_streaming_checkbox,
216
- model_select_dropdown,
217
  use_websearch_checkbox,
218
  index_files,
219
  language_select_dropdown,
220
  ],
221
- outputs=[chatbot, history, status_display, token_count],
222
  show_progress=True,
223
  )
224
 
@@ -242,12 +295,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
242
  )
243
 
244
  get_usage_args = dict(
245
- fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
246
  )
247
 
248
 
249
  # Chatbot
250
- cancelBtn.click(cancel_outputing, [], [])
251
 
252
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
253
  user_input.submit(**get_usage_args)
@@ -256,70 +315,49 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
256
  submitBtn.click(**get_usage_args)
257
 
258
  emptyBtn.click(
259
- reset_state,
260
- outputs=[chatbot, history, token_count, status_display],
261
  show_progress=True,
262
  )
263
  emptyBtn.click(**reset_textbox_args)
264
 
265
  retryBtn.click(**start_outputing_args).then(
266
- retry,
267
  [
268
- user_api_key,
269
- systemPromptTxt,
270
- history,
271
  chatbot,
272
- token_count,
273
- top_p,
274
- temperature,
275
  use_streaming_checkbox,
276
- model_select_dropdown,
 
277
  language_select_dropdown,
278
  ],
279
- [chatbot, history, status_display, token_count],
280
  show_progress=True,
281
  ).then(**end_outputing_args)
282
  retryBtn.click(**get_usage_args)
283
 
284
  delFirstBtn.click(
285
- delete_first_conversation,
286
- [history, token_count],
287
- [history, token_count, status_display],
288
  )
289
 
290
  delLastBtn.click(
291
- delete_last_conversation,
292
- [chatbot, history, token_count],
293
- [chatbot, history, token_count, status_display],
294
- show_progress=True,
295
- )
296
-
297
- reduceTokenBtn.click(
298
- reduce_token_size,
299
- [
300
- user_api_key,
301
- systemPromptTxt,
302
- history,
303
- chatbot,
304
- token_count,
305
- top_p,
306
- temperature,
307
- gr.State(sum(token_count.value[-4:])),
308
- model_select_dropdown,
309
- language_select_dropdown,
310
- ],
311
- [chatbot, history, status_display, token_count],
312
- show_progress=True,
313
  )
314
- reduceTokenBtn.click(**get_usage_args)
315
 
316
  two_column.change(update_doc_config, [two_column], None)
317
 
318
- # ChatGPT
319
- keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
320
  keyTxt.submit(**get_usage_args)
 
 
321
 
322
  # Template
 
323
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
324
  templateFileSelectDropdown.change(
325
  load_template,
@@ -336,32 +374,34 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
336
 
337
  # S&L
338
  saveHistoryBtn.click(
339
- save_chat_history,
340
- [saveFileName, systemPromptTxt, history, chatbot, user_name],
341
  downloadFile,
342
  show_progress=True,
343
  )
344
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
345
  exportMarkdownBtn.click(
346
- export_markdown,
347
- [saveFileName, systemPromptTxt, history, chatbot, user_name],
348
  downloadFile,
349
  show_progress=True,
350
  )
351
  historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
352
- historyFileSelectDropdown.change(
353
- load_chat_history,
354
- [historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
355
- [saveFileName, systemPromptTxt, history, chatbot],
356
- show_progress=True,
357
- )
358
- downloadFile.change(
359
- load_chat_history,
360
- [downloadFile, systemPromptTxt, history, chatbot, user_name],
361
- [saveFileName, systemPromptTxt, history, chatbot],
362
- )
363
 
364
  # Advanced
365
  default_btn.click(
366
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
367
  )
 
10
  from modules.utils import *
11
  from modules.presets import *
12
  from modules.overwrites import *
13
+ from modules.models import ModelManager
 
14
 
15
  gr.Chatbot.postprocess = postprocess
16
  PromptHelper.compact_text_chunks = compact_text_chunks
 
20
 
21
  with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
22
  user_name = gr.State("")
 
 
23
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
 
24
  user_question = gr.State("")
25
+ current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
26
+
27
  topic = gr.State("未命名对话历史记录")
28
 
29
  with gr.Row():
30
+ gr.HTML(CHUANHU_TITLE, elem_id="app_title")
31
  status_display = gr.Markdown(get_geoip(), elem_id="status_display")
32
  with gr.Row(elem_id="float_display"):
33
  user_info = gr.Markdown(value="getting user info...", elem_id="user_info")
 
61
  retryBtn = gr.Button("🔄 重新生成")
62
  delFirstBtn = gr.Button("🗑️ 删除最旧对话")
63
  delLastBtn = gr.Button("🗑️ 删除最新对话")
 
64
 
65
  with gr.Column():
66
  with gr.Column(min_width=50, scale=1):
67
+ with gr.Tab(label="模型"):
68
  keyTxt = gr.Textbox(
69
  show_label=True,
70
  placeholder=f"OpenAI API-key...",
 
78
  else:
79
  usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
80
  model_select_dropdown = gr.Dropdown(
81
+ label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
82
+ )
83
+ lora_select_dropdown = gr.Dropdown(
84
+ label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
85
  )
86
  use_streaming_checkbox = gr.Checkbox(
87
+ label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
88
  )
89
  use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
90
  language_select_dropdown = gr.Dropdown(
 
93
  multiselect=False,
94
  value=REPLY_LANGUAGES[0],
95
  )
96
+ index_files = gr.Files(label="上传索引文件", type="file")
97
  two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
98
  # TODO: 公式ocr
99
  # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
 
103
  show_label=True,
104
  placeholder=f"在这里输入System Prompt...",
105
  label="System prompt",
106
+ value=INITIAL_SYSTEM_PROMPT,
107
  lines=10,
108
  ).style(container=False)
109
  with gr.Accordion(label="加载Prompt模板", open=True):
 
159
 
160
  with gr.Tab(label="高级"):
161
  gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
162
+ gr.HTML(APPEARANCE_SWITCHER, elem_classes="insert_block")
 
163
  with gr.Accordion("参数", open=False):
164
+ temperature_slider = gr.Slider(
165
+ minimum=-0,
166
+ maximum=2.0,
167
+ value=1.0,
168
+ step=0.1,
169
+ interactive=True,
170
+ label="temperature",
171
+ )
172
+ top_p_slider = gr.Slider(
173
  minimum=-0,
174
  maximum=1.0,
175
  value=1.0,
176
  step=0.05,
177
  interactive=True,
178
+ label="top-p",
179
  )
180
+ n_choices_slider = gr.Slider(
181
+ minimum=1,
182
+ maximum=10,
183
+ value=1,
184
+ step=1,
185
+ interactive=True,
186
+ label="n choices",
187
+ )
188
+ stop_sequence_txt = gr.Textbox(
189
+ show_label=True,
190
+ placeholder=f"在这里输入停止符,用英文逗号隔开...",
191
+ label="stop",
192
+ value="",
193
+ lines=1,
194
+ )
195
+ max_context_length_slider = gr.Slider(
196
+ minimum=1,
197
+ maximum=32768,
198
+ value=2000,
199
+ step=1,
200
+ interactive=True,
201
+ label="max context",
202
+ )
203
+ max_generation_slider = gr.Slider(
204
+ minimum=1,
205
+ maximum=32768,
206
+ value=1000,
207
+ step=1,
208
+ interactive=True,
209
+ label="max generations",
210
+ )
211
+ presence_penalty_slider = gr.Slider(
212
+ minimum=-2.0,
213
  maximum=2.0,
214
+ value=0.0,
215
+ step=0.01,
216
+ interactive=True,
217
+ label="presence penalty",
218
+ )
219
+ frequency_penalty_slider = gr.Slider(
220
+ minimum=-2.0,
221
+ maximum=2.0,
222
+ value=0.0,
223
+ step=0.01,
224
  interactive=True,
225
+ label="frequency penalty",
226
+ )
227
+ logit_bias_txt = gr.Textbox(
228
+ show_label=True,
229
+ placeholder=f"word:likelihood",
230
+ label="logit bias",
231
+ value="",
232
+ lines=1,
233
+ )
234
+ user_identifier_txt = gr.Textbox(
235
+ show_label=True,
236
+ placeholder=f"用于定位滥用行为",
237
+ label="用户名",
238
+ value=user_name.value,
239
+ lines=1,
240
  )
241
 
242
  with gr.Accordion("网络设置", open=False):
 
257
  lines=2,
258
  )
259
  changeProxyBtn = gr.Button("🔄 设置代理地址")
260
+ default_btn = gr.Button("🔙 恢复默认设置")
261
 
262
+ gr.Markdown(CHUANHU_DESCRIPTION)
263
+ gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
264
  chatgpt_predict_args = dict(
265
+ fn=current_model.value.predict,
266
  inputs=[
267
  user_question,
268
  chatbot,
269
  use_streaming_checkbox,
 
270
  use_websearch_checkbox,
271
  index_files,
272
  language_select_dropdown,
273
  ],
274
+ outputs=[chatbot, status_display],
275
  show_progress=True,
276
  )
277
 
 
295
  )
296
 
297
  get_usage_args = dict(
298
+ fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
299
+ )
300
+
301
+ load_history_from_file_args = dict(
302
+ fn=current_model.value.load_chat_history,
303
+ inputs=[historyFileSelectDropdown, chatbot, user_name],
304
+ outputs=[saveFileName, systemPromptTxt, chatbot]
305
  )
306
 
307
 
308
  # Chatbot
309
+ cancelBtn.click(current_model.value.interrupt, [], [])
310
 
311
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
312
  user_input.submit(**get_usage_args)
 
315
  submitBtn.click(**get_usage_args)
316
 
317
  emptyBtn.click(
318
+ current_model.value.reset,
319
+ outputs=[chatbot, status_display],
320
  show_progress=True,
321
  )
322
  emptyBtn.click(**reset_textbox_args)
323
 
324
  retryBtn.click(**start_outputing_args).then(
325
+ current_model.value.retry,
326
  [
327
  chatbot,
328
  use_streaming_checkbox,
329
+ use_websearch_checkbox,
330
+ index_files,
331
  language_select_dropdown,
332
  ],
333
+ [chatbot, status_display],
334
  show_progress=True,
335
  ).then(**end_outputing_args)
336
  retryBtn.click(**get_usage_args)
337
 
338
  delFirstBtn.click(
339
+ current_model.value.delete_first_conversation,
340
+ None,
341
+ [status_display],
342
  )
343
 
344
  delLastBtn.click(
345
+ current_model.value.delete_last_conversation,
346
+ [chatbot],
347
+ [chatbot, status_display],
348
+ show_progress=False
349
  )
 
350
 
351
  two_column.change(update_doc_config, [two_column], None)
352
 
353
+ # LLM Models
354
+ keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
355
  keyTxt.submit(**get_usage_args)
356
+ model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
357
+ lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
358
 
359
  # Template
360
+ systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
361
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
362
  templateFileSelectDropdown.change(
363
  load_template,
 
374
 
375
  # S&L
376
  saveHistoryBtn.click(
377
+ current_model.value.save_chat_history,
378
+ [saveFileName, chatbot, user_name],
379
  downloadFile,
380
  show_progress=True,
381
  )
382
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
383
  exportMarkdownBtn.click(
384
+ current_model.value.export_markdown,
385
+ [saveFileName, chatbot, user_name],
386
  downloadFile,
387
  show_progress=True,
388
  )
389
  historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
390
+ historyFileSelectDropdown.change(**load_history_from_file_args)
391
+ downloadFile.change(**load_history_from_file_args)
392
 
393
  # Advanced
394
+ max_context_length_slider.change(current_model.value.set_token_upper_limit, [max_context_length_slider], None)
395
+ temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
396
+ top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
397
+ n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
398
+ stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
399
+ max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
400
+ presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
401
+ frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
402
+ logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
403
+ user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
404
+
405
  default_btn.click(
406
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
407
  )
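The central UI change above: the per-callback `history`/`token_count` states are gone, and a single model object lives in a `gr.State` whose bound methods are wired directly as event callbacks. A minimal sketch of that pattern (`EchoModel` is a hypothetical stand-in for modules.models.ModelManager):

```python
import gradio as gr

class EchoModel:
    """Hypothetical stand-in for modules.models.ModelManager."""
    def predict(self, inputs, history):
        history.append((inputs, f"echo: {inputs}"))
        return history, "ok"

with gr.Blocks() as demo:
    current_model = gr.State(EchoModel())  # model object held in Gradio state
    chatbot = gr.Chatbot()
    status = gr.Markdown()
    user_input = gr.Textbox()
    # As in this commit: the bound method of the State's initial value
    # is registered as the event callback.
    user_input.submit(current_model.value.predict,
                      [user_input, chatbot], [chatbot, status])

demo.launch()
```

One consequence worth noting when reviewing: `current_model.value` is resolved once at build time, so every browser session shares the same model instance and its conversation history.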
configs/ds_config_chatbot.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": false
4
+ },
5
+ "bf16": {
6
+ "enabled": true
7
+ },
8
+ "comms_logger": {
9
+ "enabled": false,
10
+ "verbose": false,
11
+ "prof_all": false,
12
+ "debug": false
13
+ },
14
+ "steps_per_print": 20000000000000000,
15
+ "train_micro_batch_size_per_gpu": 1,
16
+ "wall_clock_breakdown": false
17
+ }
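This DeepSpeed config disables fp16 in favor of bf16, turns off the comms logger, and sets steps_per_print absurdly high to silence periodic logging. The commit does not show the consuming code; a hedged sketch of typical usage (the toy model stands in for a locally loaded checkpoint, and bf16 requires supporting hardware):

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in; in practice a local LLaMA/ChatGLM module

# Sketch only: wrap an already-loaded model with DeepSpeed using the JSON
# config added above. Returns (engine, optimizer, dataloader, lr_scheduler).
engine, _, _, _ = deepspeed.initialize(
    model=model,
    config="configs/ds_config_chatbot.json",
)
wrapped = engine.module  # the underlying model remains accessible here
```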
modules/__init__.py ADDED
(empty file — package marker)
modules/base_model.py ADDED
@@ -0,0 +1,519 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+
12
+ from tqdm import tqdm
13
+ import colorama
14
+ from duckduckgo_search import ddg
15
+ import asyncio
16
+ import aiohttp
17
+ from enum import Enum
18
+
19
+ from .presets import *
20
+ from .llama_func import *
21
+ from .utils import *
22
+ from . import shared
23
+ from .config import retrieve_proxy
24
+
25
+
26
+ class ModelType(Enum):
27
+ Unknown = -1
28
+ OpenAI = 0
29
+ ChatGLM = 1
30
+ LLaMA = 2
31
+
32
+ @classmethod
33
+ def get_type(cls, model_name: str):
34
+ model_type = None
35
+ model_name_lower = model_name.lower()
36
+ if "gpt" in model_name_lower:
37
+ model_type = ModelType.OpenAI
38
+ elif "chatglm" in model_name_lower:
39
+ model_type = ModelType.ChatGLM
40
+ elif "llama" in model_name_lower:
41
+ model_type = ModelType.LLaMA
42
+ else:
43
+ model_type = ModelType.Unknown
44
+ return model_type
45
+
46
+
47
+ class BaseLLMModel:
48
+ def __init__(
49
+ self,
50
+ model_name,
51
+ system_prompt="",
52
+ temperature=1.0,
53
+ top_p=1.0,
54
+ n_choices=1,
55
+ stop=None,
56
+ max_generation_token=None,
57
+ presence_penalty=0,
58
+ frequency_penalty=0,
59
+ logit_bias=None,
60
+ user="",
61
+ ) -> None:
62
+ self.history = []
63
+ self.all_token_counts = []
64
+ self.model_name = model_name
65
+ self.model_type = ModelType.get_type(model_name)
66
+ try:
67
+ self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
68
+ except KeyError:
69
+ self.token_upper_limit = DEFAULT_TOKEN_LIMIT
70
+ self.interrupted = False
71
+ self.system_prompt = system_prompt
72
+ self.api_key = None
73
+ self.need_api_key = False
74
+
75
+ self.temperature = temperature
76
+ self.top_p = top_p
77
+ self.n_choices = n_choices
78
+ self.stop_sequence = stop
79
+ self.max_generation_token = max_generation_token
80
+ self.presence_penalty = presence_penalty
81
+ self.frequency_penalty = frequency_penalty
82
+ self.logit_bias = logit_bias
83
+ self.user_identifier = user
84
+
85
+ def get_answer_stream_iter(self):
86
+ """stream predict, need to be implemented
87
+ conversations are stored in self.history, with the most recent question, in OpenAI format
88
+ should return a generator that yields the next word (str) of the answer each time
89
+ """
90
+ logging.warning("stream predict not implemented, using at once predict instead")
91
+ response, _ = self.get_answer_at_once()
92
+ yield response
93
+
94
+ def get_answer_at_once(self):
95
+ """predict at once, need to be implemented
96
+ conversations are stored in self.history, with the most recent question, in OpenAI format
97
+ Should return:
98
+ the answer (str)
99
+ total token count (int)
100
+ """
101
+ logging.warning("at once predict not implemented, using stream predict instead")
102
+ response_iter = self.get_answer_stream_iter()
103
+ count = 0
104
+ for response in response_iter:
105
+ count += 1
106
+ return response, sum(self.all_token_counts) + count
107
+
108
+ def billing_info(self):
109
+ """get billing infomation, inplement if needed"""
110
+ logging.warning("billing info not implemented, using default")
111
+ return BILLING_NOT_APPLICABLE_MSG
112
+
113
+ def count_token(self, user_input):
114
+ """get token count from input, implement if needed"""
115
+ logging.warning("token count not implemented, using default")
116
+ return len(user_input)
117
+
118
+ def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
119
+ def get_return_value():
120
+ return chatbot, status_text
121
+
122
+ status_text = "开始实时传输回答……"
123
+ if fake_input:
124
+ chatbot.append((fake_input, ""))
125
+ else:
126
+ chatbot.append((inputs, ""))
127
+
128
+ user_token_count = self.count_token(inputs)
129
+ self.all_token_counts.append(user_token_count)
130
+ logging.debug(f"输入token计数: {user_token_count}")
131
+
132
+ stream_iter = self.get_answer_stream_iter()
133
+
134
+ for partial_text in stream_iter:
135
+ chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
136
+ self.all_token_counts[-1] += 1
137
+ status_text = self.token_message()
138
+ yield get_return_value()
139
+ if self.interrupted:
140
+ self.recover()
141
+ break
142
+ self.history.append(construct_assistant(partial_text))
143
+
144
+ def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
145
+ if fake_input:
146
+ chatbot.append((fake_input, ""))
147
+ else:
148
+ chatbot.append((inputs, ""))
149
+ if fake_input is not None:
150
+ user_token_count = self.count_token(fake_input)
151
+ else:
152
+ user_token_count = self.count_token(inputs)
153
+ self.all_token_counts.append(user_token_count)
154
+ ai_reply, total_token_count = self.get_answer_at_once()
155
+ self.history.append(construct_assistant(ai_reply))
156
+ if fake_input is not None:
157
+ self.history[-2] = construct_user(fake_input)
158
+ chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
159
+ if fake_input is not None:
160
+ self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
161
+ else:
162
+ self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
163
+ status_text = self.token_message()
164
+ return chatbot, status_text
165
+
166
+ def predict(
167
+ self,
168
+ inputs,
169
+ chatbot,
170
+ stream=False,
171
+ use_websearch=False,
172
+ files=None,
173
+ reply_language="中文",
174
+ should_check_token_count=True,
175
+ ): # repetition_penalty, top_k
176
+ from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
177
+ from llama_index.indices.query.schema import QueryBundle
178
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
179
+ from langchain.chat_models import ChatOpenAI
180
+ from llama_index import (
181
+ GPTSimpleVectorIndex,
182
+ ServiceContext,
183
+ LangchainEmbedding,
184
+ OpenAIEmbedding,
185
+ )
186
+
187
+ logging.info(
188
+ "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
189
+ )
190
+ if should_check_token_count:
191
+ yield chatbot + [(inputs, "")], "开始生成回答……"
192
+ if reply_language == "跟随问题语言(不稳定)":
193
+ reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
194
+ old_inputs = None
195
+ display_reference = []
196
+ limited_context = False
197
+ if files:
198
+ limited_context = True
199
+ old_inputs = inputs
200
+ msg = "加载索引中……(这可能需要几分钟)"
201
+ logging.info(msg)
202
+ yield chatbot + [(inputs, "")], msg
203
+ index = construct_index(self.api_key, file_src=files)
204
+ assert index is not None, "索引构建失败"
205
+ msg = "索引构建完成,获取回答中……"
206
+ if local_embedding:
207
+ embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
208
+ else:
209
+ embed_model = OpenAIEmbedding()
210
+ logging.info(msg)
211
+ yield chatbot + [(inputs, "")], msg
212
+ with retrieve_proxy():
213
+ prompt_helper = PromptHelper(
214
+ max_input_size=4096,
215
+ num_output=5,
216
+ max_chunk_overlap=20,
217
+ chunk_size_limit=600,
218
+ )
219
+ from llama_index import ServiceContext
220
+
221
+ service_context = ServiceContext.from_defaults(
222
+ prompt_helper=prompt_helper, embed_model=embed_model
223
+ )
224
+ query_object = GPTVectorStoreIndexQuery(
225
+ index.index_struct,
226
+ service_context=service_context,
227
+ similarity_top_k=5,
228
+ vector_store=index._vector_store,
229
+ docstore=index._docstore,
230
+ )
231
+ query_bundle = QueryBundle(inputs)
232
+ nodes = query_object.retrieve(query_bundle)
233
+ reference_results = [n.node.text for n in nodes]
234
+ reference_results = add_source_numbers(reference_results, use_source=False)
235
+ display_reference = add_details(reference_results)
236
+ display_reference = "\n\n" + "".join(display_reference)
237
+ inputs = (
238
+ replace_today(PROMPT_TEMPLATE)
239
+ .replace("{query_str}", inputs)
240
+ .replace("{context_str}", "\n\n".join(reference_results))
241
+ .replace("{reply_language}", reply_language)
242
+ )
243
+ elif use_websearch:
244
+ limited_context = True
245
+ search_results = ddg(inputs, max_results=5)
246
+ old_inputs = inputs
247
+ reference_results = []
248
+ for idx, result in enumerate(search_results):
249
+ logging.debug(f"搜索结果{idx + 1}:{result}")
250
+ domain_name = urllib3.util.parse_url(result["href"]).host
251
+ reference_results.append([result["body"], result["href"]])
252
+ display_reference.append(
253
+ f"{idx+1}. [{domain_name}]({result['href']})\n"
254
+ )
255
+ reference_results = add_source_numbers(reference_results)
256
+ display_reference = "\n\n" + "".join(display_reference)
257
+ inputs = (
258
+ replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
259
+ .replace("{query}", inputs)
260
+ .replace("{web_results}", "\n\n".join(reference_results))
261
+ .replace("{reply_language}", reply_language)
262
+ )
263
+ else:
264
+ display_reference = ""
265
+
266
+ if (
267
+ self.need_api_key and
268
+ self.api_key is None
269
+ and not shared.state.multi_api_key
270
+ ):
271
+ status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
272
+ logging.info(status_text)
273
+ chatbot.append((inputs, ""))
274
+ if len(self.history) == 0:
275
+ self.history.append(construct_user(inputs))
276
+ self.history.append("")
277
+ self.all_token_counts.append(0)
278
+ else:
279
+ self.history[-2] = construct_user(inputs)
280
+ yield chatbot + [(inputs, "")], status_text
281
+ return
282
+ elif len(inputs.strip()) == 0:
283
+ status_text = STANDARD_ERROR_MSG + NO_INPUT_MSG
284
+ logging.info(status_text)
285
+ yield chatbot + [(inputs, "")], status_text
286
+ return
287
+
288
+ self.history.append(construct_user(inputs))
289
+
290
+ try:
291
+ if stream:
292
+ logging.debug("使用流式传输")
293
+ iter = self.stream_next_chatbot(
294
+ inputs,
295
+ chatbot,
296
+ fake_input=old_inputs,
297
+ display_append=display_reference,
298
+ )
299
+ for chatbot, status_text in iter:
300
+ yield chatbot, status_text
301
+ else:
302
+ logging.debug("不使用流式传输")
303
+ chatbot, status_text = self.next_chatbot_at_once(
304
+ inputs,
305
+ chatbot,
306
+ fake_input=old_inputs,
307
+ display_append=display_reference,
308
+ )
309
+ yield chatbot, status_text
310
+ except Exception as e:
311
+ status_text = STANDARD_ERROR_MSG + str(e)
312
+ yield chatbot, status_text
313
+
314
+ if len(self.history) > 1 and self.history[-1]["content"] != inputs:
315
+ logging.info(
316
+ "回答为:"
317
+ + colorama.Fore.BLUE
318
+ + f"{self.history[-1]['content']}"
319
+ + colorama.Style.RESET_ALL
320
+ )
321
+
322
+ if limited_context:
323
+ self.history = self.history[-4:]
324
+ self.all_token_counts = self.all_token_counts[-2:]
325
+
326
+ max_token = self.token_upper_limit - TOKEN_OFFSET
327
+
328
+ if sum(self.all_token_counts) > max_token and should_check_token_count:
329
+ count = 0
330
+ while (
331
+ sum(self.all_token_counts)
332
+ > self.token_upper_limit * REDUCE_TOKEN_FACTOR
333
+ and sum(self.all_token_counts) > 0
334
+ ):
335
+ count += 1
336
+ del self.all_token_counts[0]
337
+ del self.history[:2]
338
+ logging.info(status_text)
339
+ status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
340
+ yield chatbot, status_text
341
+
342
+ def retry(
343
+ self,
344
+ chatbot,
345
+ stream=False,
346
+ use_websearch=False,
347
+ files=None,
348
+ reply_language="中文",
349
+ ):
350
+ logging.debug("重试中……")
351
+ if len(self.history) == 0:
352
+ yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
353
+ return
354
+
355
+ inputs = self.history[-2]["content"]
356
+ del self.history[-2:]
357
+ self.all_token_counts.pop()
358
+ iter = self.predict(
359
+ inputs,
360
+ chatbot,
361
+ stream=stream,
362
+ use_websearch=use_websearch,
363
+ files=files,
364
+ reply_language=reply_language,
365
+ )
366
+ for x in iter:
367
+ yield x
368
+ logging.debug("重试完毕")
369
+
370
+ # def reduce_token_size(self, chatbot):
371
+ # logging.info("开始减少token数量……")
372
+ # chatbot, status_text = self.next_chatbot_at_once(
373
+ # summarize_prompt,
374
+ # chatbot
375
+ # )
376
+ # max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
377
+ # num_chat = find_n(self.all_token_counts, max_token_count)
378
+ # logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
379
+ # chatbot = chatbot[:-1]
380
+ # self.history = self.history[-2*num_chat:] if num_chat > 0 else []
381
+ # self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
382
+ # msg = f"保留了最近{num_chat}轮对话"
383
+ # logging.info(msg)
384
+ # logging.info("减少token数量完毕")
385
+ # return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
386
+
387
+ def interrupt(self):
388
+ self.interrupted = True
389
+
390
+ def recover(self):
391
+ self.interrupted = False
392
+
393
+ def set_token_upper_limit(self, new_upper_limit):
394
+ self.token_upper_limit = new_upper_limit
395
+ print(f"token上限设置为{new_upper_limit}")
396
+
397
+ def set_temperature(self, new_temperature):
398
+ self.temperature = new_temperature
399
+
400
+ def set_top_p(self, new_top_p):
401
+ self.top_p = new_top_p
402
+
403
+ def set_n_choices(self, new_n_choices):
404
+ self.n_choices = new_n_choices
405
+
406
+ def set_stop_sequence(self, new_stop_sequence: str):
407
+ new_stop_sequence = new_stop_sequence.split(",")
408
+ self.stop_sequence = new_stop_sequence
409
+
410
+ def set_max_tokens(self, new_max_tokens):
411
+ self.max_generation_token = new_max_tokens
412
+
413
+ def set_presence_penalty(self, new_presence_penalty):
414
+ self.presence_penalty = new_presence_penalty
415
+
416
+ def set_frequency_penalty(self, new_frequency_penalty):
417
+ self.frequency_penalty = new_frequency_penalty
418
+
419
+ def set_logit_bias(self, logit_bias):
420
+ logit_bias = logit_bias.split()
421
+ bias_map = {}
422
+ encoding = tiktoken.get_encoding("cl100k_base")
423
+ for line in logit_bias:
424
+ word, bias_amount = line.split(":")
425
+ if word:
426
+ for token in encoding.encode(word):
427
+ bias_map[token] = float(bias_amount)
428
+ self.logit_bias = bias_map
429
+
430
+ def set_user_identifier(self, new_user_identifier):
431
+ self.user_identifier = new_user_identifier
432
+
433
+ def set_system_prompt(self, new_system_prompt):
434
+ self.system_prompt = new_system_prompt
435
+
436
+ def set_key(self, new_access_key):
437
+ self.api_key = new_access_key.strip()
438
+ msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
439
+ logging.info(msg)
440
+ return msg
441
+
442
+ def reset(self):
443
+ self.history = []
444
+ self.all_token_counts = []
445
+ self.interrupted = False
446
+ return [], self.token_message([0])
447
+
448
+ def delete_first_conversation(self):
449
+ if self.history:
450
+ del self.history[:2]
451
+ del self.all_token_counts[0]
452
+ return self.token_message()
453
+
454
+ def delete_last_conversation(self, chatbot):
455
+ if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
456
+ msg = "由于包含报错信息,只删除chatbot记录"
457
+ chatbot.pop()
458
+ return chatbot, msg
459
+ if len(self.history) > 0:
460
+ self.history.pop()
461
+ self.history.pop()
462
+ if len(chatbot) > 0:
463
+ msg = "删除了一组chatbot对话"
464
+ chatbot.pop()
465
+ if len(self.all_token_counts) > 0:
466
+ msg = "删除了一组对话的token计数记录"
467
+ self.all_token_counts.pop()
468
+ msg = "删除了一组对话"
469
+ return chatbot, msg
470
+
471
+ def token_message(self, token_lst=None):
472
+ if token_lst is None:
473
+ token_lst = self.all_token_counts
474
+ token_sum = 0
475
+ for i in range(len(token_lst)):
476
+ token_sum += sum(token_lst[: i + 1])
477
+ return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
478
+
479
+ def save_chat_history(self, filename, chatbot, user_name):
480
+ if filename == "":
481
+ return
482
+ if not filename.endswith(".json"):
483
+ filename += ".json"
484
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
485
+
486
+ def export_markdown(self, filename, chatbot, user_name):
487
+ if filename == "":
488
+ return
489
+ if not filename.endswith(".md"):
490
+ filename += ".md"
491
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
492
+
493
+ def load_chat_history(self, filename, chatbot, user_name):
494
+ logging.debug(f"{user_name} 加载对话历史中……")
495
+ if type(filename) != str:
496
+ filename = filename.name
497
+ try:
498
+ with open(os.path.join(HISTORY_DIR, user_name, filename), "r") as f:
499
+ json_s = json.load(f)
500
+ try:
501
+ if type(json_s["history"][0]) == str:
502
+ logging.info("历史记录格式为旧版,正在转换……")
503
+ new_history = []
504
+ for index, item in enumerate(json_s["history"]):
505
+ if index % 2 == 0:
506
+ new_history.append(construct_user(item))
507
+ else:
508
+ new_history.append(construct_assistant(item))
509
+ json_s["history"] = new_history
510
+ logging.info(new_history)
511
+ except:
512
+ # 没有对话历史
513
+ pass
514
+ logging.debug(f"{user_name} 加载对话历史完毕")
515
+ self.history = json_s["history"]
516
+ return filename, json_s["system"], json_s["chatbot"]
517
+ except FileNotFoundError:
518
+ logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
519
+ return filename, self.system_prompt, chatbot
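The extension contract for new backends is the pair of get_answer_* methods above: a subclass implements one or both, and BaseLLMModel.predict handles history, web search, file indexing, and token bookkeeping. A minimal sketch of a custom backend (EchoLLM is illustrative, not part of the commit):

```python
from modules.base_model import BaseLLMModel

class EchoLLM(BaseLLMModel):
    """Illustrative backend that echoes the last user message."""

    def __init__(self, model_name="echo", **kwargs):
        super().__init__(model_name=model_name, **kwargs)

    def get_answer_at_once(self):
        # self.history holds OpenAI-format messages, newest question last.
        question = self.history[-1]["content"]
        answer = f"You said: {question}"
        return answer, sum(self.all_token_counts) + len(answer)

    def get_answer_stream_iter(self):
        answer, _ = self.get_answer_at_once()
        # stream_next_chatbot expects cumulative partial text, not deltas.
        for i in range(1, len(answer) + 1):
            yield answer[:i]
```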
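A worked example of the "word:likelihood" format that set_logit_bias parses (matching the logit bias textbox placeholder in the UI). Note that tiktoken is not imported explicitly in this file, so it presumably arrives via one of the star imports:

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
bias_map = {}
for pair in "hello:5 world:-5".split():    # space-separated word:likelihood pairs
    word, amount = pair.split(":")
    for token in encoding.encode(word):    # a word may span several token ids
        bias_map[token] = float(amount)

# bias_map now maps every token id of "hello" to 5.0 and of "world" to -5.0
```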
modules/chat_func.py DELETED
@@ -1,497 +0,0 @@
1
- # -*- coding:utf-8 -*-
2
- from __future__ import annotations
3
- from typing import TYPE_CHECKING, List
4
-
5
- import logging
6
- import json
7
- import os
8
- import requests
9
- import urllib3
10
-
11
- from tqdm import tqdm
12
- import colorama
13
- from duckduckgo_search import ddg
14
- import asyncio
15
- import aiohttp
16
-
17
-
18
- from modules.presets import *
19
- from modules.llama_func import *
20
- from modules.utils import *
21
- from . import shared
22
- from modules.config import retrieve_proxy
23
-
24
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
25
-
26
- if TYPE_CHECKING:
27
- from typing import TypedDict
28
-
29
- class DataframeData(TypedDict):
30
- headers: List[str]
31
- data: List[List[str | int | bool]]
32
-
33
-
34
- initial_prompt = "You are a helpful assistant."
35
- HISTORY_DIR = "history"
36
- TEMPLATES_DIR = "templates"
37
-
38
- @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
39
- def get_response(
40
- openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
- ):
42
- headers = {
43
- "Content-Type": "application/json",
44
- "Authorization": f"Bearer {openai_api_key}",
45
- }
46
-
47
- history = [construct_system(system_prompt), *history]
48
-
49
- payload = {
50
- "model": selected_model,
51
- "messages": history, # [{"role": "user", "content": f"{inputs}"}],
52
- "temperature": temperature, # 1.0,
53
- "top_p": top_p, # 1.0,
54
- "n": 1,
55
- "stream": stream,
56
- "presence_penalty": 0,
57
- "frequency_penalty": 0,
58
- }
59
- if stream:
60
- timeout = timeout_streaming
61
- else:
62
- timeout = timeout_all
63
-
64
-
65
- # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
66
- if shared.state.completion_url != COMPLETION_URL:
67
- logging.info(f"使用自定义API URL: {shared.state.completion_url}")
68
-
69
- with retrieve_proxy():
70
- response = requests.post(
71
- shared.state.completion_url,
72
- headers=headers,
73
- json=payload,
74
- stream=True,
75
- timeout=timeout,
76
- )
77
-
78
- return response
79
-
80
-
81
- def stream_predict(
82
- openai_api_key,
83
- system_prompt,
84
- history,
85
- inputs,
86
- chatbot,
87
- all_token_counts,
88
- top_p,
89
- temperature,
90
- selected_model,
91
- fake_input=None,
92
- display_append=""
93
- ):
94
- def get_return_value():
95
- return chatbot, history, status_text, all_token_counts
96
-
97
- logging.info("实时回答模式")
98
- partial_words = ""
99
- counter = 0
100
- status_text = "开始实时传输回答……"
101
- history.append(construct_user(inputs))
102
- history.append(construct_assistant(""))
103
- if fake_input:
104
- chatbot.append((fake_input, ""))
105
- else:
106
- chatbot.append((inputs, ""))
107
- user_token_count = 0
108
- if fake_input is not None:
109
- input_token_count = count_token(construct_user(fake_input))
110
- else:
111
- input_token_count = count_token(construct_user(inputs))
112
- if len(all_token_counts) == 0:
113
- system_prompt_token_count = count_token(construct_system(system_prompt))
114
- user_token_count = (
115
- input_token_count + system_prompt_token_count
116
- )
117
- else:
118
- user_token_count = input_token_count
119
- all_token_counts.append(user_token_count)
120
- logging.info(f"输入token计数: {user_token_count}")
121
- yield get_return_value()
122
- try:
123
- response = get_response(
124
- openai_api_key,
125
- system_prompt,
126
- history,
127
- temperature,
128
- top_p,
129
- True,
130
- selected_model,
131
- )
132
- except requests.exceptions.ConnectTimeout:
133
- status_text = (
134
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
135
- )
136
- yield get_return_value()
137
- return
138
- except requests.exceptions.ReadTimeout:
139
- status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
140
- yield get_return_value()
141
- return
142
-
143
- yield get_return_value()
144
- error_json_str = ""
145
-
146
- if fake_input is not None:
147
- history[-2] = construct_user(fake_input)
148
- for chunk in tqdm(response.iter_lines()):
149
- if counter == 0:
150
- counter += 1
151
- continue
152
- counter += 1
153
- # check whether each line is non-empty
154
- if chunk:
155
- chunk = chunk.decode()
156
- chunklength = len(chunk)
157
- try:
158
- chunk = json.loads(chunk[6:])
159
- except json.JSONDecodeError:
160
- logging.info(chunk)
161
- error_json_str += chunk
162
- status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
163
- yield get_return_value()
164
- continue
165
- # decode each line as response data is in bytes
166
- if chunklength > 6 and "delta" in chunk["choices"][0]:
167
- finish_reason = chunk["choices"][0]["finish_reason"]
168
- status_text = construct_token_message(all_token_counts)
169
- if finish_reason == "stop":
170
- yield get_return_value()
171
- break
172
- try:
173
- partial_words = (
174
- partial_words + chunk["choices"][0]["delta"]["content"]
175
- )
176
- except KeyError:
177
- status_text = (
178
- standard_error_msg
179
- + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
180
- + str(sum(all_token_counts))
181
- )
182
- yield get_return_value()
183
- break
184
- history[-1] = construct_assistant(partial_words)
185
- chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
186
- all_token_counts[-1] += 1
187
- yield get_return_value()
188
-
189
-
190
- def predict_all(
191
- openai_api_key,
192
- system_prompt,
193
- history,
194
- inputs,
195
- chatbot,
196
- all_token_counts,
197
- top_p,
198
- temperature,
199
- selected_model,
200
- fake_input=None,
201
- display_append=""
202
- ):
203
- logging.info("一次性回答模式")
204
- history.append(construct_user(inputs))
205
- history.append(construct_assistant(""))
206
- if fake_input:
207
- chatbot.append((fake_input, ""))
208
- else:
209
- chatbot.append((inputs, ""))
210
- if fake_input is not None:
211
- all_token_counts.append(count_token(construct_user(fake_input)))
212
- else:
213
- all_token_counts.append(count_token(construct_user(inputs)))
214
- try:
215
- response = get_response(
216
- openai_api_key,
217
- system_prompt,
218
- history,
219
- temperature,
220
- top_p,
221
- False,
222
- selected_model,
223
- )
224
- except requests.exceptions.ConnectTimeout:
225
- status_text = (
226
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
227
- )
228
- return chatbot, history, status_text, all_token_counts
229
- except requests.exceptions.ProxyError:
230
- status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
231
- return chatbot, history, status_text, all_token_counts
232
- except requests.exceptions.SSLError:
233
- status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
234
- return chatbot, history, status_text, all_token_counts
235
- response = json.loads(response.text)
236
- if fake_input is not None:
237
- history[-2] = construct_user(fake_input)
238
- try:
239
- content = response["choices"][0]["message"]["content"]
240
- history[-1] = construct_assistant(content)
241
- chatbot[-1] = (chatbot[-1][0], content+display_append)
242
- total_token_count = response["usage"]["total_tokens"]
243
- if fake_input is not None:
244
- all_token_counts[-1] += count_token(construct_assistant(content))
245
- else:
246
- all_token_counts[-1] = total_token_count - sum(all_token_counts)
247
- status_text = construct_token_message([total_token_count])
248
- return chatbot, history, status_text, all_token_counts
249
- except KeyError:
250
- status_text = standard_error_msg + str(response)
251
- return chatbot, history, status_text, all_token_counts
252
-
253
-
254
- def predict(
255
- openai_api_key,
256
- system_prompt,
257
- history,
258
- inputs,
259
- chatbot,
260
- all_token_counts,
261
- top_p,
262
- temperature,
263
- stream=False,
264
- selected_model=MODELS[0],
265
- use_websearch=False,
266
- files = None,
267
- reply_language="中文",
268
- should_check_token_count=True,
269
- ): # repetition_penalty, top_k
270
- from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
271
- from llama_index.indices.query.schema import QueryBundle
272
- from langchain.llms import OpenAIChat
273
-
274
-
275
- logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
276
- if should_check_token_count:
277
- yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
278
- if reply_language == "跟随问题语言(不稳定)":
279
- reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
280
- old_inputs = None
281
- display_reference = []
282
- limited_context = False
283
- if files:
284
- limited_context = True
285
- old_inputs = inputs
286
- msg = "加载索引中……(这可能需要几分钟)"
287
- logging.info(msg)
288
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
289
- index = construct_index(openai_api_key, file_src=files)
290
- msg = "索引构建完成,获取回答中……"
291
- logging.info(msg)
292
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
293
- with retrieve_proxy():
294
- llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
295
- prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
296
- from llama_index import ServiceContext
297
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
298
- query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
299
- query_bundle = QueryBundle(inputs)
300
- nodes = query_object.retrieve(query_bundle)
301
- reference_results = [n.node.text for n in nodes]
302
- reference_results = add_source_numbers(reference_results, use_source=False)
303
- display_reference = add_details(reference_results)
304
- display_reference = "\n\n" + "".join(display_reference)
305
- inputs = (
306
- replace_today(PROMPT_TEMPLATE)
307
- .replace("{query_str}", inputs)
308
- .replace("{context_str}", "\n\n".join(reference_results))
309
- .replace("{reply_language}", reply_language )
310
- )
311
- elif use_websearch:
312
- limited_context = True
313
- search_results = ddg(inputs, max_results=5)
314
- old_inputs = inputs
315
- reference_results = []
316
- for idx, result in enumerate(search_results):
317
- logging.info(f"搜索结果{idx + 1}:{result}")
318
- domain_name = urllib3.util.parse_url(result["href"]).host
319
- reference_results.append([result["body"], result["href"]])
320
- display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
321
- reference_results = add_source_numbers(reference_results)
322
- display_reference = "\n\n" + "".join(display_reference)
323
- inputs = (
324
- replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
325
- .replace("{query}", inputs)
326
- .replace("{web_results}", "\n\n".join(reference_results))
327
- .replace("{reply_language}", reply_language )
328
- )
329
- else:
330
- display_reference = ""
331
-
332
- if len(openai_api_key) == 0 and not shared.state.multi_api_key:
333
- status_text = standard_error_msg + no_apikey_msg
334
- logging.info(status_text)
335
- chatbot.append((inputs, ""))
336
- if len(history) == 0:
337
- history.append(construct_user(inputs))
338
- history.append("")
339
- all_token_counts.append(0)
340
- else:
341
- history[-2] = construct_user(inputs)
342
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
343
- return
344
- elif len(inputs.strip()) == 0:
345
- status_text = standard_error_msg + no_input_msg
346
- logging.info(status_text)
347
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
348
- return
349
-
350
- if stream:
351
- logging.info("使用流式传输")
352
- iter = stream_predict(
353
- openai_api_key,
354
- system_prompt,
355
- history,
356
- inputs,
357
- chatbot,
358
- all_token_counts,
359
- top_p,
360
- temperature,
361
- selected_model,
362
- fake_input=old_inputs,
363
- display_append=display_reference
364
- )
365
- for chatbot, history, status_text, all_token_counts in iter:
366
- if shared.state.interrupted:
367
- shared.state.recover()
368
- return
369
- yield chatbot, history, status_text, all_token_counts
370
- else:
371
- logging.info("不使用流式传输")
372
- chatbot, history, status_text, all_token_counts = predict_all(
373
- openai_api_key,
374
- system_prompt,
375
- history,
376
- inputs,
377
- chatbot,
378
- all_token_counts,
379
- top_p,
380
- temperature,
381
- selected_model,
382
- fake_input=old_inputs,
383
- display_append=display_reference
384
- )
385
- yield chatbot, history, status_text, all_token_counts
386
-
387
- logging.info(f"传输完毕。当前token计数为{all_token_counts}")
388
- if len(history) > 1 and history[-1]["content"] != inputs:
389
- logging.info(
390
- "回答为:"
391
- + colorama.Fore.BLUE
392
- + f"{history[-1]['content']}"
393
- + colorama.Style.RESET_ALL
394
- )
395
-
396
- if limited_context:
397
- history = history[-4:]
398
- all_token_counts = all_token_counts[-2:]
399
- yield chatbot, history, status_text, all_token_counts
400
-
401
- if stream:
402
- max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
403
- else:
404
- max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
405
-
406
- if sum(all_token_counts) > max_token and should_check_token_count:
407
- print(all_token_counts)
408
- count = 0
409
- while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
410
- count += 1
411
- del all_token_counts[0]
412
- del history[:2]
413
- logging.info(status_text)
414
- status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
415
- yield chatbot, history, status_text, all_token_counts
416
-
417
-
418
- def retry(
419
- openai_api_key,
420
- system_prompt,
421
- history,
422
- chatbot,
423
- token_count,
424
- top_p,
425
- temperature,
426
- stream=False,
427
- selected_model=MODELS[0],
428
- reply_language="中文",
429
- ):
430
- logging.info("重试中……")
431
- if len(history) == 0:
432
- yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
433
- return
434
- history.pop()
435
- inputs = history.pop()["content"]
436
- token_count.pop()
437
- iter = predict(
438
- openai_api_key,
439
- system_prompt,
440
- history,
441
- inputs,
442
- chatbot,
443
- token_count,
444
- top_p,
445
- temperature,
446
- stream=stream,
447
- selected_model=selected_model,
448
- reply_language=reply_language,
449
- )
450
- logging.info("重试中……")
451
- for x in iter:
452
- yield x
453
- logging.info("重试完毕")
454
-
455
-
456
- def reduce_token_size(
457
- openai_api_key,
458
- system_prompt,
459
- history,
460
- chatbot,
461
- token_count,
462
- top_p,
463
- temperature,
464
- max_token_count,
465
- selected_model=MODELS[0],
466
- reply_language="中文",
467
- ):
468
- logging.info("开始减少token数量……")
469
- iter = predict(
470
- openai_api_key,
471
- system_prompt,
472
- history,
473
- summarize_prompt,
474
- chatbot,
475
- token_count,
476
- top_p,
477
- temperature,
478
- selected_model=selected_model,
479
- should_check_token_count=False,
480
- reply_language=reply_language,
481
- )
482
- logging.info(f"chatbot: {chatbot}")
483
- flag = False
484
- for chatbot, history, status_text, previous_token_count in iter:
485
- num_chat = find_n(previous_token_count, max_token_count)
486
- logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
487
- if flag:
488
- chatbot = chatbot[:-1]
489
- flag = True
490
- history = history[-2*num_chat:] if num_chat > 0 else []
491
- token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
492
- msg = f"保留了最近{num_chat}轮对话"
493
- yield chatbot, history, msg + "," + construct_token_message(
494
- token_count if len(token_count) > 0 else [0],
495
- ), token_count
496
- logging.info(msg)
497
- logging.info("减少token数量完毕")
modules/config.py CHANGED
@@ -117,6 +117,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
 os.environ["HTTP_PROXY"] = ""
 os.environ["HTTPS_PROXY"] = ""
 
+local_embedding = config.get("local_embedding", False)  # 是否使用本地embedding
+
 @contextmanager
 def retrieve_proxy(proxy=None):
     """
modules/llama_func.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
15
 
16
  from modules.presets import *
17
  from modules.utils import *
 
 
18
 
19
  def get_index_name(file_src):
20
  file_paths = [x.name for x in file_src]
@@ -28,6 +30,7 @@ def get_index_name(file_src):
28
 
29
  return md5_hash.hexdigest()
30
 
 
31
  def block_split(text):
32
  blocks = []
33
  while len(text) > 0:
@@ -35,6 +38,7 @@ def block_split(text):
35
  text = text[1000:]
36
  return blocks
37
 
 
38
  def get_documents(file_src):
39
  documents = []
40
  logging.debug("Loading documents...")
@@ -44,40 +48,45 @@ def get_documents(file_src):
44
  filename = os.path.basename(filepath)
45
  file_type = os.path.splitext(filepath)[1]
46
  logging.info(f"loading file: {filename}")
47
- if file_type == ".pdf":
48
- logging.debug("Loading PDF...")
49
- try:
50
- from modules.pdf_func import parse_pdf
51
- from modules.config import advance_docs
52
- two_column = advance_docs["pdf"].get("two_column", False)
53
- pdftext = parse_pdf(filepath, two_column).text
54
- except:
55
- pdftext = ""
56
- with open(filepath, 'rb') as pdfFileObj:
57
- pdfReader = PyPDF2.PdfReader(pdfFileObj)
58
- for page in tqdm(pdfReader.pages):
59
- pdftext += page.extract_text()
60
- text_raw = pdftext
61
- elif file_type == ".docx":
62
- logging.debug("Loading Word...")
63
- DocxReader = download_loader("DocxReader")
64
- loader = DocxReader()
65
- text_raw = loader.load_data(file=filepath)[0].text
66
- elif file_type == ".epub":
67
- logging.debug("Loading EPUB...")
68
- EpubReader = download_loader("EpubReader")
69
- loader = EpubReader()
70
- text_raw = loader.load_data(file=filepath)[0].text
71
- elif file_type == ".xlsx":
72
- logging.debug("Loading Excel...")
73
- text_list = excel_to_string(filepath)
74
- for elem in text_list:
75
- documents.append(Document(elem))
76
- continue
77
- else:
78
- logging.debug("Loading text file...")
79
- with open(filepath, "r", encoding="utf-8") as f:
80
- text_raw = f.read()
81
  text = add_space(text_raw)
82
  # text = block_split(text)
83
  # documents += text
@@ -87,19 +96,21 @@ def get_documents(file_src):
87
 
88
 
89
  def construct_index(
90
- api_key,
91
- file_src,
92
- max_input_size=4096,
93
- num_outputs=5,
94
- max_chunk_overlap=20,
95
- chunk_size_limit=600,
96
- embedding_limit=None,
97
- separator=" "
98
  ):
99
  from langchain.chat_models import ChatOpenAI
100
- from llama_index import GPTSimpleVectorIndex, ServiceContext
 
101
 
102
- os.environ["OPENAI_API_KEY"] = api_key
 
103
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
104
  embedding_limit = None if embedding_limit == 0 else embedding_limit
105
  separator = " " if separator == "" else separator
@@ -107,7 +118,14 @@ def construct_index(
107
  llm_predictor = LLMPredictor(
108
  llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
109
  )
110
- prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
 
 
 
 
 
 
 
111
  index_name = get_index_name(file_src)
112
  if os.path.exists(f"./index/{index_name}.json"):
113
  logging.info("找到了缓存的索引文件,加载中……")
@@ -115,11 +133,20 @@ def construct_index(
115
  else:
116
  try:
117
  documents = get_documents(file_src)
 
 
 
 
118
  logging.info("构建索引中……")
119
  with retrieve_proxy():
120
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
 
 
 
 
 
121
  index = GPTSimpleVectorIndex.from_documents(
122
- documents, service_context=service_context
123
  )
124
  logging.debug("索引构建完成!")
125
  os.makedirs("./index", exist_ok=True)
 
15
 
16
  from modules.presets import *
17
  from modules.utils import *
18
+ from modules.config import local_embedding
19
+
20
 
21
  def get_index_name(file_src):
22
  file_paths = [x.name for x in file_src]
 
30
 
31
  return md5_hash.hexdigest()
32
 
33
+
34
  def block_split(text):
35
  blocks = []
36
  while len(text) > 0:
 
38
  text = text[1000:]
39
  return blocks
40
 
41
+
42
  def get_documents(file_src):
43
  documents = []
44
  logging.debug("Loading documents...")
 
48
  filename = os.path.basename(filepath)
49
  file_type = os.path.splitext(filepath)[1]
50
  logging.info(f"loading file: {filename}")
51
+ try:
52
+ if file_type == ".pdf":
53
+ logging.debug("Loading PDF...")
54
+ try:
55
+ from modules.pdf_func import parse_pdf
56
+ from modules.config import advance_docs
57
+
58
+ two_column = advance_docs["pdf"].get("two_column", False)
59
+ pdftext = parse_pdf(filepath, two_column).text
60
+ except:
61
+ pdftext = ""
62
+ with open(filepath, "rb") as pdfFileObj:
63
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
64
+ for page in tqdm(pdfReader.pages):
65
+ pdftext += page.extract_text()
66
+ text_raw = pdftext
67
+ elif file_type == ".docx":
68
+ logging.debug("Loading Word...")
69
+ DocxReader = download_loader("DocxReader")
70
+ loader = DocxReader()
71
+ text_raw = loader.load_data(file=filepath)[0].text
72
+ elif file_type == ".epub":
73
+ logging.debug("Loading EPUB...")
74
+ EpubReader = download_loader("EpubReader")
75
+ loader = EpubReader()
76
+ text_raw = loader.load_data(file=filepath)[0].text
77
+ elif file_type == ".xlsx":
78
+ logging.debug("Loading Excel...")
79
+ text_list = excel_to_string(filepath)
80
+ for elem in text_list:
81
+ documents.append(Document(elem))
82
+ continue
83
+ else:
84
+ logging.debug("Loading text file...")
85
+ with open(filepath, "r", encoding="utf-8") as f:
86
+ text_raw = f.read()
87
+ except Exception as e:
88
+ logging.error(f"Error loading file: {filename}")
89
+ pass
90
  text = add_space(text_raw)
91
  # text = block_split(text)
92
  # documents += text
 
96
 
97
 
98
  def construct_index(
99
+ api_key,
100
+ file_src,
101
+ max_input_size=4096,
102
+ num_outputs=5,
103
+ max_chunk_overlap=20,
104
+ chunk_size_limit=600,
105
+ embedding_limit=None,
106
+ separator=" ",
107
  ):
108
  from langchain.chat_models import ChatOpenAI
109
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
110
+ from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
111
 
112
+ if api_key:
113
+ os.environ["OPENAI_API_KEY"] = api_key
114
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
115
  embedding_limit = None if embedding_limit == 0 else embedding_limit
116
  separator = " " if separator == "" else separator
 
118
  llm_predictor = LLMPredictor(
119
  llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
120
  )
121
+ prompt_helper = PromptHelper(
122
+ max_input_size=max_input_size,
123
+ num_output=num_outputs,
124
+ max_chunk_overlap=max_chunk_overlap,
125
+ embedding_limit=embedding_limit,
126
+ chunk_size_limit=chunk_size_limit,
127
+ separator=separator,
128
+ )
129
  index_name = get_index_name(file_src)
130
  if os.path.exists(f"./index/{index_name}.json"):
131
  logging.info("找到了缓存的索引文件,加载中……")
 
133
  else:
134
  try:
135
  documents = get_documents(file_src)
136
+ if local_embedding:
137
+ embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
138
+ else:
139
+ embed_model = OpenAIEmbedding()
140
  logging.info("构建索引中……")
141
  with retrieve_proxy():
142
+ service_context = ServiceContext.from_defaults(
143
+ llm_predictor=llm_predictor,
144
+ prompt_helper=prompt_helper,
145
+ chunk_size_limit=chunk_size_limit,
146
+ embed_model=embed_model,
147
+ )
148
  index = GPTSimpleVectorIndex.from_documents(
149
+ documents, service_context=service_context
150
  )
151
  logging.debug("索引构建完成!")
152
  os.makedirs("./index", exist_ok=True)
modules/models.py ADDED
@@ -0,0 +1,568 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+ import platform
+ import datetime  # billing_info 中使用 datetime.datetime.now()
12
+
13
+ from tqdm import tqdm
14
+ import colorama
15
+ from duckduckgo_search import ddg
16
+ import asyncio
17
+ import aiohttp
18
+ from enum import Enum
19
+
20
+ from .presets import *
21
+ from .llama_func import *
22
+ from .utils import *
23
+ from . import shared
24
+ from .config import retrieve_proxy
25
+ from modules import config
26
+ from .base_model import BaseLLMModel, ModelType
27
+
28
+
29
+ class OpenAIClient(BaseLLMModel):
30
+ def __init__(
31
+ self,
32
+ model_name,
33
+ api_key,
34
+ system_prompt=INITIAL_SYSTEM_PROMPT,
35
+ temperature=1.0,
36
+ top_p=1.0,
37
+ ) -> None:
38
+ super().__init__(
39
+ model_name=model_name,
40
+ temperature=temperature,
41
+ top_p=top_p,
42
+ system_prompt=system_prompt,
43
+ )
44
+ self.api_key = api_key
45
+ self.need_api_key = True
46
+ self._refresh_header()
47
+
48
+ def get_answer_stream_iter(self):
49
+ response = self._get_response(stream=True)
50
+ if response is not None:
51
+ iter = self._decode_chat_response(response)
52
+ partial_text = ""
53
+ for i in iter:
54
+ partial_text += i
55
+ yield partial_text
56
+ else:
57
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
58
+
59
+ def get_answer_at_once(self):
60
+ response = self._get_response()
61
+ response = json.loads(response.text)
62
+ content = response["choices"][0]["message"]["content"]
63
+ total_token_count = response["usage"]["total_tokens"]
64
+ return content, total_token_count
65
+
66
+ def count_token(self, user_input):
67
+ input_token_count = count_token(construct_user(user_input))
68
+ if self.system_prompt is not None and len(self.all_token_counts) == 0:
69
+ system_prompt_token_count = count_token(
70
+ construct_system(self.system_prompt)
71
+ )
72
+ return input_token_count + system_prompt_token_count
73
+ return input_token_count
74
+
75
+ def billing_info(self):
76
+ try:
77
+ curr_time = datetime.datetime.now()
78
+ last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
79
+ first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
80
+ usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
81
+ try:
82
+ usage_data = self._get_billing_data(usage_url)
83
+ except Exception as e:
84
+ logging.error(f"获取API使用情况失败:" + str(e))
85
+ return f"**获取API使用情况失败**"
86
+ rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
87
+ return f"**本月使用金额** \u3000 ${rounded_usage}"
88
+ except requests.exceptions.ConnectTimeout:
89
+ status_text = (
90
+ STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
91
+ )
92
+ return status_text
93
+ except requests.exceptions.ReadTimeout:
94
+ status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
95
+ return status_text
96
+ except Exception as e:
97
+ logging.error(f"获取API使用情况失败:" + str(e))
98
+ return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
99
+
100
+ def set_token_upper_limit(self, new_upper_limit):
101
+ pass
102
+
103
+ def set_key(self, new_access_key):
104
+ self.api_key = new_access_key.strip()
105
+ self._refresh_header()
106
+ msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
107
+ logging.info(msg)
108
+ return msg
109
+
110
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
111
+ def _get_response(self, stream=False):
112
+ openai_api_key = self.api_key
113
+ system_prompt = self.system_prompt
114
+ history = self.history
115
+ logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
116
+ headers = {
117
+ "Content-Type": "application/json",
118
+ "Authorization": f"Bearer {openai_api_key}",
119
+ }
120
+
121
+ if system_prompt is not None:
122
+ history = [construct_system(system_prompt), *history]
123
+
124
+ payload = {
125
+ "model": self.model_name,
126
+ "messages": history,
127
+ "temperature": self.temperature,
128
+ "top_p": self.top_p,
129
+ "n": self.n_choices,
130
+ "stream": stream,
131
+ "presence_penalty": self.presence_penalty,
132
+ "frequency_penalty": self.frequency_penalty,
133
+ }
134
+
135
+ if self.max_generation_token is not None:
136
+ payload["max_tokens"] = self.max_generation_token
137
+ if self.stop_sequence is not None:
138
+ payload["stop"] = self.stop_sequence
139
+ if self.logit_bias is not None:
140
+ payload["logit_bias"] = self.logit_bias
141
+ if self.user_identifier is not None:
142
+ payload["user"] = self.user_identifier
143
+
144
+ if stream:
145
+ timeout = TIMEOUT_STREAMING
146
+ else:
147
+ timeout = TIMEOUT_ALL
148
+
149
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
150
+ if shared.state.completion_url != COMPLETION_URL:
151
+ logging.info(f"使用自定义API URL: {shared.state.completion_url}")
152
+
153
+ with retrieve_proxy():
154
+ try:
155
+ response = requests.post(
156
+ shared.state.completion_url,
157
+ headers=headers,
158
+ json=payload,
159
+ stream=stream,
160
+ timeout=timeout,
161
+ )
162
+ except:
163
+ return None
164
+ return response
165
+
166
+ def _refresh_header(self):
167
+ self.headers = {
168
+ "Content-Type": "application/json",
169
+ "Authorization": f"Bearer {self.api_key}",
170
+ }
171
+
172
+ def _get_billing_data(self, billing_url):
173
+ with retrieve_proxy():
174
+ response = requests.get(
175
+ billing_url,
176
+ headers=self.headers,
177
+ timeout=TIMEOUT_ALL,
178
+ )
179
+
180
+ if response.status_code == 200:
181
+ data = response.json()
182
+ return data
183
+ else:
184
+ raise Exception(
185
+ f"API request failed with status code {response.status_code}: {response.text}"
186
+ )
187
+
188
+ def _decode_chat_response(self, response):
189
+ error_msg = ""
190
+ for chunk in response.iter_lines():
191
+ if chunk:
192
+ chunk = chunk.decode()
193
+ chunk_length = len(chunk)
194
+ try:
195
+ chunk = json.loads(chunk[6:])
196
+ except json.JSONDecodeError:
197
+ print(f"JSON解析错误,收到的内容: {chunk}")
198
+ error_msg += chunk
199
+ continue
200
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
201
+ if chunk["choices"][0]["finish_reason"] == "stop":
202
+ break
203
+ try:
204
+ yield chunk["choices"][0]["delta"]["content"]
205
+ except Exception as e:
206
+ # logging.error(f"Error: {e}")
207
+ continue
208
+ if error_msg:
209
+ raise Exception(error_msg)
210
+
211
+
212
+ class ChatGLM_Client(BaseLLMModel):
213
+ def __init__(self, model_name) -> None:
214
+ super().__init__(model_name=model_name)
215
+ from transformers import AutoTokenizer, AutoModel
216
+ import torch
217
+
218
+ system_name = platform.system()
219
+ model_path = None
220
+ if os.path.exists("models"):
221
+ model_dirs = os.listdir("models")
222
+ if model_name in model_dirs:
223
+ model_path = f"models/{model_name}"
224
+ if model_path is not None:
225
+ model_source = model_path
226
+ else:
227
+ model_source = f"THUDM/{model_name}"
228
+ self.tokenizer = AutoTokenizer.from_pretrained(
229
+ model_source, trust_remote_code=True
230
+ )
231
+ quantified = False
232
+ if "int4" in model_name:
233
+ quantified = True
234
+ if quantified:
235
+ model = AutoModel.from_pretrained(
236
+ model_source, trust_remote_code=True
237
+ ).float()
238
+ else:
239
+ model = AutoModel.from_pretrained(
240
+ model_source, trust_remote_code=True
241
+ ).half()
242
+ if torch.cuda.is_available():
243
+ # run on CUDA
244
+ logging.info("CUDA is available, using CUDA")
245
+ model = model.cuda()
246
+ # 仅在 macOS 上、模型已下载到本地且未量化时才使用 MPS(MPS 加速仍存在一些问题)
247
+ elif system_name == "Darwin" and model_path is not None and not quantified:
248
+ logging.info("Running on macOS, using MPS")
249
+ # running on macOS and model already downloaded
250
+ model = model.to("mps")
251
+ else:
252
+ logging.info("GPU is not available, using CPU")
253
+ model = model.eval()
254
+ self.model = model
255
+
256
+ def _get_glm_style_input(self):
257
+ history = [x["content"] for x in self.history]
258
+ query = history.pop()
259
+ logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
260
+ assert (
261
+ len(history) % 2 == 0
262
+ ), f"History should be even length. current history is: {history}"
263
+ history = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
264
+ return history, query
265
+
266
+ def get_answer_at_once(self):
267
+ history, query = self._get_glm_style_input()
268
+ response, _ = self.model.chat(self.tokenizer, query, history=history)
269
+ return response, len(response)
270
+
271
+ def get_answer_stream_iter(self):
272
+ history, query = self._get_glm_style_input()
273
+ for response, history in self.model.stream_chat(
274
+ self.tokenizer,
275
+ query,
276
+ history,
277
+ max_length=self.token_upper_limit,
278
+ top_p=self.top_p,
279
+ temperature=self.temperature,
280
+ ):
281
+ yield response
282
+
283
+
284
+ class LLaMA_Client(BaseLLMModel):
285
+ def __init__(
286
+ self,
287
+ model_name,
288
+ lora_path=None,
289
+ ) -> None:
290
+ super().__init__(model_name=model_name)
291
+ from lmflow.datasets.dataset import Dataset
292
+ from lmflow.pipeline.auto_pipeline import AutoPipeline
293
+ from lmflow.models.auto_model import AutoModel
294
+ from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
295
+ model_path = None
296
+ if os.path.exists("models"):
297
+ model_dirs = os.listdir("models")
298
+ if model_name in model_dirs:
299
+ model_path = f"models/{model_name}"
300
+ if model_path is not None:
301
+ model_source = model_path
302
+ else:
303
+ raise Exception(f"models目录下没有这个模型: {model_name}")
304
+ if lora_path is not None:
305
+ lora_path = f"lora/{lora_path}"
306
+ self.max_generation_token = 1000
307
+ pipeline_name = "inferencer"
308
+ model_args = ModelArguments(
+ model_name_or_path=model_source, lora_model_path=lora_path,
+ model_type=None, config_overrides=None, config_name=None,
+ tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True,
+ model_revision='main', use_auth_token=False, torch_dtype=None,
+ use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1,
+ use_ram_optimized_load=True)
309
+ pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
310
+
311
+ with open(pipeline_args.deepspeed, "r") as f:
312
+ ds_config = json.load(f)
313
+
314
+ self.model = AutoModel.get_model(
315
+ model_args,
316
+ tune_strategy="none",
317
+ ds_config=ds_config,
318
+ )
319
+
320
+ # We don't need input data
321
+ data_args = DatasetArguments(dataset_path=None)
322
+ self.dataset = Dataset(data_args)
323
+
324
+ self.inferencer = AutoPipeline.get_pipeline(
325
+ pipeline_name=pipeline_name,
326
+ model_args=model_args,
327
+ data_args=data_args,
328
+ pipeline_args=pipeline_args,
329
+ )
330
+
331
+ # Chats
332
+ model_name = model_args.model_name_or_path
333
+ if model_args.lora_model_path is not None:
334
+ model_name += f" + {model_args.lora_model_path}"
335
+
336
+ # context = (
337
+ # "You are a helpful assistant who follows the given instructions"
338
+ # " unconditionally."
339
+ # )
340
+ self.end_string = "\n\n"
341
+
342
+ def _get_llama_style_input(self):
343
+ history = [x["content"] for x in self.history]
344
+ context = "\n".join(history)
345
+ return context
346
+
347
+ def get_answer_at_once(self):
348
+ context = self._get_llama_style_input()
349
+
350
+ input_dataset = self.dataset.from_dict(
351
+ {"type": "text_only", "instances": [{"text": context}]}
352
+ )
353
+
354
+ output_dataset = self.inferencer.inference(
355
+ model=self.model,
356
+ dataset=input_dataset,
357
+ max_new_tokens=self.max_generation_token,
358
+ temperature=self.temperature,
359
+ )
360
+
361
+ response = output_dataset.to_dict()["instances"][0]["text"]
362
+
363
+ try:
364
+ index = response.index(self.end_string)
365
+ except ValueError:
366
+ response += self.end_string
367
+ index = response.index(self.end_string)
368
+
369
+ response = response[: index + 1]
370
+ return response, len(response)
371
+
372
+ def get_answer_stream_iter(self):
373
+ context = self._get_llama_style_input()
374
+
375
+ input_dataset = self.dataset.from_dict(
376
+ {"type": "text_only", "instances": [{"text": context}]}
377
+ )
378
+
379
+ output_dataset = self.inferencer.inference(
380
+ model=self.model,
381
+ dataset=input_dataset,
382
+ max_new_tokens=self.max_generation_token,
383
+ temperature=self.temperature,
384
+ )
385
+
386
+ response = output_dataset.to_dict()["instances"][0]["text"]
387
+
388
+ try:
389
+ index = response.index(self.end_string)
390
+ except ValueError:
391
+ response += self.end_string
392
+ index = response.index(self.end_string)
393
+
394
+ response = response[: index + 1]
395
+ yield response
396
+
397
+
398
+ class ModelManager:
399
+ def __init__(self, **kwargs) -> None:
400
+ self.get_model(**kwargs)
401
+
402
+ def get_model(
403
+ self,
404
+ model_name,
405
+ lora_model_path=None,
406
+ access_key=None,
407
+ temperature=None,
408
+ top_p=None,
409
+ system_prompt=None,
410
+ ) -> BaseLLMModel:
411
+ msg = f"模型设置为了: {model_name}"
412
+ model_type = ModelType.get_type(model_name)
413
+ lora_selector_visibility = False
414
+ lora_choices = []
415
+ dont_change_lora_selector = False
416
+ if model_type != ModelType.OpenAI:
417
+ config.local_embedding = True
418
+ model = None
419
+ try:
420
+ if model_type == ModelType.OpenAI:
421
+ model = OpenAIClient(
422
+ model_name=model_name,
423
+ api_key=access_key,
424
+ system_prompt=system_prompt,
425
+ temperature=temperature,
426
+ top_p=top_p,
427
+ )
428
+ elif model_type == ModelType.ChatGLM:
429
+ model = ChatGLM_Client(model_name)
430
+ elif model_type == ModelType.LLaMA and lora_model_path == "":
431
+ msg = "现在请选择LoRA模型"
432
+ logging.info(msg)
433
+ lora_selector_visibility = True
434
+ if os.path.isdir("lora"):
435
+ lora_choices = get_file_names("lora", plain=True, filetypes=[""])
436
+ lora_choices = ["No LoRA"] + lora_choices
437
+ elif model_type == ModelType.LLaMA and lora_model_path != "":
438
+ dont_change_lora_selector = True
439
+ if lora_model_path == "No LoRA":
440
+ lora_model_path = None
441
+ msg += " + No LoRA"
442
+ else:
443
+ msg += f" + {lora_model_path}"
444
+ model = LLaMA_Client(model_name, lora_model_path)
446
+ elif model_type == ModelType.Unknown:
447
+ raise ValueError(f"未知模型: {model_name}")
448
+ logging.info(msg)
449
+ except Exception as e:
450
+ logging.error(e)
451
+ msg = f"{STANDARD_ERROR_MSG}: {e}"
452
+ if model is not None:
453
+ self.model = model
454
+ if dont_change_lora_selector:
455
+ return msg
456
+ else:
457
+ return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
458
+
459
+ def predict(self, *args):
460
+ iter = self.model.predict(*args)
461
+ for i in iter:
462
+ yield i
463
+
464
+ def billing_info(self):
465
+ return self.model.billing_info()
466
+
467
+ def set_key(self, *args):
468
+ return self.model.set_key(*args)
469
+
470
+ def load_chat_history(self, *args):
471
+ return self.model.load_chat_history(*args)
472
+
473
+ def interrupt(self, *args):
474
+ return self.model.interrupt(*args)
475
+
476
+ def reset(self, *args):
477
+ return self.model.reset(*args)
478
+
479
+ def retry(self, *args):
480
+ iter = self.model.retry(*args)
481
+ for i in iter:
482
+ yield i
483
+
484
+ def delete_first_conversation(self, *args):
485
+ return self.model.delete_first_conversation(*args)
486
+
487
+ def delete_last_conversation(self, *args):
488
+ return self.model.delete_last_conversation(*args)
489
+
490
+ def set_system_prompt(self, *args):
491
+ return self.model.set_system_prompt(*args)
492
+
493
+ def save_chat_history(self, *args):
494
+ return self.model.save_chat_history(*args)
495
+
496
+ def export_markdown(self, *args):
497
+ return self.model.export_markdown(*args)
498
+
502
+ def set_token_upper_limit(self, *args):
503
+ return self.model.set_token_upper_limit(*args)
504
+
505
+ def set_temperature(self, *args):
506
+ self.model.set_temperature(*args)
507
+
508
+ def set_top_p(self, *args):
509
+ self.model.set_top_p(*args)
510
+
511
+ def set_n_choices(self, *args):
512
+ self.model.set_n_choices(*args)
513
+
514
+ def set_stop_sequence(self, *args):
515
+ self.model.set_stop_sequence(*args)
516
+
517
+ def set_max_tokens(self, *args):
518
+ self.model.set_max_tokens(*args)
519
+
520
+ def set_presence_penalty(self, *args):
521
+ self.model.set_presence_penalty(*args)
522
+
523
+ def set_frequency_penalty(self, *args):
524
+ self.model.set_frequency_penalty(*args)
525
+
526
+ def set_logit_bias(self, *args):
527
+ self.model.set_logit_bias(*args)
528
+
529
+ def set_user_identifier(self, *args):
530
+ self.model.set_user_identifier(*args)
531
+
532
+
533
+
534
+
535
+ if __name__ == "__main__":
536
+ with open("config.json", "r") as f:
537
+ openai_api_key = cjson.load(f)["openai_api_key"]
538
+ # set logging level to debug
539
+ logging.basicConfig(level=logging.DEBUG)
540
+ # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
541
+ client = ModelManager(model_name="chatglm-6b-int4")
542
+ chatbot = []
543
+ stream = False
544
+ # 测试账单功能
545
+ logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
546
+ logging.info(client.billing_info())
547
+ # 测试问答
548
+ logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
549
+ question = "巴黎是中国的首都吗?"
550
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
551
+ logging.info(i)
552
+ logging.info(f"测试问答后history : {client.history}")
553
+ # 测试记忆力
554
+ logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
555
+ question = "我刚刚问了你什么问题?"
556
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
557
+ logging.info(i)
558
+ logging.info(f"测试记忆力后history : {client.history}")
559
+ # 测试重试功能
560
+ logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
561
+ for i in client.retry(chatbot=chatbot, stream=stream):
562
+ logging.info(i)
563
+ logging.info(f"重试后history : {client.history}")
564
+ # # 测试总结功能
565
+ # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
566
+ # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
567
+ # print(chatbot, msg)
568
+ # print(f"总结后history: {client.history}")
modules/openai_func.py DELETED
@@ -1,65 +0,0 @@
1
- import requests
2
- import logging
3
- from modules.presets import (
4
- timeout_all,
5
- USAGE_API_URL,
6
- BALANCE_API_URL,
7
- standard_error_msg,
8
- connection_timeout_prompt,
9
- error_retrieve_prompt,
10
- read_timeout_prompt
11
- )
12
-
13
- from . import shared
14
- from modules.config import retrieve_proxy
15
- import os, datetime
16
-
17
- def get_billing_data(openai_api_key, billing_url):
18
- headers = {
19
- "Content-Type": "application/json",
20
- "Authorization": f"Bearer {openai_api_key}"
21
- }
22
-
23
- timeout = timeout_all
24
- with retrieve_proxy():
25
- response = requests.get(
26
- billing_url,
27
- headers=headers,
28
- timeout=timeout,
29
- )
30
-
31
- if response.status_code == 200:
32
- data = response.json()
33
- return data
34
- else:
35
- raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
36
-
37
-
38
- def get_usage(openai_api_key):
39
- try:
40
- curr_time = datetime.datetime.now()
41
- last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
42
- first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
43
- usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
44
- try:
45
- usage_data = get_billing_data(openai_api_key, usage_url)
46
- except Exception as e:
47
- logging.error(f"获取API使用情况失败:"+str(e))
48
- return f"**获取API使用情况失败**"
49
- rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
50
- return f"**本月使用金额** \u3000 ${rounded_usage}"
51
- except requests.exceptions.ConnectTimeout:
52
- status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
53
- return status_text
54
- except requests.exceptions.ReadTimeout:
55
- status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
56
- return status_text
57
- except Exception as e:
58
- logging.error(f"获取API使用情况失败:"+str(e))
59
- return standard_error_msg + error_retrieve_prompt
60
-
61
- def get_last_day_of_month(any_day):
62
- # The day 28 exists in every month. 4 days later, it's always next month
63
- next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
64
- # subtracting the number of the current day brings us back one month
65
- return next_month - datetime.timedelta(days=next_month.day)
 
modules/presets.py CHANGED
@@ -3,48 +3,50 @@ import gradio as gr
3
  from pathlib import Path
4
 
5
  # ChatGPT 设置
6
- initial_prompt = "You are a helpful assistant."
7
  API_HOST = "api.openai.com"
8
  COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
9
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
10
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
11
  HISTORY_DIR = Path("history")
 
12
  TEMPLATES_DIR = "templates"
13
 
14
  # 错误信息
15
- standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
16
- error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
17
- connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
18
- read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
19
- proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
20
- ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
21
- no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
22
- no_input_msg = "请输入对话内容。" # 未输入对话内容
23
-
24
- timeout_streaming = 60 # 流式对话时的超时时间
25
- timeout_all = 200 # 非流式对话时的超时时间
26
- enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
 
 
27
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
28
  CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
29
 
30
  SIM_K = 5
31
  INDEX_QUERY_TEMPRATURE = 1.0
32
 
33
- title = """<h1 align="left">川虎ChatGPT 🚀</h1>"""
34
- description = """\
35
  <div align="center" style="margin:16px 0">
36
 
37
  由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
38
 
39
  访问川虎ChatGPT的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
40
 
41
- 此App使用 `gpt-3.5-turbo` 大语言模型
42
  </div>
43
  """
44
 
45
- footer = """<div class="versions">{versions}</div>"""
46
 
47
- appearance_switcher = """
48
  <div style="display: flex; justify-content: space-between;">
49
  <span style="margin-top: 4px !important;">切换亮暗色主题</span>
50
  <span><label class="apSwitch" for="checkbox">
@@ -53,7 +55,8 @@ appearance_switcher = """
53
  </label></span>
54
  </div>
55
  """
56
- summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
 
57
 
58
  MODELS = [
59
  "gpt-3.5-turbo",
@@ -62,35 +65,34 @@ MODELS = [
62
  "gpt-4-0314",
63
  "gpt-4-32k",
64
  "gpt-4-32k-0314",
 
65
  ] # 可选的模型
66
 
67
- MODEL_SOFT_TOKEN_LIMIT = {
68
- "gpt-3.5-turbo": {
69
- "streaming": 3500,
70
- "all": 3500
71
- },
72
- "gpt-3.5-turbo-0301": {
73
- "streaming": 3500,
74
- "all": 3500
75
- },
76
- "gpt-4": {
77
- "streaming": 7500,
78
- "all": 7500
79
- },
80
- "gpt-4-0314": {
81
- "streaming": 7500,
82
- "all": 7500
83
- },
84
- "gpt-4-32k": {
85
- "streaming": 31000,
86
- "all": 31000
87
- },
88
- "gpt-4-32k-0314": {
89
- "streaming": 31000,
90
- "all": 31000
91
- }
92
  }
93
 
 
 
 
 
94
  REPLY_LANGUAGES = [
95
  "简体中文",
96
  "繁體中文",
 
3
  from pathlib import Path
4
 
5
  # ChatGPT 设置
6
+ INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
7
  API_HOST = "api.openai.com"
8
  COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
9
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
10
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
11
  HISTORY_DIR = Path("history")
12
+ HISTORY_DIR = "history"
13
  TEMPLATES_DIR = "templates"
14
 
15
  # 错误信息
16
+ STANDARD_ERROR_MSG = "☹️发生了错误:" # 错误信息的标准前缀
17
+ GENERAL_ERROR_MSG = "获取对话时发生错误,请查看后台日志"
18
+ ERROR_RETRIEVE_MSG = "请检查网络连接,或者API-Key是否有效。"
19
+ CONNECTION_TIMEOUT_MSG = "连接超时,无法获取对话。" # 连接超时
20
+ READ_TIMEOUT_MSG = "读取超时,无法获取对话。" # 读取超时
21
+ PROXY_ERROR_MSG = "代理错误,无法获取对话。" # 代理错误
22
+ SSL_ERROR_PROMPT = "SSL错误,无法获取对话。" # SSL 错误
23
+ NO_APIKEY_MSG = "API key为空,请检查是否输入正确。" # API key 为空
24
+ NO_INPUT_MSG = "请输入对话内容。" # 未输入对话内容
25
+ BILLING_NOT_APPLICABLE_MSG = "模型本地运行中" # 本地运行的模型返回的账单信息
26
+
27
+ TIMEOUT_STREAMING = 60 # 流式对话时的超时时间
28
+ TIMEOUT_ALL = 200 # 非流式对话时的超时时间
29
+ ENABLE_STREAMING_OPTION = True # 是否启用“实时显示回答”的勾选框
30
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
31
  CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
32
 
33
  SIM_K = 5
34
  INDEX_QUERY_TEMPRATURE = 1.0
35
 
36
+ CHUANHU_TITLE = """<h1 align="left">川虎ChatGPT 🚀</h1>"""
37
+ CHUANHU_DESCRIPTION = """\
38
  <div align="center" style="margin:16px 0">
39
 
40
  由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
41
 
42
  访问川虎ChatGPT的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
43
 
 
44
  </div>
45
  """
46
 
47
+ FOOTER = """<div class="versions">{versions}</div>"""
48
 
49
+ APPEARANCE_SWITCHER = """
50
  <div style="display: flex; justify-content: space-between;">
51
  <span style="margin-top: 4px !important;">切换亮暗色主题</span>
52
  <span><label class="apSwitch" for="checkbox">
 
55
  </label></span>
56
  </div>
57
  """
58
+
59
+ SUMMARIZE_PROMPT = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
60
 
61
  MODELS = [
62
  "gpt-3.5-turbo",
 
65
  "gpt-4-0314",
66
  "gpt-4-32k",
67
  "gpt-4-32k-0314",
68
+ "chatglm-6b",
69
+ "chatglm-6b-int4",
70
+ "chatglm-6b-int4-qe",
71
+ "llama-7b-hf",
72
+ "llama-7b-hf-int4",
73
+ "llama-7b-hf-int8",
74
+ "llama-13b-hf",
75
+ "llama-13b-hf-int4",
76
+ "llama-30b-hf",
77
+ "llama-30b-hf-int4",
78
+ "llama-65b-hf",
79
  ] # 可选的模型
80
 
81
+ DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
82
+
83
+ MODEL_TOKEN_LIMIT = {
84
+ "gpt-3.5-turbo": 4096,
85
+ "gpt-3.5-turbo-0301": 4096,
86
+ "gpt-4": 8192,
87
+ "gpt-4-0314": 8192,
88
+ "gpt-4-32k": 32768,
89
+ "gpt-4-32k-0314": 32768
 
90
  }
91
 
92
+ TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
93
+ DEFAULT_TOKEN_LIMIT = 3000 # 默认的token上限
94
+ REDUCE_TOKEN_FACTOR = 0.5 # 与模型token上限相乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
95
+
96
  REPLY_LANGUAGES = [
97
  "简体中文",
98
  "繁體中文",
modules/utils.py CHANGED
@@ -153,107 +153,22 @@ def construct_assistant(text):
153
  return construct_text("assistant", text)
154
 
155
 
156
- def construct_token_message(tokens: List[int]):
157
- token_sum = 0
158
- for i in range(len(tokens)):
159
- token_sum += sum(tokens[: i + 1])
160
- return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
161
-
162
-
163
- def delete_first_conversation(history, previous_token_count):
164
- if history:
165
- del history[:2]
166
- del previous_token_count[0]
167
- return (
168
- history,
169
- previous_token_count,
170
- construct_token_message(previous_token_count),
171
- )
172
-
173
-
174
- def delete_last_conversation(chatbot, history, previous_token_count):
175
- if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
176
- logging.info("由于包含报错信息,只删除chatbot记录")
177
- chatbot.pop()
178
- return chatbot, history
179
- if len(history) > 0:
180
- logging.info("删除了一组对话历史")
181
- history.pop()
182
- history.pop()
183
- if len(chatbot) > 0:
184
- logging.info("删除了一组chatbot对话")
185
- chatbot.pop()
186
- if len(previous_token_count) > 0:
187
- logging.info("删除了一组对话的token计数记录")
188
- previous_token_count.pop()
189
- return (
190
- chatbot,
191
- history,
192
- previous_token_count,
193
- construct_token_message(previous_token_count),
194
- )
195
-
196
-
197
  def save_file(filename, system, history, chatbot, user_name):
198
- logging.info(f"{user_name} 保存对话历史中……")
199
- os.makedirs(HISTORY_DIR / user_name, exist_ok=True)
200
  if filename.endswith(".json"):
201
  json_s = {"system": system, "history": history, "chatbot": chatbot}
202
  print(json_s)
203
- with open(os.path.join(HISTORY_DIR / user_name, filename), "w") as f:
204
  json.dump(json_s, f)
205
  elif filename.endswith(".md"):
206
  md_s = f"system: \n- {system} \n"
207
  for data in history:
208
  md_s += f"\n{data['role']}: \n- {data['content']} \n"
209
- with open(os.path.join(HISTORY_DIR / user_name, filename), "w", encoding="utf8") as f:
210
  f.write(md_s)
211
- logging.info(f"{user_name} 保存对话历史完毕")
212
- return os.path.join(HISTORY_DIR / user_name, filename)
213
-
214
-
215
- def save_chat_history(filename, system, history, chatbot, user_name):
216
- if filename == "":
217
- return
218
- if not filename.endswith(".json"):
219
- filename += ".json"
220
- return save_file(filename, system, history, chatbot, user_name)
221
-
222
-
223
- def export_markdown(filename, system, history, chatbot, user_name):
224
- if filename == "":
225
- return
226
- if not filename.endswith(".md"):
227
- filename += ".md"
228
- return save_file(filename, system, history, chatbot, user_name)
229
-
230
-
231
- def load_chat_history(filename, system, history, chatbot, user_name):
232
- logging.info(f"{user_name} 加载对话历史中……")
233
- if type(filename) != str:
234
- filename = filename.name
235
- try:
236
- with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
237
- json_s = json.load(f)
238
- try:
239
- if type(json_s["history"][0]) == str:
240
- logging.info("历史记录格式为旧版,正在转换……")
241
- new_history = []
242
- for index, item in enumerate(json_s["history"]):
243
- if index % 2 == 0:
244
- new_history.append(construct_user(item))
245
- else:
246
- new_history.append(construct_assistant(item))
247
- json_s["history"] = new_history
248
- logging.info(new_history)
249
- except:
250
- # 没有对话历史
251
- pass
252
- logging.info(f"{user_name} 加载对话历史完毕")
253
- return filename, json_s["system"], json_s["history"], json_s["chatbot"]
254
- except FileNotFoundError:
255
- logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
256
- return filename, system, history, chatbot
257
 
258
 
259
  def sorted_by_pinyin(list):
@@ -261,7 +176,7 @@ def sorted_by_pinyin(list):
261
 
262
 
263
  def get_file_names(dir, plain=False, filetypes=[".json"]):
264
- logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
265
  files = []
266
  try:
267
  for type in filetypes:
@@ -279,14 +194,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
279
 
280
 
281
  def get_history_names(plain=False, user_name=""):
282
- logging.info(f"从用户 {user_name} 中获取历史记录文件名列表")
283
- return get_file_names(HISTORY_DIR / user_name, plain)
284
 
285
 
286
  def load_template(filename, mode=0):
287
- logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
288
  lines = []
289
- logging.info("Loading template...")
290
  if filename.endswith(".json"):
291
  with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
292
  lines = json.load(f)
@@ -310,23 +224,18 @@ def load_template(filename, mode=0):
310
 
311
 
312
  def get_template_names(plain=False):
313
- logging.info("获取模板文件名列表")
314
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
315
 
316
 
317
  def get_template_content(templates, selection, original_system_prompt):
318
- logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
319
  try:
320
  return templates[selection]
321
  except:
322
  return original_system_prompt
323
 
324
 
325
- def reset_state():
326
- logging.info("重置状态")
327
- return [], [], [], construct_token_message([0])
328
-
329
-
330
  def reset_textbox():
331
  logging.debug("重置文本框")
332
  return gr.update(value="")
@@ -530,3 +439,13 @@ def excel_to_string(file_path):
530
 
531
 
532
  return result
 
153
  return construct_text("assistant", text)
154
 
155
 
156
  def save_file(filename, system, history, chatbot, user_name):
157
+ logging.debug(f"{user_name} 保存对话历史中……")
158
+ os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
159
  if filename.endswith(".json"):
160
  json_s = {"system": system, "history": history, "chatbot": chatbot}
161
  print(json_s)
162
+ with open(os.path.join(HISTORY_DIR, user_name, filename), "w") as f:
163
  json.dump(json_s, f)
164
  elif filename.endswith(".md"):
165
  md_s = f"system: \n- {system} \n"
166
  for data in history:
167
  md_s += f"\n{data['role']}: \n- {data['content']} \n"
168
+ with open(os.path.join(HISTORY_DIR, user_name, filename), "w", encoding="utf8") as f:
169
  f.write(md_s)
170
+ logging.debug(f"{user_name} 保存对话历史完毕")
171
+ return os.path.join(HISTORY_DIR, user_name, filename)
172
 
173
 
174
  def sorted_by_pinyin(list):
 
176
 
177
 
178
  def get_file_names(dir, plain=False, filetypes=[".json"]):
179
+ logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
180
  files = []
181
  try:
182
  for type in filetypes:
 
194
 
195
 
196
  def get_history_names(plain=False, user_name=""):
197
+ logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
198
+ return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
199
 
200
 
201
  def load_template(filename, mode=0):
202
+ logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
203
  lines = []
 
204
  if filename.endswith(".json"):
205
  with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
206
  lines = json.load(f)
 
224
 
225
 
226
  def get_template_names(plain=False):
227
+ logging.debug("获取模板文件名列表")
228
  return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
229
 
230
 
231
  def get_template_content(templates, selection, original_system_prompt):
232
+ logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
233
  try:
234
  return templates[selection]
235
  except:
236
  return original_system_prompt
237
 
238
 
239
  def reset_textbox():
240
  logging.debug("重置文本框")
241
  return gr.update(value="")
 
439
 
440
 
441
  return result
442
+
443
+ def get_last_day_of_month(any_day):
444
+ # The day 28 exists in every month. 4 days later, it's always next month
445
+ next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
446
+ # subtracting the number of the current day brings us back one month
447
+ return next_month - datetime.timedelta(days=next_month.day)
448
+
449
+ def get_model_source(model_name, alternative_source):
450
+ if model_name == "gpt2-medium":
451
+ return "https://huggingface.co/gpt2-medium"
requirements_advanced.txt ADDED
@@ -0,0 +1,7 @@
1
+ transformers
2
+ torch
3
+ icetk
4
+ protobuf==3.19.0
5
+ git+https://github.com/OptimalScale/LMFlow.git#egg=lmflow
6
+ cpm-kernels
7
+ sentence_transformers
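
These extras back the new local-model paths and are kept out of the base requirements so API-only deployments stay light: transformers, torch, icetk and cpm-kernels serve ChatGLM (including the int4 variants), the LMFlow git dependency drives LLaMA_Client, sentence_transformers provides the local embedding backend, and the protobuf pin works around version conflicts common with these packages at the time. Installing them is a separate, optional step: pip install -r requirements_advanced.txt.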