Tuchuanhuhuhu committed
Commit 77f2c42
1 Parent(s): 64eb375

Remove the chat_func file; switch to class-based control of models (去除chat_func文件,改用类控制模型)

ChuanhuChatbot.py CHANGED
@@ -10,8 +10,7 @@ from modules.config import *
from modules.utils import *
from modules.presets import *
from modules.overwrites import *
- from modules.chat_func import *
- from modules.openai_func import get_usage
+ from modules.models import get_model

gr.Chatbot.postprocess = postprocess
PromptHelper.compact_text_chunks = compact_text_chunks
@@ -21,12 +20,11 @@ with open("assets/custom.css", "r", encoding="utf-8") as f:

with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
user_name = gr.State("")
- history = gr.State([])
- token_count = gr.State([])
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
user_api_key = gr.State(my_api_key)
user_question = gr.State("")
- outputing = gr.State(False)
+ current_model = gr.State(get_model(MODELS[0], my_api_key))
+
topic = gr.State("未命名对话历史记录")

with gr.Row():
@@ -64,7 +62,6 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
retryBtn = gr.Button("🔄 重新生成")
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
delLastBtn = gr.Button("🗑️ 删除最新对话")
- reduceTokenBtn = gr.Button("♻️ 总结对话")

with gr.Column():
with gr.Column(min_width=50, scale=1):
@@ -94,7 +91,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
multiselect=False,
value=REPLY_LANGUAGES[0],
)
- index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
+ index_files = gr.Files(label="上传索引文件", type="file")
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
# TODO: 公式ocr
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
@@ -104,7 +101,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
show_label=True,
placeholder=f"在这里输入System Prompt...",
label="System prompt",
- value=initial_prompt,
+ value=INITIAL_SYSTEM_PROMPT,
lines=10,
).style(container=False)
with gr.Accordion(label="加载Prompt模板", open=True):
@@ -202,23 +199,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
gr.Markdown(description)
gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
chatgpt_predict_args = dict(
- fn=predict,
+ fn=current_model.value.predict,
inputs=[
- user_api_key,
- systemPromptTxt,
- history,
user_question,
chatbot,
- token_count,
- top_p,
- temperature,
use_streaming_checkbox,
- model_select_dropdown,
use_websearch_checkbox,
index_files,
language_select_dropdown,
],
- outputs=[chatbot, history, status_display, token_count],
+ outputs=[chatbot, status_display],
show_progress=True,
)

@@ -242,12 +232,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
)

get_usage_args = dict(
- fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
+ fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
+ )
+
+ load_history_from_file_args = dict(
+ fn=current_model.value.load_chat_history,
+ inputs=[historyFileSelectDropdown, chatbot, user_name],
+ outputs=[saveFileName, systemPromptTxt, chatbot]
)


# Chatbot
- cancelBtn.click(cancel_outputing, [], [])
+ cancelBtn.click(current_model.value.interrupt, [], [])

user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
user_input.submit(**get_usage_args)
@@ -256,63 +252,39 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
submitBtn.click(**get_usage_args)

emptyBtn.click(
- reset_state,
- outputs=[chatbot, history, token_count, status_display],
+ current_model.value.reset,
+ outputs=[chatbot, status_display],
show_progress=True,
)
emptyBtn.click(**reset_textbox_args)

retryBtn.click(**start_outputing_args).then(
- retry,
+ current_model.value.retry,
[
- user_api_key,
- systemPromptTxt,
- history,
chatbot,
- token_count,
- top_p,
- temperature,
use_streaming_checkbox,
- model_select_dropdown,
+ use_websearch_checkbox,
+ index_files,
language_select_dropdown,
],
- [chatbot, history, status_display, token_count],
+ [chatbot, status_display],
show_progress=True,
).then(**end_outputing_args)
retryBtn.click(**get_usage_args)

delFirstBtn.click(
- delete_first_conversation,
- [history, token_count],
- [history, token_count, status_display],
+ current_model.value.delete_first_conversation,
+ None,
+ [status_display],
)

delLastBtn.click(
- delete_last_conversation,
- [chatbot, history, token_count],
- [chatbot, history, token_count, status_display],
- show_progress=True,
+ current_model.value.delete_last_conversation,
+ [chatbot],
+ [chatbot, status_display],
+ show_progress=False
)

- reduceTokenBtn.click(
- reduce_token_size,
- [
- user_api_key,
- systemPromptTxt,
- history,
- chatbot,
- token_count,
- top_p,
- temperature,
- gr.State(sum(token_count.value[-4:])),
- model_select_dropdown,
- language_select_dropdown,
- ],
- [chatbot, history, status_display, token_count],
- show_progress=True,
- )
- reduceTokenBtn.click(**get_usage_args)
-
two_column.change(update_doc_config, [two_column], None)

# ChatGPT
@@ -336,30 +308,21 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:

# S&L
saveHistoryBtn.click(
- save_chat_history,
- [saveFileName, systemPromptTxt, history, chatbot, user_name],
+ current_model.value.save_chat_history,
+ [saveFileName, chatbot, user_name],
downloadFile,
show_progress=True,
)
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
exportMarkdownBtn.click(
- export_markdown,
- [saveFileName, systemPromptTxt, history, chatbot, user_name],
+ current_model.value.export_markdown,
+ [saveFileName, chatbot, user_name],
downloadFile,
show_progress=True,
)
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
- historyFileSelectDropdown.change(
- load_chat_history,
- [historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
- [saveFileName, systemPromptTxt, history, chatbot],
- show_progress=True,
- )
- downloadFile.change(
- load_chat_history,
- [downloadFile, systemPromptTxt, history, chatbot, user_name],
- [saveFileName, systemPromptTxt, history, chatbot],
- )
+ historyFileSelectDropdown.change(**load_history_from_file_args)
+ downloadFile.change(**load_history_from_file_args)

# Advanced
default_btn.click(
modules/__init__.py ADDED
File without changes
modules/base_model.py ADDED
@@ -0,0 +1,427 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+
12
+ from tqdm import tqdm
13
+ import colorama
14
+ from duckduckgo_search import ddg
15
+ import asyncio
16
+ import aiohttp
17
+ from enum import Enum
18
+
19
+ from .presets import *
20
+ from .llama_func import *
21
+ from .utils import *
22
+ from . import shared
23
+ from .config import retrieve_proxy
24
+
25
+
26
+ class ModelType(Enum):
27
+ OpenAI = 0
28
+ ChatGLM = 1
29
+ LLaMA = 2
30
+
31
+ @classmethod
32
+ def get_type(cls, model_name: str):
33
+ model_type = None
34
+ if "gpt" in model_name.lower():
35
+ model_type = ModelType.OpenAI
36
+ elif "chatglm" in model_name.upper():
37
+ model_type = ModelType.ChatGLM
38
+ else:
39
+ model_type = ModelType.LLaMA
40
+ return model_type
41
+
42
+
43
+ class BaseLLMModel:
44
+ def __init__(self, model_name, temperature=1.0, top_p=1.0, max_generation_token=None, system_prompt="") -> None:
45
+ self.history = []
46
+ self.all_token_counts = []
47
+ self.model_name = model_name
48
+ self.model_type = ModelType.get_type(model_name)
49
+ self.api_key = None
50
+ self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
51
+ self.max_generation_token = max_generation_token if max_generation_token is not None else self.token_upper_limit
52
+ self.interrupted = False
53
+ self.temperature = temperature
54
+ self.top_p = top_p
55
+ self.system_prompt = system_prompt
56
+
57
+
58
+ def get_answer_stream_iter(self):
59
+ """stream predict, need to be implemented
60
+ conversations are stored in self.history, with the most recent question, in OpenAI format
61
+ should return a generator, each time give the next word (str) in the answer
62
+ """
63
+ pass
64
+
65
+ def get_answer_at_once(self):
66
+ """predict at once, need to be implemented
67
+ conversations are stored in self.history, with the most recent question, in OpenAI format
68
+ Should return:
69
+ the answer (str)
70
+ total token count (int)
71
+ """
72
+ pass
73
+
74
+ def billing_info(self):
75
+ """get billing infomation, inplement if needed"""
76
+ return billing_not_applicable_msg
77
+
78
+
79
+ def count_token(self, user_input):
80
+ """get token count from input, implement if needed
81
+ """
82
+ return 0
83
+
84
+ def stream_next_chatbot(
85
+ self, inputs, chatbot, fake_input=None, display_append=""
86
+ ):
87
+ def get_return_value():
88
+ return chatbot, status_text
89
+
90
+ status_text = "开始实时传输回答……"
91
+ if fake_input:
92
+ chatbot.append((fake_input, ""))
93
+ else:
94
+ chatbot.append((inputs, ""))
95
+
96
+ user_token_count = self.count_token(inputs)
97
+ self.all_token_counts.append(user_token_count)
98
+ logging.debug(f"输入token计数: {user_token_count}")
99
+
100
+ stream_iter = self.get_answer_stream_iter()
101
+
102
+ for partial_text in stream_iter:
103
+ self.history[-1] = construct_assistant(partial_text)
104
+ chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
105
+ self.all_token_counts[-1] += 1
106
+ status_text = self.token_message()
107
+ yield get_return_value()
108
+
109
+ def next_chatbot_at_once(
110
+ self, inputs, chatbot, fake_input=None, display_append=""
111
+ ):
112
+ if fake_input:
113
+ chatbot.append((fake_input, ""))
114
+ else:
115
+ chatbot.append((inputs, ""))
116
+ if fake_input is not None:
117
+ user_token_count = self.count_token(fake_input)
118
+ else:
119
+ user_token_count = self.count_token(inputs)
120
+ self.all_token_counts.append(user_token_count)
121
+ ai_reply, total_token_count = self.get_answer_at_once()
122
+ if fake_input is not None:
123
+ self.history[-2] = construct_user(fake_input)
124
+ self.history[-1] = construct_assistant(ai_reply)
125
+ chatbot[-1] = (chatbot[-1][0], ai_reply+display_append)
126
+ if fake_input is not None:
127
+ self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
128
+ else:
129
+ self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
130
+ status_text = self.token_message()
131
+ return chatbot, status_text
132
+
133
+ def predict(
134
+ self,
135
+ inputs,
136
+ chatbot,
137
+ stream=False,
138
+ use_websearch=False,
139
+ files=None,
140
+ reply_language="中文",
141
+ should_check_token_count=True,
142
+ ): # repetition_penalty, top_k
143
+ from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
144
+ from llama_index.indices.query.schema import QueryBundle
145
+ from langchain.llms import OpenAIChat
146
+
147
+ logging.info(
148
+ "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
149
+ )
150
+ if should_check_token_count:
151
+ yield chatbot + [(inputs, "")], "开始生成回答……"
152
+ if reply_language == "跟随问题语言(不稳定)":
153
+ reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
154
+ old_inputs = None
155
+ display_reference = []
156
+ limited_context = False
157
+ if files and self.api_key:
158
+ limited_context = True
159
+ old_inputs = inputs
160
+ msg = "加载索引中……(这可能需要几分钟)"
161
+ logging.info(msg)
162
+ yield chatbot + [(inputs, "")], msg
163
+ index = construct_index(self.api_key, file_src=files)
164
+ msg = "索引构建完成,获取回答中……"
165
+ logging.info(msg)
166
+ yield chatbot + [(inputs, "")], msg
167
+ with retrieve_proxy():
168
+ llm_predictor = LLMPredictor(
169
+ llm=OpenAIChat(temperature=0, model_name=self.model_name)
170
+ )
171
+ prompt_helper = PromptHelper(
172
+ max_input_size=4096,
173
+ num_output=5,
174
+ max_chunk_overlap=20,
175
+ chunk_size_limit=600,
176
+ )
177
+ from llama_index import ServiceContext
178
+
179
+ service_context = ServiceContext.from_defaults(
180
+ llm_predictor=llm_predictor, prompt_helper=prompt_helper
181
+ )
182
+ query_object = GPTVectorStoreIndexQuery(
183
+ index.index_struct,
184
+ service_context=service_context,
185
+ similarity_top_k=5,
186
+ vector_store=index._vector_store,
187
+ docstore=index._docstore,
188
+ )
189
+ query_bundle = QueryBundle(inputs)
190
+ nodes = query_object.retrieve(query_bundle)
191
+ reference_results = [n.node.text for n in nodes]
192
+ reference_results = add_source_numbers(reference_results, use_source=False)
193
+ display_reference = add_details(reference_results)
194
+ display_reference = "\n\n" + "".join(display_reference)
195
+ inputs = (
196
+ replace_today(PROMPT_TEMPLATE)
197
+ .replace("{query_str}", inputs)
198
+ .replace("{context_str}", "\n\n".join(reference_results))
199
+ .replace("{reply_language}", reply_language)
200
+ )
201
+ elif use_websearch:
202
+ limited_context = True
203
+ search_results = ddg(inputs, max_results=5)
204
+ old_inputs = inputs
205
+ reference_results = []
206
+ for idx, result in enumerate(search_results):
207
+ logging.debug(f"搜索结果{idx + 1}:{result}")
208
+ domain_name = urllib3.util.parse_url(result["href"]).host
209
+ reference_results.append([result["body"], result["href"]])
210
+ display_reference.append(
211
+ f"{idx+1}. [{domain_name}]({result['href']})\n"
212
+ )
213
+ reference_results = add_source_numbers(reference_results)
214
+ display_reference = "\n\n" + "".join(display_reference)
215
+ inputs = (
216
+ replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
217
+ .replace("{query}", inputs)
218
+ .replace("{web_results}", "\n\n".join(reference_results))
219
+ .replace("{reply_language}", reply_language)
220
+ )
221
+ else:
222
+ display_reference = ""
223
+
224
+ if len(self.api_key) == 0 and not shared.state.multi_api_key:
225
+ status_text = standard_error_msg + no_apikey_msg
226
+ logging.info(status_text)
227
+ chatbot.append((inputs, ""))
228
+ if len(self.history) == 0:
229
+ self.history.append(construct_user(inputs))
230
+ self.history.append("")
231
+ self.all_token_counts.append(0)
232
+ else:
233
+ self.history[-2] = construct_user(inputs)
234
+ yield chatbot + [(inputs, "")], status_text
235
+ return
236
+ elif len(inputs.strip()) == 0:
237
+ status_text = standard_error_msg + no_input_msg
238
+ logging.info(status_text)
239
+ yield chatbot + [(inputs, "")], status_text
240
+ return
241
+
242
+ self.history.append(construct_user(inputs))
243
+ self.history.append(construct_assistant(""))
244
+
245
+ if stream:
246
+ logging.debug("使用流式传输")
247
+ iter = self.stream_next_chatbot(
248
+ inputs,
249
+ chatbot,
250
+ fake_input=old_inputs,
251
+ display_append=display_reference,
252
+ )
253
+ for chatbot, status_text in iter:
254
+ yield chatbot, status_text
255
+ if self.interrupted:
256
+ self.recover()
257
+ break
258
+ else:
259
+ logging.debug("不使用流式传输")
260
+ chatbot, status_text = self.next_chatbot_at_once(
261
+ inputs,
262
+ chatbot,
263
+ fake_input=old_inputs,
264
+ display_append=display_reference,
265
+ )
266
+ yield chatbot, status_text
267
+
268
+ if len(self.history) > 1 and self.history[-1]["content"] != inputs:
269
+ logging.info(
270
+ "回答为:"
271
+ + colorama.Fore.BLUE
272
+ + f"{self.history[-1]['content']}"
273
+ + colorama.Style.RESET_ALL
274
+ )
275
+
276
+ if limited_context:
277
+ self.history = self.history[-4:]
278
+ self.all_token_counts = self.all_token_counts[-2:]
279
+
280
+
281
+ max_token = self.token_upper_limit - TOKEN_OFFSET
282
+
283
+ if sum(self.all_token_counts) > max_token and should_check_token_count:
284
+ count = 0
285
+ while sum(self.all_token_counts) > self.token_upper_limit * REDUCE_TOKEN_FACTOR and sum(self.all_token_counts) > 0:
286
+ count += 1
287
+ del self.all_token_counts[0]
288
+ del self.history[:2]
289
+ logging.info(status_text)
290
+ status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
291
+ yield chatbot, status_text
292
+
293
+ def retry(
294
+ self,
295
+ chatbot,
296
+ stream=False,
297
+ use_websearch=False,
298
+ files=None,
299
+ reply_language="中文",
300
+ ):
301
+ logging.info("重试中……")
302
+ if len(self.history) == 0:
303
+ yield chatbot, f"{standard_error_msg}上下文是空的"
304
+ return
305
+
306
+ del self.history[-2:]
307
+ inputs = chatbot[-1][0]
308
+ self.all_token_counts.pop()
309
+ iter = self.predict(
310
+ inputs,
311
+ chatbot,
312
+ stream=stream,
313
+ use_websearch=use_websearch,
314
+ files=files,
315
+ reply_language=reply_language,
316
+ )
317
+ for x in iter:
318
+ yield x
319
+ logging.info("重试完毕")
320
+
321
+ # def reduce_token_size(self, chatbot):
322
+ # logging.info("开始减少token数量……")
323
+ # chatbot, status_text = self.next_chatbot_at_once(
324
+ # summarize_prompt,
325
+ # chatbot
326
+ # )
327
+ # max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
328
+ # num_chat = find_n(self.all_token_counts, max_token_count)
329
+ # logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
330
+ # chatbot = chatbot[:-1]
331
+ # self.history = self.history[-2*num_chat:] if num_chat > 0 else []
332
+ # self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
333
+ # msg = f"保留了最近{num_chat}轮对话"
334
+ # logging.info(msg)
335
+ # logging.info("减少token数量完毕")
336
+ # return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
337
+
338
+ def interrupt(self):
339
+ self.interrupted = True
340
+
341
+ def recover(self):
342
+ self.interrupted = False
343
+
344
+ def set_temprature(self, new_temprature):
345
+ self.temperature = new_temprature
346
+
347
+ def set_top_p(self, new_top_p):
348
+ self.top_p = new_top_p
349
+
350
+ def reset(self):
351
+ self.history = []
352
+ self.all_token_counts = []
353
+ self.interrupted = False
354
+ return [], self.token_message([0])
355
+
356
+ def delete_first_conversation(self):
357
+ if self.history:
358
+ del self.history[:2]
359
+ del self.all_token_counts[0]
360
+ return self.token_message()
361
+
362
+ def delete_last_conversation(self, chatbot):
363
+ if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
364
+ msg = "由于包含报错信息,只删除chatbot记录"
365
+ chatbot.pop()
366
+ return chatbot, self.history
367
+ if len(self.history) > 0:
368
+ self.history.pop()
369
+ self.history.pop()
370
+ if len(chatbot) > 0:
371
+ msg = "删除了一组chatbot对话"
372
+ chatbot.pop()
373
+ if len(self.all_token_counts) > 0:
374
+ msg = "删除了一组对话的token计数记录"
375
+ self.all_token_counts.pop()
376
+ msg = "删除了一组对话"
377
+ return chatbot, msg
378
+
379
+ def token_message(self, token_lst = None):
380
+ if token_lst is None:
381
+ token_lst = self.all_token_counts
382
+ token_sum = 0
383
+ for i in range(len(token_lst)):
384
+ token_sum += sum(token_lst[: i + 1])
385
+ return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
386
+
387
+ def save_chat_history(self, filename, chatbot, user_name):
388
+ if filename == "":
389
+ return
390
+ if not filename.endswith(".json"):
391
+ filename += ".json"
392
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
393
+
394
+ def export_markdown(self, filename, chatbot, user_name):
395
+ if filename == "":
396
+ return
397
+ if not filename.endswith(".md"):
398
+ filename += ".md"
399
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
400
+
401
+ def load_chat_history(self, filename, chatbot, user_name):
402
+ logging.info(f"{user_name} 加载对话历史中……")
403
+ if type(filename) != str:
404
+ filename = filename.name
405
+ try:
406
+ with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
407
+ json_s = json.load(f)
408
+ try:
409
+ if type(json_s["history"][0]) == str:
410
+ logging.info("历史记录格式为旧版,正在转换……")
411
+ new_history = []
412
+ for index, item in enumerate(json_s["history"]):
413
+ if index % 2 == 0:
414
+ new_history.append(construct_user(item))
415
+ else:
416
+ new_history.append(construct_assistant(item))
417
+ json_s["history"] = new_history
418
+ logging.info(new_history)
419
+ except:
420
+ # 没有对话历史
421
+ pass
422
+ logging.info(f"{user_name} 加载对话历史完毕")
423
+ self.history = json_s["history"]
424
+ return filename, json_s["system"], json_s["chatbot"]
425
+ except FileNotFoundError:
426
+ logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
427
+ return filename, self.system_prompt, chatbot
modules/chat_func.py DELETED
@@ -1,497 +0,0 @@
1
- # -*- coding:utf-8 -*-
2
- from __future__ import annotations
3
- from typing import TYPE_CHECKING, List
4
-
5
- import logging
6
- import json
7
- import os
8
- import requests
9
- import urllib3
10
-
11
- from tqdm import tqdm
12
- import colorama
13
- from duckduckgo_search import ddg
14
- import asyncio
15
- import aiohttp
16
-
17
-
18
- from modules.presets import *
19
- from modules.llama_func import *
20
- from modules.utils import *
21
- from . import shared
22
- from modules.config import retrieve_proxy
23
-
24
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
25
-
26
- if TYPE_CHECKING:
27
- from typing import TypedDict
28
-
29
- class DataframeData(TypedDict):
30
- headers: List[str]
31
- data: List[List[str | int | bool]]
32
-
33
-
34
- initial_prompt = "You are a helpful assistant."
35
- HISTORY_DIR = "history"
36
- TEMPLATES_DIR = "templates"
37
-
38
- @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
39
- def get_response(
40
- openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
- ):
42
- headers = {
43
- "Content-Type": "application/json",
44
- "Authorization": f"Bearer {openai_api_key}",
45
- }
46
-
47
- history = [construct_system(system_prompt), *history]
48
-
49
- payload = {
50
- "model": selected_model,
51
- "messages": history, # [{"role": "user", "content": f"{inputs}"}],
52
- "temperature": temperature, # 1.0,
53
- "top_p": top_p, # 1.0,
54
- "n": 1,
55
- "stream": stream,
56
- "presence_penalty": 0,
57
- "frequency_penalty": 0,
58
- }
59
- if stream:
60
- timeout = timeout_streaming
61
- else:
62
- timeout = timeout_all
63
-
64
-
65
- # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
66
- if shared.state.completion_url != COMPLETION_URL:
67
- logging.info(f"使用自定义API URL: {shared.state.completion_url}")
68
-
69
- with retrieve_proxy():
70
- response = requests.post(
71
- shared.state.completion_url,
72
- headers=headers,
73
- json=payload,
74
- stream=True,
75
- timeout=timeout,
76
- )
77
-
78
- return response
79
-
80
-
81
- def stream_predict(
82
- openai_api_key,
83
- system_prompt,
84
- history,
85
- inputs,
86
- chatbot,
87
- all_token_counts,
88
- top_p,
89
- temperature,
90
- selected_model,
91
- fake_input=None,
92
- display_append=""
93
- ):
94
- def get_return_value():
95
- return chatbot, history, status_text, all_token_counts
96
-
97
- logging.info("实时回答模式")
98
- partial_words = ""
99
- counter = 0
100
- status_text = "开始实时传输回答……"
101
- history.append(construct_user(inputs))
102
- history.append(construct_assistant(""))
103
- if fake_input:
104
- chatbot.append((fake_input, ""))
105
- else:
106
- chatbot.append((inputs, ""))
107
- user_token_count = 0
108
- if fake_input is not None:
109
- input_token_count = count_token(construct_user(fake_input))
110
- else:
111
- input_token_count = count_token(construct_user(inputs))
112
- if len(all_token_counts) == 0:
113
- system_prompt_token_count = count_token(construct_system(system_prompt))
114
- user_token_count = (
115
- input_token_count + system_prompt_token_count
116
- )
117
- else:
118
- user_token_count = input_token_count
119
- all_token_counts.append(user_token_count)
120
- logging.info(f"输入token计数: {user_token_count}")
121
- yield get_return_value()
122
- try:
123
- response = get_response(
124
- openai_api_key,
125
- system_prompt,
126
- history,
127
- temperature,
128
- top_p,
129
- True,
130
- selected_model,
131
- )
132
- except requests.exceptions.ConnectTimeout:
133
- status_text = (
134
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
135
- )
136
- yield get_return_value()
137
- return
138
- except requests.exceptions.ReadTimeout:
139
- status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
140
- yield get_return_value()
141
- return
142
-
143
- yield get_return_value()
144
- error_json_str = ""
145
-
146
- if fake_input is not None:
147
- history[-2] = construct_user(fake_input)
148
- for chunk in tqdm(response.iter_lines()):
149
- if counter == 0:
150
- counter += 1
151
- continue
152
- counter += 1
153
- # check whether each line is non-empty
154
- if chunk:
155
- chunk = chunk.decode()
156
- chunklength = len(chunk)
157
- try:
158
- chunk = json.loads(chunk[6:])
159
- except json.JSONDecodeError:
160
- logging.info(chunk)
161
- error_json_str += chunk
162
- status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
163
- yield get_return_value()
164
- continue
165
- # decode each line as response data is in bytes
166
- if chunklength > 6 and "delta" in chunk["choices"][0]:
167
- finish_reason = chunk["choices"][0]["finish_reason"]
168
- status_text = construct_token_message(all_token_counts)
169
- if finish_reason == "stop":
170
- yield get_return_value()
171
- break
172
- try:
173
- partial_words = (
174
- partial_words + chunk["choices"][0]["delta"]["content"]
175
- )
176
- except KeyError:
177
- status_text = (
178
- standard_error_msg
179
- + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
180
- + str(sum(all_token_counts))
181
- )
182
- yield get_return_value()
183
- break
184
- history[-1] = construct_assistant(partial_words)
185
- chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
186
- all_token_counts[-1] += 1
187
- yield get_return_value()
188
-
189
-
190
- def predict_all(
191
- openai_api_key,
192
- system_prompt,
193
- history,
194
- inputs,
195
- chatbot,
196
- all_token_counts,
197
- top_p,
198
- temperature,
199
- selected_model,
200
- fake_input=None,
201
- display_append=""
202
- ):
203
- logging.info("一次性回答模式")
204
- history.append(construct_user(inputs))
205
- history.append(construct_assistant(""))
206
- if fake_input:
207
- chatbot.append((fake_input, ""))
208
- else:
209
- chatbot.append((inputs, ""))
210
- if fake_input is not None:
211
- all_token_counts.append(count_token(construct_user(fake_input)))
212
- else:
213
- all_token_counts.append(count_token(construct_user(inputs)))
214
- try:
215
- response = get_response(
216
- openai_api_key,
217
- system_prompt,
218
- history,
219
- temperature,
220
- top_p,
221
- False,
222
- selected_model,
223
- )
224
- except requests.exceptions.ConnectTimeout:
225
- status_text = (
226
- standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
227
- )
228
- return chatbot, history, status_text, all_token_counts
229
- except requests.exceptions.ProxyError:
230
- status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
231
- return chatbot, history, status_text, all_token_counts
232
- except requests.exceptions.SSLError:
233
- status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
234
- return chatbot, history, status_text, all_token_counts
235
- response = json.loads(response.text)
236
- if fake_input is not None:
237
- history[-2] = construct_user(fake_input)
238
- try:
239
- content = response["choices"][0]["message"]["content"]
240
- history[-1] = construct_assistant(content)
241
- chatbot[-1] = (chatbot[-1][0], content+display_append)
242
- total_token_count = response["usage"]["total_tokens"]
243
- if fake_input is not None:
244
- all_token_counts[-1] += count_token(construct_assistant(content))
245
- else:
246
- all_token_counts[-1] = total_token_count - sum(all_token_counts)
247
- status_text = construct_token_message(total_token_count)
248
- return chatbot, history, status_text, all_token_counts
249
- except KeyError:
250
- status_text = standard_error_msg + str(response)
251
- return chatbot, history, status_text, all_token_counts
252
-
253
-
254
- def predict(
255
- openai_api_key,
256
- system_prompt,
257
- history,
258
- inputs,
259
- chatbot,
260
- all_token_counts,
261
- top_p,
262
- temperature,
263
- stream=False,
264
- selected_model=MODELS[0],
265
- use_websearch=False,
266
- files = None,
267
- reply_language="中文",
268
- should_check_token_count=True,
269
- ): # repetition_penalty, top_k
270
- from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
271
- from llama_index.indices.query.schema import QueryBundle
272
- from langchain.llms import OpenAIChat
273
-
274
-
275
- logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
276
- if should_check_token_count:
277
- yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
278
- if reply_language == "跟随问题语言(不稳定)":
279
- reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
280
- old_inputs = None
281
- display_reference = []
282
- limited_context = False
283
- if files:
284
- limited_context = True
285
- old_inputs = inputs
286
- msg = "加载索引中……(这可能需要几分钟)"
287
- logging.info(msg)
288
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
289
- index = construct_index(openai_api_key, file_src=files)
290
- msg = "索引构建完成,获取回答中……"
291
- logging.info(msg)
292
- yield chatbot+[(inputs, "")], history, msg, all_token_counts
293
- with retrieve_proxy():
294
- llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
295
- prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
296
- from llama_index import ServiceContext
297
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
298
- query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
299
- query_bundle = QueryBundle(inputs)
300
- nodes = query_object.retrieve(query_bundle)
301
- reference_results = [n.node.text for n in nodes]
302
- reference_results = add_source_numbers(reference_results, use_source=False)
303
- display_reference = add_details(reference_results)
304
- display_reference = "\n\n" + "".join(display_reference)
305
- inputs = (
306
- replace_today(PROMPT_TEMPLATE)
307
- .replace("{query_str}", inputs)
308
- .replace("{context_str}", "\n\n".join(reference_results))
309
- .replace("{reply_language}", reply_language )
310
- )
311
- elif use_websearch:
312
- limited_context = True
313
- search_results = ddg(inputs, max_results=5)
314
- old_inputs = inputs
315
- reference_results = []
316
- for idx, result in enumerate(search_results):
317
- logging.info(f"搜索结果{idx + 1}:{result}")
318
- domain_name = urllib3.util.parse_url(result["href"]).host
319
- reference_results.append([result["body"], result["href"]])
320
- display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
321
- reference_results = add_source_numbers(reference_results)
322
- display_reference = "\n\n" + "".join(display_reference)
323
- inputs = (
324
- replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
325
- .replace("{query}", inputs)
326
- .replace("{web_results}", "\n\n".join(reference_results))
327
- .replace("{reply_language}", reply_language )
328
- )
329
- else:
330
- display_reference = ""
331
-
332
- if len(openai_api_key) == 0 and not shared.state.multi_api_key:
333
- status_text = standard_error_msg + no_apikey_msg
334
- logging.info(status_text)
335
- chatbot.append((inputs, ""))
336
- if len(history) == 0:
337
- history.append(construct_user(inputs))
338
- history.append("")
339
- all_token_counts.append(0)
340
- else:
341
- history[-2] = construct_user(inputs)
342
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
343
- return
344
- elif len(inputs.strip()) == 0:
345
- status_text = standard_error_msg + no_input_msg
346
- logging.info(status_text)
347
- yield chatbot+[(inputs, "")], history, status_text, all_token_counts
348
- return
349
-
350
- if stream:
351
- logging.info("使用流式传输")
352
- iter = stream_predict(
353
- openai_api_key,
354
- system_prompt,
355
- history,
356
- inputs,
357
- chatbot,
358
- all_token_counts,
359
- top_p,
360
- temperature,
361
- selected_model,
362
- fake_input=old_inputs,
363
- display_append=display_reference
364
- )
365
- for chatbot, history, status_text, all_token_counts in iter:
366
- if shared.state.interrupted:
367
- shared.state.recover()
368
- return
369
- yield chatbot, history, status_text, all_token_counts
370
- else:
371
- logging.info("不使用流式传输")
372
- chatbot, history, status_text, all_token_counts = predict_all(
373
- openai_api_key,
374
- system_prompt,
375
- history,
376
- inputs,
377
- chatbot,
378
- all_token_counts,
379
- top_p,
380
- temperature,
381
- selected_model,
382
- fake_input=old_inputs,
383
- display_append=display_reference
384
- )
385
- yield chatbot, history, status_text, all_token_counts
386
-
387
- logging.info(f"传输完毕。当前token计数为{all_token_counts}")
388
- if len(history) > 1 and history[-1]["content"] != inputs:
389
- logging.info(
390
- "回答为:"
391
- + colorama.Fore.BLUE
392
- + f"{history[-1]['content']}"
393
- + colorama.Style.RESET_ALL
394
- )
395
-
396
- if limited_context:
397
- history = history[-4:]
398
- all_token_counts = all_token_counts[-2:]
399
- yield chatbot, history, status_text, all_token_counts
400
-
401
- if stream:
402
- max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
403
- else:
404
- max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
405
-
406
- if sum(all_token_counts) > max_token and should_check_token_count:
407
- print(all_token_counts)
408
- count = 0
409
- while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
410
- count += 1
411
- del all_token_counts[0]
412
- del history[:2]
413
- logging.info(status_text)
414
- status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
415
- yield chatbot, history, status_text, all_token_counts
416
-
417
-
418
- def retry(
419
- openai_api_key,
420
- system_prompt,
421
- history,
422
- chatbot,
423
- token_count,
424
- top_p,
425
- temperature,
426
- stream=False,
427
- selected_model=MODELS[0],
428
- reply_language="中文",
429
- ):
430
- logging.info("重试中……")
431
- if len(history) == 0:
432
- yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
433
- return
434
- history.pop()
435
- inputs = history.pop()["content"]
436
- token_count.pop()
437
- iter = predict(
438
- openai_api_key,
439
- system_prompt,
440
- history,
441
- inputs,
442
- chatbot,
443
- token_count,
444
- top_p,
445
- temperature,
446
- stream=stream,
447
- selected_model=selected_model,
448
- reply_language=reply_language,
449
- )
450
- logging.info("重试中……")
451
- for x in iter:
452
- yield x
453
- logging.info("重试完毕")
454
-
455
-
456
- def reduce_token_size(
457
- openai_api_key,
458
- system_prompt,
459
- history,
460
- chatbot,
461
- token_count,
462
- top_p,
463
- temperature,
464
- max_token_count,
465
- selected_model=MODELS[0],
466
- reply_language="中文",
467
- ):
468
- logging.info("开始减少token数量……")
469
- iter = predict(
470
- openai_api_key,
471
- system_prompt,
472
- history,
473
- summarize_prompt,
474
- chatbot,
475
- token_count,
476
- top_p,
477
- temperature,
478
- selected_model=selected_model,
479
- should_check_token_count=False,
480
- reply_language=reply_language,
481
- )
482
- logging.info(f"chatbot: {chatbot}")
483
- flag = False
484
- for chatbot, history, status_text, previous_token_count in iter:
485
- num_chat = find_n(previous_token_count, max_token_count)
486
- logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
487
- if flag:
488
- chatbot = chatbot[:-1]
489
- flag = True
490
- history = history[-2*num_chat:] if num_chat > 0 else []
491
- token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
492
- msg = f"保留了最近{num_chat}轮对话"
493
- yield chatbot, history, msg + "," + construct_token_message(
494
- token_count if len(token_count) > 0 else [0],
495
- ), token_count
496
- logging.info(msg)
497
- logging.info("减少token数量完毕")
modules/config.py CHANGED
@@ -3,7 +3,7 @@ from contextlib import contextmanager
import os
import logging
import sys
- import json
+ import commentjson as json

from . import shared

modules/llama_func.py CHANGED
@@ -44,40 +44,44 @@ def get_documents(file_src):
filename = os.path.basename(filepath)
file_type = os.path.splitext(filepath)[1]
logging.info(f"loading file: {filename}")
- if file_type == ".pdf":
- logging.debug("Loading PDF...")
- try:
- from modules.pdf_func import parse_pdf
- from modules.config import advance_docs
- two_column = advance_docs["pdf"].get("two_column", False)
- pdftext = parse_pdf(filepath, two_column).text
- except:
- pdftext = ""
- with open(filepath, 'rb') as pdfFileObj:
- pdfReader = PyPDF2.PdfReader(pdfFileObj)
- for page in tqdm(pdfReader.pages):
- pdftext += page.extract_text()
- text_raw = pdftext
- elif file_type == ".docx":
- logging.debug("Loading Word...")
- DocxReader = download_loader("DocxReader")
- loader = DocxReader()
- text_raw = loader.load_data(file=filepath)[0].text
- elif file_type == ".epub":
- logging.debug("Loading EPUB...")
- EpubReader = download_loader("EpubReader")
- loader = EpubReader()
- text_raw = loader.load_data(file=filepath)[0].text
- elif file_type == ".xlsx":
- logging.debug("Loading Excel...")
- text_list = excel_to_string(filepath)
- for elem in text_list:
- documents.append(Document(elem))
- continue
- else:
- logging.debug("Loading text file...")
- with open(filepath, "r", encoding="utf-8") as f:
- text_raw = f.read()
+ try:
+ if file_type == ".pdf":
+ logging.debug("Loading PDF...")
+ try:
+ from modules.pdf_func import parse_pdf
+ from modules.config import advance_docs
+ two_column = advance_docs["pdf"].get("two_column", False)
+ pdftext = parse_pdf(filepath, two_column).text
+ except:
+ pdftext = ""
+ with open(filepath, 'rb') as pdfFileObj:
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
+ for page in tqdm(pdfReader.pages):
+ pdftext += page.extract_text()
+ text_raw = pdftext
+ elif file_type == ".docx":
+ logging.debug("Loading Word...")
+ DocxReader = download_loader("DocxReader")
+ loader = DocxReader()
+ text_raw = loader.load_data(file=filepath)[0].text
+ elif file_type == ".epub":
+ logging.debug("Loading EPUB...")
+ EpubReader = download_loader("EpubReader")
+ loader = EpubReader()
+ text_raw = loader.load_data(file=filepath)[0].text
+ elif file_type == ".xlsx":
+ logging.debug("Loading Excel...")
+ text_list = excel_to_string(filepath)
+ for elem in text_list:
+ documents.append(Document(elem))
+ continue
+ else:
+ logging.debug("Loading text file...")
+ with open(filepath, "r", encoding="utf-8") as f:
+ text_raw = f.read()
+ except Exception as e:
+ logging.error(f"Error loading file: {filename}")
+ pass
text = add_space(text_raw)
# text = block_split(text)
# documents += text
modules/models.py ADDED
@@ -0,0 +1,210 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+
12
+ from tqdm import tqdm
13
+ import colorama
14
+ from duckduckgo_search import ddg
15
+ import asyncio
16
+ import aiohttp
17
+ from enum import Enum
18
+
19
+ from .presets import *
20
+ from .llama_func import *
21
+ from .utils import *
22
+ from . import shared
23
+ from .config import retrieve_proxy
24
+ from .base_model import BaseLLMModel, ModelType
25
+
26
+
27
+ class OpenAIClient(BaseLLMModel):
28
+ def __init__(
29
+ self, model_name, api_key, system_prompt=INITIAL_SYSTEM_PROMPT, temperature=1.0, top_p=1.0
30
+ ) -> None:
31
+ super().__init__(model_name=model_name, temperature=temperature, top_p=top_p, system_prompt=system_prompt)
32
+ self.api_key = api_key
33
+ self.completion_url = shared.state.completion_url
34
+ self.usage_api_url = shared.state.usage_api_url
35
+ self.headers = {
36
+ "Content-Type": "application/json",
37
+ "Authorization": f"Bearer {self.api_key}",
38
+ }
39
+
40
+
41
+ def get_answer_stream_iter(self):
42
+ response = self._get_response(stream=True)
43
+ if response is not None:
44
+ iter = self._decode_chat_response(response)
45
+ partial_text = ""
46
+ for i in iter:
47
+ partial_text += i
48
+ yield partial_text
49
+ else:
50
+ yield standard_error_msg + general_error_msg
51
+
52
+ def get_answer_at_once(self):
53
+ response = self._get_response()
54
+ response = json.loads(response.text)
55
+ content = response["choices"][0]["message"]["content"]
56
+ total_token_count = response["usage"]["total_tokens"]
57
+ return content, total_token_count
58
+
59
+ def count_token(self, user_input):
60
+ input_token_count = count_token(construct_user(user_input))
61
+ if self.system_prompt is not None and len(self.all_token_counts) == 0:
62
+ system_prompt_token_count = count_token(construct_system(self.system_prompt))
63
+ return input_token_count + system_prompt_token_count
64
+ return input_token_count
65
+
66
+ def set_system_prompt(self, new_system_prompt):
67
+ self.system_prompt = new_system_prompt
68
+
69
+ def billing_info(self):
70
+ try:
71
+ curr_time = datetime.datetime.now()
72
+ last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
73
+ first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
74
+ usage_url = f"{self.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
75
+ try:
76
+ usage_data = self._get_billing_data(usage_url)
77
+ except Exception as e:
78
+ logging.error(f"获取API使用情况失败:"+str(e))
79
+ return f"**获取API使用情况失败**"
80
+ rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
81
+ return f"**本月使用金额** \u3000 ${rounded_usage}"
82
+ except requests.exceptions.ConnectTimeout:
83
+ status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
84
+ return status_text
85
+ except requests.exceptions.ReadTimeout:
86
+ status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
87
+ return status_text
88
+ except Exception as e:
89
+ logging.error(f"获取API使用情况失败:"+str(e))
90
+ return standard_error_msg + error_retrieve_prompt
91
+
92
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
93
+ def _get_response(self, stream=False):
94
+ openai_api_key = self.api_key
95
+ system_prompt = self.system_prompt
96
+ history = self.history
97
+ logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
98
+ temperature = self.temperature
99
+ top_p = self.top_p
100
+ selected_model = self.model_name
101
+ headers = {
102
+ "Content-Type": "application/json",
103
+ "Authorization": f"Bearer {openai_api_key}",
104
+ }
105
+
106
+ if system_prompt is not None:
107
+ history = [construct_system(system_prompt), *history]
108
+
109
+ payload = {
110
+ "model": selected_model,
111
+ "messages": history, # [{"role": "user", "content": f"{inputs}"}],
112
+ "temperature": temperature, # 1.0,
113
+ "top_p": top_p, # 1.0,
114
+ "n": 1,
115
+ "stream": stream,
116
+ "presence_penalty": 0,
117
+ "frequency_penalty": 0,
118
+ }
119
+ if stream:
120
+ timeout = timeout_streaming
121
+ else:
122
+ timeout = TIMEOUT_ALL
123
+
124
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
125
+ if shared.state.completion_url != COMPLETION_URL:
126
+ logging.info(f"使用自定义API URL: {shared.state.completion_url}")
127
+
128
+ with retrieve_proxy():
129
+ try:
130
+ response = requests.post(
131
+ shared.state.completion_url,
132
+ headers=headers,
133
+ json=payload,
134
+ stream=stream,
135
+ timeout=timeout,
136
+ )
137
+ except:
138
+ return None
139
+ return response
140
+
141
+ def _get_billing_data(self, usage_url):
142
+ with retrieve_proxy():
143
+ response = requests.get(
144
+ usage_url,
145
+ headers=self.headers,
146
+ timeout=TIMEOUT_ALL,
147
+ )
148
+
149
+ if response.status_code == 200:
150
+ data = response.json()
151
+ return data
152
+ else:
153
+ raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
154
+
155
+ def _decode_chat_response(self, response):
156
+ for chunk in response.iter_lines():
157
+ if chunk:
158
+ chunk = chunk.decode()
159
+ chunk_length = len(chunk)
160
+ try:
161
+ chunk = json.loads(chunk[6:])
162
+ except json.JSONDecodeError:
163
+ print(f"JSON解析错误,收到的内容: {chunk}")
164
+ continue
165
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
166
+ if chunk["choices"][0]["finish_reason"] == "stop":
167
+ break
168
+ try:
169
+ yield chunk["choices"][0]["delta"]["content"]
170
+ except Exception as e:
171
+ # logging.error(f"Error: {e}")
172
+ continue
173
+
174
+ def get_model(model_name, access_key=None, temprature=None, top_p=None, system_prompt = None) -> BaseLLMModel:
175
+ model_type = ModelType.get_type(model_name)
176
+ if model_type == ModelType.OpenAI:
177
+ model = OpenAIClient(model_name, access_key, system_prompt, temprature, top_p)
178
+ return model
179
+
180
+ if __name__=="__main__":
181
+ with open("config.json", "r") as f:
182
+ openai_api_key = cjson.load(f)["openai_api_key"]
183
+ client = OpenAIClient("gpt-3.5-turbo", openai_api_key)
184
+ chatbot = []
185
+ stream = False
186
+ # 测试账单功能
187
+ print(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
188
+ print(client.billing_info())
189
+ # 测试问答
190
+ print(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
191
+ question = "巴黎是中国的首都吗?"
192
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
193
+ print(i)
194
+ print(f"测试问答后history : {client.history}")
195
+ # 测试记忆力
196
+ print(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
197
+ question = "我刚刚问了你什么问题?"
198
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
199
+ print(i)
200
+ print(f"测试记忆力后history : {client.history}")
201
+ # 测试重试功能
202
+ print(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
203
+ for i in client.retry(chatbot=chatbot, stream=stream):
204
+ print(i)
205
+ print(f"重试后history : {client.history}")
206
+ # # 测试总结功能
207
+ # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
208
+ # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
209
+ # print(chatbot, msg)
210
+ # print(f"总结后history: {client.history}")
modules/openai_func.py DELETED
@@ -1,65 +0,0 @@
1
- import requests
2
- import logging
3
- from modules.presets import (
4
- timeout_all,
5
- USAGE_API_URL,
6
- BALANCE_API_URL,
7
- standard_error_msg,
8
- connection_timeout_prompt,
9
- error_retrieve_prompt,
10
- read_timeout_prompt
11
- )
12
-
13
- from . import shared
14
- from modules.config import retrieve_proxy
15
- import os, datetime
16
-
17
- def get_billing_data(openai_api_key, billing_url):
18
- headers = {
19
- "Content-Type": "application/json",
20
- "Authorization": f"Bearer {openai_api_key}"
21
- }
22
-
23
- timeout = timeout_all
24
- with retrieve_proxy():
25
- response = requests.get(
26
- billing_url,
27
- headers=headers,
28
- timeout=timeout,
29
- )
30
-
31
- if response.status_code == 200:
32
- data = response.json()
33
- return data
34
- else:
35
- raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
36
-
37
-
38
- def get_usage(openai_api_key):
39
- try:
40
- curr_time = datetime.datetime.now()
41
- last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
42
- first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
43
- usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
44
- try:
45
- usage_data = get_billing_data(openai_api_key, usage_url)
46
- except Exception as e:
47
- logging.error(f"获取API使用情况失败:"+str(e))
48
- return f"**获取API使用情况失败**"
49
- rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
50
- return f"**本月使用金额** \u3000 ${rounded_usage}"
51
- except requests.exceptions.ConnectTimeout:
52
- status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
53
- return status_text
54
- except requests.exceptions.ReadTimeout:
55
- status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
56
- return status_text
57
- except Exception as e:
58
- logging.error(f"获取API使用情况失败:"+str(e))
59
- return standard_error_msg + error_retrieve_prompt
60
-
61
- def get_last_day_of_month(any_day):
62
- # The day 28 exists in every month. 4 days later, it's always next month
63
- next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
64
- # subtracting the number of the current day brings us back one month
65
- return next_month - datetime.timedelta(days=next_month.day)
modules/presets.py CHANGED
@@ -3,26 +3,29 @@ import gradio as gr
from pathlib import Path

# ChatGPT 设置
- initial_prompt = "You are a helpful assistant."
+ INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
API_HOST = "api.openai.com"
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
HISTORY_DIR = Path("history")
+ HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"

# 错误信息
standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
- error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
+ general_error_msg = "获取对话时发生错误,请查看后台日志"
+ error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。"
connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
no_input_msg = "请输入对话内容。" # 未输入对话内容
+ billing_not_applicable_msg = "模型本地运行中" # 本地运行的模型返回的账单信息

timeout_streaming = 10 # 流式对话时的超时时间
- timeout_all = 200 # 非流式对话时的超时时间
+ TIMEOUT_ALL = 200 # 非流式对话时的超时时间
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
@@ -57,33 +60,18 @@ MODELS = [
"gpt-4-32k-0314",
] # 可选的模型

- MODEL_SOFT_TOKEN_LIMIT = {
- "gpt-3.5-turbo": {
- "streaming": 3500,
- "all": 3500
- },
- "gpt-3.5-turbo-0301": {
- "streaming": 3500,
- "all": 3500
- },
- "gpt-4": {
- "streaming": 7500,
- "all": 7500
- },
- "gpt-4-0314": {
- "streaming": 7500,
- "all": 7500
- },
- "gpt-4-32k": {
- "streaming": 31000,
- "all": 31000
- },
- "gpt-4-32k-0314": {
- "streaming": 31000,
- "all": 31000
- }
+ MODEL_TOKEN_LIMIT = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0301": 4096,
+ "gpt-4": 8192,
+ "gpt-4-0314": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-32k-0314": 32768
}

+ TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
+ REDUCE_TOKEN_FACTOR = 0.5 # 与模型token上限想乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
+
REPLY_LANGUAGES = [
"简体中文",
"繁體中文",
modules/utils.py CHANGED
@@ -153,47 +153,6 @@ def construct_assistant(text):
      return construct_text("assistant", text)
 
 
- def construct_token_message(tokens: List[int]):
-     token_sum = 0
-     for i in range(len(tokens)):
-         token_sum += sum(tokens[: i + 1])
-     return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
-
-
- def delete_first_conversation(history, previous_token_count):
-     if history:
-         del history[:2]
-         del previous_token_count[0]
-     return (
-         history,
-         previous_token_count,
-         construct_token_message(previous_token_count),
-     )
-
-
- def delete_last_conversation(chatbot, history, previous_token_count):
-     if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
-         logging.info("由于包含报错信息,只删除chatbot记录")
-         chatbot.pop()
-         return chatbot, history
-     if len(history) > 0:
-         logging.info("删除了一组对话历史")
-         history.pop()
-         history.pop()
-     if len(chatbot) > 0:
-         logging.info("删除了一组chatbot对话")
-         chatbot.pop()
-     if len(previous_token_count) > 0:
-         logging.info("删除了一组对话的token计数记录")
-         previous_token_count.pop()
-     return (
-         chatbot,
-         history,
-         previous_token_count,
-         construct_token_message(previous_token_count),
-     )
-
-
  def save_file(filename, system, history, chatbot, user_name):
      logging.info(f"{user_name} 保存对话历史中……")
      os.makedirs(HISTORY_DIR / user_name, exist_ok=True)
@@ -212,56 +171,12 @@ def save_file(filename, system, history, chatbot, user_name):
      return os.path.join(HISTORY_DIR / user_name, filename)
 
 
- def save_chat_history(filename, system, history, chatbot, user_name):
-     if filename == "":
-         return
-     if not filename.endswith(".json"):
-         filename += ".json"
-     return save_file(filename, system, history, chatbot, user_name)
-
-
- def export_markdown(filename, system, history, chatbot, user_name):
-     if filename == "":
-         return
-     if not filename.endswith(".md"):
-         filename += ".md"
-     return save_file(filename, system, history, chatbot, user_name)
-
-
- def load_chat_history(filename, system, history, chatbot, user_name):
-     logging.info(f"{user_name} 加载对话历史中……")
-     if type(filename) != str:
-         filename = filename.name
-     try:
-         with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
-             json_s = json.load(f)
-         try:
-             if type(json_s["history"][0]) == str:
-                 logging.info("历史记录格式为旧版,正在转换……")
-                 new_history = []
-                 for index, item in enumerate(json_s["history"]):
-                     if index % 2 == 0:
-                         new_history.append(construct_user(item))
-                     else:
-                         new_history.append(construct_assistant(item))
-                 json_s["history"] = new_history
-                 logging.info(new_history)
-         except:
-             # 没有对话历史
-             pass
-         logging.info(f"{user_name} 加载对话历史完毕")
-         return filename, json_s["system"], json_s["history"], json_s["chatbot"]
-     except FileNotFoundError:
-         logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
-         return filename, system, history, chatbot
-
-
  def sorted_by_pinyin(list):
      return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
 
 
  def get_file_names(dir, plain=False, filetypes=[".json"]):
-     logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
+     logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
      files = []
      try:
          for type in filetypes:
@@ -279,14 +194,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
 
 
  def get_history_names(plain=False, user_name=""):
-     logging.info(f"从用户 {user_name} 中获取历史记录文件名列表")
-     return get_file_names(HISTORY_DIR / user_name, plain)
+     logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
+     return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
 
 
  def load_template(filename, mode=0):
-     logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
+     logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
      lines = []
-     logging.info("Loading template...")
      if filename.endswith(".json"):
          with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
              lines = json.load(f)
@@ -310,7 +224,7 @@ def load_template(filename, mode=0):
 
 
  def get_template_names(plain=False):
-     logging.info("获取模板文件名列表")
+     logging.debug("获取模板文件名列表")
      return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 
@@ -322,11 +236,6 @@ def get_template_content(templates, selection, original_system_prompt):
      return original_system_prompt
 
 
- def reset_state():
-     logging.info("重置状态")
-     return [], [], [], construct_token_message([0])
-
-
  def reset_textbox():
      logging.debug("重置文本框")
      return gr.update(value="")
@@ -530,3 +439,9 @@ def excel_to_string(file_path):
 
 
      return result
+
+ def get_last_day_of_month(any_day):
+     # The day 28 exists in every month. 4 days later, it's always next month
+     next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+     # subtracting the number of the current day brings us back one month
+     return next_month - datetime.timedelta(days=next_month.day)
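Note: the day-28 trick in the relocated get_last_day_of_month works for every month length, including leap-year February, because day 28 exists in all months and adding four days always lands in the following month. A quick standalone check; the dates are examples only:

import datetime

def get_last_day_of_month(any_day):
    # Day 28 exists in every month, so +4 days is always in the next month;
    # subtracting that date's day-of-month steps back to the last day of the original month.
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    return next_month - datetime.timedelta(days=next_month.day)

print(get_last_day_of_month(datetime.date(2024, 2, 10)))  # 2024-02-29 (leap year)
print(get_last_day_of_month(datetime.date(2023, 4, 5)))   # 2023-04-30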