Spaces:
Running
Running
#!/usr/bin/env python3 | |
import gradio as gr | |
import numpy as np | |
import torch | |
import json | |
import io | |
import soundfile as sf | |
from PIL import Image | |
import matplotlib | |
import joblib | |
from sklearn.decomposition import PCA | |
from collections import OrderedDict | |
import nltk | |
matplotlib.use("Agg") # Use non-interactive backend | |
import matplotlib.pyplot as plt | |
from text2speech import tts_randomized, parse_speed, tts_with_style_vector | |
# ---------------------------------------------------------------------------
# Constants and paths
# ---------------------------------------------------------------------------
# Files read from / written to the current working directory.
VOICES_JSON_PATH = "voices.json"
PCA_MODEL_PATH = "pca_model.pkl"
ANNOTATED_FEATURES_PATH = "annotated_features.npy"

# Expected length of a stored style vector.
VECTOR_DIMENSION = 256

# One (name, info) pair per annotated PCA axis. The names are used as the
# slider labels in the Voice Studio tab; the info strings are not referenced
# elsewhere in this file (presumably the two poles of each axis -- TODO confirm).
_FEATURE_AXES = (
    ("Gender", "Male | Female"),
    ("Tone", "High | Low"),
    ("Quality", "Noisy | Clean"),
    ("Enunciation", "Clear | Unclear"),
    ("Pace", "Rapid | Slow"),
    ("Style", "Colloquial | Formal"),
)
ANNOTATED_FEATURES_NAMES = [axis_name for axis_name, _ in _FEATURE_AXES]
ANNOTATED_FEATURES_INFO = [axis_info for _, axis_info in _FEATURE_AXES]

# Fetch the NLTK "punkt_tab" resource at import time.
nltk.download("punkt_tab")
##############################################################################
# DEVICE CONFIGURATION
##############################################################################
# Prefer CUDA when torch reports it as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
##############################################################################
# LOAD PCA MODEL AND ANNOTATED FEATURES
##############################################################################
def _load_or_none(loader, path, label):
    """Call ``loader(path)``; return its result, or None if the file is missing.

    A success or "not found" line is printed either way. Only
    FileNotFoundError is handled; any other loader error propagates.
    """
    try:
        loaded = loader(path)
        print(f"{label} loaded successfully.")
        return loaded
    except FileNotFoundError:
        print(f"Error: {label} file '{path}' not found.")
        return None


# NOTE: joblib.load unpickles the file, so only trusted model files should be
# placed at PCA_MODEL_PATH.
pca = _load_or_none(joblib.load, PCA_MODEL_PATH, "PCA model")
annotated_features = _load_or_none(np.load, ANNOTATED_FEATURES_PATH, "Annotated features")
############################################################################## | |
# UTILITY FUNCTIONS | |
############################################################################## | |
def load_voices_json():
    """Read and parse the file at ``VOICES_JSON_PATH``.

    Returns:
        The parsed JSON document, with objects decoded as ``OrderedDict`` so
        key order is kept. If the file is missing or is not valid JSON, a
        warning is printed and an empty ``OrderedDict`` is returned instead
        (no file is created here, despite the wording of the warning).
    """
    voices = OrderedDict()
    try:
        with open(VOICES_JSON_PATH, "r") as handle:
            voices = json.load(handle, object_pairs_hook=OrderedDict)
    except FileNotFoundError:
        print(f"Warning: {VOICES_JSON_PATH} not found. Creating a new one.")
    except json.JSONDecodeError:
        print(f"Warning: {VOICES_JSON_PATH} is not valid JSON.")
    return voices
def save_voices_json(data, path=VOICES_JSON_PATH):
    """Write *data* to *path* as JSON (2-space indent) and print a confirmation.

    Args:
        data: JSON-serialisable object to store.
        path (str): Destination file; defaults to ``VOICES_JSON_PATH``.
    """
    with open(path, mode="w") as out_file:
        json.dump(data, out_file, indent=2)
    print(f"Voices saved to '{path}'.")
