Chuan Hu commited on
Commit
41d7759
1 Parent(s): 943240e

重大更新:支持像官方网页那样实时传输了;改进的保存/加载机制

Browse files
Files changed (1) hide show
  1. ChuanhuChatbot.py +182 -169
ChuanhuChatbot.py CHANGED
@@ -1,14 +1,15 @@
1
  import json
2
  import gradio as gr
3
- import openai
4
  import os
5
  import sys
6
  import traceback
7
- # import markdown
8
 
9
- my_api_key = "" # 在这里输入你的 API 密钥
10
  initial_prompt = "You are a helpful assistant."
11
 
 
 
12
  if my_api_key == "":
13
  my_api_key = os.environ.get('my_api_key')
14
 
@@ -16,17 +17,11 @@ if my_api_key == "empty":
16
  print("Please give a api key!")
17
  sys.exit(1)
18
 
19
- if my_api_key == "":
20
- initial_keytxt = None
21
- elif len(str(my_api_key)) == 51:
22
- initial_keytxt = "默认api-key(未验证):" + str(my_api_key[:4] + "..." + my_api_key[-4:])
23
- else:
24
- initial_keytxt = "默认api-key无效,请重新输入"
25
 
26
  def parse_text(text):
27
  lines = text.split("\n")
28
  count = 0
29
- for i,line in enumerate(lines):
30
  if "```" in line:
31
  count += 1
32
  items = line.split('`')
@@ -46,190 +41,208 @@ def parse_text(text):
46
  lines[i] = '<br/>'+line
47
  return "".join(lines)
48
 
49
- def get_response(system, context, myKey, raw = False):
50
- openai.api_key = myKey
51
- response = openai.ChatCompletion.create(
52
- model="gpt-3.5-turbo",
53
- messages=[system, *context],
54
- )
55
- openai.api_key = ""
56
- if raw:
57
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
- statistics = f'本次对话Tokens用量【{response["usage"]["total_tokens"]} / 4096】 ( 提问+上文 {response["usage"]["prompt_tokens"]},回答 {response["usage"]["completion_tokens"]} )'
60
- message = response["choices"][0]["message"]["content"]
61
-
62
- message_with_stats = f'{message}\n\n================\n\n{statistics}'
63
- # message_with_stats = markdown.markdown(message_with_stats)
64
-
65
- return message, parse_text(message_with_stats)
66
-
67
- def predict(chatbot, input_sentence, system, context, myKey):
68
- if len(input_sentence) == 0:
69
- return []
70
- context.append({"role": "user", "content": f"{input_sentence}"})
71
-
72
- try:
73
- message, message_with_stats = get_response(system, context, myKey)
74
- except openai.error.AuthenticationError:
75
- chatbot.append((input_sentence, "请求失败,请检查API-key是否正确。"))
76
- return chatbot, context
77
- except openai.error.Timeout:
78
- chatbot.append((input_sentence, "请求超时,请检查网络连接。"))
79
- return chatbot, context
80
- except openai.error.APIConnectionError:
81
- chatbot.append((input_sentence, "连接失败,请检查网络连接。"))
82
- return chatbot, context
83
- except openai.error.RateLimitError:
84
- chatbot.append((input_sentence, "请求过于频繁,请5s后再试。"))
85
- return chatbot, context
86
- except:
87
- chatbot.append((input_sentence, "发生了未知错误Orz"))
88
- return chatbot, context
89
-
90
- context.append({"role": "assistant", "content": message})
91
-
92
- chatbot.append((input_sentence, message_with_stats))
93
-
94
- return chatbot, context
95
-
96
- def retry(chatbot, system, context, myKey):
97
- if len(context) == 0:
98
- return [], []
99
-
100
- try:
101
- message, message_with_stats = get_response(system, context[:-1], myKey)
102
- except openai.error.AuthenticationError:
103
- chatbot.append(("重试请求", "请求失败,请检查API-key是否正确。"))
104
- return chatbot, context
105
- except openai.error.Timeout:
106
- chatbot.append(("重试请求", "请求超时,请检查网络连接。"))
107
- return chatbot, context
108
- except openai.error.APIConnectionError:
109
- chatbot.append(("重试请求", "连接失败,请检查网络连接。"))
110
- return chatbot, context
111
- except openai.error.RateLimitError:
112
- chatbot.append(("重���请求", "请求过于频繁,请5s后再试。"))
113
- return chatbot, context
114
- except:
115
- chatbot.append(("重试请求", "发生了未知错误Orz"))
116
- return chatbot, context
117
-
118
- context[-1] = {"role": "assistant", "content": message}
119
-
120
- chatbot[-1] = (context[-2]["content"], message_with_stats)
121
- return chatbot, context
122
-
123
- def delete_last_conversation(chatbot, context):
124
- if len(context) == 0:
125
- return [], []
126
- chatbot = chatbot[:-1]
127
- context = context[:-2]
128
- return chatbot, context
129
-
130
- def reduce_token(chatbot, system, context, myKey):
131
- context.append({"role": "user", "content": "请帮我总结一下上述对话的内容,实现减少tokens的同时,保证对话的质量。在总结中不要加入这一句话。"})
132
-
133
- response = get_response(system, context, myKey, raw=True)
134
-
135
- statistics = f'本次对话Tokens用量【{response["usage"]["completion_tokens"]+12+12+8} / 4096】'
136
- optmz_str = parse_text( f'好的,我们之前聊了:{response["choices"][0]["message"]["content"]}\n\n================\n\n{statistics}' )
137
- chatbot.append(("请帮我总结一下上述对话的内容,实现减少tokens的同时,保证对话的质量。", optmz_str))
138
-
139
- context = []
140
- context.append({"role": "user", "content": "我们之前聊了什么?"})
141
- context.append({"role": "assistant", "content": f'我们之前聊了:{response["choices"][0]["message"]["content"]}'})
142
- return chatbot, context
143
-
144
- def save_chat_history(filepath, system, context):
145
  if filepath == "":
