Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
commited on
Commit
•
1d03890
1
Parent(s):
4064aae
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- app.py +1 -0
- inference-cli.py +4 -3
- model/cfm.py +4 -2
- model/modules.py +1 -0
- model/trainer.py +7 -2
- model/utils.py +11 -11
- requirements.txt +1 -0
- speech_edit.py +4 -3
app.py
CHANGED
@@ -173,6 +173,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
|
|
173 |
sway_sampling_coef=sway_sampling_coef,
|
174 |
)
|
175 |
|
|
|
176 |
generated = generated[:, ref_audio_len:, :]
|
177 |
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
178 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
|
|
173 |
sway_sampling_coef=sway_sampling_coef,
|
174 |
)
|
175 |
|
176 |
+
generated = generated.to(torch.float32)
|
177 |
generated = generated[:, ref_audio_len:, :]
|
178 |
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
179 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
inference-cli.py
CHANGED
@@ -145,9 +145,9 @@ def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
|
|
145 |
else:
|
146 |
tokenizer="custom"
|
147 |
|
148 |
-
print("\nvocab : ",vocab_file,tokenizer)
|
149 |
-
print("tokenizer : ",tokenizer)
|
150 |
-
print("model : ",ckpt_path,"\n")
|
151 |
|
152 |
vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
|
153 |
model = CFM(
|
@@ -265,6 +265,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_voca
|
|
265 |
sway_sampling_coef=sway_sampling_coef,
|
266 |
)
|
267 |
|
|
|
268 |
generated = generated[:, ref_audio_len:, :]
|
269 |
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
270 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
|
|
145 |
else:
|
146 |
tokenizer="custom"
|
147 |
|
148 |
+
print("\nvocab : ", vocab_file,tokenizer)
|
149 |
+
print("tokenizer : ", tokenizer)
|
150 |
+
print("model : ", ckpt_path,"\n")
|
151 |
|
152 |
vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
|
153 |
model = CFM(
|
|
|
265 |
sway_sampling_coef=sway_sampling_coef,
|
266 |
)
|
267 |
|
268 |
+
generated = generated.to(torch.float32)
|
269 |
generated = generated[:, ref_audio_len:, :]
|
270 |
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
271 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
model/cfm.py
CHANGED
@@ -99,6 +99,8 @@ class CFM(nn.Module):
|
|
99 |
):
|
100 |
self.eval()
|
101 |
|
|
|
|
|
102 |
# raw wave
|
103 |
|
104 |
if cond.ndim == 2:
|
@@ -175,7 +177,7 @@ class CFM(nn.Module):
|
|
175 |
for dur in duration:
|
176 |
if exists(seed):
|
177 |
torch.manual_seed(seed)
|
178 |
-
y0.append(torch.randn(dur, self.num_channels, device = self.device))
|
179 |
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
|
180 |
|
181 |
t_start = 0
|
@@ -186,7 +188,7 @@ class CFM(nn.Module):
|
|
186 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
187 |
steps = int(steps * (1 - t_start))
|
188 |
|
189 |
-
t = torch.linspace(t_start, 1, steps, device = self.device)
|
190 |
if sway_sampling_coef is not None:
|
191 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
192 |
|
|
|
99 |
):
|
100 |
self.eval()
|
101 |
|
102 |
+
cond = cond.half()
|
103 |
+
|
104 |
# raw wave
|
105 |
|
106 |
if cond.ndim == 2:
|
|
|
177 |
for dur in duration:
|
178 |
if exists(seed):
|
179 |
torch.manual_seed(seed)
|
180 |
+
y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
|
181 |
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
|
182 |
|
183 |
t_start = 0
|
|
|
188 |
y0 = (1 - t_start) * y0 + t_start * test_cond
|
189 |
steps = int(steps * (1 - t_start))
|
190 |
|
191 |
+
t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
|
192 |
if sway_sampling_coef is not None:
|
193 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
194 |
|
model/modules.py
CHANGED
@@ -571,5 +571,6 @@ class TimestepEmbedding(nn.Module):
|
|
571 |
|
572 |
def forward(self, timestep: float['b']):
|
573 |
time_hidden = self.time_embed(timestep)
|
|
|
574 |
time = self.time_mlp(time_hidden) # b d
|
575 |
return time
|
|
|
571 |
|
572 |
def forward(self, timestep: float['b']):
|
573 |
time_hidden = self.time_embed(timestep)
|
574 |
+
time_hidden = time_hidden.to(timestep.dtype)
|
575 |
time = self.time_mlp(time_hidden) # b d
|
576 |
return time
|
model/trainer.py
CHANGED
@@ -45,7 +45,8 @@ class Trainer:
|
|
45 |
wandb_resume_id: str = None,
|
46 |
last_per_steps = None,
|
47 |
accelerate_kwargs: dict = dict(),
|
48 |
-
ema_kwargs: dict = dict()
|
|
|
49 |
):
|
50 |
|
51 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
@@ -107,7 +108,11 @@ class Trainer:
|
|
107 |
|
108 |
self.duration_predictor = duration_predictor
|
109 |
|
110 |
-
|
|
|
|
|
|
|
|
|
111 |
self.model, self.optimizer = self.accelerator.prepare(
|
112 |
self.model, self.optimizer
|
113 |
)
|
|
|
45 |
wandb_resume_id: str = None,
|
46 |
last_per_steps = None,
|
47 |
accelerate_kwargs: dict = dict(),
|
48 |
+
ema_kwargs: dict = dict(),
|
49 |
+
bnb_optimizer: bool = False,
|
50 |
):
|
51 |
|
52 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
|
|
108 |
|
109 |
self.duration_predictor = duration_predictor
|
110 |
|
111 |
+
if bnb_optimizer:
|
112 |
+
import bitsandbytes as bnb
|
113 |
+
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
114 |
+
else:
|
115 |
+
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
116 |
self.model, self.optimizer = self.accelerator.prepare(
|
117 |
self.model, self.optimizer
|
118 |
)
|
model/utils.py
CHANGED
@@ -557,23 +557,23 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
557 |
# load model checkpoint for inference
|
558 |
|
559 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
560 |
-
|
561 |
|
562 |
ckpt_type = ckpt_path.split(".")[-1]
|
563 |
if ckpt_type == "safetensors":
|
564 |
from safetensors.torch import load_file
|
565 |
-
checkpoint = load_file(ckpt_path
|
566 |
else:
|
567 |
-
checkpoint = torch.load(ckpt_path, weights_only=True
|
568 |
|
569 |
-
if use_ema
|
570 |
-
ema_model = EMA(model, include_online_model = False).to(device)
|
571 |
if ckpt_type == "safetensors":
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
ema_model.copy_params_from_ema_to_model()
|
576 |
else:
|
|
|
|
|
577 |
model.load_state_dict(checkpoint['model_state_dict'])
|
578 |
-
|
579 |
-
return model
|
|
|
557 |
# load model checkpoint for inference
|
558 |
|
559 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
560 |
+
model = model.half()
|
561 |
|
562 |
ckpt_type = ckpt_path.split(".")[-1]
|
563 |
if ckpt_type == "safetensors":
|
564 |
from safetensors.torch import load_file
|
565 |
+
checkpoint = load_file(ckpt_path)
|
566 |
else:
|
567 |
+
checkpoint = torch.load(ckpt_path, weights_only=True)
|
568 |
|
569 |
+
if use_ema:
|
|
|
570 |
if ckpt_type == "safetensors":
|
571 |
+
checkpoint = {'ema_model_state_dict': checkpoint}
|
572 |
+
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
|
573 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
574 |
else:
|
575 |
+
if ckpt_type == "safetensors":
|
576 |
+
checkpoint = {'model_state_dict': checkpoint}
|
577 |
model.load_state_dict(checkpoint['model_state_dict'])
|
578 |
+
|
579 |
+
return model.to(device)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
accelerate>=0.33.0
|
|
|
2 |
cached_path
|
3 |
click
|
4 |
datasets
|
|
|
1 |
accelerate>=0.33.0
|
2 |
+
bitsandbytes>0.37.0
|
3 |
cached_path
|
4 |
click
|
5 |
datasets
|
speech_edit.py
CHANGED
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
|
|
49 |
model_cls = UNetT
|
50 |
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
51 |
|
52 |
-
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.
|
53 |
output_dir = "tests"
|
54 |
|
55 |
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
@@ -172,12 +172,13 @@ with torch.inference_mode():
|
|
172 |
print(f"Generated mel: {generated.shape}")
|
173 |
|
174 |
# Final result
|
|
|
175 |
generated = generated[:, ref_audio_len:, :]
|
176 |
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
177 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
178 |
if rms < target_rms:
|
179 |
generated_wave = generated_wave * rms / target_rms
|
180 |
|
181 |
-
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/
|
182 |
-
torchaudio.save(f"{output_dir}/
|
183 |
print(f"Generated wav: {generated_wave.shape}")
|
|
|
49 |
model_cls = UNetT
|
50 |
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
51 |
|
52 |
+
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
53 |
output_dir = "tests"
|
54 |
|
55 |
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
|
|
172 |
print(f"Generated mel: {generated.shape}")
|
173 |
|
174 |
# Final result
|
175 |
+
generated = generated.to(torch.float32)
|
176 |
generated = generated[:, ref_audio_len:, :]
|
177 |
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
178 |
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
179 |
if rms < target_rms:
|
180 |
generated_wave = generated_wave * rms / target_rms
|
181 |
|
182 |
+
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
183 |
+
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
|
184 |
print(f"Generated wav: {generated_wave.shape}")
|