File size: 4,841 Bytes
67beffe 0514036 2090f04 67beffe dc94ed9 67beffe 4ea1ce4 67beffe 0514036 2090f04 67beffe 0514036 2090f04 67beffe 2090f04 0514036 67beffe 2090f04 4ea1ce4 2090f04 67beffe 2090f04 4ea1ce4 2090f04 67beffe 2090f04 67beffe 2090f04 67beffe 2090f04 67beffe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
from local.convert_metrics import nat2avaMOS, WER2INTELI
from local.indicator_plot import Intelligibility_Plot, Naturalness_Plot
# ASR part: a Hugging Face speech-recognition pipeline built from the task name
# alone (no model pinned here); calc_mos calls it with an audio file path.
from transformers import pipeline

p = pipeline(task="automatic-speech-recognition")
# WER part: normalisation applied to BOTH the reference and the ASR hypothesis
# before word error rate is computed, so that differences in casing,
# punctuation and spacing are not counted as word errors.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    # Typed references usually contain punctuation; without this step
    # "time," and "time" would be scored as different words.
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    # Drop leading/trailing spaces so no empty "words" reach the splitter.
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
# WPM part: phoneme recogniser (processor + CTC model) loaded from a single
# Hugging Face checkpoint id.
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

_PHONEME_CKPT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(_PHONEME_CKPT)
phoneme_model = Wav2Vec2ForCTC.from_pretrained(_PHONEME_CKPT)
class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples.

    Only single-channel input is intended: the input is reshaped to
    ``(wav.size(0), -1)`` and each row is resampled as one signal.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Resample ``wav`` from ``input_rate`` to ``output_rate``.

        Args:
            wav: waveform; dimension 0 is kept and every remaining dimension
                is flattened into the time axis.

        Returns:
            Tensor of shape ``(wav.size(0), T * output_rate // input_rate)``
            where ``T`` is the flattened input length. Floating-point input
            keeps its dtype; integer input yields the default float dtype.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        wav = wav.reshape(wav.size(0), -1)
        in_len = wav.size(-1)
        new_length = in_len * self.output_rate // self.input_rate

        # Fractional read positions in the input signal. float64 keeps the
        # interpolation fraction exact for long recordings; float32 loses it
        # once positions exceed ~2**20 samples.
        positions = torch.arange(
            new_length, dtype=torch.float64, device=wav.device
        ) * (self.input_rate / self.output_rate)
        lower = positions.long()
        # The last output may sit on the final input sample; clamp its
        # right neighbour so it stays in range.
        upper = (lower + 1).clamp(max=in_len - 1)

        weight_dtype = wav.dtype if wav.is_floating_point() else torch.get_default_dtype()
        frac = positions.fmod(1.0).to(weight_dtype).unsqueeze(0)
        return wav[:, lower] * (1.0 - frac) + wav[:, upper] * frac
# Naturalness (MOS) predictor restored from a local Lightning checkpoint and
# switched to evaluation mode (eval() modifies the module in place).
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "epoch=3-step=7459.ckpt"
)
model.eval()
def _trim_silence(wav, sample_rate):
    """Remove leading and trailing silence from ``wav`` with torchaudio's VAD.

    ``torchaudio.functional.vad`` only trims the start of a signal, so the
    result is reversed, trimmed again and reversed back to cut the tail too.
    """
    trimmed = torchaudio.functional.vad(wav, sample_rate=sample_rate)
    if trimmed.size(-1) == 0:
        return trimmed
    trimmed = torchaudio.functional.vad(trimmed.flip(-1), sample_rate=sample_rate)
    return trimmed.flip(-1)


def _phonemes_per_minute(n_phonemes, wav, sample_rate):
    """Return ``n_phonemes`` per minute of voiced audio in ``wav``.

    Returns 0.0 when VAD keeps no audio, instead of dividing by zero.
    """
    voiced_seconds = _trim_silence(wav, sample_rate).size(-1) / sample_rate
    if voiced_seconds == 0:
        return 0.0
    return n_phonemes / voiced_seconds * 60