def update_sliders(voice_name):
    """
    Compute slider positions for a stored voice by projecting it through the PCA.

    Args:
        voice_name (str): The name of the selected voice.

    Returns:
        list: The PCA component values for the voice. A list of zeros (one per
        annotated feature) is returned when no voice is given, the voice is
        unknown, the PCA model is not loaded, or the projection fails.
    """
    neutral = [0.0] * len(ANNOTATED_FEATURES_NAMES)

    if not voice_name:
        # Nothing selected: keep every slider at its default.
        return neutral

    stored_voices = load_voices_json()
    if voice_name not in stored_voices:
        print(f"Voice '{voice_name}' not found in {VOICES_JSON_PATH}.")
        return neutral

    # PCA.transform expects a 2-D array, hence the (1, -1) reshape.
    voice_row = np.array(stored_voices[voice_name], dtype=np.float32).reshape(1, -1)

    if pca is None:
        print("PCA model is not loaded.")
        return neutral

    try:
        projected = pca.transform(voice_row)[0]
        return projected.tolist()
    except Exception as e:
        print(f"Error transforming style vector to PCA components: {e}")
        return neutral
def generate_audio_with_voice(text, voice_key, speed_val):
    """
    Generate audio using the style vector of the selected predefined voice.

    Args:
        text (str): The text to synthesize.
        voice_key (str): The name of the selected voice.
        speed_val (float): The speed multiplier.

    Returns:
        tuple: Always a 2-tuple.
            * On success: ``((sample_rate, audio), style_vector)`` where
              ``style_vector`` is the voice's vector as a nested list.
            * On failure: ``(None, error_message)``.
    """
    try:
        # Load voices data
        voices_data = load_voices_json()
        if voice_key not in voices_data:
            print(f"Voice '{voice_key}' not found in {VOICES_JSON_PATH}.")
            # Two items, like every other return: callers unpack exactly two.
            return None, "Selected voice not found."

        # Retrieve the style vector for the selected voice as a (1, N) row.
        style_vector = np.array(voices_data[voice_key], dtype=np.float32).reshape(1, -1)
        print(f"Selected Voice: {voice_key}")
        print(f"Style Vector (First 6): {style_vector[0][:6]}")

        # Convert to torch tensor and move to device
        style_vec_torch = torch.from_numpy(style_vector).float().to(device)

        # Generate audio using the TTS model
        audio_np = tts_with_style_vector(
            text,
            style_vec=style_vec_torch,
            speed=speed_val,
            alpha=0.3,
            beta=0.7,
            diffusion_steps=7,
            embedding_scale=1.0,
        )
        if audio_np is None:
            print("Audio generation failed.")
            return None, "Audio generation failed."

        # Prepare audio for Gradio
        sr = 24000  # Adjust based on your actual sampling rate
        audio_tuple = (sr, audio_np)

        # Return audio and style vector
        return audio_tuple, style_vector.tolist()
    except Exception as e:
        # UI boundary: report the problem instead of crashing the callback.
        print(f"Error in generate_audio_with_voice: {e}")
        return None, "An error occurred during audio generation."
def build_modified_vector(voice_key, top6_values):
    """Build a style vector from PCA component values via the inverse PCA.

    The stored vector for ``voice_key`` is only validated (it must exist and
    have ``VECTOR_DIMENSION`` entries); the returned vector is reconstructed
    purely from ``top6_values``.

    Args:
        voice_key (str): Name of the base voice in voices.json.
        top6_values (sequence of float): PCA component values (slider values).

    Returns:
        numpy.ndarray or None: The reconstructed style vector, or None if the
        voice is missing or malformed, the PCA model is not loaded, or the
        inverse transform fails.
    """
    voices_data = load_voices_json()
    if voice_key not in voices_data:
        print(f"Voice '{voice_key}' not found in {VOICES_JSON_PATH}.")
        return None

    arr = np.array(voices_data[voice_key], dtype=np.float32).squeeze()
    if arr.ndim != 1 or arr.shape[0] != VECTOR_DIMENSION:
        print(f"Voice '{voice_key}' has invalid shape {arr.shape}. Expected (256,).")
        return None

    # Explicit guard, matching update_sliders / reconstruct_style_vector,
    # instead of relying on the AttributeError being swallowed below.
    if pca is None:
        print("PCA model is not loaded.")
        return None

    try:
        # Reconstruct the style vector using inverse PCA
        pca_components = np.array(top6_values).reshape(1, -1)
        reconstructed_vec = pca.inverse_transform(pca_components)[0]
        return reconstructed_vec
    except Exception as e:
        print(f"Error reconstructing style vector: {e}")
        return None
