ThreadAbort commited on
Commit
fda8dc2
1 Parent(s): 781fa66

model change

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -50,7 +50,7 @@ speed = 1.0
50
  fix_duration = None
51
 
52
  def load_model(exp_name, model_cls, model_cfg, ckpt_step):
53
- checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
  model = CFM(
56
  transformer=model_cls(
@@ -85,7 +85,7 @@ E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_
85
  @spaces.GPU
86
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
  print(gen_text)
88
- if model.predict(gen_text)['toxicity'] > 0.2:
89
  print("Flagged for toxicity:", gen_text)
90
  raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
  gr.Info("Converting audio...")
 
50
  fix_duration = None
51
 
52
  def load_model(exp_name, model_cls, model_cfg, ckpt_step):
53
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
  model = CFM(
56
  transformer=model_cls(
 
85
  @spaces.GPU
86
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
87
  print(gen_text)
88
+ if model.predict(gen_text)['toxicity'] > 0.8:
89
  print("Flagged for toxicity:", gen_text)
90
  raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
91
  gr.Info("Converting audio...")