def calc_mos(audio_path, ref):
    """Score one recording for naturalness, intelligibility and speaking rate.

    Args:
        audio_path: path to the audio file to evaluate (Gradio ``filepath``).
        ref: reference transcript used to compute the word error rate.

    Returns:
        Tuple ``(AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans,
        phone_transcription, ppm)``. It holds the naturalness score and its
        plot, the WER-derived intelligibility score and its plot, the ASR
        hypothesis, the decoded phoneme strings (a list), and phonemes per
        minute of voiced audio.

    Raises:
        gr.Error: if no audio is given or the reference is empty.
    """
    if audio_path is None:
        raise gr.Error("No audio provided. Please record or upload audio to evaluate.")
    if ref is None or not ref.strip():
        raise gr.Error("The reference must not be empty.")

    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # downmix to a single channel
    osr = 16_000
    # 16 kHz copy fed to both the MOS model and the phoneme model.
    out_wavs = ChangeSampleRate(sr, osr)(wav)

    # ASR: the pipeline reads the audio file itself.
    trans = p(audio_path)["text"]
    # WER with the same normalisation applied to reference and hypothesis.
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    # WER (as a percentage) -> intelligibility score.
    INTELI_score = WER2INTELI(wer * 100)
    INT_fig = Intelligibility_Plot(INTELI_score)

    # Naturalness (MOS). NOTE(review): domain 0 / judge 288 are fixed ids the
    # checkpoint expects; their meaning is not visible in this file.
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # Affine map of the mean model output (presumably [-1, 1] -> [1, 5]; confirm).
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
    AVA_MOS = nat2avaMOS(predic_mos)
    MOS_fig = Naturalness_Plot(AVA_MOS)

    # Phonemes per minute (PPM). The processor applies the checkpoint's own
    # input preprocessing (e.g. normalisation) before the CTC model.
    input_values = processor(
        out_wavs.squeeze(0).numpy(), sampling_rate=osr, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = phoneme_model(input_values).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # split() without an argument ignores empty tokens, so "" counts as 0 phonemes.
    lst_phonemes = phone_transcription[0].split()
    ppm = _phonemes_per_minute(len(lst_phonemes), wav, sr)
    return AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans, phone_transcription, ppm
# Markdown text passed to the interface as its description. The encoding is
# explicit so non-ASCII text reads the same regardless of the host locale.
with open("local/description.md", encoding="utf-8") as f:
    description = f.read()
# Example rows for the interface: no audio, only a reference sentence.
examples = [
    [None, reference]
    for reference in (
        "Once upon a time, there was a young rat named Arthur who couldn't make up his mind.",
        "Whenever the other rats asked Arthur if he wanted to go to the park, he would say, 'I don't know.'",
    )
]
# Gradio UI: audio + reference in; scores, plots, hypothesis and (hidden)
# phoneme / speaking-rate outputs out, in the order calc_mos returns them.
iface = gr.Interface(
    fn=calc_mos,
    inputs=[gr.Audio(type='filepath', label="Audio to evaluate"),
            gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference")],
    outputs=[gr.Textbox(placeholder="Naturalness Score, ranged from 0 to 5, the higher the better.", label="Naturalness Score, ranged from 0 to 5, the higher the better.", visible=False),
             gr.Plot(label="Naturalness Score, ranged from 0 to 5, the higher the better.", show_label=True, container=True),
             gr.Textbox(placeholder="Intelligibility Score", label="Intelligibility Score, range from 0 to 100, the higher the better", visible=False),
             gr.Plot(label="Intelligibility Score, range from 0 to 100, the higher the better", show_label=True, container=True),
             gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
             gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes", visible=False),
             gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="Speaking Rate, Phonemes per minutes", visible=False)],
    title="Speech Analysis by Laronix AI",
    description=description,
    allow_flagging="auto",
    examples=examples,
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()