|
|
|
from random import sample |
|
import gradio as gr |
|
import torchaudio |
|
import torch |
|
import torch.nn as nn |
|
import lightning_module |
|
import pdb |
|
import jiwer |
|
|
|
|
|
from transformers import pipeline |
|
# English ASR used for the "Hypothesis" output and the WER computation.
# The model is pinned explicitly: when no model is given, the pipeline falls
# back to whatever default the installed transformers version ships (and logs
# a warning). wav2vec2-base-960h matches the "wav2vec-960" ASR named in the
# demo description.
p = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
|
|
|
|
# Text normalisation applied to BOTH the reference and the ASR hypothesis
# before computing WER (lower-case, strip punctuation, collapse whitespace,
# split into words). Punctuation is removed because a CTC ASR hypothesis is
# typically unpunctuated (assumption for the configured model — verify),
# while a typed reference usually is not. Without this, "Hello, world." vs.
# "HELLO WORLD" would score every punctuated word as an error.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
|
|
|
|
|
from transformers import Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC |
|
# Phoneme recogniser used by calc_mos: the CTC model predicts token ids and
# the processor decodes them back to text. Both load the same checkpoint.
_PHONEME_CHECKPOINT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(_PHONEME_CHECKPOINT)
phoneme_model = Wav2Vec2ForCTC.from_pretrained(_PHONEME_CHECKPOINT)
|
|
|
class ChangeSampleRate(nn.Module):
    """Resample audio by linear interpolation between neighbouring samples.

    No anti-aliasing filter is applied, so downsampling can alias content
    above the new Nyquist frequency.

    Args:
        input_rate: Sample rate of the incoming waveform, in Hz.
        output_rate: Desired sample rate, in Hz.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Resample ``wav`` along its time axis.

        Args:
            wav: Waveform whose first dimension is preserved (e.g. the channel
                axis from ``torchaudio.load``); every remaining dimension is
                flattened into a single time axis.

        Returns:
            A 2-D tensor of shape ``(wav.size(0), new_length)`` with
            ``new_length = samples_per_row * output_rate // input_rate``.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        wav = wav.reshape(wav.size(0), -1)
        num_samples = wav.size(-1)
        new_length = num_samples * self.output_rate // self.input_rate

        # Fractional source position of each output sample. float64 keeps both
        # the integer part and the fraction exact for long recordings (float32
        # cannot represent every integer above 2**24). Building it on
        # wav.device avoids CPU/GPU mismatches.
        step = self.input_rate / self.output_rate
        positions = torch.arange(new_length, dtype=torch.float64, device=wav.device) * step
        lower = positions.long()  # positions are >= 0, so truncation == floor
        upper = (lower + 1).clamp(max=num_samples - 1)

        # Interpolation weights in the default float dtype, so the output dtype
        # matches what float32 weights produced before.
        frac = (positions - lower).to(torch.get_default_dtype()).unsqueeze(0)
        return wav[:, lower] * (1.0 - frac) + wav[:, upper] * frac
|
|
|
# UTMOS MOS predictor restored from the local Lightning checkpoint and put
# into evaluation mode before any request is served.
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "epoch=3-step=7459.ckpt"
)
model.eval()
|
|
|
def calc_mos(audio_path, ref):
    """Analyse one recording and return the five values shown in the UI.

    The audio is loaded, resampled to 16 kHz and passed through:
    the UTMOS MOS predictor, the ASR pipeline (plus WER against ``ref``),
    and the phoneme recogniser (plus a phonemes-per-minute speaking rate).

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``
            (the Gradio input uses ``type='filepath'``).
        ref: Reference transcript typed by the user; may be empty.

    Returns:
        ``(predic_mos, trans, wer, phone_transcription, ppm)``:
        the predicted MOS (model output rescaled by ``* 2 + 3``), the ASR
        hypothesis text, the WER (``None`` when ``ref`` is blank), the list
        returned by ``processor.batch_decode``, and phonemes per minute of
        VAD-trimmed audio (``None`` when VAD keeps no samples).
    """
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    out_wavs = ChangeSampleRate(sr, osr)(wav)

    # The ASR pipeline reads the file itself (original audio, not out_wavs).
    trans = p(audio_path)["text"]

    # jiwer raises ValueError for an empty reference; a blank Reference box
    # must not prevent the MOS / phoneme outputs from being returned.
    if ref is not None and ref.strip():
        wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    else:
        wer = None

    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
        # NOTE(review): phoneme_model receives raw resampled samples; the
        # processor's feature extraction (which may normalise the audio) is
        # not applied here — confirm this is intended.
        logits = phoneme_model(out_wavs).logits
    predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3

    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # split() with no argument drops empty tokens, so an empty transcription
    # counts as 0 phonemes (split(" ") would yield [''] -> 1).
    lst_phonemes = phone_transcription[0].split()

    # torchaudio's vad trims only from the front, so run it again on the
    # time-reversed result to drop trailing silence as well. The speaking rate
    # is measured on the original-rate audio, hence division by sr.
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    if wav_vad.shape[-1] > 0:
        wav_vad = torchaudio.functional.vad(wav_vad.flip(-1), sample_rate=sr)
    speech_seconds = wav_vad.shape[-1] / sr
    # A clip in which VAD finds no speech has no defined rate (avoid 1/0).
    ppm = len(lst_phonemes) / speech_seconds * 60 if speech_seconds > 0 else None

    return predic_mos, trans, wer, phone_transcription, ppm
|
|
|
# Markdown-formatted text (note the [here](...) link syntax) passed as the
# `description` of the Gradio interface. The blank lines are part of the
# string and separate its paragraphs.
description ="""

MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.

This demo only accepts .wav format. Best at 16 kHz sampling rate.



Paper is available [here](https://arxiv.org/abs/2204.02152)



Add ASR based on wav2vec-960, currently only English available.

Add WER interface.

"""
|
|
|
# Gradio UI: the two inputs map to calc_mos(audio_path, ref), and the five
# output boxes receive calc_mos's return tuple in the same order.
iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here", label="Reference"),
    ],
    outputs=[
        gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Word Error Rate", label="WER"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
        gr.Textbox(placeholder="Phonemes per minute", label="PPM"),
    ],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    # NOTE(review): "auto" flagging logs submissions (including uploaded audio)
    # without the user pressing Flag — confirm this is intended.
    allow_flagging="auto",
)

iface.launch()