Tuchuanhuhuhu committed
Commit ab74909
Parent: 3b0ad60

Improve code quality

Files changed (3):
  1. ChuanhuChatbot.py +12 -12
  2. presets.py +9 -1
  3. utils.py +52 -35
ChuanhuChatbot.py CHANGED
@@ -42,13 +42,6 @@ else:
 gr.Chatbot.postprocess = postprocess
 
 with gr.Blocks(css=customCSS) as demo:
-    gr.HTML(title)
-    with gr.Row():
-        with gr.Column(scale=4):
-            keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=True)
-        with gr.Column(scale=1):
-            use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
-    chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
     history = gr.State([])
     token_count = gr.State([])
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
@@ -56,6 +49,13 @@ with gr.Blocks(css=customCSS) as demo:
     FALSECONSTANT = gr.State(False)
     topic = gr.State("未命名对话历史记录")
 
+    gr.HTML(title)
+    with gr.Row():
+        with gr.Column(scale=4):
+            keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=False)
+        with gr.Column(scale=1):
+            use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
+    chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
     with gr.Row():
         with gr.Column(scale=12):
             user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
@@ -68,8 +68,9 @@ with gr.Blocks(css=customCSS) as demo:
         delLastBtn = gr.Button("🗑️ 删除最近一条对话")
         reduceTokenBtn = gr.Button("♻️ 总结对话")
     status_display = gr.Markdown("status: ready")
-    systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...",
-                                 label="System prompt", value=initial_prompt).style(container=True)
+
+    systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...", label="System prompt", value=initial_prompt).style(container=True)
+
     with gr.Accordion(label="加载Prompt模板", open=False):
         with gr.Column():
             with gr.Row():
@@ -100,11 +101,10 @@
     #inputs, top_p, temperature, top_k, repetition_penalty
     with gr.Accordion("参数", open=False):
         top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
-                          interactive=True, label="Top-p (nucleus sampling)",)
+                          interactive=True, label="Top-p (nucleus sampling)",)
         temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
                                 step=0.1, interactive=True, label="Temperature",)
-        #top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
-        #repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
+
     gr.Markdown(description)
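Note on the ChuanhuChatbot.py change: the commit moves the gr.State holders ahead of every visible widget inside the gr.Blocks context, drops the commented-out top_k / repetition_penalty sliders, and switches the API-key textbox from .style(container=True) to .style(container=False). A minimal sketch of the state-first layout pattern, with hypothetical component names (Gradio 3.x API, which this app targets):

import gradio as gr

with gr.Blocks() as demo:
    # Session state first, so every later event handler can bind to it.
    history = gr.State([])        # chat messages for this session
    token_count = gr.State([])    # per-turn token counts

    # Visible layout second.
    with gr.Row():
        with gr.Column(scale=4):
            key_box = gr.Textbox(show_label=False, type="password",
                                 placeholder="OpenAI API key...")
        with gr.Column(scale=1):
            streaming_box = gr.Checkbox(label="Stream responses", value=True)
    chatbot = gr.Chatbot()

if __name__ == "__main__":
    demo.launch()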
 
presets.py CHANGED
@@ -31,9 +31,17 @@ pre code {
 }
 """
 
+summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
+
+# 错误信息
 standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
 error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
-summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
+connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
+read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
+proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
+ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
+no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
+
 max_token_streaming = 3500 # 流式对话时的最大 token 数
 timeout_streaming = 15 # 流式对话时的超时时间
 max_token_all = 3500 # 非流式对话时的最大 token 数
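Note on the presets.py change: each failure mode now gets one reusable message fragment, and utils.py assembles a status line as standard prefix + specific cause + generic hint. A sketch of that composition (string values translated here for readability; the real constants are in Chinese):

# What stream_predict now shows when requests raises ConnectTimeout.
standard_error_msg = "☹️ An error occurred: "                   # standard prefix
connection_timeout_prompt = "Connection timed out; could not fetch a reply. "
error_retrieve_prompt = "Check your network connection and your API key."

status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
print(status_text)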
utils.py CHANGED
@@ -124,37 +124,37 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
     response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
     return response
 
-def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
+def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature):
     def get_return_value():
-        return chatbot, history, status_text, [*previous_token_count, token_counter]
+        return chatbot, history, status_text, all_token_counts
 
     print("实时回答模式")
-    token_counter = 0
     partial_words = ""
     counter = 0
     status_text = "开始实时传输回答……"
     history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
     user_token_count = 0
-    if len(previous_token_count) == 0:
+    if len(all_token_counts) == 0:
         system_prompt_token_count = count_token(system_prompt)
         user_token_count = count_token(inputs) + system_prompt_token_count
     else:
         user_token_count = count_token(inputs)
+    all_token_counts.append(user_token_count)
     print(f"输入token计数: {user_token_count}")
+    yield get_return_value()
     try:
         response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
     except requests.exceptions.ConnectTimeout:
-        history.pop()
-        status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
+        status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
         yield get_return_value()
         return
     except requests.exceptions.ReadTimeout:
-        history.pop()
-        status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
+        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
         yield get_return_value()
         return
 
-    chatbot.append((parse_text(inputs), ""))
     yield get_return_value()
 
     for chunk in tqdm(response.iter_lines()):
@@ -169,13 +169,14 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
                 try:
                     chunk = json.loads(chunk[6:])
                 except json.JSONDecodeError:
+                    print(chunk)
                     status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
                     yield get_return_value()
-                    break
+                    continue
                 # decode each line as response data is in bytes
                 if chunklength > 6 and "delta" in chunk['choices'][0]:
                     finish_reason = chunk['choices'][0]['finish_reason']
-                    status_text = construct_token_message(sum(previous_token_count)+token_counter+user_token_count, stream=True)
+                    status_text = construct_token_message(sum(all_token_counts), stream=True)
                     if finish_reason == "stop":
                         print("生成完毕")
                         yield get_return_value()
@@ -183,60 +184,76 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
                     try:
                         partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
                     except KeyError:
-                        status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(previous_token_count)+token_counter+user_token_count)
+                        status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(all_token_counts))
                         yield get_return_value()
                         break
-                    if token_counter == 0:
-                        history.append(construct_assistant(" " + partial_words))
-                    else:
-                        history[-1] = construct_assistant(partial_words)
+                    history[-1] = construct_assistant(partial_words)
                     chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
-                    token_counter += 1
+                    all_token_counts[-1] += 1
                     yield get_return_value()
 
 
-def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
+def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature):
     print("一次性回答模式")
     history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    chatbot.append((parse_text(inputs), ""))
+    all_token_counts.append(count_token(inputs))
     try:
         response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False)
     except requests.exceptions.ConnectTimeout:
         status_text = standard_error_msg + error_retrieve_prompt
-        return chatbot, history, status_text, previous_token_count
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.ProxyError:
+        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
+    except requests.exceptions.SSLError:
+        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
     response = json.loads(response.text)
     content = response["choices"][0]["message"]["content"]
-    history.append(construct_assistant(content))
+    history[-1] = construct_assistant(content)
     chatbot.append((parse_text(inputs), parse_text(content)))
     total_token_count = response["usage"]["total_tokens"]
-    previous_token_count.append(total_token_count - sum(previous_token_count))
+    all_token_counts[-1] = total_token_count - sum(all_token_counts)
     status_text = construct_token_message(total_token_count)
     print("生成一次性回答完毕")
-    return chatbot, history, status_text, previous_token_count
+    return chatbot, history, status_text, all_token_counts
 
 
-def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
+def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
     print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
+    if len(openai_api_key) != 51:
+        status_text = standard_error_msg + no_apikey_msg
+        print(status_text)
+        history.append(construct_user(inputs))
+        history.append("")
+        chatbot.append((parse_text(inputs), ""))
+        all_token_counts.append(0)
+        yield chatbot, history, status_text, all_token_counts
+        return
+    yield chatbot, history, "开始生成回答……", all_token_counts
     if stream:
         print("使用流式传输")
-        iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
-        for chatbot, history, status_text, token_count in iter:
-            yield chatbot, history, status_text, token_count
+        iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature)
+        for chatbot, history, status_text, all_token_counts in iter:
+            yield chatbot, history, status_text, all_token_counts
     else:
         print("不使用流式传输")
-        chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
-        yield chatbot, history, status_text, token_count
-    print(f"传输完毕。当前token计数为{token_count}")
-    print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
+        chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature)
+        yield chatbot, history, status_text, all_token_counts
+    print(f"传输完毕。当前token计数为{all_token_counts}")
+    if len(history) > 1 and history[-1]['content'] != inputs:
+        print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
     if stream:
         max_token = max_token_streaming
     else:
         max_token = max_token_all
-    if sum(token_count) > max_token and should_check_token_count:
-        print(f"精简token中{token_count}/{max_token}")
-        iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=True)
-        for chatbot, history, status_text, token_count in iter:
+    if sum(all_token_counts) > max_token and should_check_token_count:
+        print(f"精简token中{all_token_counts}/{max_token}")
+        iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, hidden=True)
+        for chatbot, history, status_text, all_token_counts in iter:
             status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
-            yield chatbot, history, status_text, token_count
+            yield chatbot, history, status_text, all_token_counts
 
 
 def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
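Note on the stream_predict hunks: the chat completions endpoint streams server-sent events, one "data: "-prefixed line per delta, which is why the code slices chunk[6:] and only trusts parsed chunks when chunklength > 6. A standalone sketch of that parsing (API_URL, headers, and payload below are illustrative placeholders, not the project's exact values):

import json
import requests

API_URL = "https://api.openai.com/v1/chat/completions"
headers = {"Authorization": "Bearer sk-..."}          # placeholder key
payload = {"model": "gpt-3.5-turbo", "stream": True,
           "messages": [{"role": "user", "content": "Hello"}]}

response = requests.post(API_URL, headers=headers, json=payload,
                         stream=True, timeout=15)
for raw in response.iter_lines():        # one SSE line per streamed delta
    if not raw:
        continue                         # skip keep-alive blank lines
    if raw == b"data: [DONE]":           # sentinel that closes the stream
        break
    event = json.loads(raw[6:])          # strip the 6-byte "data: " prefix
    delta = event["choices"][0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)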
 
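Note on the token bookkeeping: replacing previous_token_count plus a separate token_counter with a single all_token_counts list gives one entry per conversation turn, so sum(all_token_counts) is always the running total that construct_token_message displays. A sketch of that invariant with hypothetical numbers (count_token and the API's usage field supply the real values):

all_token_counts = []

# Streaming turn: record the prompt's token count up front, then bump the
# last entry once per received delta as a rough completion estimate.
all_token_counts.append(42)              # count_token(inputs) for this turn
for _ in range(10):                      # ten streamed chunks arrive
    all_token_counts[-1] += 1            # as stream_predict does per chunk

# Non-streaming turn: the API reports usage.total_tokens for the whole
# conversation, so this turn's share is the total minus what is recorded.
total_token_count = 180                  # response["usage"]["total_tokens"]
all_token_counts.append(total_token_count - sum(all_token_counts))

assert sum(all_token_counts) == total_token_count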