akhaliq HF staff commited on
Commit
f8bd65e
1 Parent(s): c75877d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -76
app.py CHANGED
@@ -12,6 +12,11 @@ from threading import Lock
12
 
13
  @dataclass
14
  class AppState:
 
 
 
 
 
15
  conversation: list = field(default_factory=list)
16
  client: openai.OpenAI = None
17
 
@@ -24,17 +29,27 @@ def create_client(api_key):
24
  api_key=api_key
25
  )
26
 
27
- def process_audio_file(audio_file, state):
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  if state.client is None:
29
  raise gr.Error("Please enter a valid API key first.")
30
 
31
  format_ = "opus"
32
  bitrate = 16
33
-
34
- with open(audio_file.name, "rb") as f:
35
- audio_bytes = f.read()
36
  audio_data = base64.b64encode(audio_bytes).decode()
37
-
38
  try:
39
  stream = state.client.chat.completions.create(
40
  extra_body={
@@ -50,69 +65,42 @@ def process_audio_file(audio_file, state):
50
  stream=True,
51
  )
52
 
53
- transcript = ""
54
- audio_chunks = []
55
 
56
  for chunk in stream:
57
- if chunk.choices[0].delta.content:
58
- transcript += chunk.choices[0].delta.content
59
- if hasattr(chunk.choices[0], 'audio') and chunk.choices[0].audio:
60
- audio_chunks.extend(chunk.choices[0].audio)
61
-
62
- audio_data = b''.join([base64.b64decode(a) for a in audio_chunks])
63
-
64
- return transcript, audio_data, state
 
 
 
 
 
 
65
 
66
  except Exception as e:
67
- raise gr.Error(f"Error processing audio: {str(e)}")
68
 
69
- def generate_response_and_audio(message, state):
70
- if state.client is None:
71
- raise gr.Error("Please enter a valid API key first.")
 
 
 
 
 
 
 
 
 
72
 
73
- with state_lock:
74
- state.conversation.append({"role": "user", "content": message})
75
-
76
- try:
77
- completion = state.client.chat.completions.create(
78
- model="llama3-1-8b",
79
- messages=state.conversation,
80
- max_tokens=128,
81
- stream=True,
82
- extra_body={
83
- "require_audio": "true",
84
- "tts_preset_id": "jessica",
85
- }
86
- )
87
-
88
- full_response = ""
89
- audio_chunks = []
90
-
91
- for chunk in completion:
92
- if not chunk.choices:
93
- continue
94
-
95
- content = chunk.choices[0].delta.content
96
- audio = getattr(chunk.choices[0], 'audio', [])
97
-
98
- if content:
99
- full_response += content
100
- yield full_response, None, state
101
-
102
- if audio:
103
- audio_chunks.extend(audio)
104
- audio_data = b''.join([base64.b64decode(a) for a in audio_chunks])
105
- yield full_response, audio_data, state
106
-
107
- state.conversation.append({"role": "assistant", "content": full_response})
108
- except Exception as e:
109
- raise gr.Error(f"Error generating response: {str(e)}")
110
-
111
- def chat(message, state):
112
- if not message:
113
- return "", None, state
114
-
115
- return generate_response_and_audio(message, state)
116
 
117
  def set_api_key(api_key, state):
118
  if not api_key:
@@ -120,9 +108,11 @@ def set_api_key(api_key, state):
120
  state.client = create_client(api_key)
121
  return "API key set successfully!", state
122
 
 
 
 
 
123
  with gr.Blocks() as demo:
124
- state = gr.State(AppState())
125
-
126
  with gr.Row():
127
  api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
128
  set_key_button = gr.Button("Set API Key")
@@ -130,20 +120,42 @@ with gr.Blocks() as demo:
130
  api_key_status = gr.Textbox(label="API Key Status", interactive=False)
131
 
132
  with gr.Row():
133
- with gr.Column(scale=1):
134
- audio_file_input = gr.Audio(sources="microphone")
135
- with gr.Column(scale=2):
136
- chatbot = gr.Chatbot()
137
- text_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
138
- with gr.Column(scale=1):
139
- audio_output = gr.Audio(label="Generated Audio")
140
 
 
 
141
  set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, state])
142
- audio_file_input.change(
143
- process_audio_file,
144
- inputs=[audio_file_input, state],
145
- outputs=[text_input, audio_output, state]
 
 
 
146
  )
147
- text_input.submit(chat, inputs=[text_input, state], outputs=[chatbot, audio_output, state])
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  demo.launch()
 
12
 
