# Author: KevinGeng
# Commit 5407cce: "support pitch contour and db plotting" (file size 5.51 kB)
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
from local.convert_metrics import nat2avaMOS, WER2INTELI
from local.indicator_plot import Intelligibility_Plot, Naturalness_Plot
from local.pitch_contour import draw_spec_db_pitch
# ASR part: speech-to-text pipeline whose output (the "hypothesis") is scored
# against the user's reference sentence.
from transformers import pipeline

# Name the checkpoint explicitly. Without `model=`, transformers falls back to a
# task default (facebook/wav2vec2-base-960h at the time of writing) and warns
# that relying on it is not recommended, because the default may change between
# library versions and silently alter the WER / intelligibility scores.
p = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
# WER part: text normalisation applied to both the reference and the ASR
# hypothesis before jiwer computes the word error rate.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    # The typed reference sentences (see the UI examples) contain commas,
    # periods and quotes. Without stripping them on both sides, every
    # punctuation mismatch would be counted as a word error and inflate WER.
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    # Drop leading/trailing spaces left behind by the steps above.
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
# Phoneme recogniser used by calc_mos to estimate the speaking rate
# (phonemes per minute). The processor and model share one checkpoint id.
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

PHONEME_MODEL_ID = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"

processor = Wav2Vec2Processor.from_pretrained(PHONEME_MODEL_ID)
phoneme_model = Wav2Vec2ForCTC.from_pretrained(PHONEME_MODEL_ID)
class ChangeSampleRate(nn.Module):
    """Resample waveforms by linear interpolation between neighbouring samples.

    Args:
        input_rate: Sample rate (Hz) of the waveforms passed to ``forward``.
        output_rate: Desired sample rate (Hz) of the returned waveform.

    Raises:
        ValueError: If either rate is not positive.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        if input_rate <= 0 or output_rate <= 0:
            raise ValueError(
                f"Sample rates must be positive, got input_rate={input_rate}, "
                f"output_rate={output_rate}"
            )
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Resample ``wav`` from ``input_rate`` to ``output_rate``.

        Args:
            wav: Waveform whose first dimension is kept as the row (batch) axis.
                All remaining dimensions are flattened into the time axis, so
                each row should hold a single channel of audio.

        Returns:
            Tensor of shape ``(wav.size(0), n * output_rate // input_rate)``,
            where ``n`` is the flattened length of each row. It has the same
            dtype and device as ``wav``.
        """
        # Only accepts 1-channel waveform input: flatten to (rows, samples).
        # `reshape` (unlike `view`) also works on non-contiguous tensors.
        wav = wav.reshape(wav.size(0), -1)
        num_samples = wav.size(-1)
        new_length = num_samples * self.output_rate // self.input_rate

        # Output sample k sits at source position k * input_rate / output_rate.
        # Integer arithmetic keeps both the integer part and the fractional
        # weight exact; float32 positions lose precision on long recordings.
        scaled = torch.arange(new_length, device=wav.device) * self.input_rate
        lower = torch.div(scaled, self.output_rate, rounding_mode="floor")
        # The neighbour past the last sample is clamped to the last sample.
        upper = (lower + 1).clamp(max=num_samples - 1)
        frac = (scaled % self.output_rate).to(wav.dtype) / self.output_rate

        return wav[:, lower] * (1.0 - frac) + wav[:, upper] * frac
