fastspeech2-TTS / app.py
StevenLimcorn's picture
Update app.py
a752f69
raw
history blame
4.58 kB
from matplotlib.pyplot import text
import numpy as np
import soundfile as sf
import yaml
import tensorflow as tf
from tensorflow_tts.inference import TFAutoModel
from tensorflow_tts.inference import AutoProcessor
from tensorflow_tts.inference import AutoConfig
import gradio as gr
# Model combinations offered in the UI. Each entry is written as
# "<text2mel name> + <vocoder name>"; inference() splits on " + " and looks
# both halves up in MODEL_DICT.
MODEL_NAMES = [
    "Fastspeech2 + Melgan",
    "Tacotron2 + Melgan",
    "Tacotron2 + MB-Melgan",
    "Fastspeech2 + MB-Melgan",
]

# Pretrained checkpoints, loaded once at import time so that requests do not
# pay the loading cost.
fastspeech = TFAutoModel.from_pretrained("tensorspeech/tts-fastspeech-ljspeech-en", name="fastspeech")
fastspeech2 = TFAutoModel.from_pretrained("tensorspeech/tts-fastspeech2-ljspeech-en", name="fastspeech2")
tacotron2 = TFAutoModel.from_pretrained("tensorspeech/tts-tacotron2-ljspeech-en", name="tacotron2")
melgan = TFAutoModel.from_pretrained("tensorspeech/tts-melgan-ljspeech-en", name="melgan")
mb_melgan = TFAutoModel.from_pretrained("tensorspeech/tts-mb_melgan-ljspeech-en", name="mb_melgan")

# Lookup from the names used inside MODEL_NAMES entries to the loaded models.
MODEL_DICT = {
    # Not offered in MODEL_NAMES, but the checkpoint is loaded above and
    # inference() has a "Fastspeech" branch, so register it to keep that
    # branch reachable instead of leaving the loaded model unused.
    "Fastspeech": fastspeech,
    "Fastspeech2": fastspeech2,
    "Tacotron2": tacotron2,
    "Melgan": melgan,
    "MB-Melgan": mb_melgan,
}
# Hub repo whose text processor is used for every model combination.
_PROCESSOR_REPO = "tensorspeech/tts-tacotron2-ljspeech-en"
# Names that inference() knows how to drive.
_TEXT2MEL_NAMES = ("Tacotron2", "Fastspeech", "Fastspeech2")
_VOCODER_NAMES = ("Melgan", "MB-Melgan")
# Lazily created processor, shared by all requests.
_processor = None


def _get_processor():
    """Return the shared text processor, creating it on first use."""
    global _processor
    if _processor is None:
        _processor = AutoProcessor.from_pretrained(_PROCESSOR_REPO)
    return _processor


def _split_model_type(model_type):
    """Split "<text2mel> + <vocoder>" into two validated names.

    Raises:
        ValueError: if nothing was selected, the string is not of the form
            "A + B", or either name is not a supported model name.
    """
    if not isinstance(model_type, str) or not model_type:
        raise ValueError("No TTS model selected; pick one of the listed model combinations")
    parts = model_type.split(" + ")
    if len(parts) != 2:
        raise ValueError(
            "model_type must look like '<text2mel> + <vocoder>', got %r" % (model_type,)
        )
    text2mel_name, vocoder_name = parts
    if text2mel_name not in _TEXT2MEL_NAMES:
        raise ValueError("Only TACOTRON, FASTSPEECH, FASTSPEECH2 are supported on text2mel_name")
    if vocoder_name not in _VOCODER_NAMES:
        raise ValueError("Only MELGAN, MELGAN-STFT and MB_MELGAN are supported on vocoder_name")
    return text2mel_name, vocoder_name


