mrfakename committed
Commit: d24a68b
Parent(s): cea02d8

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo; they are synced back to this Space.

Files changed (3):
  1. app.py (+35, -271)
  2. inference-cli.py (+33, -287)
  3. model/utils_infer.py (+306, -0)
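The commit factors the inference pipeline that was duplicated in app.py and inference-cli.py into a shared model/utils_infer.py. A minimal sketch of the call flow both entry points now use, assuming only the function signatures shown in the diffs below (the file path and texts are placeholder examples, not part of the commit):

import soundfile as sf
from cached_path import cached_path

from model import DiT
from model.utils_infer import load_vocoder, load_model, preprocess_ref_audio_text, infer_process

vocos = load_vocoder()  # Vocos mel vocoder (charactr/vocos-mel-24khz)
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")))

# Clean and clip the reference clip (max 15 s); empty ref_text triggers Whisper transcription
ref_audio, ref_text = preprocess_ref_audio_text("ref.wav", "")  # example path
# Chunk the text, sample each chunk with the CFM model, cross-fade the chunks
wave, sample_rate, spectrogram = infer_process(ref_audio, ref_text, "Text to synthesize.", ema_model)
sf.write("out.wav", wave, sample_rate)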
app.py CHANGED
@@ -1,22 +1,25 @@
  import re
- import torch
- import torchaudio
  import gradio as gr
  import numpy as np
- import tempfile
- from vocos import Vocos
- from pydub import AudioSegment, silence
- from model import CFM, UNetT, DiT, MMDiT
  from cached_path import cached_path
  from model.utils import (
-     load_checkpoint,
-     get_tokenizer,
-     convert_char_to_pinyin,
      save_spectrogram,
  )
- from transformers import pipeline
- import click
- import soundfile as sf

  try:
      import spaces
@@ -30,282 +33,47 @@ def gpu_decorator(func):
      else:
          return func

- device = (
-     "cuda"
-     if torch.cuda.is_available()
-     else "mps" if torch.backends.mps.is_available() else "cpu"
- )
-
- print(f"Using {device} device")
-
- pipe = pipeline(
-     "automatic-speech-recognition",
-     model="openai/whisper-large-v3-turbo",
-     torch_dtype=torch.float16,
-     device=device,
- )
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
- # --------------------- Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- target_rms = 0.1
- nfe_step = 32  # 16, 32
- cfg_strength = 2.0
- ode_method = "euler"
- sway_sampling_coef = -1.0
- speed = 1.0
- fix_duration = None
-
-
- def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-     ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
-     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
-     model = CFM(
-         transformer=model_cls(
-             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-         ),
-         mel_spec_kwargs=dict(
-             target_sample_rate=target_sample_rate,
-             n_mel_channels=n_mel_channels,
-             hop_length=hop_length,
-         ),
-         odeint_kwargs=dict(
-             method=ode_method,
-         ),
-         vocab_char_map=vocab_char_map,
-     ).to(device)
-
-     model = load_checkpoint(model, ckpt_path, device, use_ema = True)
-
-     return model


  # load models
- F5TTS_model_cfg = dict(
-     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
- )
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
- F5TTS_ema_model = load_model(
-     "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
- )
- E2TTS_ema_model = load_model(
-     "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
- )
-
- def chunk_text(text, max_chars=135):
-     """
-     Splits the input text into chunks, each with a maximum number of characters.
-
-     Args:
-         text (str): The text to be split.
-         max_chars (int): The maximum number of characters per chunk.
-
-     Returns:
-         List[str]: A list of text chunks.
-     """
-     chunks = []
-     current_chunk = ""
-     # Split the text into sentences based on punctuation followed by whitespace
-     sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
-
-     for sentence in sentences:
-         if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-             current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-         else:
-             if current_chunk:
-                 chunks.append(current_chunk.strip())
-             current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence

-     if current_chunk:
-         chunks.append(current_chunk.strip())

-     return chunks

  @gpu_decorator
- def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
-     if exp_name == "F5-TTS":
          ema_model = F5TTS_ema_model
-     elif exp_name == "E2-TTS":
          ema_model = E2TTS_ema_model

-     audio, sr = ref_audio
-     if audio.shape[0] > 1:
-         audio = torch.mean(audio, dim=0, keepdim=True)
-
-     rms = torch.sqrt(torch.mean(torch.square(audio)))
-     if rms < target_rms:
-         audio = audio * target_rms / rms
-     if sr != target_sample_rate:
-         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-         audio = resampler(audio)
-     audio = audio.to(device)
-
-     generated_waves = []
-     spectrograms = []
-
-     if len(ref_text[-1].encode('utf-8')) == 1:
-         ref_text = ref_text + " "
-     for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
-         # Prepare the text
-         text_list = [ref_text + gen_text]
-         final_text_list = convert_char_to_pinyin(text_list)
-
-         # Calculate duration
-         ref_audio_len = audio.shape[-1] // hop_length
-         ref_text_len = len(ref_text.encode('utf-8'))
-         gen_text_len = len(gen_text.encode('utf-8'))
-         duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-         # inference
-         with torch.inference_mode():
-             generated, _ = ema_model.sample(
-                 cond=audio,
-                 text=final_text_list,
-                 duration=duration,
-                 steps=nfe_step,
-                 cfg_strength=cfg_strength,
-                 sway_sampling_coef=sway_sampling_coef,
-             )
-
-         generated = generated.to(torch.float32)
-         generated = generated[:, ref_audio_len:, :]
-         generated_mel_spec = generated.permute(0, 2, 1)
-         generated_wave = vocos.decode(generated_mel_spec.cpu())
-         if rms < target_rms:
-             generated_wave = generated_wave * rms / target_rms
-
-         # wav -> numpy
-         generated_wave = generated_wave.squeeze().cpu().numpy()
-
-         generated_waves.append(generated_wave)
-         spectrograms.append(generated_mel_spec[0].cpu().numpy())
-
-     # Combine all generated waves with cross-fading
-     if cross_fade_duration <= 0:
-         # Simply concatenate
-         final_wave = np.concatenate(generated_waves)
-     else:
-         final_wave = generated_waves[0]
-         for i in range(1, len(generated_waves)):
-             prev_wave = final_wave
-             next_wave = generated_waves[i]
-
-             # Calculate cross-fade samples, ensuring it does not exceed wave lengths
-             cross_fade_samples = int(cross_fade_duration * target_sample_rate)
-             cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
-
-             if cross_fade_samples <= 0:
-                 # No overlap possible, concatenate
-                 final_wave = np.concatenate([prev_wave, next_wave])
-                 continue
-
-             # Overlapping parts
-             prev_overlap = prev_wave[-cross_fade_samples:]
-             next_overlap = next_wave[:cross_fade_samples]
-
-             # Fade out and fade in
-             fade_out = np.linspace(1, 0, cross_fade_samples)
-             fade_in = np.linspace(0, 1, cross_fade_samples)
-
-             # Cross-faded overlap
-             cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
-
-             # Combine
-             new_wave = np.concatenate([
-                 prev_wave[:-cross_fade_samples],
-                 cross_faded_overlap,
-                 next_wave[cross_fade_samples:]
-             ])
-
-             final_wave = new_wave

      # Remove silence
      if remove_silence:
          with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-             sf.write(f.name, final_wave, target_sample_rate)
-             aseg = AudioSegment.from_file(f.name)
-             non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
-             non_silent_wave = AudioSegment.silent(duration=0)
-             for non_silent_seg in non_silent_segs:
-                 non_silent_wave += non_silent_seg
-             aseg = non_silent_wave
-             aseg.export(f.name, format="wav")
              final_wave, _ = torchaudio.load(f.name)
          final_wave = final_wave.squeeze().cpu().numpy()

-     # Create a combined spectrogram
-     combined_spectrogram = np.concatenate(spectrograms, axis=1)
-
      with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
          spectrogram_path = tmp_spectrogram.name
          save_spectrogram(combined_spectrogram, spectrogram_path)

-     return (target_sample_rate, final_wave), spectrogram_path
-
- @gpu_decorator
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
-
-     print(gen_text)
-
-     gr.Info("Converting audio...")
-     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-         aseg = AudioSegment.from_file(ref_audio_orig)
-
-         non_silent_segs = silence.split_on_silence(
-             aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
-         )
-         non_silent_wave = AudioSegment.silent(duration=0)
-         for non_silent_seg in non_silent_segs:
-             non_silent_wave += non_silent_seg
-         aseg = non_silent_wave
-
-         audio_duration = len(aseg)
-         if audio_duration > 15000:
-             gr.Warning("Audio is over 15s, clipping to only first 15s.")
-             aseg = aseg[:15000]
-         aseg.export(f.name, format="wav")
-         ref_audio = f.name
-
-     if not ref_text.strip():
-         gr.Info("No reference text provided, transcribing reference audio...")
-         ref_text = pipe(
-             ref_audio,
-             chunk_length_s=30,
-             batch_size=128,
-             generate_kwargs={"task": "transcribe"},
-             return_timestamps=False,
-         )["text"].strip()
-         gr.Info("Finished transcription")
-     else:
-         gr.Info("Using custom reference text...")
-
-     # Add the functionality to ensure it ends with ". "
-     if not ref_text.endswith(". "):
-         if ref_text.endswith("."):
-             ref_text += " "
-         else:
-             ref_text += ". "
-
-     audio, sr = torchaudio.load(ref_audio)
-
-     # Use the new chunk_text function to split gen_text
-     max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
-     gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
-     print('ref_text', ref_text)
-     for i, batch_text in enumerate(gen_text_batches):
-         print(f'gen_text {i}', batch_text)
-
-     gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
-     return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)


  @gpu_decorator
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
      # Split the script into speaker blocks
      speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
      speaker_blocks = speaker_pattern.split(script)[1:]  # Skip the first empty element
@@ -327,7 +95,7 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
              continue  # Skip if the speaker is neither speaker1 nor speaker2

          # Generate audio for this block
-         audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)

          # Convert the generated audio to a numpy array
          sr, audio_data = audio
@@ -377,10 +145,6 @@ def parse_speechtypes_text(gen_text):

      return segments

- def update_speed(new_speed):
-     global speed
-     speed = new_speed
-     return f"Speed set to: {speed}"

  with gr.Blocks() as app_credits:
      gr.Markdown("""
@@ -413,7 +177,7 @@ with gr.Blocks() as app_tts:
          label="Speed",
          minimum=0.3,
          maximum=2.0,
-         value=speed,
          step=0.1,
          info="Adjust the speed of the audio.",
      )
@@ -425,7 +189,6 @@ with gr.Blocks() as app_tts:
          step=0.01,
          info="Set the duration of the cross-fade between audio clips.",
      )
-     speed_slider.change(update_speed, inputs=speed_slider)

      audio_output = gr.Audio(label="Synthesized Audio")
      spectrogram_output = gr.Image(label="Spectrogram")
@@ -439,6 +202,7 @@ with gr.Blocks() as app_tts:
              model_choice,
              remove_silence,
              cross_fade_duration_slider,
          ],
          outputs=[audio_output, spectrogram_output],
      )
@@ -1,22 +1,25 @@
  import re
+ import tempfile
+
+ import click
  import gradio as gr
  import numpy as np
+ import soundfile as sf
+ import torchaudio
  from cached_path import cached_path
+ from pydub import AudioSegment
+
+ from model import DiT, UNetT
  from model.utils import (
      save_spectrogram,
  )
+ from model.utils_infer import (
+     load_vocoder,
+     load_model,
+     preprocess_ref_audio_text,
+     infer_process,
+     remove_silence_for_generated_wav,
+ )

  try:
      import spaces
@@ -30,282 +33,47 @@ def gpu_decorator(func):
      else:
          return func

+ vocos = load_vocoder()


  # load models
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+ F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")))

+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+ E2TTS_ema_model = load_model(UNetT, E2TTS_model_cfg, str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")))


  @gpu_decorator
+ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
+
+     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
+
+     if model == "F5-TTS":
          ema_model = F5TTS_ema_model
+     elif model == "E2-TTS":
          ema_model = E2TTS_ema_model

+     final_wave, final_sample_rate, combined_spectrogram = infer_process(ref_audio, ref_text, gen_text, ema_model, cross_fade_duration=cross_fade_duration, speed=speed, show_info=gr.Info, progress=gr.Progress())

      # Remove silence
      if remove_silence:
          with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+             sf.write(f.name, final_wave, final_sample_rate)
+             remove_silence_for_generated_wav(f.name)
              final_wave, _ = torchaudio.load(f.name)
          final_wave = final_wave.squeeze().cpu().numpy()

+     # Save the spectrogram
      with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
          spectrogram_path = tmp_spectrogram.name
          save_spectrogram(combined_spectrogram, spectrogram_path)

+     return (final_sample_rate, final_wave), spectrogram_path


  @gpu_decorator
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence):
      # Split the script into speaker blocks
      speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
      speaker_blocks = speaker_pattern.split(script)[1:]  # Skip the first empty element
@@ -327,7 +95,7 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
              continue  # Skip if the speaker is neither speaker1 nor speaker2

          # Generate audio for this block
+         audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)

          # Convert the generated audio to a numpy array
          sr, audio_data = audio
@@ -377,10 +145,6 @@ def parse_speechtypes_text(gen_text):

      return segments


  with gr.Blocks() as app_credits:
      gr.Markdown("""
@@ -413,7 +177,7 @@ with gr.Blocks() as app_tts:
          label="Speed",
          minimum=0.3,
          maximum=2.0,
+         value=1.0,
          step=0.1,
          info="Adjust the speed of the audio.",
      )
@@ -425,7 +189,6 @@ with gr.Blocks() as app_tts:
          step=0.01,
          info="Set the duration of the cross-fade between audio clips.",
      )

      audio_output = gr.Audio(label="Synthesized Audio")
      spectrogram_output = gr.Image(label="Spectrogram")
@@ -439,6 +202,7 @@ with gr.Blocks() as app_tts:
              model_choice,
              remove_silence,
              cross_fade_duration_slider,
+             speed_slider,
          ],
          outputs=[audio_output, spectrogram_output],
      )
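For orientation, the refactored Gradio handler above is a thin wrapper around the shared helpers; a hypothetical direct call inside app.py (argument values are illustrative, not from the commit, and gr.Info/gr.Progress assume a running Gradio context) would look like:

audio_out, spectrogram_path = infer(
    "ref.wav",             # reference clip (example path)
    "",                    # empty ref_text -> transcribed by Whisper in preprocess_ref_audio_text
    "Hello from F5-TTS.",  # text to generate (example)
    "F5-TTS",              # model choice: "F5-TTS" or "E2-TTS"
    False,                 # remove_silence
    cross_fade_duration=0.15,
    speed=1.0,
)
# audio_out is (sample_rate, numpy waveform), the format gr.Audio expects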
inference-cli.py CHANGED
@@ -1,23 +1,22 @@
  import argparse
  import codecs
  import re
- import tempfile
  from pathlib import Path

  import numpy as np
  import soundfile as sf
  import tomli
- import torch
- import torchaudio
- import tqdm
  from cached_path import cached_path
- from pydub import AudioSegment, silence
- from transformers import pipeline
- from vocos import Vocos

- from model import CFM, DiT, MMDiT, UNetT
- from model.utils import (convert_char_to_pinyin, get_tokenizer,
-                          load_checkpoint, save_spectrogram)

  parser = argparse.ArgumentParser(
      prog="python3 inference-cli.py",
@@ -104,282 +103,35 @@ wave_path = Path(output_dir)/"out.wav"
  spectrogram_path = Path(output_dir)/"out.png"
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

- device = (
-     "cuda"
-     if torch.cuda.is_available()
-     else "mps" if torch.backends.mps.is_available() else "cpu"
- )
-
- if args.load_vocoder_from_local:
-     print(f"Load vocos from local path {vocos_local_path}")
-     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
-     vocos.load_state_dict(state_dict)
-     vocos.eval()
- else:
-     print("Download Vocos from huggingface charactr/vocos-mel-24khz")
-     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
- print(f"Using {device} device")
-
- # --------------------- Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- target_rms = 0.1
- nfe_step = 32  # 16, 32
- cfg_strength = 2.0
- ode_method = "euler"
- sway_sampling_coef = -1.0
- speed = 1.0
- # fix_duration = 27  # None or float (duration in seconds)
- fix_duration = None

- def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
-
-     if file_vocab == "":
-         file_vocab = "Emilia_ZH_EN"
-         tokenizer = "pinyin"
-     else:
-         tokenizer = "custom"
-
-     print("\nvocab : ", vocab_file, tokenizer)
-     print("tokenizer : ", tokenizer)
-     print("model : ", ckpt_path, "\n")
-
-     vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
-     model = CFM(
-         transformer=model_cls(
-             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-         ),
-         mel_spec_kwargs=dict(
-             target_sample_rate=target_sample_rate,
-             n_mel_channels=n_mel_channels,
-             hop_length=hop_length,
-         ),
-         odeint_kwargs=dict(
-             method=ode_method,
-         ),
-         vocab_char_map=vocab_char_map,
-     ).to(device)
-
-     model = load_checkpoint(model, ckpt_path, device, use_ema = True)
-
-     return model

  # load models
- F5TTS_model_cfg = dict(
-     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
- )
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
  if model == "F5-TTS":
-
      if ckpt_file == "":
          repo_name = "F5-TTS"
          exp_name = "F5TTS_Base"
          ckpt_step = 1200000
          ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-     ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, vocab_file)

  elif model == "E2-TTS":
      if ckpt_file == "":
          repo_name = "E2-TTS"
          exp_name = "E2TTS_Base"
          ckpt_step = 1200000
          ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-     ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, vocab_file)
-
- asr_pipe = pipeline(
-     "automatic-speech-recognition",
-     model="openai/whisper-large-v3-turbo",
-     torch_dtype=torch.float16,
-     device=device,
- )
-
- def chunk_text(text, max_chars=135):
-     """
-     Splits the input text into chunks, each with a maximum number of characters.
-     Args:
-         text (str): The text to be split.
-         max_chars (int): The maximum number of characters per chunk.
-     Returns:
-         List[str]: A list of text chunks.
-     """
-     chunks = []
-     current_chunk = ""
-     # Split the text into sentences based on punctuation followed by whitespace
-     sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
-
-     for sentence in sentences:
-         if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-             current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-         else:
-             if current_chunk:
-                 chunks.append(current_chunk.strip())
-             current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-
-     if current_chunk:
-         chunks.append(current_chunk.strip())
-
-     return chunks
-
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
- # if not Path(ckpt_path).exists():
- #     ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
- def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
-     audio, sr = ref_audio
-     if audio.shape[0] > 1:
-         audio = torch.mean(audio, dim=0, keepdim=True)
-
-     rms = torch.sqrt(torch.mean(torch.square(audio)))
-     if rms < target_rms:
-         audio = audio * target_rms / rms
-     if sr != target_sample_rate:
-         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-         audio = resampler(audio)
-     audio = audio.to(device)
-
-     generated_waves = []
-     spectrograms = []
-
-     if len(ref_text[-1].encode('utf-8')) == 1:
-         ref_text = ref_text + " "
-     for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
-         # Prepare the text
-         text_list = [ref_text + gen_text]
-         final_text_list = convert_char_to_pinyin(text_list)
-
-         # Calculate duration
-         ref_audio_len = audio.shape[-1] // hop_length
-         ref_text_len = len(ref_text.encode('utf-8'))
-         gen_text_len = len(gen_text.encode('utf-8'))
-         duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-         # inference
-         with torch.inference_mode():
-             generated, _ = ema_model.sample(
-                 cond=audio,
-                 text=final_text_list,
-                 duration=duration,
-                 steps=nfe_step,
-                 cfg_strength=cfg_strength,
-                 sway_sampling_coef=sway_sampling_coef,
-             )
-
-         generated = generated.to(torch.float32)
-         generated = generated[:, ref_audio_len:, :]
-         generated_mel_spec = generated.permute(0, 2, 1)
-         generated_wave = vocos.decode(generated_mel_spec.cpu())
-         if rms < target_rms:
-             generated_wave = generated_wave * rms / target_rms
-
-         # wav -> numpy
-         generated_wave = generated_wave.squeeze().cpu().numpy()
-
-         generated_waves.append(generated_wave)
-         spectrograms.append(generated_mel_spec[0].cpu().numpy())
-
-     # Combine all generated waves with cross-fading
-     if cross_fade_duration <= 0:
-         # Simply concatenate
-         final_wave = np.concatenate(generated_waves)
-     else:
-         final_wave = generated_waves[0]
-         for i in range(1, len(generated_waves)):
-             prev_wave = final_wave
-             next_wave = generated_waves[i]
-
-             # Calculate cross-fade samples, ensuring it does not exceed wave lengths
-             cross_fade_samples = int(cross_fade_duration * target_sample_rate)
-             cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
-
-             if cross_fade_samples <= 0:
-                 # No overlap possible, concatenate
-                 final_wave = np.concatenate([prev_wave, next_wave])
-                 continue
-
-             # Overlapping parts
-             prev_overlap = prev_wave[-cross_fade_samples:]
-             next_overlap = next_wave[:cross_fade_samples]
-
-             # Fade out and fade in
-             fade_out = np.linspace(1, 0, cross_fade_samples)
-             fade_in = np.linspace(0, 1, cross_fade_samples)
-
-             # Cross-faded overlap
-             cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
-
-             # Combine
-             new_wave = np.concatenate([
-                 prev_wave[:-cross_fade_samples],
-                 cross_faded_overlap,
-                 next_wave[cross_fade_samples:]
-             ])
-
-             final_wave = new_wave
-
-     # Create a combined spectrogram
-     combined_spectrogram = np.concatenate(spectrograms, axis=1)
-
-     return final_wave, combined_spectrogram
-
- def process_voice(ref_audio_orig, ref_text):
-     print("Converting", ref_audio_orig)
-     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-         aseg = AudioSegment.from_file(ref_audio_orig)
-
-         non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
-         non_silent_wave = AudioSegment.silent(duration=0)
-         for non_silent_seg in non_silent_segs:
-             non_silent_wave += non_silent_seg
-         aseg = non_silent_wave
-
-         audio_duration = len(aseg)
-         if audio_duration > 15000:
-             print("Audio is over 15s, clipping to only first 15s.")
-             aseg = aseg[:15000]
-         aseg.export(f.name, format="wav")
-         ref_audio = f.name
-
-     if not ref_text.strip():
-         print("No reference text provided, transcribing reference audio...")
-         ref_text = asr_pipe(
-             ref_audio,
-             chunk_length_s=30,
-             batch_size=128,
-             generate_kwargs={"task": "transcribe"},
-             return_timestamps=False,
-         )["text"].strip()
-         print("Finished transcription")
-     else:
-         print("Using custom reference text...")
-     return ref_audio, ref_text
-
- def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
-     # Add the functionality to ensure it ends with ". "
-     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
-         if ref_text.endswith("."):
-             ref_text += " "
-         else:
-             ref_text += ". "
-
-     # Split the input text into batches
-     audio, sr = torchaudio.load(ref_audio)
-     max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
-     gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
-     for i, gen_text in enumerate(gen_text_batches):
-         print(f'gen_text {i}', gen_text)
-
-     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-     return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)


- def process(ref_audio, ref_text, text_gen, model, remove_silence):
      main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
      if "voices" not in config:
          voices = {"main": main_voice}
@@ -387,7 +139,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
          voices = config["voices"]
          voices["main"] = main_voice
      for voice in voices:
-         voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(voices[voice]['ref_audio'], voices[voice]['ref_text'])
          print("Voice:", voice)
          print("Ref_audio:", voices[voice]['ref_audio'])
          print("Ref_text:", voices[voice]['ref_text'])
@@ -407,23 +159,17 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
          ref_audio = voices[voice]['ref_audio']
          ref_text = voices[voice]['ref_text']
          print(f"Voice: {voice}")
-         audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
          generated_audio_segments.append(audio)

      if generated_audio_segments:
          final_wave = np.concatenate(generated_audio_segments)
          with open(wave_path, "wb") as f:
-             sf.write(f.name, final_wave, target_sample_rate)
              # Remove silence
              if remove_silence:
-                 aseg = AudioSegment.from_file(f.name)
-                 non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
-                 non_silent_wave = AudioSegment.silent(duration=0)
-                 for non_silent_seg in non_silent_segs:
-                     non_silent_wave += non_silent_seg
-                 aseg = non_silent_wave
-                 aseg.export(f.name, format="wav")
              print(f.name)


- process(ref_audio, ref_text, gen_text, model, remove_silence)
@@ -1,23 +1,22 @@
  import argparse
  import codecs
  import re
  from pathlib import Path

  import numpy as np
  import soundfile as sf
  import tomli
  from cached_path import cached_path

+ from model import DiT, UNetT
+ from model.utils_infer import (
+     load_vocoder,
+     load_model,
+     preprocess_ref_audio_text,
+     infer_process,
+     remove_silence_for_generated_wav,
+ )
+

  parser = argparse.ArgumentParser(
      prog="python3 inference-cli.py",
@@ -104,282 +103,35 @@ wave_path = Path(output_dir)/"out.wav"
  spectrogram_path = Path(output_dir)/"out.png"
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

+ vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)


  # load models
  if model == "F5-TTS":
+     model_cls = DiT
+     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
      if ckpt_file == "":
+         repo_name = "F5-TTS"
+         exp_name = "F5TTS_Base"
+         ckpt_step = 1200000
+         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+         # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path

  elif model == "E2-TTS":
+     model_cls = UNetT
+     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
      if ckpt_file == "":
+         repo_name = "E2-TTS"
+         exp_name = "E2TTS_Base"
+         ckpt_step = 1200000
+         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+         # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
+
+ print(f"Using {model}...")
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)


+ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
      main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
      if "voices" not in config:
          voices = {"main": main_voice}
@@ -387,7 +139,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
          voices = config["voices"]
          voices["main"] = main_voice
      for voice in voices:
+         voices[voice]['ref_audio'], voices[voice]['ref_text'] = preprocess_ref_audio_text(voices[voice]['ref_audio'], voices[voice]['ref_text'])
          print("Voice:", voice)
          print("Ref_audio:", voices[voice]['ref_audio'])
          print("Ref_text:", voices[voice]['ref_text'])
@@ -407,23 +159,17 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
          ref_audio = voices[voice]['ref_audio']
          ref_text = voices[voice]['ref_text']
          print(f"Voice: {voice}")
+         audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
          generated_audio_segments.append(audio)

      if generated_audio_segments:
          final_wave = np.concatenate(generated_audio_segments)
          with open(wave_path, "wb") as f:
+             sf.write(f.name, final_wave, final_sample_rate)
              # Remove silence
              if remove_silence:
+                 remove_silence_for_generated_wav(f.name)
              print(f.name)


+ main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
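Both scripts now delegate text chunking and duration estimation to the shared module added below; the heuristics are easiest to see with made-up numbers. For an 8 s reference clip at 24 kHz whose transcript is 100 UTF-8 bytes, and a 150-byte text chunk at speed 1.0:

max_chars     = int(100 / 8 * (25 - 8)) = 212                   # per-chunk byte budget for gen_text
ref_audio_len = 8 * 24000 // 256 = 750                          # reference length in mel frames (hop_length 256)
duration      = 750 + int(750 / 100 * 150 / 1.0) = 1875 frames  # about 20 s of mel output including the reference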
model/utils_infer.py ADDED
@@ -0,0 +1,306 @@
+ # A unified script for inference process
+ # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
+
+ import re
+ import tempfile
+
+ import numpy as np
+ import torch
+ import torchaudio
+ import tqdm
+ from pydub import AudioSegment, silence
+ from transformers import pipeline
+ from vocos import Vocos
+
+ from model import CFM
+ from model.utils import (
+     load_checkpoint,
+     get_tokenizer,
+     convert_char_to_pinyin,
+ )
+
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available() else "cpu"
+ )
+ print(f"Using {device} device")
+
+ asr_pipe = pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-large-v3-turbo",
+     torch_dtype=torch.float16,
+     device=device,
+ )
+
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+
+ # -----------------------------------------
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+ nfe_step = 32  # 16, 32
+ cfg_strength = 2.0
+ ode_method = "euler"
+ sway_sampling_coef = -1.0
+ speed = 1.0
+ fix_duration = None
+
+ # -----------------------------------------
+
+
+ # chunk text into smaller pieces
+
+ def chunk_text(text, max_chars=135):
+     """
+     Splits the input text into chunks, each with a maximum number of characters.
+
+     Args:
+         text (str): The text to be split.
+         max_chars (int): The maximum number of characters per chunk.
+
+     Returns:
+         List[str]: A list of text chunks.
+     """
+     chunks = []
+     current_chunk = ""
+     # Split the text into sentences based on punctuation followed by whitespace
+     sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
+
+     for sentence in sentences:
+         if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+             current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+         else:
+             if current_chunk:
+                 chunks.append(current_chunk.strip())
+             current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+
+     return chunks
+
+
+ # load vocoder
+
+ def load_vocoder(is_local=False, local_path=""):
+     if is_local:
+         print(f"Load vocos from local path {local_path}")
+         vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
+         state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
+         vocos.load_state_dict(state_dict)
+         vocos.eval()
+     else:
+         print("Download Vocos from huggingface charactr/vocos-mel-24khz")
+         vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+     return vocos
+
+
+ # load model for inference
+
+ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
+
+     if vocab_file == "":
+         vocab_file = "Emilia_ZH_EN"
+         tokenizer = "pinyin"
+     else:
+         tokenizer = "custom"
+
+     print("\nvocab : ", vocab_file, tokenizer)
+     print("tokenizer : ", tokenizer)
+     print("model : ", ckpt_path, "\n")
+
+     vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
+     model = CFM(
+         transformer=model_cls(
+             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
+         ),
+         mel_spec_kwargs=dict(
+             target_sample_rate=target_sample_rate,
+             n_mel_channels=n_mel_channels,
+             hop_length=hop_length,
+         ),
+         odeint_kwargs=dict(
+             method=ode_method,
+         ),
+         vocab_char_map=vocab_char_map,
+     ).to(device)
+
+     model = load_checkpoint(model, ckpt_path, device, use_ema = True)
+
+     return model
+
+
+ # preprocess reference audio and text
+
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
+     show_info("Converting audio...")
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+         aseg = AudioSegment.from_file(ref_audio_orig)
+
+         non_silent_segs = silence.split_on_silence(
+             aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
+         )
+         non_silent_wave = AudioSegment.silent(duration=0)
+         for non_silent_seg in non_silent_segs:
+             non_silent_wave += non_silent_seg
+         aseg = non_silent_wave
+
+         audio_duration = len(aseg)
+         if audio_duration > 15000:
+             show_info("Audio is over 15s, clipping to only first 15s.")
+             aseg = aseg[:15000]
+         aseg.export(f.name, format="wav")
+         ref_audio = f.name
+
+     if not ref_text.strip():
+         show_info("No reference text provided, transcribing reference audio...")
+         ref_text = asr_pipe(
+             ref_audio,
+             chunk_length_s=30,
+             batch_size=128,
+             generate_kwargs={"task": "transcribe"},
+             return_timestamps=False,
+         )["text"].strip()
+         show_info("Finished transcription")
+     else:
+         show_info("Using custom reference text...")
+
+     # Add the functionality to ensure it ends with ". "
+     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+         if ref_text.endswith("."):
+             ref_text += " "
+         else:
+             ref_text += ". "
+
+     return ref_audio, ref_text
+
+
+ # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
+
+ def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm):
+
+     # Split the input text into batches
+     audio, sr = torchaudio.load(ref_audio)
+     max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+     gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
+     for i, gen_text in enumerate(gen_text_batches):
+         print(f'gen_text {i}', gen_text)
+
+     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
+     return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
+
+
+ # infer batches
+
+ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm):
+     audio, sr = ref_audio
+     if audio.shape[0] > 1:
+         audio = torch.mean(audio, dim=0, keepdim=True)
+
+     rms = torch.sqrt(torch.mean(torch.square(audio)))
+     if rms < target_rms:
+         audio = audio * target_rms / rms
+     if sr != target_sample_rate:
+         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+         audio = resampler(audio)
+     audio = audio.to(device)
+
+     generated_waves = []
+     spectrograms = []
+
+     if len(ref_text[-1].encode('utf-8')) == 1:
+         ref_text = ref_text + " "
+     for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
+         # Prepare the text
+         text_list = [ref_text + gen_text]
+         final_text_list = convert_char_to_pinyin(text_list)
+
+         # Calculate duration
+         ref_audio_len = audio.shape[-1] // hop_length
+         ref_text_len = len(ref_text.encode('utf-8'))
+         gen_text_len = len(gen_text.encode('utf-8'))
+         duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+
+         # inference
+         with torch.inference_mode():
+             generated, _ = model_obj.sample(
+                 cond=audio,
+                 text=final_text_list,
+                 duration=duration,
+                 steps=nfe_step,
+                 cfg_strength=cfg_strength,
+                 sway_sampling_coef=sway_sampling_coef,
+             )
+
+         generated = generated.to(torch.float32)
+         generated = generated[:, ref_audio_len:, :]
+         generated_mel_spec = generated.permute(0, 2, 1)
+         generated_wave = vocos.decode(generated_mel_spec.cpu())
+         if rms < target_rms:
+             generated_wave = generated_wave * rms / target_rms
+
+         # wav -> numpy
+         generated_wave = generated_wave.squeeze().cpu().numpy()
+
+         generated_waves.append(generated_wave)
+         spectrograms.append(generated_mel_spec[0].cpu().numpy())
+
+     # Combine all generated waves with cross-fading
+     if cross_fade_duration <= 0:
+         # Simply concatenate
+         final_wave = np.concatenate(generated_waves)
+     else:
+         final_wave = generated_waves[0]
+         for i in range(1, len(generated_waves)):
+             prev_wave = final_wave
+             next_wave = generated_waves[i]
+
+             # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+             cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+             cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+             if cross_fade_samples <= 0:
+                 # No overlap possible, concatenate
+                 final_wave = np.concatenate([prev_wave, next_wave])
+                 continue
+
+             # Overlapping parts
+             prev_overlap = prev_wave[-cross_fade_samples:]
+             next_overlap = next_wave[:cross_fade_samples]
+
+             # Fade out and fade in
+             fade_out = np.linspace(1, 0, cross_fade_samples)
+             fade_in = np.linspace(0, 1, cross_fade_samples)
+
+             # Cross-faded overlap
+             cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+             # Combine
+             new_wave = np.concatenate([
+                 prev_wave[:-cross_fade_samples],
+                 cross_faded_overlap,
+                 next_wave[cross_fade_samples:]
+             ])
+
+             final_wave = new_wave
+
+     # Create a combined spectrogram
+     combined_spectrogram = np.concatenate(spectrograms, axis=1)
+
+     return final_wave, target_sample_rate, combined_spectrogram
+
+
+ # remove silence from generated wav
+
+ def remove_silence_for_generated_wav(filename):
+     aseg = AudioSegment.from_file(filename)
+     non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+     non_silent_wave = AudioSegment.silent(duration=0)
+     for non_silent_seg in non_silent_segs:
+         non_silent_wave += non_silent_seg
+     aseg = non_silent_wave
+     aseg.export(filename, format="wav")