Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import copy
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from llama2 import GradioLLaMA2ChatPPManager
|
8 |
-
from llama2 import gen_text
|
9 |
|
10 |
from styles import MODEL_SELECTION_CSS
|
11 |
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
|
@@ -64,38 +64,38 @@ def fill_up_placeholders(txt):
|
|
64 |
)
|
65 |
|
66 |
|
67 |
-
def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
76 |
-
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
|
92 |
-
|
93 |
-
|
94 |
|
95 |
async def rollback_last(
|
96 |
idx, local_data, chat_state,
|
97 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
98 |
-
internet_option, serper_api_key
|
99 |
):
|
100 |
internet_option = True if internet_option == "on" else False
|
101 |
|
@@ -137,7 +137,7 @@ async def rollback_last(
|
|
137 |
|
138 |
yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
|
139 |
|
140 |
-
def reset_chat(idx, ld, state):
|
141 |
res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
|
142 |
res[idx].pingpongs = []
|
143 |
|
@@ -152,7 +152,7 @@ def reset_chat(idx, ld, state):
|
|
152 |
async def chat_stream(
|
153 |
idx, local_data, instruction_txtbox, chat_state,
|
154 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
155 |
-
internet_option, serper_api_key
|
156 |
):
|
157 |
internet_option = True if internet_option == "on" else False
|
158 |
|
@@ -161,35 +161,23 @@ async def chat_stream(
|
|
161 |
for ppm in local_data
|
162 |
]
|
163 |
|
164 |
-
ppm = res[idx]
|
165 |
-
ppm.add_pingpong(
|
166 |
-
PingPong(instruction_txtbox, "")
|
167 |
-
)
|
168 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
169 |
|
170 |
#######
|
171 |
-
if internet_option:
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
async for result in gen_text(
|
178 |
-
search_prompt if internet_option else prompt,
|
179 |
-
hf_model=MODEL_ID, hf_token=TOKEN,
|
180 |
-
parameters={
|
181 |
-
'max_new_tokens': res_mnts,
|
182 |
-
'do_sample': res_sample,
|
183 |
-
'return_full_text': False,
|
184 |
-
'temperature': res_temp,
|
185 |
-
'top_k': res_topk,
|
186 |
-
'repetition_penalty': res_rpen
|
187 |
-
}
|
188 |
-
):
|
189 |
-
ppm.append_pong(result)
|
190 |
-
yield "", prompt, ppm.build_uis(), str(res), gr.update(interactive=False), "off"
|
191 |
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
def channel_num(btn_title):
|
195 |
choice = 0
|
@@ -234,10 +222,12 @@ def get_final_template(
|
|
234 |
)
|
235 |
|
236 |
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
|
|
|
237 |
with gr.Column() as chat_view:
|
238 |
idx = gr.State(0)
|
239 |
chat_state = gr.State({
|
240 |
-
"ppmanager_type":
|
241 |
})
|
242 |
local_data = gr.JSON({}, visible=False)
|
243 |
|
@@ -377,7 +367,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
377 |
chat_stream,
|
378 |
[idx, local_data, instruction_txtbox, chat_state,
|
379 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
380 |
-
internet_option, serper_api_key],
|
381 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
|
382 |
).then(
|
383 |
None, local_data, None,
|
@@ -409,7 +399,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
409 |
rollback_last,
|
410 |
[idx, local_data, chat_state,
|
411 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
412 |
-
internet_option, serper_api_key],
|
413 |
[context_inspector, chatbot, local_data, regenerate, internet_option]
|
414 |
).then(
|
415 |
None, local_data, None,
|
@@ -440,7 +430,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
440 |
|
441 |
clean.click(
|
442 |
reset_chat,
|
443 |
-
[idx, local_data, chat_state],
|
444 |
[instruction_txtbox, chatbot, local_data, example_block, regenerate]
|
445 |
).then(
|
446 |
None, local_data, None,
|
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from llama2 import GradioLLaMA2ChatPPManager
|
8 |
+
from llama2 import gen_text
|
9 |
|
10 |
from styles import MODEL_SELECTION_CSS
|
11 |
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS
|
|
|
64 |
)
|
65 |
|
66 |
|
67 |
+
# def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cuda"):
|
68 |
+
# internet_search_ppm = copy.deepcopy(ppmanager)
|
69 |
+
# user_msg = internet_search_ppm.pingpongs[-1].ping
|
70 |
+
# internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
|
71 |
|
72 |
+
# internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
|
73 |
+
# internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
|
74 |
|
75 |
+
# search_query = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
|
76 |
+
# ###
|
77 |
|
78 |
+
# searcher = SimilaritySearcher.from_pretrained(device=device)
|
79 |
+
# iss = InternetSearchStrategy(
|
80 |
+
# searcher,
|
81 |
+
# serper_api_key=serper_api_key
|
82 |
+
# )(ppmanager, search_query=search_query)
|
83 |
|
84 |
+
# step_ppm = None
|
85 |
+
# while True:
|
86 |
+
# try:
|
87 |
+
# step_ppm, _ = next(iss)
|
88 |
+
# yield "", step_ppm.build_uis()
|
89 |
+
# except StopIteration:
|
90 |
+
# break
|
91 |
|
92 |
+
# search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
|
93 |
+
# yield search_prompt, ppmanager.build_uis()
|
94 |
|
95 |
async def rollback_last(
|
96 |
idx, local_data, chat_state,
|
97 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
98 |
+
internet_option, serper_api_key, palm_if
|
99 |
):
|
100 |
internet_option = True if internet_option == "on" else False
|
101 |
|
|
|
137 |
|
138 |
yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
|
139 |
|
140 |
+
def reset_chat(idx, ld, state, palm_if):
|
141 |
res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
|
142 |
res[idx].pingpongs = []
|
143 |
|
|
|
152 |
async def chat_stream(
|
153 |
idx, local_data, instruction_txtbox, chat_state,
|
154 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
155 |
+
internet_option, serper_api_key, palm_if
|
156 |
):
|
157 |
internet_option = True if internet_option == "on" else False
|
158 |
|
|
|
161 |
for ppm in local_data
|
162 |
]
|
163 |
|
|
|
|
|
|
|
|
|
164 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
165 |
|
166 |
#######
|
167 |
+
# if internet_option:
|
168 |
+
# search_prompt = None
|
169 |
+
# for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
|
170 |
+
# search_prompt = tmp_prompt
|
171 |
+
# yield "", prompt, uis, str(res), gr.update(interactive=False), "off"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
palm_if, response_txt = gen_text(instruction_txtbox, palm_if)
|
174 |
+
|
175 |
+
ppm = res[idx]
|
176 |
+
ppm.add_pingpong(
|
177 |
+
PingPong(instruction_txtbox, response_txt)
|
178 |
+
)
|
179 |
+
|
180 |
+
return "", "", ppm.build_uis(), str(res), gr.update(interactive=True), "off"
|
181 |
|
182 |
def channel_num(btn_title):
|
183 |
choice = 0
|
|
|
222 |
)
|
223 |
|
224 |
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
225 |
+
palm_if = gr.State()
|
226 |
+
|
227 |
with gr.Column() as chat_view:
|
228 |
idx = gr.State(0)
|
229 |
chat_state = gr.State({
|
230 |
+
"ppmanager_type": GradioPaLMChatPPManager
|
231 |
})
|
232 |
local_data = gr.JSON({}, visible=False)
|
233 |
|
|
|
367 |
chat_stream,
|
368 |
[idx, local_data, instruction_txtbox, chat_state,
|
369 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
370 |
+
internet_option, serper_api_key, palm_if],
|
371 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
|
372 |
).then(
|
373 |
None, local_data, None,
|
|
|
399 |
rollback_last,
|
400 |
[idx, local_data, chat_state,
|
401 |
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
402 |
+
internet_option, serper_api_key, palm_if],
|
403 |
[context_inspector, chatbot, local_data, regenerate, internet_option]
|
404 |
).then(
|
405 |
None, local_data, None,
|
|
|
430 |
|
431 |
clean.click(
|
432 |
reset_chat,
|
433 |
+
[idx, local_data, chat_state, palm_if],
|
434 |
[instruction_txtbox, chatbot, local_data, example_block, regenerate]
|
435 |
).then(
|
436 |
None, local_data, None,
|