chansung committed on
Commit
e5b0cf0
1 Parent(s): 602e36b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -55
app.py CHANGED
@@ -5,7 +5,7 @@ import copy
5
  import gradio as gr
6
 
7
  from llama2 import GradioLLaMA2ChatPPManager
8
- from llama2 import gen_text, gen_text_none_stream
9
 
10
  from styles import MODEL_SELECTION_CSS
11
  from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
@@ -64,38 +64,38 @@ def fill_up_placeholders(txt):
64
  )
65
 
66
 
67
- def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
68
- internet_search_ppm = copy.deepcopy(ppmanager)
69
- user_msg = internet_search_ppm.pingpongs[-1].ping
70
- internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
71
 
72
- internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
73
- internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
74
 
75
- search_query = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
76
- ###
77
 
78
- searcher = SimilaritySearcher.from_pretrained(device=device)
79
- iss = InternetSearchStrategy(
80
- searcher,
81
- serper_api_key=serper_api_key
82
- )(ppmanager, search_query=search_query)
83
 
84
- step_ppm = None
85
- while True:
86
- try:
87
- step_ppm, _ = next(iss)
88
- yield "", step_ppm.build_uis()
89
- except StopIteration:
90
- break
91
 
92
- search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
93
- yield search_prompt, ppmanager.build_uis()
94
 
95
  async def rollback_last(
96
  idx, local_data, chat_state,
97
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
98
- internet_option, serper_api_key
99
  ):
100
  internet_option = True if internet_option == "on" else False
101
 
