mrfakename committed on
Commit
9cebe0a
1 Parent(s): 61075cd

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo there.

Files changed (1) hide show
  1. src/f5_tts/model/trainer.py +4 -4
src/f5_tts/model/trainer.py CHANGED
@@ -61,7 +61,7 @@ class Trainer:
61
  gradient_accumulation_steps=grad_accumulation_steps,
62
  **accelerate_kwargs,
63
  )
64
-
65
  self.logger = logger
66
  if self.logger == "wandb":
67
  if exists(wandb_resume_id):
@@ -325,7 +325,7 @@ class Trainer:
325
 
326
  if self.log_samples and self.accelerator.is_local_main_process:
327
  ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
328
- torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
329
  with torch.inference_mode():
330
  generated, _ = self.accelerator.unwrap_model(self.model).sample(
331
  cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +336,8 @@ class Trainer:
336
  sway_sampling_coef=sway_sampling_coef,
337
  )
338
  generated = generated.to(torch.float32)
339
- gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
340
- torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
341
 
342
  if global_step % self.last_per_steps == 0:
343
  self.save_checkpoint(global_step, last=True)
 
61
  gradient_accumulation_steps=grad_accumulation_steps,
62
  **accelerate_kwargs,
63
  )
64
+ self.device = self.accelerator.device
65
  self.logger = logger
66
  if self.logger == "wandb":
67
  if exists(wandb_resume_id):
 
325
 
326
  if self.log_samples and self.accelerator.is_local_main_process:
327
  ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
328
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate)
329
  with torch.inference_mode():
330
  generated, _ = self.accelerator.unwrap_model(self.model).sample(
331
  cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
 
336
  sway_sampling_coef=sway_sampling_coef,
337
  )
338
  generated = generated.to(torch.float32)
339
+ gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device))
340
+ torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate)
341
 
342
  if global_step % self.last_per_steps == 0:
343
  self.save_checkpoint(global_step, last=True)