KevinGeng committed on
Commit
211fff4
1 Parent(s): b27c1ab

Update app.py

Browse files

Support multi-channel audio (downmixed to mono)
Use a better ASR model

Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from random import sample
3
  import gradio as gr
4
  import torchaudio
@@ -10,8 +9,12 @@ import jiwer
10
 
11
  # ASR part
12
  from transformers import pipeline
13
- p = pipeline("automatic-speech-recognition")
14
-
 
 
 
 
15
  # WER part
16
  transformation = jiwer.Compose([
17
  jiwer.ToLowerCase(),
@@ -21,10 +24,10 @@ transformation = jiwer.Compose([
21
  ])
22
 
23
  # WPM part
24
- from transformers import Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
25
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
26
  phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
27
-
28
  class ChangeSampleRate(nn.Module):
29
  def __init__(self, input_rate: int, output_rate: int):
30
  super().__init__()
@@ -44,7 +47,9 @@ class ChangeSampleRate(nn.Module):
44
  model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
45
 
46
  def calc_mos(audio_path, ref):
47
- wav, sr = torchaudio.load(audio_path, channels_first=True) # Mono channel
 
 
48
  osr = 16_000
49
  batch = wav.unsqueeze(0).repeat(10, 1, 1)
50
  csr = ChangeSampleRate(sr, osr)
@@ -73,6 +78,7 @@ def calc_mos(audio_path, ref):
73
 
74
  return predic_mos, trans, wer, phone_transcription, ppm
75
 
 
76
  description ="""
77
  MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
78
  This demo only accepts .wav format. Best at 16 kHz sampling rate.
@@ -83,15 +89,16 @@ Add ASR based on wav2vec-960, currently only English available.
83
  Add WER interface.
84
  """
85
 
 
86
  iface = gr.Interface(
87
  fn=calc_mos,
88
  inputs=[gr.Audio(type='filepath', label="Audio to evaluate"),
89
- gr.Textbox(placeholder="Input referance here", label="Referance")],
90
- outputs=[gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
91
  gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
92
- gr.Textbox(placeholder="Word Error Rate", label = "WER"),
93
  gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
94
- gr.Textbox(placeholder="Phonemes per minutes", label="PPM")],
95
  title="Laronix's Voice Quality Checking System Demo",
96
  description=description,
97
  allow_flagging="auto",
 
 
1
  from random import sample
2
  import gradio as gr
3
  import torchaudio
 
9
 
10
  # ASR part
11
  from transformers import pipeline
12
+ # p = pipeline("automatic-speech-recognition")
13
+ p = pipeline(
14
+ "automatic-speech-recognition",
15
+ model="KevinGeng/whipser_medium_en_PAL300_step25",
16
+ device=0,
17
+ )
18
  # WER part
19
  transformation = jiwer.Compose([
20
  jiwer.ToLowerCase(),
 
24
  ])
25
 
26
  # WPM part
27
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
28
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
29
  phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
30
+ # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
31
  class ChangeSampleRate(nn.Module):
32
  def __init__(self, input_rate: int, output_rate: int):
33
  super().__init__()
 
47
  model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
48
 
49
  def calc_mos(audio_path, ref):
50
+ wav, sr = torchaudio.load(audio_path, channels_first=True)
51
+ if wav.shape[0] > 1:
52
+ wav = wav.mean(dim=0, keepdim=True) # Mono channel
53
  osr = 16_000
54
  batch = wav.unsqueeze(0).repeat(10, 1, 1)
55
  csr = ChangeSampleRate(sr, osr)
 
78
 
79
  return predic_mos, trans, wer, phone_transcription, ppm
80
 
81
+
82
  description ="""
83
  MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
84
  This demo only accepts .wav format. Best at 16 kHz sampling rate.
 
89
  Add WER interface.
90
  """
91
 
92
+
93
  iface = gr.Interface(
94
  fn=calc_mos,
95
  inputs=[gr.Audio(type='filepath', label="Audio to evaluate"),
96
+ gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference")],
97
+ outputs=[gr.Textbox(placeholder="Naturalness evaluation, ranged 1 to 5, the higher the better.", label="Predicted MOS"),
98
  gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
99
+ gr.Textbox(placeholder="Word Error Rate: Only valid when Reference is given", label = "WER"),
100
  gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
101
+ gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="PPM")],
102
  title="Laronix's Voice Quality Checking System Demo",
103
  description=description,
104
  allow_flagging="auto",