ntt123 committed
Commit
7591e94
1 Parent(s): ffe1f9e

Support generating long clips

Files changed (1):
  1. app.py +36 -115
app.py CHANGED
@@ -30,102 +30,11 @@ assert phone_set[0][1:-1] == "SEP"
 assert "sil" in phone_set
 sil_idx = phone_set.index("sil")
 
-vietnamese_characters = [
-    "a",
-    "à",
-    "á",
-    "ả",
-    "ã",
-    "ạ",
-    "ă",
-    "ằ",
-    "ắ",
-    "ẳ",
-    "ẵ",
-    "ặ",
-    "â",
-    "ầ",
-    "ấ",
-    "ẩ",
-    "ẫ",
-    "ậ",
-    "e",
-    "è",
-    "é",
-    "ẻ",
-    "ẽ",
-    "ẹ",
-    "ê",
-    "ề",
-    "ế",
-    "ể",
-    "ễ",
-    "ệ",
-    "i",
-    "ì",
-    "í",
-    "ỉ",
-    "ĩ",
-    "ị",
-    "o",
-    "ò",
-    "ó",
-    "ỏ",
-    "õ",
-    "ọ",
-    "ô",
-    "ồ",
-    "ố",
-    "ổ",
-    "ỗ",
-    "ộ",
-    "ơ",
-    "ờ",
-    "ớ",
-    "ở",
-    "ỡ",
-    "ợ",
-    "u",
-    "ù",
-    "ú",
-    "ủ",
-    "ũ",
-    "ụ",
-    "ư",
-    "ừ",
-    "ứ",
-    "ử",
-    "ữ",
-    "ự",
-    "y",
-    "ỳ",
-    "ý",
-    "ỷ",
-    "ỹ",
-    "ỵ",
-    "b",
-    "c",
-    "d",
-    "đ",
-    "g",
-    "h",
-    "k",
-    "l",
-    "m",
-    "n",
-    "p",
-    "q",
-    "r",
-    "s",
-    "t",
-    "v",
-    "x",
-]
-alphabet = "".join(vietnamese_characters)
 space_re = regex.compile(r"\s+")
 number_re = regex.compile("([0-9]+)")
 digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
 num_re = regex.compile(r"([0-9.,]*[0-9])")
+alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
 keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
 keep_text_re = regex.compile(rf"[^\s{alphabet}]")
 
@@ -225,7 +134,7 @@ def text_to_phone_idx(text):
     return tokens
 
 
-def text_to_speech(text):
+def text_to_speech(duration_net, generator, text):
     # prevent too long text
     if len(text) > 500:
         text = text[:500]
@@ -237,9 +146,6 @@ def text_to_speech(text):
     }
 
     # predict phoneme duration
-    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
-    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
-    duration_net = duration_net.eval()
     phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
     phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
     with torch.inference_mode():
@@ -249,24 +155,7 @@
     )
     phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
 
-    generator = SynthesizerTrn(
-        hps.data.vocab_size,
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        **vars(hps.model),
-    ).to(device)
-    del generator.enc_q
-    ckpt = torch.load(lightspeed_model_path, map_location=device)
-    params = {}
-    for k, v in ckpt["net_g"].items():
-        k = k[7:] if k.startswith("module.") else k
-        params[k] = v
-    generator.load_state_dict(params, strict=False)
-    del ckpt, params
-    generator = generator.eval()
-    # mininum 1 frame for each phone
-    # phone_duration = torch.clamp_min(phone_duration, hps.data.hop_length * 1000 / hps.data.sampling_rate)
-    # phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
+    # generate waveform
     end_time = torch.cumsum(phone_duration, dim=-1)
     start_time = end_time - phone_duration
     start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
@@ -285,8 +174,40 @@
     return (wave * (2**15)).astype(np.int16)
 
 
+def load_models():
+    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
+    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
+    duration_net = duration_net.eval()
+    generator = SynthesizerTrn(
+        hps.data.vocab_size,
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        **vars(hps.model),
+    ).to(device)
+    del generator.enc_q
+    ckpt = torch.load(lightspeed_model_path, map_location=device)
+    params = {}
+    for k, v in ckpt["net_g"].items():
+        k = k[7:] if k.startswith("module.") else k
+        params[k] = v
+    generator.load_state_dict(params, strict=False)
+    del ckpt, params
+    generator = generator.eval()
+    return duration_net, generator
+
+
 def speak(text):
-    y = text_to_speech(text)
+    duration_net, generator = load_models()
+    paragraphs = text.split("\n")
+    clips = []  # list of audio clips
+    # silence = np.zeros(hps.data.sampling_rate // 4)
+    for paragraph in paragraphs:
+        paragraph = paragraph.strip()
+        if paragraph == "":
+            continue
+        clips.append(text_to_speech(duration_net, generator, paragraph))
+        # clips.append(silence)
+    y = np.concatenate(clips)
     return hps.data.sampling_rate, y
 
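One consequence of this split worth noting: `speak` still calls `load_models` on every request, so both networks are rebuilt from their checkpoints each time audio is generated. Below is a minimal sketch of one way to load them once and reuse them across calls, assuming the rest of app.py stays as in this commit; the `get_models` wrapper is an illustration, not part of the change.

from functools import lru_cache

import numpy as np


@lru_cache(maxsize=1)
def get_models():
    # Build DurationNet and SynthesizerTrn once; later calls hit the cache.
    # Illustrative wrapper around load_models() from this commit.
    return get_models.__wrapped__.load if False else load_models()


def speak(text):
    duration_net, generator = get_models()
    # Same long-clip strategy as the commit: synthesize each non-empty
    # paragraph separately, then concatenate the int16 waveforms.
    clips = [
        text_to_speech(duration_net, generator, p.strip())
        for p in text.split("\n")
        if p.strip()
    ]
    y = np.concatenate(clips)
    return hps.data.sampling_rate, y

With this wrapper, the first request pays the checkpoint-loading cost and later requests only run inference. Note that `np.concatenate` raises on an empty list, so input consisting only of blank lines would still need a guard.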