# Naturalness (MOS) predictor, loaded once at start-up from a local checkpoint.
_MOS_CHECKPOINT = "epoch=3-step=7459.ckpt"
model = lightning_module.BaselineLightningModule.load_from_checkpoint(_MOS_CHECKPOINT)
model.eval()  # switch to evaluation mode; only inference is done with it
def _predict_phonemes(wav_16k, sampling_rate):
    """Transcribe ``wav_16k`` into phonemes with the module-level phoneme model.

    Args:
        wav_16k: Mono waveform of shape ``(1, samples)`` at ``sampling_rate``.
        sampling_rate: Sample rate of ``wav_16k`` in Hz, forwarded to the processor.

    Returns:
        The list produced by ``processor.batch_decode`` (one transcription string).
    """
    # Go through the processor's feature extractor so the waveform gets the
    # preprocessing the checkpoint expects (e.g. normalisation), instead of
    # feeding raw samples straight into the model.
    input_values = processor(
        wav_16k.squeeze(0).numpy(), sampling_rate=sampling_rate, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = phoneme_model(input_values).logits
    return processor.batch_decode(torch.argmax(logits, dim=-1))


def _trim_silence(wav, sample_rate):
    """Trim leading and trailing silence from ``wav`` using torchaudio's VAD.

    ``torchaudio.functional.vad`` only trims the front of the signal, so the
    front-trimmed waveform is time-reversed, trimmed again, and reversed back.
    """
    trimmed = torchaudio.functional.vad(wav, sample_rate=sample_rate)
    if trimmed.shape[-1] == 0:
        return trimmed
    return torchaudio.functional.vad(trimmed.flip(-1), sample_rate=sample_rate).flip(-1)


def _phonemes_per_minute(num_phonemes, wav, sample_rate):
    """Return ``num_phonemes`` per minute of non-silent audio in ``wav``.

    Returns 0.0 when no audio remains after silence trimming, rather than
    dividing by zero.
    """
    voiced_seconds = _trim_silence(wav, sample_rate).shape[-1] / sample_rate
    if voiced_seconds == 0:
        return 0.0
    return num_phonemes / voiced_seconds * 60


def calc_mos(audio_path, ref):
    """Analyse a recording against the sentence the speaker was asked to read.

    Args:
        audio_path: Path to the audio file to evaluate (from the
            ``gr.Audio(type='filepath')`` input).
        ref: Reference transcript typed by the user.

    Returns:
        Tuple ``(AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans,
        phone_transcription, ppm, f0_db_fig)``, in the order of the Gradio
        interface outputs:
        - naturalness score and its plot;
        - intelligibility score and its plot;
        - ASR hypothesis;
        - predicted phonemes;
        - speaking rate in phonemes per minute;
        - pitch-contour / dB figure.

    Raises:
        gr.Error: If ``ref`` is empty or whitespace-only. WER is undefined
            without a reference, so a readable message is shown in the UI.
    """
    if not ref or not ref.strip():
        raise gr.Error("Reference text is empty. Please enter the sentence that was read aloud.")

    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Mono channel

    # The MOS predictor and the phoneme recogniser both receive 16 kHz audio.
    osr = 16_000
    out_wavs = ChangeSampleRate(sr, osr)(wav)

    # ASR hypothesis -> WER against the reference -> intelligibility score.
    trans = p(audio_path)["text"]
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    INTELI_score = WER2INTELI(wer * 100)
    INT_fig = Intelligibility_Plot(INTELI_score)

    # Naturalness (MOS). The domain/judge ids are fixed values the checkpoint
    # is queried with; their meaning is not visible here.
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # `* 2 + 3` presumably maps a [-1, 1] prediction onto the 1-5 MOS scale
    # (TODO confirm against the training code).
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
    AVA_MOS = nat2avaMOS(predic_mos)
    MOS_fig = Naturalness_Plot(AVA_MOS)

    # Speaking rate in phonemes per minute (PPM) of non-silent audio.
    phone_transcription = _predict_phonemes(out_wavs, osr)
    # str.split() with no argument drops empty strings, so an empty
    # transcription counts as zero phonemes rather than one.
    ppm = _phonemes_per_minute(len(phone_transcription[0].split()), wav, sr)

    # Spectrogram / pitch contour / dB analysis figure.
    f0_db_fig = draw_spec_db_pitch(audio_path, save_fig_path=None)

    return AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans, phone_transcription, ppm, f0_db_fig
# Markdown description passed to the Gradio interface. Read it as UTF-8
# explicitly so the result does not depend on the platform's default encoding.
with open("local/description.md", encoding="utf-8") as f:
    description = f.read()

# (audio file, reference transcript) pairs offered as clickable examples.
examples = [
    ["local/Julianna_Set1_Author_01.wav", "Once upon a time, there was a young rat named Arthur who couldn't make up his mind."],
    ["local/Patient_Arthur_set1_002_noisy.wav", "Whenever the other rats asked him if he would like to go hunting with them, he would answer in a soft voice, 'I don't know.'"],
]
# Web UI: one audio input plus a reference sentence. The outputs map 1:1, in
# order, to the tuple returned by calc_mos.
iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference"),
    ],
    outputs=[
        # Label now matches the placeholder and the plot below (1 to 5).
        gr.Textbox(placeholder="Naturalness Score, ranged from 1 to 5, the higher the better.", label="Naturalness Score, ranged from 1 to 5, the higher the better.", visible=False),
        gr.Plot(label="Naturalness Score, ranged from 1 to 5, the higher the better.", show_label=True, container=True),
        gr.Textbox(placeholder="Intelligibility Score", label="Intelligibility Score, range from 0 to 100, the higher the better", visible=False),
        gr.Plot(label="Intelligibility Score, range from 0 to 100, the higher the better", show_label=True, container=True),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes", visible=False),
        gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="Speaking Rate, Phonemes per minutes", visible=False),
        gr.Plot(label="Pitch Contour and dB Analysis", show_label=True, container=True),
    ],
    title="Speech Analysis by Laronix AI",
    description=description,
    # NOTE(review): "auto" flags every submission, which stores the uploaded
    # audio; confirm this is intended for patient recordings.
    allow_flagging="auto",
    examples=examples,
)

if __name__ == "__main__":
    import os

    # Password-protect the interface. Credentials can be supplied through
    # environment variables so they need not live in source control; the
    # defaults keep the previous login working.
    auth = (
        os.environ.get("LARONIX_APP_USER", "Laronix"),
        os.environ.get("LARONIX_APP_PASSWORD", "LaronixSLP"),
    )
    iface.launch(share=False, auth=auth, auth_message="Authentication Required, ask kevin@laronix.com for password.\n Thanks for your cooperation!")