Spaces:
Running
Running
mrfakename
commited on
Commit
•
831ba2e
1
Parent(s):
b4752cf
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- api.py +1 -0
- model/utils_infer.py +7 -23
api.py
CHANGED
@@ -105,6 +105,7 @@ class F5TTS:
|
|
105 |
sway_sampling_coef=sway_sampling_coef,
|
106 |
speed=speed,
|
107 |
fix_duration=fix_duration,
|
|
|
108 |
)
|
109 |
|
110 |
if file_wave is not None:
|
|
|
105 |
sway_sampling_coef=sway_sampling_coef,
|
106 |
speed=speed,
|
107 |
fix_duration=fix_duration,
|
108 |
+
device=self.device,
|
109 |
)
|
110 |
|
111 |
if file_wave is not None:
|
model/utils_infer.py
CHANGED
@@ -19,13 +19,8 @@ from model.utils import (
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
22 |
-
# get device
|
23 |
-
|
24 |
-
|
25 |
-
def get_device():
|
26 |
-
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
27 |
-
return device
|
28 |
|
|
|
29 |
|
30 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
31 |
|
@@ -81,9 +76,7 @@ def chunk_text(text, max_chars=135):
|
|
81 |
|
82 |
|
83 |
# load vocoder
|
84 |
-
def load_vocoder(is_local=False, local_path="", device=
|
85 |
-
if device is None:
|
86 |
-
device = get_device()
|
87 |
if is_local:
|
88 |
print(f"Load vocos from local path {local_path}")
|
89 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
@@ -101,11 +94,8 @@ def load_vocoder(is_local=False, local_path="", device=None):
|
|
101 |
asr_pipe = None
|
102 |
|
103 |
|
104 |
-
def initialize_asr_pipeline(device=
|
105 |
global asr_pipe
|
106 |
-
if device is None:
|
107 |
-
device = get_device()
|
108 |
-
|
109 |
asr_pipe = pipeline(
|
110 |
"automatic-speech-recognition",
|
111 |
model="openai/whisper-large-v3-turbo",
|
@@ -117,9 +107,7 @@ def initialize_asr_pipeline(device=None):
|
|
117 |
# load model for inference
|
118 |
|
119 |
|
120 |
-
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=
|
121 |
-
if device is None:
|
122 |
-
device = get_device()
|
123 |
if vocab_file == "":
|
124 |
vocab_file = "Emilia_ZH_EN"
|
125 |
tokenizer = "pinyin"
|
@@ -152,10 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
|
|
152 |
# preprocess reference audio and text
|
153 |
|
154 |
|
155 |
-
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
|
156 |
-
if device is None:
|
157 |
-
device = get_device()
|
158 |
-
|
159 |
show_info("Converting audio...")
|
160 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
161 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
@@ -216,6 +201,7 @@ def infer_process(
|
|
216 |
sway_sampling_coef=sway_sampling_coef,
|
217 |
speed=speed,
|
218 |
fix_duration=fix_duration,
|
|
|
219 |
):
|
220 |
# Split the input text into batches
|
221 |
audio, sr = torchaudio.load(ref_audio)
|
@@ -238,6 +224,7 @@ def infer_process(
|
|
238 |
sway_sampling_coef=sway_sampling_coef,
|
239 |
speed=speed,
|
240 |
fix_duration=fix_duration,
|
|
|
241 |
)
|
242 |
|
243 |
|
@@ -259,9 +246,6 @@ def infer_batch_process(
|
|
259 |
fix_duration=None,
|
260 |
device=None,
|
261 |
):
|
262 |
-
if device is None:
|
263 |
-
device = get_device()
|
264 |
-
|
265 |
audio, sr = ref_audio
|
266 |
if audio.shape[0] > 1:
|
267 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
24 |
|
25 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
26 |
|
|
|
76 |
|
77 |
|
78 |
# load vocoder
|
79 |
+
def load_vocoder(is_local=False, local_path="", device=device):
|
|
|
|
|
80 |
if is_local:
|
81 |
print(f"Load vocos from local path {local_path}")
|
82 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
|
|
94 |
asr_pipe = None
|
95 |
|
96 |
|
97 |
+
def initialize_asr_pipeline(device=device):
|
98 |
global asr_pipe
|
|
|
|
|
|
|
99 |
asr_pipe = pipeline(
|
100 |
"automatic-speech-recognition",
|
101 |
model="openai/whisper-large-v3-turbo",
|
|
|
107 |
# load model for inference
|
108 |
|
109 |
|
110 |
+
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
|
|
|
|
|
111 |
if vocab_file == "":
|
112 |
vocab_file = "Emilia_ZH_EN"
|
113 |
tokenizer = "pinyin"
|
|
|
140 |
# preprocess reference audio and text
|
141 |
|
142 |
|
143 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
|
|
|
|
|
|
|
144 |
show_info("Converting audio...")
|
145 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
146 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
|
201 |
sway_sampling_coef=sway_sampling_coef,
|
202 |
speed=speed,
|
203 |
fix_duration=fix_duration,
|
204 |
+
device=device,
|
205 |
):
|
206 |
# Split the input text into batches
|
207 |
audio, sr = torchaudio.load(ref_audio)
|
|
|
224 |
sway_sampling_coef=sway_sampling_coef,
|
225 |
speed=speed,
|
226 |
fix_duration=fix_duration,
|
227 |
+
device=device,
|
228 |
)
|
229 |
|
230 |
|
|
|
246 |
fix_duration=None,
|
247 |
device=None,
|
248 |
):
|
|
|
|
|
|
|
249 |
audio, sr = ref_audio
|
250 |
if audio.shape[0] > 1:
|
251 |
audio = torch.mean(audio, dim=0, keepdim=True)
|