146
  return
147
- history = {"system": system, "context": context}
148
- with open(f"{filepath}.json", "w") as f:
149
- json.dump(history, f)
 
 
150
 
151
- def load_chat_history(fileobj):
152
- with open(fileobj.name, "r") as f:
153
- history = json.load(f)
154
- context = history["context"]
155
- chathistory = []
156
- for i in range(0, len(context), 2):
157
- chathistory.append((parse_text(context[i]["content"]), parse_text(context[i+1]["content"])))
158
- return chathistory , history["system"], context, history["system"]["content"]
159
 
160
- def get_history_names():
161
- with open("history.json", "r") as f:
162
- history = json.load(f)
163
- return list(history.keys())
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  def reset_state():
167
  return [], []
168
 
169
- def update_system(new_system_prompt):
170
- return {"role": "system", "content": new_system_prompt}
171
-
172
- def set_apikey(new_api_key, myKey):
173
- old_api_key = myKey
174
-
175
- try:
176
- get_response(update_system(initial_prompt), [{"role": "user", "content": "test"}], new_api_key)
177
- except openai.error.AuthenticationError:
178
- return "无效的api-key", myKey
179
- except openai.error.Timeout:
180
- return "请求超时,请检查网络设置", myKey
181
- except openai.error.APIConnectionError:
182
- return "网络错误", myKey
183
- except:
184
- return "发生了未知错误Orz", myKey
185
-
186
- encryption_str = "验证成功,api-key已做遮挡处理:" + new_api_key[:4] + "..." + new_api_key[-4:]
187
- return encryption_str, new_api_key
188
 
189
 
190
  with gr.Blocks() as demo:
191
- keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...", value=initial_keytxt, label="API Key").style(container=True)
192
- chatbot = gr.Chatbot().style(color_map=("#1D51EE", "#585A5B"))
193
- context = gr.State([])
194
- systemPrompt = gr.State(update_system(initial_prompt))
195
- myKey = gr.State(my_api_key)
 
196
  topic = gr.State("未命名对话历史记录")
197
 
198
  with gr.Row():