def reconstruct_style_vector(pca_components):
    """
    Map PCA component values back to a full style vector.

    Args:
        pca_components: One sample's PCA component values.

    Returns:
        The first row of ``pca.inverse_transform`` for that sample, or None
        when the PCA model is not loaded or the inverse transform raises.
    """
    if pca is None:
        print("PCA model is not loaded.")
        return None

    try:
        # inverse_transform works on a batch, so wrap the single sample.
        restored_batch = pca.inverse_transform([pca_components])
        return restored_batch[0]
    except Exception as e:
        print(f"Error during inverse PCA transform: {e}")
        return None
def generate_custom_audio(text, voice_key, randomize, speed_str, *slider_values):
    """
    Generate audio from either a randomized style vector or the slider values.

    Args:
        text (str): The text to synthesize.
        voice_key (str): Name of the base voice (validated when not randomizing).
        randomize (bool): If true, let ``tts_randomized`` pick the style vector.
        speed_str: Speed specification handed to ``parse_speed``.
            NOTE(review): the caller in this file passes a float multiplier;
            confirm ``parse_speed`` accepts that.
        *slider_values (float): PCA component values used when not randomizing.

    Returns:
        tuple or None: ``((sample_rate, audio), style_vector_list)`` on
        success. ``None`` on any failure, which is what the caller's
        ``result is None`` check expects.
    """
    try:
        speed_val = parse_speed(speed_str)
        print(f"Parsed speed: {speed_val}")

        if randomize:
            # Generate randomized style vector
            audio_np, random_style_vec = tts_randomized(text, speed=speed_val)
            if random_style_vec is None:
                print("Failed to generate randomized style vector.")
                return None
            # Flatten to a 1-D NumPy array (tensors are moved to CPU first).
            final_vec = (
                random_style_vec.cpu().numpy().flatten()
                if isinstance(random_style_vec, torch.Tensor)
                else np.array(random_style_vec).flatten()
            )
            print("Randomized Style Vector (First 6):", final_vec[:6])
        else:
            # Reconstruct the style vector from slider values using inverse PCA
            reconstructed_vec = build_modified_vector(voice_key, slider_values)
            if reconstructed_vec is None:
                print(
                    "No reconstructed vector could be constructed, skipping audio generation."
                )
                return None

            # Convert to a (1, N) torch tensor and move to device
            style_vec_torch = (
                torch.from_numpy(reconstructed_vec).float().unsqueeze(0).to(device)
            )

            # Generate audio with the reconstructed style vector
            audio_np = tts_with_style_vector(
                text,
                style_vec=style_vec_torch,
                speed=speed_val,
                alpha=0.3,
                beta=0.7,
                diffusion_steps=7,
                embedding_scale=1.0,
            )
            final_vec = reconstructed_vec
            print("Reconstructed Style Vector (First 6):", final_vec[:6])

        if audio_np is None:
            print("Audio generation failed.")
            return None

        # Prepare audio for Gradio
        sr = 24000  # Adjust based on your actual sampling rate
        audio_tuple = (sr, audio_np)

        # Return audio and style vector
        return audio_tuple, final_vec.tolist()
    except Exception as e:
        # UI boundary: log and signal failure rather than crash the callback.
        print(f"Error generating audio and style plot: {e}")
        return None
def save_style_to_json(style_data, style_name):
    """Save ``style_data`` (list of floats) into voices.json under ``style_name``.

    Args:
        style_data (list of float or None): The style vector to store. None
            means nothing has been generated yet.
        style_name (str or None): Key for the new voice.

    Returns:
        str: A status message describing success or why nothing was saved.
    """
    # Both arguments come straight from UI state and may be None or blank.
    if not style_name or not style_name.strip():
        return "Please enter a new style name before saving."
    if style_data is None:
        return "No style vector to save. Generate audio first."

    voices_data = load_voices_json()
    if style_name in voices_data:
        return (
            f"Style name '{style_name}' already exists. Please choose a different name."
        )

    # Ensure the style_data has the correct length
    if len(style_data) != VECTOR_DIMENSION:
        return f"Style vector length mismatch. Expected {VECTOR_DIMENSION}, got {len(style_data)}."

    # Save the style vector
    voices_data[style_name] = style_data
    save_voices_json(voices_data)
    return f"Saved style as '{style_name}' in {VOICES_JSON_PATH}."
