mrfakename commited on
Commit
cea02d8
1 Parent(s): 9eac142

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. model/cfm.py +2 -1
  2. model/utils.py +2 -1
model/cfm.py CHANGED
@@ -96,7 +96,8 @@ class CFM(nn.Module):
96
  ):
97
  self.eval()
98
 
99
- cond = cond.half()
 
100
 
101
  # raw wave
102
 
 
96
  ):
97
  self.eval()
98
 
99
+ if cond.device != torch.device('cpu'):
100
+ cond = cond.half()
101
 
102
  # raw wave
103
 
model/utils.py CHANGED
@@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10):
555
  # load model checkpoint for inference
556
 
557
  def load_checkpoint(model, ckpt_path, device, use_ema = True):
558
- model = model.half()
 
559
 
560
  ckpt_type = ckpt_path.split(".")[-1]
561
  if ckpt_type == "safetensors":
 
555
  # load model checkpoint for inference
556
 
557
  def load_checkpoint(model, ckpt_path, device, use_ema = True):
558
+ if device != "cpu":
559
+ model = model.half()
560
 
561
  ckpt_type = ckpt_path.split(".")[-1]
562
  if ckpt_type == "safetensors":