Corvius committed
Commit f7d27cc • 1 Parent(s): 435c68c

OOOOOOHMYYYYYYYYGOOOOOOOOOOOOOOOOOOOOOO

Files changed (1)
  1. app.py +70 -119
app.py CHANGED
@@ -1,10 +1,9 @@
  import gradio as gr
+ import requests
  import json
  import os
  import datetime
- import asyncio
- import aiohttp
- from aiohttp import ClientSession, ClientTimeout
+ from requests.exceptions import RequestException
 
  API_URL = os.environ.get('API_URL')
  API_KEY = os.environ.get('API_KEY')
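
Note: this commit swaps the async HTTP stack (aiohttp) for synchronous requests. The streaming counterpart of the old ClientSession.post call is requests.post(..., stream=True); a minimal sketch, with a placeholder URL and headers rather than the app's real endpoint:

    import requests

    # Placeholder endpoint and key, for illustration only.
    url = "https://example.com/v1/chat/completions"
    headers = {"Authorization": "Bearer <API_KEY>"}

    # stream=True defers the body download; iter_lines() yields it line by line.
    with requests.post(url, headers=headers, json={"stream": True}, stream=True) as resp:
        for line in resp.iter_lines():
            print(line)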
@@ -27,7 +26,7 @@ DEFAULT_PARAMS = {
  def get_timestamp():
      return datetime.datetime.now().strftime("%H:%M:%S")
 
- async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+ def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
      history_format = [{"role": "system", "content": system_prompt}]
      for human, assistant in history:
          history_format.append({"role": "user", "content": human})
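
The loop above flattens Gradio's [user, assistant] history pairs into an OpenAI-style message list. A worked example with sample values (the conditional skip of a None assistant turn is an assumption, since the hunk is truncated before that line):

    history = [["Hi", "Hello! How can I help?"], ["What's 2+2?", None]]
    system_prompt = "You are a helpful assistant."

    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:  # assumption: skip the pending turn's None placeholder
            history_format.append({"role": "assistant", "content": assistant})

    # history_format: system message, then alternating user/assistant turns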
@@ -69,36 +68,31 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
      }
 
      try:
-         timeout = ClientTimeout(total=60)  # Set a 60-second timeout
-         async with ClientSession(timeout=timeout) as session:
-             async with session.post(API_URL, headers=headers, json=data) as response:
-                 partial_message = ""
-                 async for line in response.content:
-                     if asyncio.current_task().cancelled():
-                         print("Task cancelled during API request")
-                         break
-                     if line:
-                         line = line.decode('utf-8')
-                         if line.startswith("data: "):
-                             if line.strip() == "data: [DONE]":
-                                 break
-                             try:
-                                 json_data = json.loads(line[6:])
-                                 if 'choices' in json_data and json_data['choices']:
-                                     content = json_data['choices'][0]['delta'].get('content', '')
-                                     if content:
-                                         partial_message += content
-                                         yield partial_message
-                             except json.JSONDecodeError:
-                                 continue
+         with requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
+             partial_message = ""
+             for line in response.iter_lines():
+                 if stop_flag[0]:
+                     response.close()
+                     break
+                 if line:
+                     line = line.decode('utf-8')
+                     if line.startswith("data: "):
+                         if line.strip() == "data: [DONE]":
+                             break
+                         try:
+                             json_data = json.loads(line[6:])
+                             if 'choices' in json_data and json_data['choices']:
+                                 content = json_data['choices'][0]['delta'].get('content', '')
+                                 if content:
+                                     partial_message += content
+                                     yield partial_message
+                         except json.JSONDecodeError:
+                             continue
 
          if partial_message:
              yield partial_message
 
-     except asyncio.TimeoutError:
-         print("Request timed out")
-         yield "Request timed out. Please try again."
-     except Exception as e:
+     except RequestException as e:
          print(f"Request error: {e}")
          yield f"An error occurred: {str(e)}"
 
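
The new predict body parses OpenAI-style server-sent events: each payload line starts with "data: " and "data: [DONE]" terminates the stream. The same parsing logic run against canned sample lines (hypothetical payloads, not captured API output):

    import json

    sample_lines = [
        b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
        b'',  # keep-alive blanks are skipped by the `if line:` guard
        b'data: {"choices": [{"delta": {"content": "lo"}}]}',
        b'data: [DONE]',
    ]

    partial_message = ""
    for line in sample_lines:
        if not line:
            continue
        line = line.decode('utf-8')
        if line.startswith("data: "):
            if line.strip() == "data: [DONE]":
                break
            try:
                json_data = json.loads(line[6:])  # strip the "data: " prefix
                content = json_data['choices'][0]['delta'].get('content', '')
                partial_message += content
            except json.JSONDecodeError:
                continue

    print(partial_message)  # -> Hello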
 
@@ -136,71 +130,12 @@ def export_chat(history, system_prompt):
          export_data += f"<|assistant|> {assistant_msg}\n\n"
      return export_data
 
- def sanitize_chatbot_history(history):
-     """Ensure each entry in the chatbot history is a tuple of two items."""
-     return [tuple(entry[:2]) if isinstance(entry, (list, tuple)) else (str(entry), None) for entry in history]
-
- async def user(user_message, history):
-     history = sanitize_chatbot_history(history or [])
-     return "", history + [(user_message, None)]
-
- async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-     history = sanitize_chatbot_history(history or [])
-     if not history:
-         yield history
-         return
-     user_message = history[-1][0]
-     bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
-     history[-1] = (history[-1][0], "")
-     task = asyncio.current_task()
-     task_info['task'] = task
-     task_info['stop_requested'] = False
-     try:
-         async for chunk in bot_message:
-             if task_info.get('stop_requested', False):
-                 print("Stop requested, breaking the loop")
-                 break
-             history[-1] = (history[-1][0], chunk)
-             yield history
-     except asyncio.CancelledError:
-         print("Bot generation cancelled")
-     except GeneratorExit:
-         print("Generator exited")
-     except Exception as e:
-         print(f"Error in bot generation: {e}")
-     finally:
-         if history[-1][1] == "":
-             history[-1] = (history[-1][0], " [Generation stopped]")
-         task_info['task'] = None
-         task_info['stop_requested'] = False
-         yield history
-
- async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-     if 'task' in task_info and task_info['task']:
-         print("Cancelling previous task")
-         task_info['stop_requested'] = True
-         task_info['task'].cancel()
-
-         await asyncio.sleep(0.1)
-
-     history = sanitize_chatbot_history(history or [])
-     if history:
-         history[-1] = (history[-1][0], None)
-         try:
-             async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
-                 yield sanitize_chatbot_history(new_history)
-         except Exception as e:
-             print(f"Error in regenerate_response: {e}")
-             yield history
-     else:
-         yield []
-
- def import_chat_wrapper(custom_format_string):
-     imported_history, imported_system_prompt = import_chat(custom_format_string)
-     return sanitize_chatbot_history(imported_history), imported_system_prompt
+ def stop_generation_func(stop_flag):
+     stop_flag[0] = True
+     return stop_flag
 
  with gr.Blocks(theme='gradio/monochrome') as demo:
-     task_info = gr.State({'task': None, 'stop_requested': False})
+     stop_flag = gr.State([False])
 
      with gr.Row():
          with gr.Column(scale=2):
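
stop_generation_func replaces asyncio task cancellation with a shared one-element list: the list object is mutated in place, so every callback holding a reference observes the change on its next check. A minimal sketch of the pattern outside Gradio (names are hypothetical):

    import threading
    import time

    stop_flag = [False]  # one-element list: shared by reference, mutable in place

    def worker(flag):
        # Polls the flag between chunks, as predict() does between SSE lines.
        for i in range(100):
            if flag[0]:
                print(f"stopped at chunk {i}")
                return
            time.sleep(0.05)

    t = threading.Thread(target=worker, args=(stop_flag,))
    t.start()
    time.sleep(0.2)
    stop_flag[0] = True  # what stop_generation_func does
    t.join()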
@@ -227,43 +162,59 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
              repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
              max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
 
-     submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info], chatbot,
-         concurrency_limit=10
+     def user(user_message, history):
+         history = history or []
+         return "", history + [[user_message, None]]
+
+     def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+         stop_flag[0] = False
+         history = history or []
+         if not history:
+             return history
+         user_message = history[-1][0]
+         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag)
+         history[-1][1] = ""
+         for chunk in bot_message:
+             if stop_flag[0]:
+                 history[-1][1] += " [Generation stopped]"
+                 break
+             history[-1][1] = chunk
+             yield history
+
+     def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+         if history and len(history) > 0:
+             last_user_message = history[-1][0]
+             history[-1][1] = None
+             for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag):
+                 yield new_history
+         else:
+             yield []
+
+     def import_chat_wrapper(custom_format_string):
+         imported_history, imported_system_prompt = import_chat(custom_format_string)
+         return imported_history, imported_system_prompt
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag], chatbot
      )
 
-     clear.click(lambda: [], None, chatbot, queue=False)
+     clear.click(lambda: None, None, chatbot, queue=False)
 
-     regenerate_event = regenerate.click(
+     regenerate.click(
          regenerate_response,
-         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info],
-         chatbot,
-         concurrency_limit=10
+         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, stop_flag],
+         chatbot
      )
 
-     import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=10)
+     import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
 
      export_button.click(
          export_chat,
          inputs=[chatbot, system_prompt],
-         outputs=[import_textbox],
-         concurrency_limit=10
+         outputs=[import_textbox]
      )
 
-     def stop_generation(task_info):
-         if 'task' in task_info and task_info['task']:
-             print("Stop requested")
-             task_info['stop_requested'] = True
-             task_info['task'].cancel()
-         return task_info
-
-     stop_btn.click(
-         stop_generation,
-         inputs=[task_info],
-         outputs=[task_info],
-         cancels=[submit_event, regenerate_event],
-         queue=False
-     )
+     stop_btn.click(stop_generation_func, inputs=[stop_flag], outputs=[stop_flag])
 
  if __name__ == "__main__":
-     demo.launch(debug=True, max_threads=40)
+     demo.queue(max_size=20, default_concurrency_limit=20).launch(debug=True)
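
The wiring keeps Gradio's two-step chat pattern: an unqueued user step appends the message instantly, then a queued generator step streams the reply, re-rendering the Chatbot on every yield. A stripped-down sketch of the same wiring, assuming Gradio 4.x where queue() accepts default_concurrency_limit (the echo bot stands in for the real model call):

    import gradio as gr
    import time

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()

        def user(message, history):
            # Runs with queue=False so the user's message appears immediately.
            return "", (history or []) + [[message, None]]

        def bot(history):
            # Generator: each yield pushes the partial reply to the Chatbot.
            reply = f"echo: {history[-1][0]}"
            history[-1][1] = ""
            for ch in reply:
                history[-1][1] += ch
                time.sleep(0.02)
                yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

    if __name__ == "__main__":
        demo.queue(max_size=20, default_concurrency_limit=20).launch()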
 