# Gradio Interface Functions | |
def rearrange_voices(new_order):
    """Reorder the voices in voices.json according to ``new_order``.

    Args:
        new_order (str): Comma-separated voice names in the desired order.

    Returns:
        tuple: ``(status_message, voice_names)``. If any listed name is
        unknown, nothing is written and the current order is returned.
        Voices not mentioned in ``new_order`` are kept, after the listed
        ones and in their existing relative order, rather than deleted.
    """
    voices_data = load_voices_json()
    new_order_list = [name.strip() for name in new_order.split(",")]
    if not all(name in voices_data for name in new_order_list):
        return "Error: New order contains invalid voice names.", list(
            voices_data.keys()
        )

    ordered_data = OrderedDict()
    for name in new_order_list:
        ordered_data[name] = voices_data[name]
    # Reordering must not lose data: append every voice the caller left out.
    for name, vector in voices_data.items():
        ordered_data.setdefault(name, vector)

    save_voices_json(ordered_data)
    print(f"Voices rearranged: {list(ordered_data.keys())}")
    return "Voices rearranged successfully.", list(ordered_data.keys())
def delete_voice(selected):
    """Remove the given voices from voices.json.

    Args:
        selected: Iterable of voice names to remove; names that are not
            stored are ignored.

    Returns:
        tuple: ``(status_message, remaining_voice_names)``.
    """
    if not selected:
        return "No voices selected for deletion.", list(load_voices_json().keys())

    remaining = load_voices_json()
    for name in selected:
        if name not in remaining:
            continue
        del remaining[name]
        print(f"Voice '{name}' deleted.")

    save_voices_json(remaining)
    return "Deleted selected voices successfully.", list(remaining.keys())
def upload_new_voices(uploaded_file):
    """Merge voices from an uploaded JSON file into voices.json.

    Args:
        uploaded_file: Object passed to ``json.load`` (None means nothing was
            uploaded).

    Returns:
        tuple: ``(status_message, voice_names)``. Entries in the upload
        overwrite stored voices that share a name.
    """

    def stored_names():
        # Re-read from disk so the returned list reflects the current file.
        return list(load_voices_json().keys())

    if uploaded_file is None:
        return "No file uploaded.", stored_names()

    try:
        incoming = json.load(uploaded_file)
        if not isinstance(incoming, dict):
            return "Invalid JSON format. Expected a dictionary of voices.", stored_names()

        merged = load_voices_json()
        merged.update(incoming)
        save_voices_json(merged)
        print(f"Voices uploaded: {list(incoming.keys())}")
        return "Voices uploaded successfully.", list(merged.keys())
    except json.JSONDecodeError:
        return "Uploaded file is not valid JSON.", stored_names()
