Spaces:
Running
Running
ThreadAbort
commited on
Commit
•
fda8dc2
1
Parent(s):
781fa66
model change
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ speed = 1.0
|
|
50 |
fix_duration = None
|
51 |
|
52 |
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
53 |
-
checkpoint = torch.load(str(cached_path(f"hf://SWivid/
|
54 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
55 |
model = CFM(
|
56 |
transformer=model_cls(
|
@@ -85,7 +85,7 @@ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_
|
|
85 |
@spaces.GPU
|
86 |
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
|
87 |
print(gen_text)
|
88 |
-
if model.predict(gen_text)['toxicity'] > 0.
|
89 |
print("Flagged for toxicity:", gen_text)
|
90 |
raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
|
91 |
gr.Info("Converting audio...")
|
|
|
50 |
fix_duration = None
|
51 |
|
52 |
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
53 |
+
checkpoint = torch.load(str(cached_path(f"hf://SWivid/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
|
54 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
55 |
model = CFM(
|
56 |
transformer=model_cls(
|
|
|
85 |
@spaces.GPU
|
86 |
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
|
87 |
print(gen_text)
|
88 |
+
if model.predict(gen_text)['toxicity'] > 0.8:
|
89 |
print("Flagged for toxicity:", gen_text)
|
90 |
raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
|
91 |
gr.Info("Converting audio...")
|