def inference(input_text, model_type):
    """Synthesize speech for ``input_text`` and return the path of the WAV file.

    Args:
        input_text: Text to speak; it is converted to ids with the shared
            processor's ``text_to_sequence``.
        model_type: Model combination written as "<text2mel> + <vocoder>",
            e.g. "Fastspeech2 + Melgan". Both names are looked up in
            ``MODEL_DICT``.

    Returns:
        The string './audio_after.wav', the file the audio was written to
        (22050 Hz, 16-bit PCM). The file is overwritten on every call.

    Raises:
        ValueError: if ``model_type`` is missing or malformed, names an
            unsupported model, or names a model absent from ``MODEL_DICT``.
    """
    # Validate before touching any model, so a bad selection fails fast with
    # a readable message instead of an AttributeError/KeyError.
    text2mel_name, vocoder_name = _split_model_type(model_type)
    try:
        text2mel_model = MODEL_DICT[text2mel_name]
        vocoder_model = MODEL_DICT[vocoder_name]
    except KeyError as err:
        raise ValueError("Model %r is not available in MODEL_DICT" % (err.args[0],)) from err

    input_ids = _get_processor().text_to_sequence(input_text)
    # The models are called with a leading batch axis of size one.
    batch_ids = tf.expand_dims(tf.convert_to_tensor(input_ids, dtype=tf.int32), 0)
    speaker_ids = tf.convert_to_tensor([0], dtype=tf.int32)

    # Each text2mel family has its own inference() signature and returns a
    # tuple of a different length; only the mel output is used afterwards.
    if text2mel_name == "Tacotron2":
        _, mel_outputs, _, _ = text2mel_model.inference(
            batch_ids,
            tf.convert_to_tensor([len(input_ids)], tf.int32),
            speaker_ids,
        )
    elif text2mel_name == "Fastspeech":
        _, mel_outputs, _ = text2mel_model.inference(
            input_ids=batch_ids,
            speaker_ids=speaker_ids,
            speed_ratios=tf.convert_to_tensor([1.0], dtype=tf.float32),
        )
    else:  # "Fastspeech2" -- the only name left after validation
        _, mel_outputs, _, _, _ = text2mel_model.inference(
            batch_ids,
            speaker_ids=speaker_ids,
            speed_ratios=tf.convert_to_tensor([1.0], dtype=tf.float32),
            f0_ratios=tf.convert_to_tensor([1.0], dtype=tf.float32),
            energy_ratios=tf.convert_to_tensor([1.0], dtype=tf.float32),
        )

    # Both supported vocoders are invoked the same way; [0, :, 0] keeps the
    # first entry of the first and last axes, leaving a 1-D signal.
    audio = vocoder_model(mel_outputs)[0, :, 0]

    sf.write('./audio_after.wav', audio, 22050, "PCM_16")
    return './audio_after.wav'
# Input widgets, in the same order as inference()'s parameters.
inputs = [
    gr.inputs.Textbox(lines=5, label="Input Text"),
    # Preselect the first combination so the model choice is never empty
    # when the user submits without touching the radio buttons.
    gr.inputs.Radio(label="Pick a TTS Model", choices=MODEL_NAMES, default=MODEL_NAMES[0]),
]
# inference() returns a file path, hence type="file".
outputs = gr.outputs.Audio(type="file", label="Output Audio")

title = "Tensorflow TTS"
description = "Gradio demo for TensorFlowTTS: Real-Time State-of-the-art Speech Synthesis for Tensorflow 2. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
# Footer HTML. The credit link wraps its text and is closed properly so that
# it is actually rendered as a clickable link.
article = (
    "<p style='text-align: center'>"
    "<a href='https://tensorspeech.github.io/TensorFlowTTS/'>TensorFlowTTS: Real-Time State-of-the-art Speech Synthesis for Tensorflow 2</a>"
    " | <a href='https://github.com/TensorSpeech/TensorFlowTTS'>Github Repo</a></p>"
    "<p>An extension to <a href='https://huggingface.co/spaces/akhaliq/TensorFlowTTS'>akhaliq's implementation</a></p>"
)

# One value per input component: the text and the model combination to use.
examples = [
    [
        "Once upon a time there was an old mother pig who had three little pigs and not enough food to feed them. So when they were old enough, she sent them out into the world to seek their fortunes.",
        MODEL_NAMES[0],
    ],
    [
        "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
        MODEL_NAMES[0],
    ],
]

gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()