13
  @dataclass
14
  class AppState:
15
+ stream: np.ndarray | None = None
16
+ sampling_rate: int = 0
17
+ pause_detected: bool = False
18
+ started_talking: bool = False
19
+ stopped: bool = False
20
  conversation: list = field(default_factory=list)
21
  client: openai.OpenAI = None
22
 
 
29
  api_key=api_key
30
  )
31
 
32
+ def process_audio(audio: tuple, state: AppState):
33
+ if state.stream is None:
34
+ state.stream = audio[1]
35
+ state.sampling_rate = audio[0]
36
+ else:
37
+ state.stream = np.concatenate((state.stream, audio[1]))
38
+
39
+ # Simple pause detection (you might want to implement a more sophisticated method)
40
+ if len(state.stream) > state.sampling_rate * 0.5: # 0.5 second of silence
41
+ state.pause_detected = True
42
+ return gr.Audio(recording=False), state
43
+ return None, state
44
+
45
+ def generate_response_and_audio(audio_bytes: bytes, state: AppState):
46
  if state.client is None:
47
  raise gr.Error("Please enter a valid API key first.")
48
 
49
  format_ = "opus"
50
  bitrate = 16
 
 
 
51
  audio_data = base64.b64encode(audio_bytes).decode()
52
+
53
  try:
54
  stream = state.client.chat.completions.create(
55
  extra_body={
 
65
  stream=True,
66
  )
67
 
68
+ full_response = ""
69
+ audios = []
70
 
71
  for chunk in stream:
72
+ if not chunk.choices:
73
+ continue
74
+ content = chunk.choices[0].delta.content
75
+ audio = getattr(chunk.choices[0], 'audio', [])
76
+ if content:
77
+ full_response += content
78
+ yield full_response, None, state
79
+ if audio:
80
+ audios.extend(audio)
81
+ audio_data = b''.join([base64.b64decode(a) for a in audios])
82
+ yield full_response, audio_data, state
83
+
84
+ state.conversation.append({"role": "user", "content": "Audio input"})
85
+ state.conversation.append({"role": "assistant", "content": full_response})
86
 
87
  except Exception as e:
88
+ raise gr.Error(f"Error during audio streaming: {e}")
89
 
90
+ def response(state: AppState):
91
+ if not state.pause_detected:
92
+ return None, None, AppState()
93
+
94
+ audio_buffer = io.BytesIO()
95
+ segment = AudioSegment(
96
+ state.stream.tobytes(),
97
+ frame_rate=state.sampling_rate,
98
+ sample_width=state.stream.dtype.itemsize,
99
+ channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
100
+ )
101
+ segment.export(audio_buffer, format="wav")
102
 
103
+ return generate_response_and_audio(audio_buffer.getvalue(), state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  def set_api_key(api_key, state):
106
  if not api_key:
 
108
  state.client = create_client(api_key)
109
  return "API key set successfully!", state
110
 
111
+ def start_recording_user(state: AppState):
112
+ if not state.stopped:
113
+ return gr.Audio(recording=True)
114
+
115
  with gr.Blocks() as demo:
 
 
116
  with gr.Row():
117
  api_key_input = gr.Textbox(type="password", label="Enter your Lepton API Key")
118
  set_key_button = gr.Button("Set API Key")
 
120
  api_key_status = gr.Textbox(label="API Key Status", interactive=False)
121
 
122
  with gr.Row():
123
+ with gr.Column():
124
+ input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
125
+ with gr.Column():
126
+ chatbot = gr.Chatbot(label="Conversation", type="messages")
127
+ output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
 
 
128
 
129
+ state = gr.State(AppState())
130
+
131
  set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, state])
132
+
133
+ stream = input_audio.stream(
134
+ process_audio,
135
+ [input_audio, state],
136
+ [input_audio, state],
137
+ stream_every=0.50,
138
+ time_limit=30,
139
  )
 
140
 
141
+ respond = input_audio.stop_recording(
142
+ response,
143
+ [state],
144
+ [chatbot, output_audio, state]
145
+ )
146
+
147
+ restart = output_audio.stop(
148
+ start_recording_user,
149
+ [state],
150
+ [input_audio]
151
+ )
152
+
153
+ cancel = gr.Button("Stop Conversation", variant="stop")
154
+ cancel.click(
155
+ lambda: (AppState(stopped=True), gr.Audio(recording=False)),
156
+ None,
157
+ [state, input_audio],
158
+ cancels=[respond, restart]
159
+ )
160
+
161
  demo.launch()