yadongxie committed on
Commit 88fe5dc
1 Parent(s): dec22aa

feat: support streaming & default api

Files changed (1)
  1. app.py +132 -90
app.py CHANGED
@@ -8,42 +8,37 @@ import time
 from dataclasses import dataclass, field
 from threading import Lock
 import base64
+import uuid
+import os
 
 
 @dataclass
 class AppState:
     stream: np.ndarray | None = None
     sampling_rate: int = 0
-    pause_detected: bool = False
     conversation: list = field(default_factory=list)
     client: openai.OpenAI = None
     output_format: str = "mp3"
-    stopped: bool = False
+
 
 # Global lock for thread safety
 state_lock = Lock()
 
+
 def create_client(api_key):
     return openai.OpenAI(
-        base_url="https://llama3-1-8b.lepton.run/api/v1/",
-        api_key=api_key
+        base_url="https://llama3-1-8b.lepton.run/api/v1/", api_key=api_key
     )
 
-def determine_pause(audio, sampling_rate, state):
-    # Take the last 1 second of audio
-    pause_length = int(sampling_rate * 1)  # 1 second
-    if len(audio) < pause_length:
-        return False
-    last_audio = audio[-pause_length:]
-    amplitude = np.abs(last_audio)
-
-    # Calculate the average amplitude in the last 1 second
-    avg_amplitude = np.mean(amplitude)
-    silence_threshold = 0.01  # Adjust this threshold as needed
-    if avg_amplitude < silence_threshold:
-        return True
-    else:
-        return False
+
+def test_api_key(client):
+    # Try making a simple request to check if the API key works
+    try:
+        # Attempt to retrieve available models as a test
+        client.models.list()
+    except Exception as e:
+        raise e
+
 
 def process_audio(audio: tuple, state: AppState):
     if state.stream is None:
@@ -52,13 +47,16 @@ def process_audio(audio: tuple, state: AppState):
     else:
         state.stream = np.concatenate((state.stream, audio[1]))
 
-    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
-    state.pause_detected = pause_detected
-
-    if state.pause_detected:
-        return gr.Audio(recording=False), state
-    else:
-        return None, state
+    return None, state
+
+
+def update_or_append_conversation(conversation, id, role, new_content):
+    for entry in conversation:
+        if entry["id"] == id and entry["role"] == role:
+            entry["content"] = new_content
+            return
+    conversation.append({"id": id, "role": role, "content": new_content})
 
 
 def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     if state.client is None:
@@ -67,6 +65,14 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     format_ = state.output_format
     bitrate = 128 if format_ == "mp3" else 32  # Higher bitrate for MP3, lower for OPUS
     audio_data = base64.b64encode(audio_bytes).decode()
+    old_messages = []
+
+    for item in state.conversation:
+        old_messages.append({"role": item["role"], "content": item["content"]})
+
+    old_messages.append(
+        {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
+    )
 
     try:
         stream = state.client.chat.completions.create(
@@ -74,36 +80,42 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
                 "require_audio": True,
                 "tts_preset_id": "jessica",
                 "tts_audio_format": format_,
-                "tts_audio_bitrate": bitrate
+                "tts_audio_bitrate": bitrate,
             },
             model="llama3.1-8b",
-            messages=[{"role": "user", "content": [{"type": "audio", "data": audio_data}]}],
+            messages=old_messages,
            temperature=0.7,
            max_tokens=256,
            stream=True,
         )
 
         full_response = ""
-        audios = []
+        asr_result = ""
+        final_audio = b""
+        id = uuid.uuid4()
 
         for chunk in stream:
             if not chunk.choices:
                 continue
             content = chunk.choices[0].delta.content
-            audio = getattr(chunk.choices[0], 'audio', [])
+            audio = getattr(chunk.choices[0], "audio", [])
+            asr_results = getattr(chunk.choices[0], "asr_results", [])
+            if asr_results:
+                asr_result += "".join(asr_results)
+                yield id, full_response, asr_result, None, state
             if content:
                 full_response += content
-                yield full_response, None, state
+                yield id, full_response, asr_result, None, state
             if audio:
-                audios.extend(audio)
-
-        final_audio = b''.join([base64.b64decode(a) for a in audios])
-
-        yield full_response, final_audio, state
+                final_audio = b"".join([base64.b64decode(a) for a in audio])
+                yield id, full_response, asr_result, final_audio, state
 
+        yield id, full_response, asr_result, final_audio, state
 
     except Exception as e:
         raise gr.Error(f"Error during audio streaming: {e}")
 
+
 def response(state: AppState):
     if state.stream is None or len(state.stream) == 0:
         return None, None, state
@@ -119,63 +131,106 @@ def response(state: AppState):
 
     generator = generate_response_and_audio(audio_buffer.getvalue(), state)
 
-    # Process the generator to get the final results
-    final_text = ""
-    final_audio = None
-    for text, audio, updated_state in generator:
-        final_text = text if text else final_text
-        final_audio = audio if audio else final_audio
+    for id, text, asr, audio, updated_state in generator:
         state = updated_state
-
-    # Update the chatbot with the final conversation
-    state.conversation.append({"role": "user", "content": "Audio input"})
-    state.conversation.append({"role": "assistant", "content": final_text})
+        if asr:
+            update_or_append_conversation(state.conversation, id, "user", asr)
+        if text:
+            update_or_append_conversation(state.conversation, id, "assistant", text)
+        chatbot_output = state.conversation
+        yield chatbot_output, audio, state
 
     # Reset the audio stream for the next interaction
     state.stream = None
-    state.pause_detected = False
-
-    chatbot_output = state.conversation[-2:]  # Get the last two messages
 
-    return chatbot_output, final_audio, state
-
-def start_recording_user(state: AppState):
-    if not state.stopped:
-        return gr.Audio(recording=True)
-    else:
-        return gr.Audio(recording=False)
 
 def set_api_key(api_key, state):
-    if not api_key:
-        raise gr.Error("Please enter a valid API key.")
-    state.client = create_client(api_key)
-    return "API key set successfully!", state
+    try:
+        state.client = create_client(api_key)
+        test_api_key(state.client)  # Test the provided API key
+        api_key_status = gr.update(value="API key set successfully!", visible=True)
+        api_key_input = gr.update(visible=False)
+        set_key_button = gr.update(visible=False)
+        return api_key_status, api_key_input, set_key_button, state
+    except Exception as e:
+        api_key_status = gr.update(
+            value="Invalid API key. Please try again.", visible=True
+        )
+        return api_key_status, None, None, state
+
+
+def initial_setup(state):
+    api_key = os.getenv("API_KEY")
+    if api_key:
+        try:
+            state.client = create_client(api_key)
+            test_api_key(state.client)  # Test the API key from the environment variable
+            api_key_status = gr.update(
+                value="You are using default Lepton API key, which have 10 requests/min limit",
+                visible=True,
+            )
+            api_key_input = gr.update(visible=False)
+            set_key_button = gr.update(visible=False)
+            return api_key_status, api_key_input, set_key_button, state
+        except Exception as e:
+            # Failed to use the api_key, show input box
+            api_key_status = gr.update(
+                value="Failed to use default API key. Please enter a valid API key.",
+                visible=True,
+            )
+            api_key_input = gr.update(visible=True)
+            set_key_button = gr.update(visible=True)
+            return api_key_status, api_key_input, set_key_button, state
+    else:
+        # No API key in environment variable
+        api_key_status = gr.update(visible=False)
+        api_key_input = gr.update(visible=True)
+        set_key_button = gr.update(visible=True)
+        return api_key_status, api_key_input, set_key_button, state
 
-def update_format(format, state):
-    state.output_format = format
-    return state
 
 with gr.Blocks() as demo:
+    gr.Markdown("# Lepton AI LLM Voice Mode")
+    gr.Markdown(
+        "You can find Lepton AI LLM voice doc [here](https://www.lepton.ai/playground/chat/llama-3.2-3b) and serverless endpoint API Key [here](https://dashboard.lepton.ai/workspace-redirect/settings/api-tokens)"
+    )
     with gr.Row():
-        api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
-        set_key_button = gr.Button("Set API Key")
-
-    api_key_status = gr.Textbox(label="API Key Status", interactive=False)
-
-    with gr.Row():
-        format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
-
-    with gr.Row():
-        with gr.Column():
-            input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
-        with gr.Column():
-            chatbot = gr.Chatbot(label="Conversation", type="messages")
-            output_audio = gr.Audio(label="Output Audio", autoplay=True)
+        with gr.Column(scale=3):
+            api_key_input = gr.Textbox(
+                type="password",
+                placeholder="Enter your Lepton API Key",
+                show_label=False,
+                container=False,
+            )
+        with gr.Column(scale=1):
+            set_key_button = gr.Button("Set API Key", scale=2, variant="primary")
+
+    api_key_status = gr.Textbox(
+        show_label=False, container=False, interactive=False, visible=False
+    )
 
+    with gr.Blocks():
+        with gr.Row():
+            input_audio = gr.Audio(
+                label="Input Audio", sources="microphone", type="numpy"
+            )
+            output_audio = gr.Audio(label="Output Audio", autoplay=True, streaming=True)
+        chatbot = gr.Chatbot(label="Conversation", type="messages")
 
     state = gr.State(AppState())
 
-    set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, state])
-    format_dropdown.change(update_format, inputs=[format_dropdown, state], outputs=[state])
+    # Initial setup to set API key from environment variable
+    demo.load(
+        initial_setup,
+        inputs=state,
+        outputs=[api_key_status, api_key_input, set_key_button, state],
+    )
+
+    set_key_button.click(
+        set_api_key,
+        inputs=[api_key_input, state],
+        outputs=[api_key_status, api_key_input, set_key_button, state],
+    )
 
     stream = input_audio.stream(
         process_audio,
@@ -186,23 +241,10 @@ with gr.Blocks() as demo:
     )
 
     respond = input_audio.stop_recording(
-        response,
-        [state],
-        [chatbot, output_audio, state]
+        response, [state], [chatbot, output_audio, state]
    )
    # Update the chatbot with the final conversation
    respond.then(lambda s: s.conversation, [state], [chatbot])
 
-    # Automatically restart recording after the assistant's response
-    restart = output_audio.stop(
-        start_recording_user,
-        [state],
-        [input_audio]
-    )
-
-    # Add a "Stop Conversation" button
-    cancel = gr.Button("Stop Conversation", variant="stop")
-    cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
-                 [state, input_audio], cancels=[respond, restart])
 
 demo.launch()
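
A minimal usage sketch, not part of the commit: after this change, generate_response_and_audio() yields (id, text, asr, audio, state) tuples, and response() re-yields them so Gradio can stream partial ASR text, assistant text, and TTS audio as they arrive; the per-turn id lets update_or_append_conversation() update a message in place instead of appending duplicates. Assuming a valid key in the API_KEY environment variable and a short WAV file sample.wav (both hypothetical):

import os

# Hypothetical driver: build a client the way the app does, then drain the
# generator exactly as response() does, printing each partial result.
state = AppState(client=create_client(os.environ["API_KEY"]))
with open("sample.wav", "rb") as f:
    audio_bytes = f.read()

for id, text, asr, audio, state in generate_response_and_audio(audio_bytes, state):
    print(f"id={id} asr={asr!r} text={text!r} audio_len={len(audio or b'')}")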