OOOOOOHMYYYYYYYYGOOOOOOOOOOOOOOOOOOOOOO
app.py CHANGED
@@ -1,10 +1,9 @@
 import gradio as gr
+import requests
 import json
 import os
 import datetime
-import asyncio
-import aiohttp
-from aiohttp import ClientSession, ClientTimeout
+from requests.exceptions import RequestException

 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
@@ -27,7 +26,7 @@ DEFAULT_PARAMS = {
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")

-async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
         history_format.append({"role": "user", "content": human})
@@ -69,36 +68,31 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, fr
     }

     try:
-        async with ClientSession(timeout=ClientTimeout(total=60)) as session:
-            async with session.post(API_URL, headers=headers, json=data) as response:
-                partial_message = ""
-                async for line in response.content:
-                    if line:
-                        line = line.decode('utf-8').strip()
-                        if line.startswith("data: "):
-                            if line == "data: [DONE]":
-                                break
-                            try:
-                                json_data = json.loads(line[6:])
-                                if 'choices' in json_data and json_data['choices']:
-                                    content = json_data['choices'][0]['delta'].get('content', '')
-                                    if content:
-                                        partial_message += content
-                                        yield partial_message
-                            except json.JSONDecodeError:
-                                continue
+        with requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
+            partial_message = ""
+            for line in response.iter_lines():
+                if stop_flag[0]:
+                    response.close()
+                    break
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith("data: "):
+                        if line.strip() == "data: [DONE]":
+                            break
+                        try:
+                            json_data = json.loads(line[6:])
+                            if 'choices' in json_data and json_data['choices']:
+                                content = json_data['choices'][0]['delta'].get('content', '')
+                                if content:
+                                    partial_message += content
+                                    yield partial_message
+                        except json.JSONDecodeError:
+                            continue

         if partial_message:
             yield partial_message

-    except asyncio.TimeoutError:
-        print("Request timed out")
-        yield "Request timed out. Please try again."
-    except Exception as e:
+    except RequestException as e:
         print(f"Request error: {e}")
         yield f"An error occurred: {str(e)}"

@@ -136,71 +130,12 @@ def export_chat(history, system_prompt):
         export_data += f"<|assistant|> {assistant_msg}\n\n"
     return export_data

-def sanitize_chatbot_history(history):
-    return [tuple(item) for item in history]
-
-async def user(user_message, history):
-    history = sanitize_chatbot_history(history or [])
-    return "", history + [(user_message, None)]
-
-async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-    history = sanitize_chatbot_history(history or [])
-    if not history:
-        yield history
-        return
-    user_message = history[-1][0]
-    bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
-    history[-1] = (history[-1][0], "")
-    task = asyncio.current_task()
-    task_info['task'] = task
-    task_info['stop_requested'] = False
-    try:
-        async for chunk in bot_message:
-            if task_info.get('stop_requested', False):
-                print("Stop requested, breaking the loop")
-                break
-            history[-1] = (history[-1][0], chunk)
-            yield history
-    except asyncio.CancelledError:
-        print("Bot generation cancelled")
-    except GeneratorExit:
-        print("Generator exited")
-    except Exception as e:
-        print(f"Error in bot generation: {e}")
-    finally:
-        if history[-1][1] == "":
-            history[-1] = (history[-1][0], " [Generation stopped]")
-        task_info['task'] = None
-        task_info['stop_requested'] = False
-        yield history
-
-async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-    if 'task' in task_info and task_info['task']:
-        print("Cancelling previous task")
-        task_info['stop_requested'] = True
-        task_info['task'].cancel()
-
-        await asyncio.sleep(0.1)
-
-    history = sanitize_chatbot_history(history or [])
-    if history:
-        history[-1] = (history[-1][0], None)
-        try:
-            async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-                yield sanitize_chatbot_history(new_history)
-        except Exception as e:
-            print(f"Error in regenerate_response: {e}")
-            yield history
-    else:
-        yield []
-
-def import_chat_wrapper(custom_format_string):
-    imported_history, imported_system_prompt = import_chat(custom_format_string)
-    return sanitize_chatbot_history(imported_history), imported_system_prompt
+def stop_generation_func(stop_flag):
+    stop_flag[0] = True
+    return stop_flag

 with gr.Blocks(theme='gradio/monochrome') as demo:
-    task_info = gr.State({})
+    stop_flag = gr.State([False])

     with gr.Row():
         with gr.Column(scale=2):
@@ -227,43 +162,59 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")

-    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info], chatbot,
-        concurrency_limit=10
+    def user(user_message, history):
+        history = history or []
+        return "", history + [[user_message, None]]
+
+    def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+        stop_flag[0] = False
+        history = history or []
+        if not history:
+            return history
+        user_message = history[-1][0]
+        bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag)
+        history[-1][1] = ""
+        for chunk in bot_message:
+            if stop_flag[0]:
+                history[-1][1] += " [Generation stopped]"
+                break
+            history[-1][1] = chunk
+            yield history
+
+    def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+        if history and len(history) > 0:
+            last_user_message = history[-1][0]
+            history[-1][1] = None
+            for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+                yield new_history
+        else:
+            yield []
+
+    def import_chat_wrapper(custom_format_string):
+        imported_history, imported_system_prompt = import_chat(custom_format_string)
+        return imported_history, imported_system_prompt
+
+    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag], chatbot
     )

-    clear.click(lambda:
+    clear.click(lambda: None, None, chatbot, queue=False)

-    regenerate_event = regenerate.click(
+    regenerate.click(
         regenerate_response,
-        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info],
-        chatbot,
-        concurrency_limit=10
+        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag],
+        chatbot
     )

-    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=10)
+    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])

     export_button.click(
         export_chat,
         inputs=[chatbot, system_prompt],
-        outputs=[import_textbox],
-        concurrency_limit=10
+        outputs=[import_textbox]
     )

-    def stop_generation(task_info):
-        if 'task' in task_info and task_info['task']:
-            print("Stop requested")
-            task_info['stop_requested'] = True
-            task_info['task'].cancel()
-        return task_info
-
-    stop_btn.click(
-        stop_generation,
-        inputs=[task_info],
-        outputs=[task_info],
-        cancels=[submit_event, regenerate_event],
-        queue=False
-    )
+    stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])

 if __name__ == "__main__":
-    demo.
+    demo.queue(max_size=20, default_concurrency_limit=20).launch(debug=True)