mrfakename
commited on
Commit
·
d430de8
1
Parent(s):
dca07a4
Add LJSpeech model
Browse files- app.py +28 -9
- ljspeechimportable.py +226 -0
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import styletts2importable
|
|
|
|
|
3 |
theme = gr.themes.Base(
|
4 |
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
|
5 |
)
|
@@ -20,7 +22,32 @@ def synthesize(text, voice):
|
|
20 |
raise gr.Error("Text must be under 500 characters")
|
21 |
v = voice.lower()
|
22 |
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
|
25 |
gr.Markdown("""# StyleTTS 2
|
26 |
|
@@ -34,15 +61,7 @@ This space does NOT allow voice cloning. We use some default voice from Tortoise
|
|
34 |
|
35 |
Is there a long queue on this space? Duplicate it and add a GPU to skip the wait!""")
|
36 |
gr.DuplicateButton("Duplicate Space")
|
37 |
-
|
38 |
-
with gr.Column(scale=1):
|
39 |
-
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
40 |
-
voice = gr.Dropdown(['Angie', 'Daniel', 'Tom', 'LJ', 'Pat', 'Tom', 'Dotrice', 'Mouse', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
|
41 |
-
with gr.Column(scale=1):
|
42 |
-
btn = gr.Button("Synthesize", variant="primary")
|
43 |
-
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
44 |
-
btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
|
45 |
-
|
46 |
if __name__ == "__main__":
|
47 |
demo.queue(api_open=False, max_size=15).launch(show_api=False)
|
48 |
|
|
|
1 |
import gradio as gr
|
2 |
import styletts2importable
|
3 |
+
import ljspeechimportable
|
4 |
+
import torch
|
5 |
theme = gr.themes.Base(
|
6 |
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
|
7 |
)
|
|
|
22 |
raise gr.Error("Text must be under 500 characters")
|
23 |
v = voice.lower()
|
24 |
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
|
25 |
+
def ljsynthesize(text):
|
26 |
+
if text.strip() == "":
|
27 |
+
raise gr.Error("You must enter some text")
|
28 |
+
if len(text) > 500:
|
29 |
+
raise gr.Error("Text must be under 500 characters")
|
30 |
+
noise = torch.randn(1,1,256).to('cuda' if torch.cuda.is_available() else 'cpu')
|
31 |
+
return (24000, ljspeechimportable.inference(text, noise, diffusion_steps=7, embedding_scale=1))
|
32 |
+
|
33 |
|
34 |
+
with gr.Blocks() as vctk:
|
35 |
+
with gr.Row():
|
36 |
+
with gr.Column(scale=1):
|
37 |
+
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
38 |
+
voice = gr.Dropdown(['Angie', 'Daniel', 'Tom', 'LJ', 'Pat', 'Tom', 'Dotrice', 'Mouse', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
|
39 |
+
with gr.Column(scale=1):
|
40 |
+
btn = gr.Button("Synthesize", variant="primary")
|
41 |
+
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
42 |
+
btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
|
43 |
+
with gr.Blocks() as lj:
|
44 |
+
with gr.Row():
|
45 |
+
with gr.Column(scale=1):
|
46 |
+
ljinp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
47 |
+
with gr.Column(scale=1):
|
48 |
+
ljbtn = gr.Button("Synthesize", variant="primary")
|
49 |
+
ljaudio = gr.Audio(interactive=False, label="Synthesized Audio")
|
50 |
+
ljbtn.click(ljsynthesize, inputs=[ljinp], outputs=[ljaudio], concurrency_limit=4)
|
51 |
with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
|
52 |
gr.Markdown("""# StyleTTS 2
|
53 |
|
|
|
61 |
|
62 |
Is there a long queue on this space? Duplicate it and add a GPU to skip the wait!""")
|
63 |
gr.DuplicateButton("Duplicate Space")
|
64 |
+
gr.TabbedInterface([vctk, lj], ['Multi-Voice', 'LJSpeech'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
if __name__ == "__main__":
|
66 |
demo.queue(api_open=False, max_size=15).launch(show_api=False)
|
67 |
|
ljspeechimportable.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cached_path import cached_path
|
2 |
+
from dp.phonemizer import Phonemizer
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
torch.manual_seed(0)
|
7 |
+
torch.backends.cudnn.benchmark = False
|
8 |
+
torch.backends.cudnn.deterministic = True
|
9 |
+
|
10 |
+
import random
|
11 |
+
random.seed(0)
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
np.random.seed(0)
|
15 |
+
|
16 |
+
import nltk
|
17 |
+
nltk.download('punkt')
|
18 |
+
|
19 |
+
# load packages
|
20 |
+
import time
|
21 |
+
import random
|
22 |
+
import yaml
|
23 |
+
from munch import Munch
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
from torch import nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
import torchaudio
|
29 |
+
import librosa
|
30 |
+
from nltk.tokenize import word_tokenize
|
31 |
+
|
32 |
+
from models import *
|
33 |
+
from utils import *
|
34 |
+
from text_utils import TextCleaner
|
35 |
+
textclenaer = TextCleaner()
|
36 |
+
|
37 |
+
|
38 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
39 |
+
|
40 |
+
to_mel = torchaudio.transforms.MelSpectrogram(
|
41 |
+
n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
|
42 |
+
mean, std = -4, 4
|
43 |
+
|
44 |
+
def length_to_mask(lengths):
|
45 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
46 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
47 |
+
return mask
|
48 |
+
|
49 |
+
def preprocess(wave):
|
50 |
+
wave_tensor = torch.from_numpy(wave).float()
|
51 |
+
mel_tensor = to_mel(wave_tensor)
|
52 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
|
53 |
+
return mel_tensor
|
54 |
+
|
55 |
+
def compute_style(ref_dicts):
|
56 |
+
reference_embeddings = {}
|
57 |
+
for key, path in ref_dicts.items():
|
58 |
+
wave, sr = librosa.load(path, sr=24000)
|
59 |
+
audio, index = librosa.effects.trim(wave, top_db=30)
|
60 |
+
if sr != 24000:
|
61 |
+
audio = librosa.resample(audio, sr, 24000)
|
62 |
+
mel_tensor = preprocess(audio).to(device)
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
ref = model.style_encoder(mel_tensor.unsqueeze(1))
|
66 |
+
reference_embeddings[key] = (ref.squeeze(1), audio)
|
67 |
+
|
68 |
+
return reference_embeddings
|
69 |
+
|
70 |
+
# load phonemizer
|
71 |
+
# import phonemizer
|
72 |
+
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
|
73 |
+
|
74 |
+
phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
|
75 |
+
|
76 |
+
|
77 |
+
config = yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))))
|
78 |
+
|
79 |
+
# load pretrained ASR model
|
80 |
+
ASR_config = config.get('ASR_config', False)
|
81 |
+
ASR_path = config.get('ASR_path', False)
|
82 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
83 |
+
|
84 |
+
# load pretrained F0 model
|
85 |
+
F0_path = config.get('F0_path', False)
|
86 |
+
pitch_extractor = load_F0_models(F0_path)
|
87 |
+
|
88 |
+
# load BERT model
|
89 |
+
from Utils.PLBERT.util import load_plbert
|
90 |
+
BERT_path = config.get('PLBERT_dir', False)
|
91 |
+
plbert = load_plbert(BERT_path)
|
92 |
+
|
93 |
+
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
|
94 |
+
_ = [model[key].eval() for key in model]
|
95 |
+
_ = [model[key].to(device) for key in model]
|
96 |
+
|
97 |
+
# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
|
98 |
+
params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu')
|
99 |
+
params = params_whole['net']
|
100 |
+
|
101 |
+
for key in model:
|
102 |
+
if key in params:
|
103 |
+
print('%s loaded' % key)
|
104 |
+
try:
|
105 |
+
model[key].load_state_dict(params[key])
|
106 |
+
except:
|
107 |
+
from collections import OrderedDict
|
108 |
+
state_dict = params[key]
|
109 |
+
new_state_dict = OrderedDict()
|
110 |
+
for k, v in state_dict.items():
|
111 |
+
name = k[7:] # remove `module.`
|
112 |
+
new_state_dict[name] = v
|
113 |
+
# load params
|
114 |
+
model[key].load_state_dict(new_state_dict, strict=False)
|
115 |
+
# except:
|
116 |
+
# _load(params[key], model[key])
|
117 |
+
_ = [model[key].eval() for key in model]
|
118 |
+
|
119 |
+
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
|
120 |
+
|
121 |
+
sampler = DiffusionSampler(
|
122 |
+
model.diffusion.diffusion,
|
123 |
+
sampler=ADPM2Sampler(),
|
124 |
+
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
|
125 |
+
clamp=False
|
126 |
+
)
|
127 |
+
|
128 |
+
def inference(text, noise, diffusion_steps=5, embedding_scale=1):
|
129 |
+
text = text.strip()
|
130 |
+
text = text.replace('"', '')
|
131 |
+
ps = phonemizer([text], lang='en_us')
|
132 |
+
ps = word_tokenize(ps[0])
|
133 |
+
ps = ' '.join(ps)
|
134 |
+
|
135 |
+
tokens = textclenaer(ps)
|
136 |
+
tokens.insert(0, 0)
|
137 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
138 |
+
|
139 |
+
with torch.no_grad():
|
140 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
|
141 |
+
text_mask = length_to_mask(input_lengths).to(tokens.device)
|
142 |
+
|
143 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
144 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
145 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
146 |
+
|
147 |
+
s_pred = sampler(noise,
|
148 |
+
embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
|
149 |
+
embedding_scale=embedding_scale).squeeze(0)
|
150 |
+
|
151 |
+
s = s_pred[:, 128:]
|
152 |
+
ref = s_pred[:, :128]
|
153 |
+
|
154 |
+
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
155 |
+
|
156 |
+
x, _ = model.predictor.lstm(d)
|
157 |
+
duration = model.predictor.duration_proj(x)
|
158 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
159 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
160 |
+
|
161 |
+
pred_dur[-1] += 5
|
162 |
+
|
163 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
164 |
+
c_frame = 0
|
165 |
+
for i in range(pred_aln_trg.size(0)):
|
166 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
167 |
+
c_frame += int(pred_dur[i].data)
|
168 |
+
|
169 |
+
# encode prosody
|
170 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
171 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
172 |
+
out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
|
173 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
174 |
+
|
175 |
+
return out.squeeze().cpu().numpy()
|
176 |
+
|
177 |
+
def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
|
178 |
+
text = text.strip()
|
179 |
+
text = text.replace('"', '')
|
180 |
+
ps = phonemizer([text], lang='en_us')
|
181 |
+
ps = word_tokenize(ps[0])
|
182 |
+
ps = ' '.join(ps)
|
183 |
+
|
184 |
+
tokens = textclenaer(ps)
|
185 |
+
tokens.insert(0, 0)
|
186 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
187 |
+
|
188 |
+
with torch.no_grad():
|
189 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
|
190 |
+
text_mask = length_to_mask(input_lengths).to(tokens.device)
|
191 |
+
|
192 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
193 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
194 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
195 |
+
|
196 |
+
s_pred = sampler(noise,
|
197 |
+
embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
|
198 |
+
embedding_scale=embedding_scale).squeeze(0)
|
199 |
+
|
200 |
+
if s_prev is not None:
|
201 |
+
# convex combination of previous and current style
|
202 |
+
s_pred = alpha * s_prev + (1 - alpha) * s_pred
|
203 |
+
|
204 |
+
s = s_pred[:, 128:]
|
205 |
+
ref = s_pred[:, :128]
|
206 |
+
|
207 |
+
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
208 |
+
|
209 |
+
x, _ = model.predictor.lstm(d)
|
210 |
+
duration = model.predictor.duration_proj(x)
|
211 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
212 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
213 |
+
|
214 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
215 |
+
c_frame = 0
|
216 |
+
for i in range(pred_aln_trg.size(0)):
|
217 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
218 |
+
c_frame += int(pred_dur[i].data)
|
219 |
+
|
220 |
+
# encode prosody
|
221 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
222 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
223 |
+
out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
|
224 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
225 |
+
|
226 |
+
return out.squeeze().cpu().numpy(), s_pred
|