mrfakename committed on
Commit d430de8
1 Parent(s): dca07a4

Add LJSpeech model

Files changed (2)
  1. app.py +28 -9
  2. ljspeechimportable.py +226 -0
app.py CHANGED
@@ -1,5 +1,7 @@
 import gradio as gr
 import styletts2importable
+import ljspeechimportable
+import torch
 theme = gr.themes.Base(
     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
 )
@@ -20,7 +22,32 @@ def synthesize(text, voice):
         raise gr.Error("Text must be under 500 characters")
     v = voice.lower()
     return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
+def ljsynthesize(text):
+    if text.strip() == "":
+        raise gr.Error("You must enter some text")
+    if len(text) > 500:
+        raise gr.Error("Text must be under 500 characters")
+    noise = torch.randn(1, 1, 256).to('cuda' if torch.cuda.is_available() else 'cpu')
+    return (24000, ljspeechimportable.inference(text, noise, diffusion_steps=7, embedding_scale=1))
+
 
+with gr.Blocks() as vctk:
+    with gr.Row():
+        with gr.Column(scale=1):
+            inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
+            voice = gr.Dropdown(['Angie', 'Daniel', 'Tom', 'LJ', 'Pat', 'Tom', 'Dotrice', 'Mouse', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
+        with gr.Column(scale=1):
+            btn = gr.Button("Synthesize", variant="primary")
+            audio = gr.Audio(interactive=False, label="Synthesized Audio")
+            btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
+with gr.Blocks() as lj:
+    with gr.Row():
+        with gr.Column(scale=1):
+            ljinp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
+        with gr.Column(scale=1):
+            ljbtn = gr.Button("Synthesize", variant="primary")
+            ljaudio = gr.Audio(interactive=False, label="Synthesized Audio")
+            ljbtn.click(ljsynthesize, inputs=[ljinp], outputs=[ljaudio], concurrency_limit=4)
 with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
     gr.Markdown("""# StyleTTS 2
 
@@ -34,15 +61,7 @@ This space does NOT allow voice cloning. We use some default voice from Tortoise
 
 Is there a long queue on this space? Duplicate it and add a GPU to skip the wait!""")
     gr.DuplicateButton("Duplicate Space")
-    with gr.Row():
-        with gr.Column(scale=1):
-            inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
-            voice = gr.Dropdown(['Angie', 'Daniel', 'Tom', 'LJ', 'Pat', 'Tom', 'Dotrice', 'Mouse', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
-        with gr.Column(scale=1):
-            btn = gr.Button("Synthesize", variant="primary")
-            audio = gr.Audio(interactive=False, label="Synthesized Audio")
-            btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
-
+    gr.TabbedInterface([vctk, lj], ['Multi-Voice', 'LJSpeech'])
 if __name__ == "__main__":
     demo.queue(api_open=False, max_size=15).launch(show_api=False)
 
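For reference, the new LJSpeech path can also be driven outside the Gradio UI. The sketch below is not part of this commit; it mirrors what ljsynthesize does, assumes the ljspeechimportable module added in this commit imports cleanly (its checkpoints download on import), and uses the soundfile package purely for illustration (it is not a dependency of this Space):

    import torch
    import soundfile as sf  # illustrative extra dependency, not used by the Space
    import ljspeechimportable

    text = "StyleTTS 2 synthesizes speech from text in a single pass."
    # Same noise shape the app uses to seed the style diffusion sampler.
    noise = torch.randn(1, 1, 256).to('cuda' if torch.cuda.is_available() else 'cpu')
    wav = ljspeechimportable.inference(text, noise, diffusion_steps=7, embedding_scale=1)
    sf.write("ljspeech_sample.wav", wav, 24000)  # the model outputs 24 kHz audio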
ljspeechimportable.py ADDED
@@ -0,0 +1,226 @@
+from cached_path import cached_path
+from dp.phonemizer import Phonemizer
+
+
+import torch
+torch.manual_seed(0)
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
+import random
+random.seed(0)
+
+import numpy as np
+np.random.seed(0)
+
+import nltk
+nltk.download('punkt')
+
+# load packages
+import time
+import random
+import yaml
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+from nltk.tokenize import word_tokenize
+
+from models import *
+from utils import *
+from text_utils import TextCleaner
+textclenaer = TextCleaner()
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+to_mel = torchaudio.transforms.MelSpectrogram(
+    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
+mean, std = -4, 4
+
+def length_to_mask(lengths):
+    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+    mask = torch.gt(mask+1, lengths.unsqueeze(1))
+    return mask
+
+def preprocess(wave):
+    wave_tensor = torch.from_numpy(wave).float()
+    mel_tensor = to_mel(wave_tensor)
+    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
+    return mel_tensor
+
+def compute_style(ref_dicts):
+    reference_embeddings = {}
+    for key, path in ref_dicts.items():
+        wave, sr = librosa.load(path, sr=24000)
+        audio, index = librosa.effects.trim(wave, top_db=30)
+        if sr != 24000:
+            audio = librosa.resample(audio, sr, 24000)
+        mel_tensor = preprocess(audio).to(device)
+
+        with torch.no_grad():
+            ref = model.style_encoder(mel_tensor.unsqueeze(1))
+        reference_embeddings[key] = (ref.squeeze(1), audio)
+
+    return reference_embeddings
+
+# load phonemizer
+# import phonemizer
+# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
+
+phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
+
+
+config = yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))))
+
+# load pretrained ASR model
+ASR_config = config.get('ASR_config', False)
+ASR_path = config.get('ASR_path', False)
+text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+# load pretrained F0 model
+F0_path = config.get('F0_path', False)
+pitch_extractor = load_F0_models(F0_path)
+
+# load BERT model
+from Utils.PLBERT.util import load_plbert
+BERT_path = config.get('PLBERT_dir', False)
+plbert = load_plbert(BERT_path)
+
+model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
+_ = [model[key].eval() for key in model]
+_ = [model[key].to(device) for key in model]
+
+# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
+params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu')
+params = params_whole['net']
+
+for key in model:
+    if key in params:
+        print('%s loaded' % key)
+        try:
+            model[key].load_state_dict(params[key])
+        except:
+            from collections import OrderedDict
+            state_dict = params[key]
+            new_state_dict = OrderedDict()
+            for k, v in state_dict.items():
+                name = k[7:] # remove `module.`
+                new_state_dict[name] = v
+            # load params
+            model[key].load_state_dict(new_state_dict, strict=False)
+            # except:
+            #     _load(params[key], model[key])
+_ = [model[key].eval() for key in model]
+
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+sampler = DiffusionSampler(
+    model.diffusion.diffusion,
+    sampler=ADPM2Sampler(),
+    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+    clamp=False
+)
+
+def inference(text, noise, diffusion_steps=5, embedding_scale=1):
+    text = text.strip()
+    text = text.replace('"', '')
+    ps = phonemizer([text], lang='en_us')
+    ps = word_tokenize(ps[0])
+    ps = ' '.join(ps)
+
+    tokens = textclenaer(ps)
+    tokens.insert(0, 0)
+    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+
+    with torch.no_grad():
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
+        text_mask = length_to_mask(input_lengths).to(tokens.device)
+
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+        s_pred = sampler(noise,
+              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
+              embedding_scale=embedding_scale).squeeze(0)
+
+        s = s_pred[:, 128:]
+        ref = s_pred[:, :128]
+
+        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
+
+        x, _ = model.predictor.lstm(d)
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1)
+        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
+        pred_dur[-1] += 5
+
+        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+            c_frame += int(pred_dur[i].data)
+
+        # encode prosody
+        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
+                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+    return out.squeeze().cpu().numpy()
+
+def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
+    text = text.strip()
+    text = text.replace('"', '')
+    ps = phonemizer([text], lang='en_us')
+    ps = word_tokenize(ps[0])
+    ps = ' '.join(ps)
+
+    tokens = textclenaer(ps)
+    tokens.insert(0, 0)
+    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+
+    with torch.no_grad():
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
+        text_mask = length_to_mask(input_lengths).to(tokens.device)
+
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+        s_pred = sampler(noise,
+              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
+              embedding_scale=embedding_scale).squeeze(0)
+
+        if s_prev is not None:
+            # convex combination of previous and current style
+            s_pred = alpha * s_prev + (1 - alpha) * s_pred
+
+        s = s_pred[:, 128:]
+        ref = s_pred[:, :128]
+
+        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
+
+        x, _ = model.predictor.lstm(d)
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1)
+        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
+        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+            c_frame += int(pred_dur[i].data)
+
+        # encode prosody
+        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
+                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+    return out.squeeze().cpu().numpy(), s_pred
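Besides inference, the module exposes LFinference, which accepts the previous segment's style vector (s_prev) and blends it into the newly sampled style with weight alpha, so longer passages keep a consistent prosody. A usage sketch under those assumptions follows; the sentence splitting and concatenation are illustrative and not part of this commit:

    import numpy as np
    import torch
    from nltk.tokenize import sent_tokenize
    import ljspeechimportable

    passage = ("StyleTTS 2 models speaking style with a diffusion prior. "
               "Each sentence is synthesized on its own. "
               "Carrying the previous style forward keeps the prosody steady.")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    s_prev = None
    pieces = []
    for sentence in sent_tokenize(passage):
        noise = torch.randn(1, 1, 256).to(device)
        # alpha=0.7 keeps 70% of the previous style and 30% of the newly sampled one
        wav, s_prev = ljspeechimportable.LFinference(
            sentence, s_prev, noise, alpha=0.7, diffusion_steps=7, embedding_scale=1)
        pieces.append(wav)
    audio = np.concatenate(pieces)  # 24 kHz mono waveform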