Tuchuanhuhuhu committed
Commit 2cb5d92
1 Parent(s): 41d7a7a

Add online search feature (加入在线搜索功能)

Files changed (4)
  1. ChuanhuChatbot.py +8 -7
  2. presets.py +8 -0
  3. requirements.txt +1 -0
  4. utils.py +12 -1
ChuanhuChatbot.py CHANGED
@@ -70,8 +70,8 @@ with gr.Blocks(css=customCSS,) as demo:
                     interactive=True, label="Top-p (nucleus sampling)",)
                temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
                                        step=0.1, interactive=True, label="Temperature",)
-
-
+
+
            with gr.Accordion(label="加载Prompt模板", open=False):
                with gr.Column():
                    with gr.Row():
@@ -101,9 +101,10 @@ with gr.Blocks(css=customCSS,) as demo:
                historyReadBtn = gr.Button("📂 读入对话")

                use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
-
+                use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
+
+

-
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot().style(height=700) # .style(color_map=("#1D51EE", "#585A5B"))
@@ -118,7 +119,7 @@ with gr.Blocks(css=customCSS,) as demo:
                retryBtn = gr.Button("🔄 重新生成")
                delLastBtn = gr.Button("🗑️ 删除最近一条对话")
                reduceTokenBtn = gr.Button("♻️ 总结对话")
-
+

        gr.HTML("""
        <div style="text-align: center; margin-top: 20px; margin-bottom: 20px;">
@@ -126,10 +127,10 @@ with gr.Blocks(css=customCSS,) as demo:
    gr.Markdown(description)


-    user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
+    user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown, use_websearch_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
    user_input.submit(reset_textbox, [], [user_input])

-    submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
+    submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown, use_websearch_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
    submitBtn.click(reset_textbox, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
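Note on the UI wiring above: in Gradio, whatever components appear in the inputs list of .submit()/.click() are passed to the handler as positional arguments in the same order, so appending use_websearch_checkbox is all that is needed for predict() to receive the checkbox's boolean. A minimal standalone sketch of that pattern (the names answer, question, and output are illustrative, not from the repository):

import gradio as gr

# Minimal sketch of the Checkbox-to-handler pattern used in this commit.
def answer(question, use_websearch):
    # In the real app this role is played by predict(); here we just echo.
    prefix = "[web search enabled] " if use_websearch else ""
    return prefix + question

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    use_websearch = gr.Checkbox(label="使用在线搜索", value=False)
    output = gr.Textbox(label="Answer")
    # Appending the checkbox to the inputs list forwards its boolean value as
    # an extra positional argument, mirroring use_websearch_checkbox above.
    question.submit(answer, [question, use_websearch], [output])

# demo.launch()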
presets.py CHANGED
@@ -40,6 +40,14 @@ pre code {

 summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
 MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-4","gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"] # 可选的模型
+websearch_prompt = """Web search results:
+
+{web_results}
+Current date: {current_date}
+
+Instructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.
+Query: {query}
+Reply in 中文"""

 # 错误信息
 standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
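For context on how the new websearch_prompt is consumed: the {web_results}, {current_date}, and {query} placeholders are literal text that utils.py later fills with chained str.replace calls. A short sketch of that substitution, with made-up search results (the snippet and URL are placeholders, not real output); it assumes it is run from the repository root so presets is importable:

import datetime
from presets import websearch_prompt

# Placeholder search data, in the same "[n]"snippet"\nURL: ..." shape utils.py builds.
fake_results = '[1]"Example snippet about the query."\nURL: https://example.com'
prompt = (websearch_prompt
          .replace("{current_date}", datetime.datetime.today().strftime("%Y-%m-%d"))
          .replace("{query}", "example query")
          .replace("{web_results}", fake_results))
print(prompt)  # this text replaces the user's raw input before the model is called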
requirements.txt CHANGED
@@ -5,3 +5,4 @@ tiktoken
 socksio
 tqdm
 colorama
+duckduckgo_search
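duckduckgo_search is the package that provides the ddg() helper imported by utils.py in this commit; in the releases available at the time, ddg(query, max_results=...) returned a list of result dicts with "title", "body", and "href" keys, which is why predict() reads result["body"] and result["href"]. Later releases reorganized the API around a DDGS class, so if ddg cannot be imported, pinning an older duckduckgo_search release may be necessary. A small smoke test under that assumption (network access required):

# Sanity check for the new dependency; assumes a duckduckgo_search release
# that still exports ddg(), as utils.py expects.
from duckduckgo_search import ddg

results = ddg("ChuanhuChatGPT", max_results=3)  # same call shape as in predict()
for idx, result in enumerate(results or []):    # guard in case no results come back
    # Each result is a dict; utils.py uses the "body" and "href" fields.
    print(f'[{idx + 1}] {result["title"]}\n    {result["href"]}')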
utils.py CHANGED
@@ -16,6 +16,8 @@ from presets import *
 import tiktoken
 from tqdm import tqdm
 import colorama
+from duckduckgo_search import ddg
+import datetime

 # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")

@@ -224,8 +226,17 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_tok
     return chatbot, history, status_text, all_token_counts


-def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], should_check_token_count = True): # repetition_penalty, top_k
+def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], use_websearch_checkbox = False, should_check_token_count = True): # repetition_penalty, top_k
     logging.info("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
+    if use_websearch_checkbox:
+        results = ddg(inputs, max_results=3)
+        web_results = []
+        for idx, result in enumerate(results):
+            logging.info(f"搜索结果{idx + 1}:{result}")
+            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
+        web_results = "\n\n".join(web_results)
+        today = datetime.datetime.today().strftime("%Y-%m-%d")
+        inputs = websearch_prompt.replace("{current_date}", today).replace("{query}", inputs).replace("{web_results}", web_results)
     if len(openai_api_key) != 51:
         status_text = standard_error_msg + no_apikey_msg
         logging.info(status_text)
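The web-search branch added to predict() above is self-contained: run the search, format the hits, then rewrite inputs from the template. Purely as an illustration (not code from this commit), the same logic reads as a standalone helper; build_websearch_prompt is a hypothetical name:

import datetime
import logging

from duckduckgo_search import ddg      # same import this commit adds to utils.py
from presets import websearch_prompt   # template this commit adds to presets.py

def build_websearch_prompt(query, max_results=3):
    """Hypothetical helper restating the web-search branch of predict()."""
    results = ddg(query, max_results=max_results)
    web_results = []
    for idx, result in enumerate(results or []):  # guard in case no results come back
        logging.info(f"Search result {idx + 1}: {result}")
        web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
    today = datetime.datetime.today().strftime("%Y-%m-%d")
    return (websearch_prompt
            .replace("{current_date}", today)
            .replace("{query}", query)
            .replace("{web_results}", "\n\n".join(web_results)))

# inputs = build_websearch_prompt(inputs)  # what predict() effectively does when the box is ticked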