""" TODO: + [x] Load Configuration + [ ] Multi ASR Engine + [ ] Batch / Real Time support """ import numpy as np from pathlib import Path import jiwer import pdb import torch.nn as nn import torch import torchaudio import gradio as gr from logging import PlaceHolder from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2CTCTokenizer from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC from datasets import load_dataset import datasets import yaml from transformers import pipeline import librosa import librosa.display import matplotlib.pyplot as plt import soundfile as sf # local import import sys from local.vis import token_plot from local.wer import get_WER_highlight sys.path.append("src") # Load automos config_yaml = "config/samples.yaml" with open(config_yaml, "r") as f: # pdb.set_trace() try: config = yaml.safe_load(f) except FileExistsError: print("Config file Loading Error") exit() # Auto load examples refs = np.loadtxt(config["ref_txt"], delimiter="\n", dtype="str") refs_ids = [x.split()[0] for x in refs] refs_txt = [" ".join(x.split()[1:]) for x in refs] ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))] # with open("src/description.html", "r", encoding="utf-8") as f: # description = f.read() description = "" reference_id = gr.Textbox( value="ID", placeholder="Utter ID", label="Reference_ID" ) reference_textbox = gr.Textbox( value="Input reference here", placeholder="Input reference here", label="Reference", ) reference_PPM = gr.Textbox( placeholder="Pneumatic Voice's PPM", label="Ref PPM" ) examples = [ [x, y] for x, y in zip(ref_wavs, refs_txt) ] # def map_to_array(batch): # speech, _ = sf.read(batch["file"]) # batch["speech"] = speech # return batch # ASR part p = pipeline("automatic-speech-recognition") import pdb # Tokenlizer part # import model, feature extractor, tokenizer def TOKENLIZER(audio_path, activate_plot=False): token_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h") tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h") feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") # # load first sample of English common_voice # dataset = load_dataset("common_voice", "en", split="train", streaming=True) # dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) # dataset_iter = iter(dataset) # sample = next(dataset_iter) # # forward sample through model to get greedily predicted transcription ids # input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values # pdb.set_trace() input_values, sr = torchaudio.load(audio_path) if sr != feature_extractor.sampling_rate: input_values = torchaudio.functional.resample(input_values, sr, feature_extractor.sampling_rate) logits = token_model(input_values).logits[0] pred_ids = torch.argmax(logits, axis=-1) # retrieve word stamps (analogous commands for `output_char_offsets`) outputs = tokenizer.decode(pred_ids, output_word_offsets=True) # pdb.set_trace() # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate time_offset = token_model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate word_offsets = [ { "word": d["word"], "start_time": round(d["start_offset"] * time_offset, 2), "end_time": round(d["end_offset"] * time_offset, 2), } for d in outputs.word_offsets ] if activate_plot == True: token_fig = token_plot(input_values, feature_extractor.sampling_rate, word_offsets) return word_offsets, 
token_fig return word_offsets # TOKENLIZER("data/samples/p326_020.wav") # pdb.set_trace() # Load dataset # pdb.set_trace() # dataset = load_dataset("common_voice", "en", split="train", streaming=True) # dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) # dataset_iter = iter(dataset) # sample = next(dataset_iter) # pdb.set_trace() # input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values # pdb.set_trace() # WER part transformation = jiwer.Compose( [ jiwer.RemovePunctuation(), jiwer.ToUpperCase(), jiwer.RemoveWhiteSpace(replace_by_space=True), jiwer.RemoveMultipleSpaces(), jiwer.ReduceToListOfListOfWords(word_delimiter=" "), ] ) () class ChangeSampleRate(nn.Module): def __init__(self, input_rate: int, output_rate: int): super().__init__() self.output_rate = output_rate self.input_rate = input_rate def forward(self, wav: torch.tensor) -> torch.tensor: # Only accepts 1-channel waveform input wav = wav.view(wav.size(0), -1) new_length = wav.size(-1) * self.output_rate // self.input_rate indices = torch.arange(new_length) * ( self.input_rate / self.output_rate ) round_down = wav[:, indices.long()] round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)] output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + ( round_up * indices.fmod(1.0).unsqueeze(0) ) return output # Flagging setup def calc_wer(audio_path, ref): wav, sr = torchaudio.load(audio_path) if wav.shape[0] != 1: wav = wav[0, :].unsqueeze(0) print(wav.shape) osr = 16000 batch = wav.unsqueeze(0).repeat(10, 1, 1) csr = ChangeSampleRate(sr, osr) out_wavs = csr(wav) # ASR # trans = jiwer.ToUpperCase()(p(audio_path)["text"]) # Tokenlizer tokens, token_wav_plot = TOKENLIZER(audio_path, activate_plot=True) # ASR part trans_cnt = [] for i in tokens: word, start_time, end_time = i.values() trans_cnt.append(word) trans = " ".join(x for x in trans_cnt) trans = jiwer.ToUpperCase()(trans) # WER ref = jiwer.ToUpperCase()(ref) highlight_hyp = get_WER_highlight(ref.split(" "), trans.split(" ")) wer = jiwer.wer( ref, trans, truth_transform=transformation, hypothesis_transform=transformation, ) # pdb.set_trace() word_acc = "%0.2f%%" %((1.0 - float(wer))*100) return [highlight_hyp, word_acc, token_wav_plot] # calc_wer(examples[1][0], examples[1][1]) # # calc_wer() # pdb.set_trace() iface = gr.Interface( fn=calc_wer, inputs=[ gr.Audio( source="upload", type="filepath", label="Audio_to_evaluate", show_label=False ), reference_textbox, ], #gr.Textbox(placeholder="Hypothesis", label="Recognition by AI"), outputs=[ gr.HighlightedText(placeholder="Hypothesis", label="Diff", combine_adjacent=True, adjacent_separator=" ", show_label=False).style(color_map={"1": "#78bd91", "0": "#ddbabf"}), gr.Textbox(placeholder="Word Accuracy", label="Word Accuracy (The Higher the better)"), gr.Plot(label="waveform", show_label=False) ], description=description, examples=examples, examples_per_page=20, css=".body {background-color: green}", ) print("Launch examples") iface.launch( share=False, )
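
# ---------------------------------------------------------------------------
# Illustrative input layout (an assumption inferred from the config handling
# above; the concrete paths and the sample line are hypothetical, not part of
# the project): config/samples.yaml is expected to provide `ref_txt` and
# `ref_wavs`, and each line of the `ref_txt` file pairs an utterance ID with
# its reference transcript, for example:
#
#   # config/samples.yaml
#   ref_txt: data/samples/refs.txt   # one "<utt_id> <reference transcript>" per line
#   ref_wavs: data/samples           # directory searched recursively for *.wav
#
#   # refs.txt
#   utt_0001 THE CAT SAT ON THE MAT
# ---------------------------------------------------------------------------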