# Spaces: Runtime error
import spaces | |
import gradio as gr | |
import torch | |
from transformers import VitsModel, VitsTokenizer, set_seed | |
# Static text shown around the demo tabs: an HTML header, a Markdown
# introduction, and a credits line (rendered via gr.Markdown further below).
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
VITS TTS Demo
</h1> </div>
</div>
"""
# The checkpoint links are labelled with the checkpoint ids they point to;
# the dataset names are stated separately so link text matches its target.
description = """
VITS is an end-to-end speech synthesis model that predicts a speech waveform conditional on an input text sequence. It is a conditional variational autoencoder (VAE) composed of a posterior encoder, decoder, and conditional prior.
This demo showcases the official VITS checkpoints [kakao-enterprise/vits-ljs](https://huggingface.co/kakao-enterprise/vits-ljs) and [kakao-enterprise/vits-vctk](https://huggingface.co/kakao-enterprise/vits-vctk), trained on the LJSpeech and VCTK datasets respectively.
"""
article = "Model by Jaehyeon Kim et al. from Kakao Enterprise. Code and demo by 🤗 Hugging Face."
# Load both official VITS checkpoints (model + matching tokenizer) once, at
# import time, so every request reuses the same objects.
_LJS_CHECKPOINT = "kakao-enterprise/vits-ljs"
_VCTK_CHECKPOINT = "kakao-enterprise/vits-vctk"

ljs_model = VitsModel.from_pretrained(_LJS_CHECKPOINT)
ljs_tokenizer = VitsTokenizer.from_pretrained(_LJS_CHECKPOINT)
vctk_model = VitsModel.from_pretrained(_VCTK_CHECKPOINT)
vctk_tokenizer = VitsTokenizer.from_pretrained(_VCTK_CHECKPOINT)

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
for _model in (ljs_model, vctk_model):
    _model.to(device)
del _model
# NOTE(review): `spaces` is imported at the top of the file but never used. If
# this Space runs on ZeroGPU hardware, the two public functions below likely
# need the `@spaces.GPU` decorator — confirm against the Space's hardware.


def _synthesize(model, tokenizer, text, speaking_rate, **model_kwargs):
    """Run one seeded VITS forward pass and return ``(sampling_rate, waveform)``.

    Args:
        model: A loaded ``VitsModel`` (one of the module-level models).
        tokenizer: The ``VitsTokenizer`` matching ``model``.
        text: Input text to synthesize.
        speaking_rate: Value assigned to ``model.speaking_rate`` before the
            forward pass.
        **model_kwargs: Extra keyword arguments forwarded to ``model(...)``
            (e.g. ``speaker_id``).

    Returns:
        A ``(sampling_rate, numpy_waveform)`` tuple — the form a ``gr.Audio``
        output component accepts directly.
    """
    # Put the token tensors on the same device as the model; leaving them on
    # the CPU while the model sits on CUDA raises a device-mismatch error.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # The rate is stored on the shared model object, not passed per call.
    # NOTE(review): this is only safe while requests are processed one at a
    # time (Gradio's queue) — confirm no concurrency limit is raised.
    model.speaking_rate = speaking_rate
    # Fixed seed so the same text/settings always yield the same audio; the
    # seeding suggests the forward pass is stochastic (presumably the VAE's
    # sampled noise).
    set_seed(555)
    with torch.no_grad():
        waveform = model(**inputs, **model_kwargs).waveform
    # Keep the first (only) batch item as a float32 NumPy array on the CPU.
    audio = waveform[0].cpu().float().numpy()
    # Use the checkpoint's own rate instead of a hard-coded 22050 Hz.
    return model.config.sampling_rate, audio


def ljs_forward(text, speaking_rate=1.0):
    """Synthesize ``text`` with the single-speaker LJ Speech checkpoint.

    Returns ``(sampling_rate, waveform)`` for the tab's ``gr.Audio`` output.
    The original returned a ``gr.make_waveform`` video path, which did not
    match that output (and the helper is removed in Gradio 5).
    """
    return _synthesize(ljs_model, ljs_tokenizer, text, speaking_rate)


def vctk_forward(text, speaking_rate=1.0, speaker_id=1):
    """Synthesize ``text`` with the multi-speaker VCTK checkpoint.

    ``speaker_id`` is 1-based, matching the UI slider; the model is given the
    0-based index ``speaker_id - 1``. Returns ``(sampling_rate, waveform)``
    for the tab's ``gr.Audio`` output.
    """
    # The slider value may arrive as a float (e.g. 3.0). VitsModel turns only a
    # Python int into an embedding index, so round and coerce first.
    speaker_index = int(round(speaker_id)) - 1
    return _synthesize(
        vctk_model, vctk_tokenizer, text, speaking_rate, speaker_id=speaker_index
    )
def _text_and_rate_inputs():
    """Return a fresh ``[text box, speaking-rate slider]`` pair for one tab.

    New component instances are built on every call, so each interface gets
    its own copies of the two inputs both tabs share.
    """
    text_box = gr.Textbox(
        value="Hey, it's Hugging Face on the phone",
        max_lines=1,
        label="Input text",
    )
    rate_slider = gr.Slider(
        0.5,
        1.5,
        value=1,
        step=0.1,
        label="Speaking rate",
    )
    return [text_box, rate_slider]


# Single-speaker tab: text + speaking rate -> audio.
ljs_inference = gr.Interface(
    fn=ljs_forward,
    inputs=_text_and_rate_inputs(),
    outputs=gr.Audio(),
)

# Multi-speaker tab: adds a 1-based speaker selector sized from the model config.
_num_vctk_speakers = vctk_model.config.num_speakers
vctk_inference = gr.Interface(
    fn=vctk_forward,
    inputs=_text_and_rate_inputs()
    + [
        gr.Slider(
            1,
            _num_vctk_speakers,
            value=1,
            step=1,
            label="Speaker id",
            info=f"The VCTK model is trained on {_num_vctk_speakers} speakers. You can prompt the model using one of these speaker ids.",
        ),
    ],
    outputs=gr.Audio(),
)
# Assemble the page: header, introduction, one tab per checkpoint, credits.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface(
        [ljs_inference, vctk_inference],
        ["LJ Speech", "VCTK"],
    )
    gr.Markdown(article)

# Cap the number of waiting requests at 10, then start the server.
demo.queue(max_size=10).launch()