import gradio as gr
import numpy as np
import io
from pydub import AudioSegment
import tempfile
import openai
import time
from dataclasses import dataclass, field
from threading import Lock
import base64
import uuid
import os


@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    conversation: list = field(default_factory=list)
    client: openai.OpenAI | None = None
    output_format: str = "mp3"
    stopped: bool = False


# Global lock for thread safety
state_lock = Lock()


def create_client(api_key):
    # Lepton exposes an OpenAI-compatible endpoint, so the standard OpenAI client works
    return openai.OpenAI(
        base_url="https://llama3-1-8b.lepton.run/api/v1/",
        api_key=api_key,
    )


def test_api_key(client):
    # Make a simple request (listing available models) to check whether the API key works;
    # any failure propagates to the caller
    client.models.list()


def determine_pause(audio, sampling_rate, state):
    # Take the last 1 second of audio
    pause_length = int(sampling_rate * 1)  # 1 second
    if len(audio) < pause_length:
        return False
    last_audio = audio[-pause_length:]
    amplitude = np.abs(last_audio)

    # Calculate the average amplitude in the last 1 second
    avg_amplitude = np.mean(amplitude)
    silence_threshold = 0.01  # Adjust this threshold as needed
    if avg_amplitude < silence_threshold:
        return True
    else:
        return False


def process_audio(audio: tuple, state: AppState):
    # Accumulate streamed chunks into one buffer and stop recording once a pause is detected
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected

    if state.pause_detected:
        return gr.Audio(recording=False), state
    else:
        return None, state


def update_or_append_conversation(conversation, id, role, new_content):
    # Update the entry for this (id, role) pair if it exists, otherwise append a new one
    for entry in conversation:
        if entry["id"] == id and entry["role"] == role:
            entry["content"] = new_content
            return
    conversation.append({"id": id, "role": role, "content": new_content})


def generate_response_and_audio(audio_bytes: bytes, state: AppState):
    if state.client is None:
        raise gr.Error("Please enter a valid API key first.")

    format_ = state.output_format
    bitrate = 128 if format_ == "mp3" else 32  # Higher bitrate for MP3, lower for OPUS
    audio_data = base64.b64encode(audio_bytes).decode()

    # Build the message history, then append the new user turn as base64-encoded audio
    old_messages = []
    for item in state.conversation:
        old_messages.append({"role": item["role"], "content": item["content"]})

    old_messages.append(
        {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
    )

    try:
        stream = state.client.chat.completions.create(
            extra_body={
                "require_audio": True,
                "tts_preset_id": "jessica",
                "tts_audio_format": format_,
                "tts_audio_bitrate": bitrate,
            },
            model="llama3.1-8b",
            messages=old_messages,
            temperature=0.7,
            max_tokens=256,
            stream=True,
        )

        full_response = ""
        asr_result = ""
        audios = []
        id = uuid.uuid4()

        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            audio = getattr(chunk.choices[0], "audio", [])
            asr_results = getattr(chunk.choices[0], "asr_results", [])

            if asr_results:
                asr_result += "".join(asr_results)
                yield id, full_response, asr_result, None, state

            if content:
                full_response += content
                yield id, full_response, asr_result, None, state

            if audio:
                audios.extend(audio)

        # Concatenate all audio chunks and yield the final result once the stream ends
        final_audio = b"".join([base64.b64decode(a) for a in audios])
        yield id, full_response, asr_result, final_audio, state

    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}")


def response(state: AppState):
    if state.stream is None or len(state.stream) == 0:
        return None, None, state
    # Convert the accumulated numpy audio into an in-memory WAV file
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    generator = generate_response_and_audio(audio_buffer.getvalue(), state)

    for id, text, asr, audio, updated_state in generator:
        state = updated_state
        if asr:
            update_or_append_conversation(state.conversation, id, "user", asr)
        if text:
            update_or_append_conversation(state.conversation, id, "assistant", text)
        chatbot_output = state.conversation
        yield chatbot_output, audio, state

    # Reset the audio stream for the next interaction
    state.stream = None
    state.pause_detected = False


def set_api_key(api_key, state):
    try:
        state.client = create_client(api_key)
        test_api_key(state.client)  # Test the provided API key
        api_key_status = gr.update(value="API key set successfully!", visible=True)
        api_key_input = gr.update(visible=False)
        set_key_button = gr.update(visible=False)
        return api_key_status, api_key_input, set_key_button, state
    except Exception:
        api_key_status = gr.update(
            value="Invalid API key. Please try again.", visible=True
        )
        # Leave the input and button unchanged so the user can retry
        return api_key_status, gr.update(), gr.update(), state


def initial_setup(state):
    api_key = os.getenv("API_KEY")
    if api_key:
        try:
            state.client = create_client(api_key)
            test_api_key(state.client)  # Test the API key from the environment variable
            api_key_status = gr.update(value="Using default API key", visible=True)
            api_key_input = gr.update(visible=False)
            set_key_button = gr.update(visible=False)
            return api_key_status, api_key_input, set_key_button, state
        except Exception:
            # Failed to use the API key from the environment, so show the input box
            api_key_status = gr.update(
                value="Failed to use API key from environment variable. Please enter a valid API key.",
                visible=True,
            )
            api_key_input = gr.update(visible=True)
            set_key_button = gr.update(visible=True)
            return api_key_status, api_key_input, set_key_button, state
    else:
        # No API key in the environment variable
        api_key_status = gr.update(visible=False)
        api_key_input = gr.update(visible=True)
        set_key_button = gr.update(visible=True)
        return api_key_status, api_key_input, set_key_button, state


with gr.Blocks() as demo:
    gr.Markdown("# Lepton AI LLM Voice Mode")
    gr.Markdown(
        "You can find the Lepton AI LLM voice-mode documentation [here](https://www.lepton.ai/playground/chat/llama-3.2-3b)"
    )
    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(
                type="password",
                placeholder="Enter your Lepton API Key",
                show_label=False,
                container=False,
            )
        with gr.Column(scale=1):
            set_key_button = gr.Button("Set API Key", scale=2, variant="primary")

    api_key_status = gr.Textbox(
        show_label=False, container=False, interactive=False, visible=False
    )

    with gr.Blocks():
        with gr.Row():
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="numpy"
            )
            output_audio = gr.Audio(label="Output Audio", autoplay=True)
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        cancel = gr.Button("Stop Conversation", variant="stop")
        state = gr.State(AppState())

        # Initial setup to set the API key from the environment variable
        demo.load(
            initial_setup,
            inputs=state,
            outputs=[api_key_status, api_key_input, set_key_button, state],
        )

        set_key_button.click(
            set_api_key,
            inputs=[api_key_input, state],
            outputs=[api_key_status, api_key_input, set_key_button, state],
        )

        stream = input_audio.stream(
            process_audio,
            [input_audio, state],
            [input_audio, state],
            stream_every=0.25,  # Stream chunks every 0.25 s so pause detection stays responsive
            time_limit=60,  # Allow up to 60 s per message
        )

        respond = input_audio.stop_recording(
            response, [state], [chatbot, output_audio, state]
        )
        # Update the chatbot with the final conversation
        respond.then(lambda s: s.conversation, [state], [chatbot])

        # "Stop Conversation" button: reset the state and stop recording
        cancel.click(
            lambda: (AppState(stopped=True), gr.Audio(recording=False)),
            None,
            [state, input_audio],
            cancels=[respond],
        )

demo.launch()
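
# A minimal sketch of how to run this demo locally, assuming the script is saved as
# app.py (the file name and package list are assumptions, not part of this project):
#
#   export API_KEY="your-lepton-api-key"   # optional; otherwise enter the key in the UI
#   pip install gradio numpy pydub openai
#   python app.py
#
# Note: pydub may additionally require ffmpeg on the PATH for some audio formats.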