199
  with gr.Column(scale=12):
200
- txt = gr.Textbox(show_label=False, placeholder="在这里输入").style(container=False)
 
201
  with gr.Column(min_width=50, scale=1):
202
  submitBtn = gr.Button("🚀", variant="primary")
203
  with gr.Row():
204
  emptyBtn = gr.Button("🧹 新的对话")
205
  retryBtn = gr.Button("🔄 重新生成")
206
  delLastBtn = gr.Button("🗑️ 删除上条对话")
207
- reduceTokenBtn = gr.Button("♻️ 优化Tokens")
208
- newSystemPrompt = gr.Textbox(show_label=True, placeholder=f"在这里输入新的System Prompt...", label="更改 System prompt").style(container=True)
209
- systemPromptDisplay = gr.Textbox(show_label=True, value=initial_prompt, interactive=False, label="目前的 System prompt").style(container=True)
210
- with gr.Accordion(label="保存/加载对话历史记录(在文本框中输入文件名,点击“保存对话”按钮,历史记录文件会被存储到本地)", open=False):
211
  with gr.Column():
212
  with gr.Row():
213
  with gr.Column(scale=6):
214
- saveFileName = gr.Textbox(show_label=True, placeholder=f"在这里输入保存的文件名...", label="保存对话", value="对话历史记录").style(container=True)
 
215
  with gr.Column(scale=1):
216
  saveBtn = gr.Button("💾 保存对话")
217
- uploadBtn = gr.UploadButton("📂 读取对话", file_count="single", file_types=["json"])
218
-
219
- txt.submit(predict, [chatbot, txt, systemPrompt, context, myKey], [chatbot, context], show_progress=True)
220
- txt.submit(lambda :"", None, txt)
221
- submitBtn.click(predict, [chatbot, txt, systemPrompt, context, myKey], [chatbot, context], show_progress=True)
222
- submitBtn.click(lambda :"", None, txt)
223
- emptyBtn.click(reset_state, outputs=[chatbot, context])
224
- newSystemPrompt.submit(update_system, newSystemPrompt, systemPrompt)
225
- newSystemPrompt.submit(lambda x: x, newSystemPrompt, systemPromptDisplay)
226
- newSystemPrompt.submit(lambda :"", None, newSystemPrompt)
227
- retryBtn.click(retry, [chatbot, systemPrompt, context, myKey], [chatbot, context], show_progress=True)
228
- delLastBtn.click(delete_last_conversation, [chatbot, context], [chatbot, context], show_progress=True)
229
- reduceTokenBtn.click(reduce_token, [chatbot, systemPrompt, context, myKey], [chatbot, context], show_progress=True)
230
- keyTxt.submit(set_apikey, [keyTxt, myKey], [keyTxt, myKey], show_progress=True)
231
- uploadBtn.upload(load_chat_history, uploadBtn, [chatbot, systemPrompt, context, systemPromptDisplay], show_progress=True)
232
- saveBtn.click(save_chat_history, [saveFileName, systemPrompt, context], None, show_progress=True)
233
-
234
-
235
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import gradio as gr
 
3
  import os
4
  import sys
5
  import traceback
6
+ import requests
7
 
8
+ my_api_key = "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" # 在这里输入你的 API 密钥
9
  initial_prompt = "You are a helpful assistant."
10
 
11
+ API_URL = "https://api.openai.com/v1/chat/completions"
12
+
13
  if my_api_key == "":
14
  my_api_key = os.environ.get('my_api_key')
15
 
 
17
  print("Please give a api key!")
18
  sys.exit(1)
19
 
 
 
 
 
 
 
20
 
21
  def parse_text(text):
22
  lines = text.split("\n")
23
  count = 0
24
+ for i, line in enumerate(lines):
25
  if "```" in line:
26
  count += 1
27
  items = line.split('`')
 
41
  lines[i] = '<br/>'+line
42
  return "".join(lines)
43
 
