import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import jiwer

# ASR: Whisper medium checkpoint (PAL300 fine-tune); device=0 assumes a GPU is available
from transformers import pipeline

p = pipeline(
    "automatic-speech-recognition",
    model="KevinGeng/whipser_medium_en_PAL300_step25",
    device=0,
)

# WER: normalize case and whitespace on both reference and hypothesis before scoring
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])
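
# Illustrative check of the normalization above (not executed by the app; strings are
# made-up examples): case and extra whitespace differences are ignored, so WER is 0.0.
# jiwer.wer("Hello World", "hello   world",
#           truth_transform=transformation,
#           hypothesis_transform=transformation)  # -> 0.0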

# Phoneme recognition (used to compute phonemes per minute, PPM)
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")


class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
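
# Illustrative usage of ChangeSampleRate (not part of the app flow; the sample rates and
# shapes below are assumptions): resample a 1-second, 44.1 kHz mono waveform to 16 kHz.
# _wav = torch.rand(1, 44100)                      # (channels, samples)
# _wav_16k = ChangeSampleRate(44100, 16000)(_wav)  # -> shape (1, 16000)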

# MOS predictor (UTMOS-strong w/o phoneme encoder) loaded from a local checkpoint
model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Downmix to mono
    # Resample to 16 kHz for the MOS and phoneme models
    osr = 16_000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER (only meaningful when a reference transcript is given)
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    # MOS prediction
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3  # Map to the 1-5 MOS scale
    # Phonemes per minute (PPM) over the VAD-trimmed waveform
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    lst_phonemes = phone_transcription[0].split(" ")
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60

    return predic_mos, trans, wer, phone_transcription, ppm
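
# Illustrative call (assumptions: a local mono .wav file named "sample.wav" and a matching
# reference transcript; not executed by the app):
# mos, hyp, wer, phonemes, ppm = calc_mos("sample.wav", "this is a test sentence")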


description ="""
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.

Paper is available [here](https://arxiv.org/abs/2204.02152)

Add ASR based on wav2vec-960, currently only English available.
Add WER interface.
""" 


iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Enter the reference transcript here (do not leave this empty)", label="Reference"),
    ],
    outputs=[
        gr.Textbox(placeholder="Naturalness score, ranging from 1 to 5 (higher is better)", label="Predicted MOS"),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Word Error Rate: only valid when a reference is given", label="WER"),
        gr.Textbox(placeholder="Predicted phonemes", label="Predicted Phonemes"),
        gr.Textbox(placeholder="Speaking rate in phonemes per minute", label="PPM"),
    ],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()