"""Lepton LLM Voice Mode: a Gradio voice-chat demo.

Streams microphone audio, detects a pause, sends the utterance to a
Lepton-hosted Llama 3.1 8B endpoint (OpenAI-compatible API), and plays
back the streamed TTS reply. Requires gradio, numpy, openai, and pydub
(which in turn needs ffmpeg).
"""

import base64
import io
from dataclasses import dataclass, field
from threading import Lock

import gradio as gr
import numpy as np
import openai
from pydub import AudioSegment


@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    conversation: list = field(default_factory=list)
    client: openai.OpenAI | None = None
    output_format: str = "mp3"
    stopped: bool = False


# Global lock for thread safety
state_lock = Lock()


def create_client(api_key):
    """Build an OpenAI-compatible client pointed at the Lepton endpoint."""
    return openai.OpenAI(
        base_url="https://llama3-1-8b.lepton.run/api/v1/",
        api_key=api_key,
    )


def determine_pause(audio, sampling_rate, state):
    """Return True if the last 1 second of audio is effectively silent."""
    pause_length = int(sampling_rate * 1)  # 1 second
    if len(audio) < pause_length:
        return False
    last_audio = audio[-pause_length:].astype(np.float32)

    # Normalize integer PCM (e.g. int16 from the browser microphone) to
    # [-1, 1] so the silence threshold is independent of the sample dtype.
    if np.issubdtype(audio.dtype, np.integer):
        last_audio /= np.iinfo(audio.dtype).max

    # Calculate the average amplitude in the last 1 second
    avg_amplitude = np.mean(np.abs(last_audio))
    silence_threshold = 0.01  # Adjust this threshold as needed
    return avg_amplitude < silence_threshold


def process_audio(audio: tuple, state: AppState):
    """Accumulate streamed microphone chunks; stop recording once a pause is detected."""
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    state.pause_detected = determine_pause(state.stream, state.sampling_rate, state)

    if state.pause_detected:
        return gr.Audio(recording=False), state
    return None, state


def generate_response_and_audio(audio_bytes: bytes, state: AppState):
    """Send the recorded audio to the model and stream back text and TTS audio."""
    if state.client is None:
        raise gr.Error("Please enter a valid API key first.")

    format_ = state.output_format
    bitrate = 128 if format_ == "mp3" else 32  # Higher bitrate for MP3, lower for OPUS
    audio_data = base64.b64encode(audio_bytes).decode()

    try:
        stream = state.client.chat.completions.create(
            extra_body={
                "require_audio": True,
                "tts_preset_id": "jessica",
                "tts_audio_format": format_,
                "tts_audio_bitrate": bitrate,
            },
            model="llama3.1-8b",
            messages=[
                {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
            ],
            temperature=0.7,
            max_tokens=256,
            stream=True,
        )

        full_response = ""
        audios = []

        for chunk in stream:
            if not chunk.choices:
                continue

            content = chunk.choices[0].delta.content
            audio = getattr(chunk.choices[0], "audio", [])
            if content:
                full_response += content
                yield full_response, None, state
            if audio:
                audios.extend(audio)

        final_audio = b"".join([base64.b64decode(a) for a in audios])
        yield full_response, final_audio, state

    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}")


def response(state: AppState):
    """Convert the recorded stream to WAV, query the model, and update the conversation."""
    if state.stream is None or len(state.stream) == 0:
        return None, None, state

    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    generator = generate_response_and_audio(audio_buffer.getvalue(), state)

    # Drain the generator to get the final text and audio
    final_text = ""
    final_audio = None
    for text, audio, updated_state in generator:
        final_text = text if text else final_text
        final_audio = audio if audio else final_audio
        state = updated_state

    # Record the exchange in the conversation history
    state.conversation.append({"role": "user", "content": "Audio input"})
    state.conversation.append({"role": "assistant", "content": final_text})

    # Reset the audio stream for the next interaction
    state.stream = None
    state.pause_detected = False

    chatbot_output = state.conversation[-2:]  # Get the last two messages
    return chatbot_output, final_audio, state


def start_recording_user(state: AppState):
    """Restart recording automatically unless the conversation was stopped."""
    if not state.stopped:
        return gr.Audio(recording=True)
    return gr.Audio(recording=False)


def set_api_key(api_key, state):
    if not api_key:
        raise gr.Error("Please enter a valid API key.")
    try:
        state.client = create_client(api_key)
        return gr.update(value="API key set successfully!", visible=True), state
    except Exception:
        return gr.update(value="Connection error", visible=True), state


with gr.Blocks() as demo:
    gr.Markdown("# Lepton LLM Voice Mode")
    gr.Markdown(
        "You can find your Lepton serverless endpoint API key "
        "[here](https://dashboard.lepton.ai/workspace-redirect/settings/api-tokens)."
    )

    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(
                type="password", label="Enter your Lepton API Key"
            )
        with gr.Column(scale=1):
            set_key_button = gr.Button("Set API Key", scale=2, variant="secondary")

    api_key_status = gr.Textbox(
        show_label=False, container=False, interactive=False, visible=False
    )

    with gr.Group():
        input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
        output_audio = gr.Audio(label="Output Audio", autoplay=True)
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        cancel = gr.Button("Stop Conversation", variant="stop")

    state = gr.State(AppState())

    set_key_button.click(
        set_api_key,
        inputs=[api_key_input, state],
        outputs=[api_key_status, state],
    )

    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.25,  # Reduced to make it more responsive
        time_limit=60,  # Increased to allow for longer messages
    )

    respond = input_audio.stop_recording(
        response, [state], [chatbot, output_audio, state]
    )
    # Update the chatbot with the final conversation
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Automatically restart recording after the assistant's response
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # The "Stop Conversation" button resets state and cancels in-flight events
    cancel.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        None,
        [state, input_audio],
        cancels=[respond, restart],
    )

demo.launch(share=True)