akhaliq committed a0f34aa (1 parent: e65a834)

Update app.py

Files changed (1): app.py (+59 -33)
app.py CHANGED
@@ -3,8 +3,6 @@ import numpy as np
 import io
 from pydub import AudioSegment
 import tempfile
-import os
-import base64
 import openai
 import time
 from dataclasses import dataclass, field
@@ -14,11 +12,11 @@ from threading import Lock
 class AppState:
     stream: np.ndarray | None = None
     sampling_rate: int = 0
-    pause_start: float | None = None
-    last_speech: float = 0
+    pause_detected: bool = False
     conversation: list = field(default_factory=list)
     client: openai.OpenAI = None
     output_format: str = "mp3"
+    stopped: bool = False
 
 # Global lock for thread safety
 state_lock = Lock()
@@ -29,27 +27,36 @@ def create_client(api_key):
         api_key=api_key
     )
 
+def determine_pause(audio, sampling_rate, state):
+    # Take the last 1 second of audio
+    pause_length = int(sampling_rate * 1)  # 1 second
+    if len(audio) < pause_length:
+        return False
+    last_audio = audio[-pause_length:]
+    amplitude = np.abs(last_audio)
+
+    # Calculate the average amplitude in the last 1 second
+    avg_amplitude = np.mean(amplitude)
+    silence_threshold = 0.01  # Adjust this threshold as needed
+    if avg_amplitude < silence_threshold:
+        return True
+    else:
+        return False
+
 def process_audio(audio: tuple, state: AppState):
     if state.stream is None:
         state.stream = audio[1]
         state.sampling_rate = audio[0]
-        state.last_speech = time.time()
     else:
         state.stream = np.concatenate((state.stream, audio[1]))
 
-    # Improved pause detection
-    current_time = time.time()
-    if np.max(np.abs(audio[1])) > 0.1:  # Adjust this threshold as needed
-        state.last_speech = current_time
-        state.pause_start = None
-    elif state.pause_start is None:
-        state.pause_start = current_time
+    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
+    state.pause_detected = pause_detected
 
-    # Check if pause is long enough to stop recording
-    if state.pause_start and (current_time - state.pause_start > 2.0):  # 2 seconds of silence
+    if state.pause_detected:
         return gr.Audio(recording=False), state
-
-    return None, state
+    else:
+        return None, state
 
 def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     if state.client is None:
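
Note on the new pause detection: determine_pause replaces the wall-clock pause_start/last_speech bookkeeping with a pure function of the accumulated buffer, so the decision no longer depends on when the streaming callback happens to fire. Below is a minimal, self-contained sketch (not part of the commit) of how it behaves on synthetic audio; it assumes float samples normalized to roughly [-1, 1], whereas Gradio's microphone can deliver int16 PCM, in which case the 0.01 threshold would need rescaling.

import numpy as np

def determine_pause(audio, sampling_rate, state=None):
    # Average absolute amplitude over the trailing second vs. a fixed threshold.
    pause_length = int(sampling_rate * 1)  # one second of samples
    if len(audio) < pause_length:
        return False  # not enough audio yet to judge
    avg_amplitude = np.mean(np.abs(audio[-pause_length:]))
    return avg_amplitude < 0.01  # silence_threshold from the commit

sr = 16000
tone = 0.2 * np.sin(2 * np.pi * 220 * np.arange(sr) / sr)  # one second of "speech"
hush = 0.001 * np.ones(sr)                                 # one second of near-silence

print(determine_pause(tone, sr))                            # False: still talking
print(determine_pause(np.concatenate([tone, hush]), sr))    # True: trailing second is quiet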
@@ -58,7 +65,7 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
     format_ = state.output_format
     bitrate = 128 if format_ == "mp3" else 32  # Higher bitrate for MP3, lower for OPUS
     audio_data = base64.b64encode(audio_bytes).decode()
-
+
     try:
         stream = state.client.chat.completions.create(
             extra_body={
@@ -90,9 +97,6 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
 
         final_audio = b''.join([base64.b64decode(a) for a in audios])
 
-        state.conversation.append({"role": "user", "content": "Audio input"})
-        state.conversation.append({"role": "assistant", "content": full_response})
-
         yield full_response, final_audio, state
 
     except Exception as e:
@@ -101,7 +105,7 @@ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
 def response(state: AppState):
     if state.stream is None or len(state.stream) == 0:
         return None, None, state
-
+
     audio_buffer = io.BytesIO()
     segment = AudioSegment(
         state.stream.tobytes(),
@@ -112,7 +116,7 @@ def response(state: AppState):
     segment.export(audio_buffer, format="wav")
 
     generator = generate_response_and_audio(audio_buffer.getvalue(), state)
-
+
     # Process the generator to get the final results
     final_text = ""
     final_audio = None
@@ -122,15 +126,23 @@ def response(state: AppState):
         state = updated_state
 
     # Update the chatbot with the final conversation
-    chatbot_output = state.conversation[-2:]  # Get the last two messages (user input and AI response)
-
+    state.conversation.append({"role": "user", "content": "Audio input"})
+    state.conversation.append({"role": "assistant", "content": final_text})
+
     # Reset the audio stream for the next interaction
     state.stream = None
-    state.pause_start = None
-    state.last_speech = 0
-
+    state.pause_detected = False
+
+    chatbot_output = state.conversation[-2:]  # Get the last two messages
+
     return chatbot_output, final_audio, state
 
+def start_recording_user(state: AppState):
+    if not state.stopped:
+        return gr.Audio(recording=True)
+    else:
+        return gr.Audio(recording=False)
+
 def set_api_key(api_key, state):
     if not api_key:
         raise gr.Error("Please enter a valid API key.")
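
The conversation bookkeeping moves out of the streaming generator and into response(), so each turn is appended exactly once, after final_text has been fully accumulated. The appended dicts are already in the shape that gr.Chatbot(type="messages") renders; a tiny illustrative sketch (the content strings are placeholders, not from the commit):

conversation = []
conversation.append({"role": "user", "content": "Audio input"})
conversation.append({"role": "assistant", "content": "...final_text..."})
chatbot_output = conversation[-2:]  # the single user/assistant pair response() returns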
@@ -145,19 +157,19 @@ with gr.Blocks() as demo:
     with gr.Row():
         api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
         set_key_button = gr.Button("Set API Key")
-
+
     api_key_status = gr.Textbox(label="API Key Status", interactive=False)
-
+
     with gr.Row():
         format_dropdown = gr.Dropdown(choices=["mp3", "opus"], value="mp3", label="Output Audio Format")
-
+
     with gr.Row():
         with gr.Column():
             input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
         with gr.Column():
             chatbot = gr.Chatbot(label="Conversation", type="messages")
             output_audio = gr.Audio(label="Output Audio", autoplay=True)
-
+
     state = gr.State(AppState())
 
     set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, state])
@@ -170,11 +182,25 @@ with gr.Blocks() as demo:
         stream_every=0.25,  # Reduced to make it more responsive
         time_limit=60,  # Increased to allow for longer messages
     )
-
+
     respond = input_audio.stop_recording(
         response,
         [state],
         [chatbot, output_audio, state]
    )
+    # Update the chatbot with the final conversation
+    respond.then(lambda s: s.conversation, [state], [chatbot])
+
+    # Automatically restart recording after the assistant's response
+    restart = output_audio.stop(
+        start_recording_user,
+        [state],
+        [input_audio]
+    )
+
+    # Add a "Stop Conversation" button
+    cancel = gr.Button("Stop Conversation", variant="stop")
+    cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
+                 [state, input_audio], cancels=[respond, restart])
 
-demo.launch()
+demo.launch()
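
Taken together, the new wiring makes the demo a hands-free loop: stop_recording runs response(), respond.then pushes the full conversation into the chatbot, and when the autoplayed reply finishes, output_audio.stop re-arms the microphone via start_recording_user unless state.stopped is set. A stripped-down sketch of the same pattern (dummy echo handler and hypothetical component names, not the commit's code), mainly to isolate how cancels= aborts the in-flight chain:

import gradio as gr

def echo(audio):
    return audio  # stand-in for the real response() handler

with gr.Blocks() as demo:
    mic = gr.Audio(sources="microphone", type="numpy", label="Input Audio")
    out = gr.Audio(autoplay=True, label="Output Audio")

    respond = mic.stop_recording(echo, [mic], [out])
    restart = out.stop(lambda: gr.Audio(recording=True), None, [mic])  # re-arm the mic

    stop = gr.Button("Stop Conversation", variant="stop")
    stop.click(lambda: gr.Audio(recording=False), None, [mic],
               cancels=[respond, restart])  # abort queued/in-flight events

demo.launch()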
 