# Create Gradio Interface with Tabs | |
def create_combined_interface():
    """Build the Gradio Blocks app with "Text-to-Speech" and "Voice Studio" tabs.

    The voice list is read once from voices.json when the interface is built;
    saving a new voice in the studio refreshes both voice dropdowns.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface.
    """
    voices_data = load_voices_json()
    voice_choices = list(voices_data.keys())
    default_voice = voice_choices[0] if voice_choices else None
    css = """
    h4 {
        text-align: center;
        display:block;
    }
    """

    def refresh_voices():
        """Refresh the voices by reloading the JSON."""
        # NOTE: not wired to any event in this function.
        new_choices = list(load_voices_json().keys())
        print(f"Voices refreshed: {new_choices}")
        return gr.Dropdown(choices=new_choices)

    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("# StyleTTS2 Studio - Build custom voices")

        # ----------- Text-to-Speech Tab -----------
        with gr.Tab("Text-to-Speech"):
            gr.Markdown("### Generate Speech with Predefined Voices")
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to Synthesize",
                    value="Did you know that you can just do stuff?",
                    lines=3,
                )
                voice_dropdown = gr.Dropdown(
                    choices=voice_choices,
                    label="Select Base Voice",
                    value=default_voice,
                    interactive=True,
                )
                speed_slider = gr.Slider(
                    minimum=50,
                    maximum=200,
                    step=1,
                    label="Speed (%)",
                    value=100,
                )
            with gr.Row():
                generate_btn = gr.Button("Generate Audio")
                audio_output = gr.Audio(label="Synthesized Audio")

            # Generate button functionality
            def on_generate_tts(text, voice, speed):
                if not voice:
                    return None, "No voice selected."
                speed_val = speed / 100  # Convert percentage to multiplier
                result = generate_audio_with_voice(text, voice, speed_val)
                # Index instead of unpacking: the first item is the audio and
                # the last is the style vector (success) or the error message
                # (failure), however many items the helper returns.
                audio, detail = result[0], result[-1]
                if audio is None:
                    return None, detail
                return audio, "Audio generated successfully."

            tts_status = gr.Textbox(label="Status", visible=False)
            generate_btn.click(
                fn=on_generate_tts,
                inputs=[text_input, voice_dropdown, speed_slider],
                outputs=[audio_output, tts_status],
            )

        # ----------- Voice Studio Tab -----------
        with gr.Tab("Voice Studio"):
            gr.Markdown("### Customize and Create New Voices")
            with gr.Column():
                text_input_studio = gr.Textbox(
                    label="Text to Synthesize",
                    value="Use the sliders to customize a voice!",
                    lines=3,
                )
                voice_dropdown_studio = gr.Dropdown(
                    choices=voice_choices,
                    label="Select Base Voice",
                    value=default_voice,
                )
                speed_slider_studio = gr.Slider(
                    minimum=50,
                    maximum=200,
                    step=1,
                    label="Speed (%)",
                    value=100,
                )
                # Sliders for PCA components (one per annotated feature)
                pca_sliders = [
                    gr.Slider(
                        minimum=-2.0,
                        maximum=2.0,
                        value=0.0,
                        step=0.1,
                        label=feature,
                    )
                    for feature in ANNOTATED_FEATURES_NAMES
                ]
                generate_btn_studio = gr.Button("Generate Customized Audio")
                audio_output_studio = gr.Audio(label="Customized Synthesized Audio")
                new_style_name = gr.Textbox(label="New Style Name", value="")
                save_btn_studio = gr.Button("Save Customized Voice")
                status_text = gr.Textbox(label="Status", visible=True)

            # State to hold the last style vector
            style_vector_state_studio = gr.State()

            # Generate button functionality
            def on_generate_studio(text, voice, speed, *pca_values):
                if not voice:
                    return None, "No voice selected.", None
                speed_val = speed / 100  # Convert percentage to multiplier
                result = generate_custom_audio(
                    text, voice, False, speed_val, *pca_values
                )
                # Failure is signalled either by None or by a tuple whose
                # audio slot is None; handle both instead of unpacking blindly.
                if result is None or result[0] is None:
                    return None, "Failed to generate audio.", None
                audio_tuple, style_vector = result[0], result[-1]
                # The state is updated through the event outputs below; the
                # component's shared default value is deliberately not touched.
                return audio_tuple, "Audio generated successfully.", style_vector

            generate_btn_studio.click(
                fn=on_generate_studio,
                inputs=[text_input_studio, voice_dropdown_studio, speed_slider_studio]
                + pca_sliders,
                outputs=[audio_output_studio, status_text, style_vector_state_studio],
            )

            # Save button functionality
            def on_save_style_studio(style_vector, style_name):
                # Every return is ordered to match the outputs list:
                # (voice_dropdown, voice_dropdown_studio, status_text).
                if not style_name:
                    return (
                        gr.update(),
                        gr.update(),
                        "Please enter a name for the new voice!",
                    )
                if style_vector is None:
                    return (
                        gr.update(),
                        gr.update(),
                        "Generate audio before saving a voice.",
                    )
                result = save_style_to_json(style_vector, style_name)
                new_choices = list(load_voices_json().keys())
                # Refresh both dropdowns and show the status message
                return (
                    gr.update(choices=new_choices),
                    gr.update(choices=new_choices),
                    result,  # Status message
                )

            save_btn_studio.click(
                fn=on_save_style_studio,
                inputs=[style_vector_state_studio, new_style_name],
                outputs=[voice_dropdown, voice_dropdown_studio, status_text],
            )

            # Add callback to update sliders when a voice is selected
            voice_dropdown_studio.change(
                fn=update_sliders,
                inputs=voice_dropdown_studio,
                outputs=pca_sliders,
            )

        gr.Markdown(
            "#### Based on [StyleTTS2](https://github.com/yl4579/StyleTTS2) and [artificial StyleTTS2](https://huggingface.co/dkounadis/artificial-styletts2/tree/main)"
        )
    return demo
# Script entry point: build the UI and serve it locally (no public share link).
if __name__ == "__main__":
    try:
        app = create_combined_interface()
        app.launch(share=False)
    except Exception as launch_error:
        # Report the failure on stdout instead of letting the traceback escape.
        print(f"An error occurred while launching the interface: {launch_error}")