""" utils.py Functions: - generate_script: Get the dialogue from the LLM. - call_llm: Call the LLM with the given prompt and dialogue format. - parse_url: Parse the given URL and return the text content. - generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models. - _use_suno_model: Generate advanced audio using Bark. - _use_melotts_api: Generate audio using TTS model. - _get_melo_tts_params: Get TTS parameters based on speaker and language. """ # Standard library imports import time from typing import Any, Union # Third-party imports import instructor import requests from bark import SAMPLE_RATE, generate_audio, preload_models from fireworks.client import Fireworks from gradio_client import Client from scipy.io.wavfile import write as write_wav # Local imports from constants import ( FIREWORKS_API_KEY, FIREWORKS_MODEL_ID, FIREWORKS_MAX_TOKENS, FIREWORKS_TEMPERATURE, MELO_API_NAME, MELO_TTS_SPACES_ID, MELO_RETRY_ATTEMPTS, MELO_RETRY_DELAY, JINA_READER_URL, JINA_RETRY_ATTEMPTS, JINA_RETRY_DELAY, ) from schema import ShortDialogue, MediumDialogue # Initialize Fireworks client, with Instructor patch fw_client = Fireworks(api_key=FIREWORKS_API_KEY) fw_client = instructor.from_fireworks(fw_client) # Initialize Hugging Face client hf_client = Client(MELO_TTS_SPACES_ID) # Download and load all models for Bark preload_models() def generate_script( system_prompt: str, input_text: str, output_model: Union[ShortDialogue, MediumDialogue], ) -> Union[ShortDialogue, MediumDialogue]: """Get the dialogue from the LLM.""" # Call the LLM for the first time first_draft_dialogue = call_llm(system_prompt, input_text, output_model) # Call the LLM a second time to improve the dialogue system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue.model_dump_json()}." final_dialogue = call_llm(system_prompt_with_dialogue, "Please improve the dialogue. Make it more natural and engaging.", output_model) return final_dialogue def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any: """Call the LLM with the given prompt and dialogue format.""" response = fw_client.chat.completions.create( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": text}, ], model=FIREWORKS_MODEL_ID, max_tokens=FIREWORKS_MAX_TOKENS, temperature=FIREWORKS_TEMPERATURE, response_model=dialogue_format, ) return response def parse_url(url: str) -> str: """Parse the given URL and return the text content.""" for attempt in range(JINA_RETRY_ATTEMPTS): try: full_url = f"{JINA_READER_URL}{url}" response = requests.get(full_url, timeout=60) response.raise_for_status() # Raise an exception for bad status codes break except requests.RequestException as e: if attempt == JINA_RETRY_ATTEMPTS - 1: # Last attempt raise ValueError( f"Failed to fetch URL after {JINA_RETRY_ATTEMPTS} attempts: {e}" ) from e time.sleep(JINA_RETRY_DELAY) # Wait for X second before retrying return response.text def generate_podcast_audio( text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int ) -> str: """Generate audio for podcast using TTS or advanced audio models.""" if use_advanced_audio: return _use_suno_model(text, speaker, language, random_voice_number) else: return _use_melotts_api(text, speaker, language) def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str: """Generate advanced audio using Bark.""" host_voice_num = str(random_voice_number) guest_voice_num = str(random_voice_number + 1) audio_array = generate_audio( text, history_prompt=f"v2/{language}_speaker_{host_voice_num if speaker == 'Host (Jane)' else guest_voice_num}", ) file_path = f"audio_{language}_{speaker}.mp3" write_wav(file_path, SAMPLE_RATE, audio_array) return file_path def _use_melotts_api(text: str, speaker: str, language: str) -> str: """Generate audio using TTS model.""" accent, speed = _get_melo_tts_params(speaker, language) for attempt in range(MELO_RETRY_ATTEMPTS): try: return hf_client.predict( text=text, language=language, speaker=accent, speed=speed, api_name=MELO_API_NAME, ) except Exception as e: if attempt == MELO_RETRY_ATTEMPTS - 1: # Last attempt raise # Re-raise the last exception if all attempts fail time.sleep(MELO_RETRY_DELAY) # Wait for X second before retrying def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]: """Get TTS parameters based on speaker and language.""" if speaker == "Guest": accent = "EN-US" if language == "EN" else language speed = 0.9 else: # host accent = "EN-Default" if language == "EN" else language speed = ( 1.1 if language != "EN" else 1 ) # if the language is not English, try speeding up so it'll sound different from the host # for non-English, there is only one voice return accent, speed