KevinGeng committed on
Commit 5918e9e
1 Parent(s): eb2441e

Update app.py

Files changed (1)
  1. app.py +26 -7
app.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
 import lightning_module
 import pdb
 import jiwer
+
 # ASR part
 from transformers import pipeline
 p = pipeline("automatic-speech-recognition")
@@ -19,6 +20,11 @@ transformation = jiwer.Compose([
     jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
 ])
 
+# WPM part
+from transformers import Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
+phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
+
 class ChangeSampleRate(nn.Module):
     def __init__(self, input_rate: int, output_rate: int):
         super().__init__()
@@ -35,7 +41,8 @@ class ChangeSampleRate(nn.Module):
         output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
         return output
 
-model = lightning_module.BaselineLightningModule.load_from_checkpoint("./epoch=3-step=7459.ckpt").eval()
+model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
+
 def calc_mos(audio_path, ref):
     wav, sr = torchaudio.load(audio_path)
     osr = 16_000
@@ -46,7 +53,7 @@ def calc_mos(audio_path, ref):
     trans = p(audio_path)["text"]
     # WER
     wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
-
+    # MOS
     batch = {
         'wav': out_wavs,
         'domains': torch.tensor([0]),
@@ -54,10 +61,17 @@ def calc_mos(audio_path, ref):
     }
     with torch.no_grad():
         output = model(batch)
-
     predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3
+    # Phonemes per minute (PPM)
+    with torch.no_grad():
+        logits = phoneme_model(out_wavs).logits
+    phone_predicted_ids = torch.argmax(logits, dim=-1)
+    phone_transcription = processor.batch_decode(phone_predicted_ids)
+    lst_phonemes = phone_transcription[0].split(" ")
+    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
+    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
 
-    return predic_mos, trans, wer
+    return predic_mos, trans, wer, phone_transcription, ppm
 
 description = """
 MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
@@ -71,9 +85,14 @@ Add WER interface.
 
 iface = gr.Interface(
     fn=calc_mos,
-    inputs=[gr.Audio(type='filepath'), gr.Textbox(placeholder="Insert reference here", label="Reference")],
-    outputs=[gr.Textbox("Predicted MOS"), gr.Textbox("Hypothesis"), gr.Textbox("WER")],
-    title="Laronix Voice Quality Checking Demo",
+    inputs=[gr.Audio(type='filepath', label="Audio to evaluate"),
+            gr.Textbox(placeholder="Input reference here", label="Reference")],
+    outputs=[gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
+             gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
+             gr.Textbox(placeholder="Word Error Rate", label="WER"),
+             gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
+             gr.Textbox(placeholder="Phonemes per minute", label="PPM")],
+    title="Laronix's Voice Quality Checking System Demo",
     description=description,
     allow_flagging="auto",
 )
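
A note on the jiwer call: both the reference text and the ASR hypothesis are normalized through the shared Compose pipeline before scoring. The diff only shows the pipeline's final step in the hunk header, so the first two steps below are illustrative assumptions rather than the commit's actual list; a minimal sketch:

import jiwer

transformation = jiwer.Compose([
    jiwer.ToLowerCase(),                                  # assumed step, not visible in this diff
    jiwer.RemovePunctuation(),                            # assumed step, not visible in this diff
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),  # the step shown in the hunk header
])

wer = jiwer.wer(
    "the quick brown fox",   # reference
    "the quick brown box",   # hypothesis from the ASR pipeline
    truth_transform=transformation,
    hypothesis_transform=transformation,
)
print(wer)  # 0.25: one substitution out of four reference words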
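
On the MOS head, the output.mean(dim=1).squeeze()...*2 + 3 line reads as a rescaling from a normalized model score onto the 1-5 MOS scale: a raw -1 maps to 1, 0 to 3, and +1 to 5. The commit does not state the training-time normalization, so treat that range as an inference from the arithmetic.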
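
The headline addition is the phonemes-per-minute (PPM) speaking-rate estimate: greedy CTC decoding of the wav2vec2 phoneme logits, a whitespace split of the decoded string, and division by the VAD-trimmed duration. Below is a self-contained sketch of that logic, assuming a mono 16 kHz waveform tensor; the helper name phonemes_per_minute is mine, not the commit's:

import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")

def phonemes_per_minute(wav: torch.Tensor, sr: int = 16_000) -> float:
    # Greedy CTC decode: argmax over each frame's logits, then let the
    # processor collapse repeated ids and CTC blanks into phoneme strings.
    with torch.no_grad():
        logits = phoneme_model(wav).logits
    ids = torch.argmax(logits, dim=-1)
    phonemes = processor.batch_decode(ids)[0].split(" ")
    # Trim leading silence (torchaudio's vad, like SoX's, trims from the
    # front) so a quiet lead-in does not deflate the estimated rate.
    voiced = torchaudio.functional.vad(wav, sample_rate=sr)
    return len(phonemes) / (voiced.shape[-1] / sr) * 60

One caveat worth flagging in review: since vad appears to trim only from the front of the recording, trailing silence still counts toward the duration and will pull the PPM estimate down.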
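
The hunk ends at the gr.Interface construction; the launch call sits outside the changed lines and is not shown. Presumably the untouched tail of app.py still calls the conventional Gradio entry point:

iface.launch()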