mrfakename committed
Commit: 1d03890
Parent: 4064aae

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

Files changed (8):
  1. app.py +1 -0
  2. inference-cli.py +4 -3
  3. model/cfm.py +4 -2
  4. model/modules.py +1 -0
  5. model/trainer.py +7 -2
  6. model/utils.py +11 -11
  7. requirements.txt +1 -0
  8. speech_edit.py +4 -3
app.py CHANGED
@@ -173,6 +173,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
             sway_sampling_coef=sway_sampling_coef,
         )
 
+        generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
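
Note on the app.py change: the sampling path can now run in half precision (see the model/cfm.py changes below), so the generated mel is cast back to float32 before the Vocos decode step. A minimal, self-contained sketch of the pattern; the tensor shape and the commented `vocos.decode` call are illustrative, not the repo's exact code:

```python
import torch
from einops import rearrange

# Toy stand-in for the sampler's fp16 output: (batch=1, frames, n_mels)
generated = torch.randn(1, 512, 100, dtype=torch.float16)

# Cast back to float32 before vocoding, mirroring the added line in the diff
generated = generated.to(torch.float32)
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
# In the real pipeline this is followed by:
# generated_wave = vocos.decode(generated_mel_spec.cpu())
```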
inference-cli.py CHANGED
@@ -145,9 +145,9 @@ def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
     else:
         tokenizer="custom"
 
-    print("\nvocab : ",vocab_file,tokenizer)
-    print("tokenizer : ",tokenizer)
-    print("model : ",ckpt_path,"\n")
+    print("\nvocab : ", vocab_file,tokenizer)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path,"\n")
 
     vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(
@@ -265,6 +265,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_voca
             sway_sampling_coef=sway_sampling_coef,
         )
 
+        generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
model/cfm.py CHANGED
@@ -99,6 +99,8 @@ class CFM(nn.Module):
     ):
         self.eval()
 
+        cond = cond.half()
+
         # raw wave
 
         if cond.ndim == 2:
@@ -175,7 +177,7 @@ class CFM(nn.Module):
         for dur in duration:
             if exists(seed):
                 torch.manual_seed(seed)
-            y0.append(torch.randn(dur, self.num_channels, device = self.device))
+            y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
         y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
 
         t_start = 0
@@ -186,7 +188,7 @@ class CFM(nn.Module):
             y0 = (1 - t_start) * y0 + t_start * test_cond
             steps = int(steps * (1 - t_start))
 
-        t = torch.linspace(t_start, 1, steps, device = self.device)
+        t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
 
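Note on the cfm.py changes: every tensor in the sampling loop is kept in the same dtype as the conditioning. The input is cast with `cond.half()`, and both the initial noise `y0` and the time grid `t` are created with `dtype=step_cond.dtype`, so nothing is silently promoted back to float32 mid-solve. A small self-contained sketch of the dtype-matching idea; the shapes and values are made up, only the construction pattern mirrors the diff:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Pretend conditioning mel, cast to half precision as in CFM.sample()
step_cond = torch.randn(2, 300, 100).half()
duration = [300, 250]
num_channels = 100

# Initial noise per sample, created directly in the conditioning dtype
y0 = [torch.randn(dur, num_channels, dtype=step_cond.dtype) for dur in duration]
y0 = pad_sequence(y0, padding_value=0, batch_first=True)

# Time grid for the ODE steps, also matched to the conditioning dtype
t = torch.linspace(0, 1, 32, dtype=step_cond.dtype)
assert y0.dtype == t.dtype == step_cond.dtype  # everything stays fp16
```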
model/modules.py CHANGED
@@ -571,5 +571,6 @@ class TimestepEmbedding(nn.Module):
 
     def forward(self, timestep: float['b']):
         time_hidden = self.time_embed(timestep)
+        time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden) # b d
         return time
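
Note on the modules.py change: the time embedding is cast back to the dtype of the incoming timestep, since a sinusoidal embedding is typically computed in float32 even when the rest of the forward pass runs in fp16. A minimal sketch of that pattern with a made-up embedding function (not the repo's actual embedding module):

```python
import torch

def sinus_embed(t: torch.Tensor, dim: int = 8) -> torch.Tensor:
    # Illustrative sinusoidal embedding; the math runs in float32 regardless of input dtype
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32))
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)

timestep = torch.rand(4).half()
time_hidden = sinus_embed(timestep)           # comes back as float32
time_hidden = time_hidden.to(timestep.dtype)  # re-align dtype before the MLP, as in the diff
assert time_hidden.dtype == torch.float16
```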
model/trainer.py CHANGED
@@ -45,7 +45,8 @@ class Trainer:
         wandb_resume_id: str = None,
         last_per_steps = None,
         accelerate_kwargs: dict = dict(),
-        ema_kwargs: dict = dict()
+        ema_kwargs: dict = dict(),
+        bnb_optimizer: bool = False,
     ):
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
@@ -107,7 +108,11 @@ class Trainer:
 
         self.duration_predictor = duration_predictor
 
-        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+        if bnb_optimizer:
+            import bitsandbytes as bnb
+            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
+        else:
+            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(
             self.model, self.optimizer
         )
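
Note on the trainer.py change: the new `bnb_optimizer` flag swaps the standard AdamW for bitsandbytes' 8-bit AdamW, which keeps optimizer state in 8-bit buffers to reduce optimizer memory while accepting the same `parameters()` and `lr` arguments. A hedged usage sketch; the toy model and learning rate are made up, bitsandbytes requires a CUDA device, and only the optimizer selection mirrors the diff:

```python
import torch.nn as nn
from torch.optim import AdamW

use_bnb = True                      # corresponds to Trainer(bnb_optimizer=True)
model = nn.Linear(256, 256).cuda()  # toy model; bitsandbytes needs CUDA parameters
learning_rate = 1e-4                # arbitrary value for illustration

if use_bnb:
    import bitsandbytes as bnb
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
else:
    optimizer = AdamW(model.parameters(), lr=learning_rate)
```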
model/utils.py CHANGED
@@ -557,23 +557,23 @@ def repetition_found(text, length = 2, tolerance = 10):
 # load model checkpoint for inference
 
 def load_checkpoint(model, ckpt_path, device, use_ema = True):
-    from ema_pytorch import EMA
+    model = model.half()
 
     ckpt_type = ckpt_path.split(".")[-1]
     if ckpt_type == "safetensors":
         from safetensors.torch import load_file
-        checkpoint = load_file(ckpt_path, device=device)
+        checkpoint = load_file(ckpt_path)
     else:
-        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
+        checkpoint = torch.load(ckpt_path, weights_only=True)
 
-    if use_ema == True:
-        ema_model = EMA(model, include_online_model = False).to(device)
+    if use_ema:
         if ckpt_type == "safetensors":
-            ema_model.load_state_dict(checkpoint)
-        else:
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        ema_model.copy_params_from_ema_to_model()
+            checkpoint = {'ema_model_state_dict': checkpoint}
+        checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
+        model.load_state_dict(checkpoint['model_state_dict'])
     else:
+        if ckpt_type == "safetensors":
+            checkpoint = {'model_state_dict': checkpoint}
         model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
+
+    return model.to(device)
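
Note on the utils.py change: the rewritten `load_checkpoint` no longer goes through `ema_pytorch.EMA`. It casts the model to half precision, loads the raw checkpoint, and when `use_ema` is set, strips the `ema_model.` prefix from the EMA state-dict keys (dropping the `initted` and `step` bookkeeping entries) before loading them into the model. A toy illustration of just that key remapping; the key names and values are invented, only the dict comprehension mirrors the diff:

```python
# Invented EMA state dict, as it might appear under 'ema_model_state_dict'
ema_state = {
    "initted": 1,
    "step": 120000,
    "ema_model.transformer.proj.weight": "<tensor>",
    "ema_model.transformer.proj.bias": "<tensor>",
}

# Same remapping as the diff: drop bookkeeping keys, strip the 'ema_model.' prefix
model_state = {
    k.replace("ema_model.", ""): v
    for k, v in ema_state.items()
    if k not in ["initted", "step"]
}
print(sorted(model_state))  # ['transformer.proj.bias', 'transformer.proj.weight']
```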
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 accelerate>=0.33.0
+bitsandbytes>0.37.0
 cached_path
 click
 datasets
speech_edit.py CHANGED
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -172,12 +172,13 @@ with torch.inference_mode():
     print(f"Generated mel: {generated.shape}")
 
     # Final result
+    generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
     generated_wave = vocos.decode(generated_mel_spec.cpu())
     if rms < target_rms:
         generated_wave = generated_wave * rms / target_rms
 
-    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
-    torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
+    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")