mrfakename committed on
Commit
831ba2e
1 Parent(s): b4752cf

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. api.py +1 -0
  2. model/utils_infer.py +7 -23
api.py CHANGED
@@ -105,6 +105,7 @@ class F5TTS:
105
  sway_sampling_coef=sway_sampling_coef,
106
  speed=speed,
107
  fix_duration=fix_duration,
 
108
  )
109
 
110
  if file_wave is not None:
 
105
  sway_sampling_coef=sway_sampling_coef,
106
  speed=speed,
107
  fix_duration=fix_duration,
108
+ device=self.device,
109
  )
110
 
111
  if file_wave is not None:
model/utils_infer.py CHANGED
@@ -19,13 +19,8 @@ from model.utils import (
19
  convert_char_to_pinyin,
20
  )
21
 
22
- # get device
23
-
24
-
25
- def get_device():
26
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
- return device
28
 
 
29
 
30
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
31
 
@@ -81,9 +76,7 @@ def chunk_text(text, max_chars=135):
81
 
82
 
83
  # load vocoder
84
- def load_vocoder(is_local=False, local_path="", device=None):
85
- if device is None:
86
- device = get_device()
87
  if is_local:
88
  print(f"Load vocos from local path {local_path}")
89
  vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -101,11 +94,8 @@ def load_vocoder(is_local=False, local_path="", device=None):
101
  asr_pipe = None
102
 
103
 
104
- def initialize_asr_pipeline(device=None):
105
  global asr_pipe
106
- if device is None:
107
- device = get_device()
108
-
109
  asr_pipe = pipeline(
110
  "automatic-speech-recognition",
111
  model="openai/whisper-large-v3-turbo",
@@ -117,9 +107,7 @@ def initialize_asr_pipeline(device=None):
117
  # load model for inference
118
 
119
 
120
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
121
- if device is None:
122
- device = get_device()
123
  if vocab_file == "":
124
  vocab_file = "Emilia_ZH_EN"
125
  tokenizer = "pinyin"
@@ -152,10 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
152
  # preprocess reference audio and text
153
 
154
 
155
- def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
156
- if device is None:
157
- device = get_device()
158
-
159
  show_info("Converting audio...")
160
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
161
  aseg = AudioSegment.from_file(ref_audio_orig)
@@ -216,6 +201,7 @@ def infer_process(
216
  sway_sampling_coef=sway_sampling_coef,
217
  speed=speed,
218
  fix_duration=fix_duration,
 
219
  ):
220
  # Split the input text into batches
221
  audio, sr = torchaudio.load(ref_audio)
@@ -238,6 +224,7 @@ def infer_process(
238
  sway_sampling_coef=sway_sampling_coef,
239
  speed=speed,
240
  fix_duration=fix_duration,
 
241
  )
242
 
243
 
@@ -259,9 +246,6 @@ def infer_batch_process(
259
  fix_duration=None,
260
  device=None,
261
  ):
262
- if device is None:
263
- device = get_device()
264
-
265
  audio, sr = ref_audio
266
  if audio.shape[0] > 1:
267
  audio = torch.mean(audio, dim=0, keepdim=True)
 
19
  convert_char_to_pinyin,
20
  )
21
 
 
 
 
 
 
 
22
 
23
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
24
 
25
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
26
 
 
76
 
77
 
78
  # load vocoder
79
+ def load_vocoder(is_local=False, local_path="", device=device):
 
 
80
  if is_local:
81
  print(f"Load vocos from local path {local_path}")
82
  vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
 
94
  asr_pipe = None
95
 
96
 
97
+ def initialize_asr_pipeline(device=device):
98
  global asr_pipe
 
 
 
99
  asr_pipe = pipeline(
100
  "automatic-speech-recognition",
101
  model="openai/whisper-large-v3-turbo",
 
107
  # load model for inference
108
 
109
 
110
+ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
 
 
111
  if vocab_file == "":
112
  vocab_file = "Emilia_ZH_EN"
113
  tokenizer = "pinyin"
 
140
  # preprocess reference audio and text
141
 
142
 
143
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
 
 
 
144
  show_info("Converting audio...")
145
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
146
  aseg = AudioSegment.from_file(ref_audio_orig)
 
201
  sway_sampling_coef=sway_sampling_coef,
202
  speed=speed,
203
  fix_duration=fix_duration,
204
+ device=device,
205
  ):
206
  # Split the input text into batches
207
  audio, sr = torchaudio.load(ref_audio)
 
224
  sway_sampling_coef=sway_sampling_coef,
225
  speed=speed,
226
  fix_duration=fix_duration,
227
+ device=device,
228
  )
229
 
230
 
 
246
  fix_duration=None,
247
  device=None,
248
  ):
 
 
 
249
  audio, sr = ref_audio
250
  if audio.shape[0] > 1:
251
  audio = torch.mean(audio, dim=0, keepdim=True)