@@ -137,7 +137,7 @@ async def rollback_last(
137
 
138
  yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
139
 
140
- def reset_chat(idx, ld, state):
141
  res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
142
  res[idx].pingpongs = []
143
 
@@ -152,7 +152,7 @@ def reset_chat(idx, ld, state):
152
  async def chat_stream(
153
  idx, local_data, instruction_txtbox, chat_state,
154
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
155
- internet_option, serper_api_key
156
  ):
157
  internet_option = True if internet_option == "on" else False
158
 
@@ -161,35 +161,23 @@ async def chat_stream(
161
  for ppm in local_data
162
  ]
163
 
164
- ppm = res[idx]
165
- ppm.add_pingpong(
166
- PingPong(instruction_txtbox, "")
167
- )
168
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
169
 
170
  #######
171
- if internet_option:
172
- search_prompt = None
173
- for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
174
- search_prompt = tmp_prompt
175
- yield "", prompt, uis, str(res), gr.update(interactive=False), "off"
176
-
177
- async for result in gen_text(
178
- search_prompt if internet_option else prompt,
179
- hf_model=MODEL_ID, hf_token=TOKEN,
180
- parameters={
181
- 'max_new_tokens': res_mnts,
182
- 'do_sample': res_sample,
183
- 'return_full_text': False,
184
- 'temperature': res_temp,
185
- 'top_k': res_topk,
186
- 'repetition_penalty': res_rpen
187
- }
188
- ):
189
- ppm.append_pong(result)
190
- yield "", prompt, ppm.build_uis(), str(res), gr.update(interactive=False), "off"
191
 
192
- yield "", prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
 
 
 
 
 
 
 
193
 
194
  def channel_num(btn_title):
195
  choice = 0
@@ -234,10 +222,12 @@ def get_final_template(
234
  )
235
 
236
  with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
 
 
237
  with gr.Column() as chat_view:
238
  idx = gr.State(0)
239
  chat_state = gr.State({
240
- "ppmanager_type": GradioLLaMA2ChatPPManager
241
  })
242
  local_data = gr.JSON({}, visible=False)
243
 
@@ -377,7 +367,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
377
  chat_stream,
378
  [idx, local_data, instruction_txtbox, chat_state,
379
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
380
- internet_option, serper_api_key],
381
  [instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
382
  ).then(
383
  None, local_data, None,
@@ -409,7 +399,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
409
  rollback_last,
410
  [idx, local_data, chat_state,
411
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
412
- internet_option, serper_api_key],
413
  [context_inspector, chatbot, local_data, regenerate, internet_option]
414
  ).then(
415
  None, local_data, None,
@@ -440,7 +430,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
440
 
441
  clean.click(
442
  reset_chat,
443
- [idx, local_data, chat_state],
444
  [instruction_txtbox, chatbot, local_data, example_block, regenerate]
445
  ).then(
446
  None, local_data, None,
 
5
  import gradio as gr
6
 
7
  from llama2 import GradioLLaMA2ChatPPManager
8
+ from llama2 import gen_text
9
 
10
  from styles import MODEL_SELECTION_CSS
11
  from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
 
64
  )
65
 
66
 
67
+ # def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
68
+ # internet_search_ppm = copy.deepcopy(ppmanager)
69
+ # user_msg = internet_search_ppm.pingpongs[-1].ping
70
+ # internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
71
 
72
+ # internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
73
+ # internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
74
 
75
+ # search_query = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
76
+ # ###
77
 
78
+ # searcher = SimilaritySearcher.from_pretrained(device=device)
79
+ # iss = InternetSearchStrategy(
80
+ # searcher,
81
+ # serper_api_key=serper_api_key
82
+ # )(ppmanager, search_query=search_query)
83
 
84
+ # step_ppm = None
85
+ # while True:
86
+ # try:
87
+ # step_ppm, _ = next(iss)
88
+ # yield "", step_ppm.build_uis()
89
+ # except StopIteration:
90
+ # break
91
 
92
+ # search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
93
+ # yield search_prompt, ppmanager.build_uis()
94
 
95
  async def rollback_last(
96
  idx, local_data, chat_state,
97
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
98
+ internet_option, serper_api_key, palm_if
99
  ):
100
  internet_option = True if internet_option == "on" else False
101
 
 
137
 
138
  yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
139
 
140
+ def reset_chat(idx, ld, state, palm_if):
141
  res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
142
  res[idx].pingpongs = []
143
 
 
152
  async def chat_stream(
153
  idx, local_data, instruction_txtbox, chat_state,
154
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
155
+ internet_option, serper_api_key, palm_if
156
  ):
157
  internet_option = True if internet_option == "on" else False
158
 
 
161
  for ppm in local_data
162
  ]
163
 
 
 
 
 
164
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
165
 
166
  #######
167
+ # if internet_option:
168
+ # search_prompt = None
169
+ # for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
170
+ # search_prompt = tmp_prompt
171
+ # yield "", prompt, uis, str(res), gr.update(interactive=False), "off"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ palm_if, response_txt = gen_text(instruction_txtbox, palm_if)
174
+
175
+ ppm = res[idx]
176
+ ppm.add_pingpong(
177
+ PingPong(instruction_txtbox, response_txt)
178
+ )
179
+
180
+ return "", "", ppm.build_uis(), str(res), gr.update(interactive=True), "off"
181
 
182
  def channel_num(btn_title):
183
  choice = 0
 
222
  )
223
 
224
  with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
225
+ palm_if = gr.State()
226
+
227
  with gr.Column() as chat_view:
228
  idx = gr.State(0)
229
  chat_state = gr.State({
230
+ "ppmanager_type": GradioPaLMChatPPManager
231
  })
232
  local_data = gr.JSON({}, visible=False)
233
 
 
367
  chat_stream,
368
  [idx, local_data, instruction_txtbox, chat_state,
369
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
370
+ internet_option, serper_api_key, palm_if],
371
  [instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
372
  ).then(
373
  None, local_data, None,
 
399
  rollback_last,
400
  [idx, local_data, chat_state,
401
  global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
402
+ internet_option, serper_api_key, palm_if],
403
  [context_inspector, chatbot, local_data, regenerate, internet_option]
404
  ).then(
405
  None, local_data, None,
 
430
 
431
  clean.click(
432
  reset_chat,
433
+ [idx, local_data, chat_state, palm_if],
434
  [instruction_txtbox, chatbot, local_data, example_block, regenerate]
435
  ).then(
436
  None, local_data, None,