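"""Gradio demo: Nepali text-to-speech with Tacotron 2 and WaveGlow.

Pipeline: input text -> Tacotron 2 (fine-tuned on Nepali data) produces a mel
spectrogram -> the WaveGlow vocoder converts the spectrogram into a 22.05 kHz
waveform. The app returns the synthesized audio and a plot of the spectrogram.
"""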
import torch
import os
import torchaudio
import gradio as gr
import matplotlib.pyplot as plt
device="cpu"
# Load the NVIDIA Tacotron 2 architecture from Torch Hub (weights are loaded
# from the local checkpoint below, hence pretrained=False)
tacotron2 = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_tacotron2",
    model_math="fp32",
    pretrained=False,
)
# Load the fine-tuned Nepali Tacotron 2 weights from a local checkpoint
checkpoint_path = os.path.join(os.getcwd(), 'model_E45.ckpt')
state_dict = torch.load(checkpoint_path, map_location=device)
tacotron2.load_state_dict(state_dict)
tacotron2 = tacotron2.to(device)
tacotron2.eval()
# Load the NVIDIA WaveGlow vocoder architecture from Torch Hub
waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    model_math="fp32",
    pretrained=False,
)
# Download the pretrained WaveGlow weights published by NVIDIA on NGC
checkpoint = torch.hub.load_state_dict_from_url(
    "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
    progress=False,
    map_location=device,
)
# The checkpoint stores parameters under a "module." prefix (an artifact of
# DataParallel training); strip it so the names match the bare model
state_dict = {
    key.replace("module.", ""): value
    for key, value in checkpoint["state_dict"].items()
}
waveglow.load_state_dict(state_dict)
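# Fuse weight normalization into the convolution weights; a one-time
# optimization that speeds up inference without changing the output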
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()
# Load the NVIDIA TTS utilities from Torch Hub
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
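# utils.prepare_input_sequence normalizes the text and returns padded
# character-ID tensors together with their lengths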
def inference(text):
    with torch.inference_mode():
        sequences, lengths = utils.prepare_input_sequence([text])
        sequences = sequences.to(device)
        lengths = lengths.to(device)
        # Tacotron 2: text -> mel spectrogram
        mel, _, _ = tacotron2.infer(sequences, lengths)
        # Save the mel spectrogram as an image for the UI
        plt.imshow(mel[0].cpu().detach())
        plt.axis("off")
        plt.savefig("test.png", bbox_inches="tight")
        # WaveGlow: mel spectrogram -> waveform
        audio = waveglow.infer(mel)
    # Both models operate at a 22.05 kHz sampling rate
    torchaudio.save("output.wav", audio[0:1].cpu(), sample_rate=22050)
    return "output.wav", "test.png"
title="TACOTRON 2"
description="Nepali Speech TACOTRON 2: The Tacotron 2 model for generating mel spectrograms from text. To use it, simply add you text or click on one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1712.05884' target='_blank'>Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions</a> | <a href='https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2' target='_blank'>Github Repo</a></p>"
examples=[["life is like a box of chocolates"]]
gr.Interface(
    inference,
    "text",
    [
        gr.outputs.Audio(type="file", label="Audio"),
        gr.outputs.Image(type="file", label="Spectrogram"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(enable_queue=True)