44
+ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[], system_prompt=initial_prompt, retry=False, summary=False): # repetition_penalty, top_k
45
+
46
+ headers = {
47
+ "Content-Type": "application/json",
48
+ "Authorization": f"Bearer {openai_api_key}"
49
+ }
50
+
51
+ chat_counter = len(history) // 2
52
+
53
+ print(f"chat_counter - {chat_counter}")
54
+
55
+ messages = [compose_system(system_prompt)]
56
+ if chat_counter:
57
+ for data in chatbot:
58
+ temp1 = {}
59
+ temp1["role"] = "user"
60
+ temp1["content"] = data[0]
61
+ temp2 = {}
62
+ temp2["role"] = "assistant"
63
+ temp2["content"] = data[1]
64
+ if temp1["content"] != "":
65
+ messages.append(temp1)
66
+ messages.append(temp2)
67
+ else:
68
+ messages[-1]['content'] = temp2['content']
69
+ if retry and chat_counter:
70
+ messages.pop()
71
+ elif summary and chat_counter:
72
+ messages.append(compose_user(
73
+ "请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
74
+ history = ["我们刚刚聊了什么?"]
75
  else:
76
+ temp3 = {}
77
+ temp3["role"] = "user"
78
+ temp3["content"] = inputs
79
+ messages.append(temp3)
80
+ chat_counter += 1
81
+ # messages
82
+ payload = {
83
+ "model": "gpt-3.5-turbo",
84
+ "messages": messages, # [{"role": "user", "content": f"{inputs}"}],
85
+ "temperature": temperature, # 1.0,
86
+ "top_p": top_p, # 1.0,
87
+ "n": 1,
88
+ "stream": True,
89
+ "presence_penalty": 0,
90
+ "frequency_penalty": 0,
91
+ }
92
+
93
+ if not summary:
94
+ history.append(inputs)
95
+ print(f"payload is - {payload}")
96
+ # make a POST request to the API endpoint using the requests.post method, passing in stream=True
97
+ response = requests.post(API_URL, headers=headers,
98
+ json=payload, stream=True)
99
+ #response = requests.post(API_URL, headers=headers, json=payload, stream=True)
100
+
101
+ token_counter = 0
102
+ partial_words = ""
103
+
104
+ counter = 0
105
+ chatbot.append((history[-1], ""))
106
+ for chunk in response.iter_lines():
107
+ if counter == 0:
108
+ counter += 1
109
+ continue
110
+ counter += 1
111
+ # check whether each line is non-empty
112
+ if chunk:
113
+ # decode each line as response data is in bytes
114
+ if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0:
115
+ break
116
+ #print(json.loads(chunk.decode()[6:])['choices'][0]["delta"] ["content"])
117
+ partial_words = partial_words + \
118
+ json.loads(chunk.decode()[6:])[
119
+ 'choices'][0]["delta"]["content"]
120
+ if token_counter == 0:
121
+ history.append(" " + partial_words)
122
+ else:
123
+ history[-1] = partial_words
124
+ chatbot[-1] = (history[-2], history[-1])
125
+ # chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
126
+ token_counter += 1
127
+ # resembles {chatbot: chat, state: history}
128
+ yield chatbot, history
129
+
130
+
131
+
132
+ def delete_last_conversation(chatbot, history):
133
+ if chat_counter > 0:
134
+ chat_counter -= 1
135
+ chatbot.pop()
136
+ history.pop()
137
+ history.pop()
138
+ return chatbot, history
139
+
140
+ def save_chat_history(filepath, system, history, chatbot):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  if filepath == "":
142
  return
143
+ if not filepath.endswith(".json"):
144
+ filepath += ".json"
145
+ json_s = {"system": system, "history": history, "chatbot": chatbot}
146
+ with open(filepath, "w") as f:
147
+ json.dump(json_s, f)
148
 
 
 
 
 
 
 
 
 
149
 
150
+ def load_chat_history(filename):
151
+ with open(filename, "r") as f:
152
+ json_s = json.load(f)
153
+ return filename, json_s["system"], json_s["history"], json_s["chatbot"]
154
+
155
+
156
+ def get_history_names(plain=False):
157
+ # find all json files in the current directory and return their names
158
+ files = [f for f in os.listdir() if f.endswith(".json")]
159
+ if plain:
160
+ return files
161
+ else:
162
+ return gr.Dropdown.update(choices=files)
163
 
164
 
165
  def reset_state():
166
  return [], []
167
 
168
+
169
+ def compose_system(system_prompt):
170
+ return {"role": "system", "content": system_prompt}
171
+
172
+
173
+ def compose_user(user_input):
174
+ return {"role": "user", "content": user_input}
175
+
176
+
177
+ def reset_textbox():
178
+ return gr.update(value='')
 
 
 
 
 
 
 
 
179
 
180
 
181
  with gr.Blocks() as demo:
182
+ keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...",
183
+ value=my_api_key, label="API Key", type="password").style(container=True)
184
+ chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
185
+ history = gr.State([])
186
+ TRUECOMSTANT = gr.State(True)
187
+ FALSECONSTANT = gr.State(False)
188
  topic = gr.State("未命名对话历史记录")
189
 
190
  with gr.Row():
191
  with gr.Column(scale=12):
192
+ txt = gr.Textbox(show_label=False, placeholder="在这里输入").style(
193
+ container=False)
194
  with gr.Column(min_width=50, scale=1):
195
  submitBtn = gr.Button("🚀", variant="primary")
196
  with gr.Row():
197
  emptyBtn = gr.Button("🧹 新的对话")
198
  retryBtn = gr.Button("🔄 重新生成")
199
  delLastBtn = gr.Button("🗑️ 删除上条对话")
200
+ reduceTokenBtn = gr.Button("♻️ 总结对话")
201
+ systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...",
202
+ label="System prompt", value=initial_prompt).style(container=True)
203
+ with gr.Accordion(label="保存/加载对话历史记录(在文本框中输入文件名,点击“保存对话”按钮,历史记录文件会被存储到Python文件旁边)", open=False):
204
  with gr.Column():
