File size: 4,087 Bytes
b33c328 5918e9e b33c328 211fff4 b33c328 5918e9e 211fff4 5918e9e 211fff4 b33c328 5918e9e b33c328 211fff4 b33c328 5918e9e b33c328 5918e9e b33c328 5918e9e b33c328 211fff4 b33c328 eb2441e b33c328 211fff4 b33c328 5918e9e 211fff4 5918e9e 211fff4 5918e9e 211fff4 5918e9e b33c328 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
# ASR part
from transformers import pipeline
# Speech-recognition pipeline (a Whisper-medium English fine-tune, judging by
# the repo id). Use the first GPU when CUDA is available and fall back to CPU
# (device=-1) otherwise, so the app still starts on CPU-only hosts.
p = pipeline(
    "automatic-speech-recognition",
    model="KevinGeng/whipser_medium_en_PAL300_step25",
    device=0 if torch.cuda.is_available() else -1,
)
# WER part
# Text normalisation applied to both reference and hypothesis before WER.
# Punctuation is removed because ASR output is typically punctuated while a
# typed reference often is not; without it "Hello," vs "hello" counts as an
# error. Strip() drops leading/trailing spaces left after the other steps.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
# Phoneme recognition part (used for the speaking rate in phonemes per minute, PPM)
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Phoneme recogniser and its matching processor, both loaded from the same
# pretrained checkpoint; used to count phonemes for the speaking-rate metric.
_PHONEME_MODEL_ID = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(_PHONEME_MODEL_ID)
phoneme_model = Wav2Vec2ForCTC.from_pretrained(_PHONEME_MODEL_ID)
class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples.

    Args:
        input_rate: Sample rate of the incoming waveform, in Hz.
        output_rate: Desired sample rate, in Hz.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Return ``wav`` resampled from ``input_rate`` to ``output_rate``.

        The first dimension is kept and every remaining dimension is flattened
        into the time axis, so pass a ``(channels, time)`` tensor such as a mono
        ``(1, time)`` waveform. The result has shape
        ``(channels, time * output_rate // input_rate)``.
        """
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source position of every output sample. Built on the
        # input's device so indexing also works for GPU tensors.
        indices = torch.arange(new_length, device=wav.device) * (self.input_rate / self.output_rate)
        lower = indices.long()
        # Right-hand neighbour, clamped so the final sample pairs with itself.
        upper = (lower + 1).clamp(max=wav.size(-1) - 1)
        frac = indices.fmod(1.).unsqueeze(0)
        return wav[:, lower] * (1. - frac) + wav[:, upper] * frac
# UTMOS baseline MOS predictor restored from the local Lightning checkpoint,
# then switched to evaluation mode for inference.
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "epoch=3-step=7459.ckpt",
)
model.eval()
def _speech_duration_seconds(wav, sr):
    """Return the duration in seconds of ``wav`` after trimming silence at both ends.

    ``torchaudio.functional.vad`` only trims from the front of the signal, so
    the waveform is trimmed, time-reversed, trimmed again and restored.
    """
    speech = torchaudio.functional.vad(wav, sample_rate=sr)
    if speech.shape[-1] > 0:
        speech = torchaudio.functional.vad(speech.flip(-1), sample_rate=sr).flip(-1)
    return speech.shape[-1] / sr


def calc_mos(audio_path, ref):
    """Evaluate one recording.

    Args:
        audio_path: Path of the audio file to evaluate.
        ref: Reference transcript for WER. If it is empty or blank, WER is
            not computed.

    Returns:
        Tuple ``(predic_mos, trans, wer, phone_transcription, ppm)``:
        ``predic_mos`` is the MOS model output averaged over dim 1 and mapped
        as ``x * 2 + 3``; ``trans`` is the ASR pipeline's text;
        ``wer`` is the word error rate of ``trans`` against ``ref`` (``None``
        when no reference is given); ``phone_transcription`` is the list
        returned by ``processor.batch_decode``; ``ppm`` is phonemes per
        minute of silence-trimmed audio (``0.0`` when no speech is detected).
    """
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Down-mix to mono: (1, time).
    osr = 16_000
    # Resample to 16 kHz for the MOS and phoneme models.
    out_wavs = ChangeSampleRate(sr, osr)(wav)

    # ASR: the pipeline loads and decodes the file itself.
    trans = p(audio_path)["text"]

    # WER: jiwer rejects an empty reference, so only score when one is given.
    # NOTE(review): jiwer>=3 renames truth_transform to reference_transform;
    # keep this keyword in sync with the pinned jiwer version.
    if ref and ref.strip():
        wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    else:
        wer = None

    # MOS
    # NOTE(review): domain 0 / judge_id 288 are fixed conditioning ids for the
    # UTMOS model; confirm their meaning against the training setup.
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # Presumably the model predicts on a [-1, 1] scale; * 2 + 3 maps it to 1-5.
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3

    # Phonemes per minute (PPM)
    # Feed the phoneme model through its processor so the waveform gets the
    # feature extractor's preprocessing (e.g. normalisation, if enabled in its
    # config) rather than raw samples.
    phone_inputs = processor(out_wavs.squeeze(0).numpy(), sampling_rate=osr, return_tensors="pt")
    with torch.no_grad():
        logits = phoneme_model(phone_inputs.input_values).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # split() without an argument ignores repeated spaces and yields [] for an
    # empty decode (split(" ") would count that as one phoneme).
    lst_phonemes = phone_transcription[0].split()
    speech_seconds = _speech_duration_seconds(wav, sr)
    ppm = len(lst_phonemes) / speech_seconds * 60 if speech_seconds > 0 else 0.0
    return predic_mos, trans, wer, phone_transcription, ppm
# Text shown as the Gradio interface description.
description = """
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.
Paper is available [here](https://arxiv.org/abs/2204.02152)
ASR is based on a fine-tuned Whisper medium English model (KevinGeng/whipser_medium_en_PAL300_step25); currently only English is available.
WER is computed between the ASR hypothesis and the reference text.
Speaking rate is reported as phonemes per minute (PPM), using the facebook/wav2vec2-xlsr-53-espeak-cv-ft phoneme recognizer.
"""
# Gradio UI: an audio file and a reference transcript in. The outputs are
# listed in the same order as calc_mos's return tuple: MOS, hypothesis, WER,
# phonemes, PPM.
iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference"),
    ],
    outputs=[
        gr.Textbox(placeholder="Naturalness evaluation, ranged 1 to 5, the higher the better.", label="Predicted MOS"),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Word Error Rate: Only valid when Reference is given", label="WER"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
        gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="PPM"),
    ],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    # "auto": every submission is flagged (logged) automatically.
    allow_flagging="auto",
)
iface.launch()