import os import torch import numpy as np import time import matplotlib.pyplot as plt from typing import Tuple, List from statistics import mean, median, stdev from lib import ( normalize_text, chunk_text, count_tokens, load_module_from_file, download_model_files, list_voice_files, download_voice_files, ensure_dir, concatenate_audio_chunks ) import spaces class TTSModel: """GPU-accelerated TTS model manager""" def __init__(self): self.model = None self.voices_dir = "voices" self.model_repo = "hexgrad/Kokoro-82M" ensure_dir(self.voices_dir) self.model_path = None # Load required modules py_modules = ["istftnet", "plbert", "models", "kokoro"] module_files = download_model_files(self.model_repo, [f"{m}.py" for m in py_modules]) for module_name, file_path in zip(py_modules, module_files): load_module_from_file(module_name, file_path) # Import required functions from kokoro module kokoro = __import__("kokoro") self.generate = kokoro.generate self.build_model = __import__("models").build_model def initialize(self) -> bool: """Initialize model and download voices""" try: print("Initializing model...") # Download model files model_files = download_model_files( self.model_repo, ["kokoro-v0_19.pth", "config.json"] ) self.model_path = model_files[0] # kokoro-v0_19.pth # Download voice files download_voice_files(self.model_repo, "voices", self.voices_dir) # Get list of available voices available_voices = self.list_voices() print("Model initialization complete") return True except Exception as e: print(f"Error initializing model: {str(e)}") return False def ensure_voice_downloaded(self, voice_name: str) -> bool: """Ensure specific voice is downloaded""" try: voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt") if not os.path.exists(voice_path): print(f"Downloading voice {voice_name}.pt...") download_voice_files(self.model_repo, [f"{voice_name}.pt"], self.voices_dir) return True except Exception as e: print(f"Error downloading voice {voice_name}: {str(e)}") return False def list_voices(self) -> List[str]: """List available voices""" voices = [] voices_subdir = os.path.join(self.voices_dir, "voices") if os.path.exists(voices_subdir): for file in os.listdir(voices_subdir): if file.endswith(".pt"): voice_name = file[:-3] voices.append(voice_name) return voices # def _ensure_model_on_gpu(self) -> None: # """Ensure model is on GPU and stays there""" # if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu: # print("Moving model to GPU...") # with torch.cuda.device(0): # torch.cuda.set_device(0) # if hasattr(self.model, 'to'): # self.model.to('cuda') # else: # for name in self.model: # if isinstance(self.model[name], torch.Tensor): # self.model[name] = self.model[name].cuda() # self._model_on_gpu = True def _generate_audio(self, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> np.ndarray: """GPU-accelerated audio generation""" try: with torch.cuda.device(0): torch.cuda.set_device(0) try: # Build model if needed if self.model is None: print("Building model...") device = torch.device('cuda') self.model = self.build_model(self.model_path, device=device) if self.model is None: raise ValueError("Failed to build model") print("Model built successfully") # Move model to GPU if needed if not hasattr(self.model, '_on_gpu'): print("Moving model to GPU...") if hasattr(self.model, 'to'): self.model = self.model.to('cuda') else: for name in self.model: if isinstance(self.model[name], torch.Tensor): self.model[name] = self.model[name].cuda() self.model._on_gpu = True except Exception as e: print(f"Error building model: {str(e)}") print("Attempting to continue") raise e # Move voicepack to GPU voicepack = voicepack.cuda() # Run generation with everything on GPU audio, _ = self.generate( self.model, text, voicepack, lang=lang, speed=speed ) return audio except Exception as e: print(f"Error in audio generation: {str(e)}") raise e @spaces.GPU(duration=None) # Duration will be set by the UI def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float]: """Generate speech from text. Returns (audio_array, duration) Args: text: Input text to convert to speech voice_name: Name of voice to use speed: Speech speed multiplier progress_callback: Optional callback function(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress) progress_state: Dictionary tracking generation progress metrics progress: Progress callback from Gradio """ try: start_time = time.time() with torch.cuda.device(0): torch.cuda.set_device(0) if not text or not voice_names: raise ValueError("Text and voice name are required") # Build model directly on GPU # Build model if needed if self.model is None: print("Building model...") self.model = self.build_model(self.model_path, device='cuda') if self.model is None: raise ValueError("Failed to build model") print("Model built successfully") # Move model to GPU if needed if not hasattr(self.model, '_on_gpu'): print("Moving model to GPU...") if hasattr(self.model, 'to'): self.model = self.model.to('cuda') else: for name in self.model: if isinstance(self.model[name], torch.Tensor): self.model[name] = self.model[name].cuda() self.model._on_gpu = True t_voices = [] if isinstance(voice_names, list) and len(voice_names) > 1: for voice in voice_names: try: voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt") voicepack = torch.load(voice_path, weights_only=True) t_voices.append(voicepack) except Exception as e: print(f"Warning: Failed to load voice {voice}: {str(e)}") # Combine voices by taking mean voicepack = torch.mean(torch.stack(t_voices), dim=0) voice_name = "_".join(voice_names) else: voice_name = voice_names[0] voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt") voicepack = torch.load(voice_path, weights_only=True) # Count tokens and normalize text total_tokens = count_tokens(text) text = normalize_text(text) if not text: raise ValueError("Text is empty after normalization") # Break text into chunks for better memory management chunks = chunk_text(text) print(f"Processing {len(chunks)} chunks...") # Process all chunks within same GPU context audio_chunks = [] chunk_times = [] chunk_sizes = [] # Store chunk lengths total_processed_tokens = 0 total_processed_time = 0 for i, chunk in enumerate(chunks): chunk_start = time.time() chunk_audio = self._generate_audio( text=chunk, voicepack=voicepack, lang=voice_name[0], speed=speed ) chunk_time = time.time() - chunk_start # Calculate per-chunk metrics chunk_tokens = count_tokens(chunk) chunk_tokens_per_sec = chunk_tokens / chunk_time # Update totals for overall stats total_processed_tokens += chunk_tokens total_processed_time += chunk_time # Calculate processing speed metrics chunk_duration = len(chunk_audio) / 24000 # audio duration in seconds rtf = chunk_time / chunk_duration times_faster = 1 / rtf chunk_times.append(chunk_time) chunk_sizes.append(len(chunk)) print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s") print(f"Current tokens/sec: {chunk_tokens_per_sec:.2f}") print(f"Real-time factor: {rtf:.2f}x") print(f"{times_faster:.1f}x faster than real-time") audio_chunks.append(chunk_audio) # Call progress callback if provided if progress_callback: progress_callback( i + 1, # chunk_num len(chunks), # total_chunks chunk_tokens_per_sec, # Pass per-chunk rate instead of cumulative rtf, progress_state, # Added start_time, # Added gpu_timeout, # Use the timeout value from UI progress # Added ) # Concatenate audio chunks audio = concatenate_audio_chunks(audio_chunks) # Return audio and metrics return ( audio, # Audio array len(audio) / 24000, # Duration { "chunk_times": chunk_times, "chunk_sizes": chunk_sizes, "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]], "rtf": [float(x) for x in progress_state["rtf"]], "total_tokens": total_tokens, "total_time": time.time() - start_time } ) except Exception as e: print(f"Error generating speech: {str(e)}") raise