205
  with gr.Row():
206
  with gr.Column(scale=6):
207
+ saveFileName = gr.Textbox(
208
+ show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
209
  with gr.Column(scale=1):
210
  saveBtn = gr.Button("💾 保存对话")
211
+ with gr.Row():
212
+ with gr.Column(scale=6):
213
+ uploadDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False)
214
+ with gr.Column(scale=1):
215
+ refreshBtn = gr.Button("🔄 刷新")
216
+ uploadBtn = gr.Button("📂 读取对话")
217
+ #inputs, top_p, temperature, top_k, repetition_penalty
218
+ with gr.Accordion("参数", open=False):
219
+ top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
220
+ interactive=True, label="Top-p (nucleus sampling)",)
221
+ temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
222
+ step=0.1, interactive=True, label="Temperature",)
223
+ #top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
224
+ #repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
225
+
226
+ txt.submit(predict, [txt, top_p, temperature, keyTxt,
227
+ chatbot, history, systemPromptTxt], [chatbot, history])
228
+ txt.submit(reset_textbox, [], [txt])
229
+ submitBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot,
230
+ history, systemPromptTxt], [chatbot, history], show_progress=True)
231
+ submitBtn.click(reset_textbox, [], [txt])
232
+ emptyBtn.click(reset_state, outputs=[chatbot, history])
233
+ retryBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
234
+ systemPromptTxt, TRUECOMSTANT], [chatbot, history], show_progress=True)
235
+ delLastBtn.click(delete_last_conversation, [chatbot, history], [
236
+ chatbot, history], show_progress=True)
237
+ reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
238
+ systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history], show_progress=True)
239
+ saveBtn.click(save_chat_history, [
240
+ saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
241
+ saveBtn.click(get_history_names, None, [uploadDropdown])
242
+ refreshBtn.click(get_history_names, None, [uploadDropdown])
243
+ # uploadBtn.upload(load_chat_history, uploadBtn, [
244
+ # saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
245
+ uploadBtn.click(load_chat_history, [uploadDropdown], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
246
+
247
+
248
+ demo.queue().launch(debug=True)