import html
import os
import unicodedata
from enum import Enum

import diff_match_patch as dmp_module
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from jiwer import process_words, wer_default


class Action(Enum):
    """Operation codes carried by each ``(op, text)`` chunk of a diff-match-patch diff."""

    INSERTION = 1
    DELETION = -1
    EQUAL = 0


# Inline CSS used to highlight diff chunks in the rendered Markdown.
INSERTION_STYLE = "background-color:Lightgreen"
DELETION_STYLE = "background-color:#FFCCCB"


def compare_string(text1: str, text2: str) -> list:
    """Return a semantically cleaned-up character diff turning ``text1`` into ``text2``.

    Both strings are NFKC-normalised first, so compatibility-equivalent Unicode
    forms are not reported as differences.

    Returns:
        The list of ``(op, text)`` chunks produced by diff-match-patch, where
        ``op`` is one of the :class:`Action` values.
    """
    text1_normalized = unicodedata.normalize("NFKC", text1)
    text2_normalized = unicodedata.normalize("NFKC", text2)

    dmp = dmp_module.diff_match_patch()
    diff = dmp.diff_main(text1_normalized, text2_normalized)
    # Merge trivial character-level edits into human-readable word/phrase edits.
    dmp.diff_cleanupSemantic(diff)

    return diff


def style_text(diff):
    """Render a diff as Markdown/HTML with insertions and deletions highlighted.

    Inserted text gets a green background; deleted text is struck through on
    a red background; unchanged text is emitted as-is.

    Raises:
        ValueError: if a chunk carries an unknown operation code.
    """
    parts = []
    for action, text in diff:
        # Escape HTML-special characters so transcript text cannot break the span markup.
        text = html.escape(text, quote=False)
        if action == Action.INSERTION.value:
            parts.append(f"<span style='{INSERTION_STYLE}'>{text}</span>")
        elif action == Action.DELETION.value:
            parts.append(f"<span style='{DELETION_STYLE}'><s>{text}</s></span>")
        elif action == Action.EQUAL.value:
            parts.append(text)
        else:
            raise ValueError(f"Unknown diff action: {action}")
    full_text = "".join(parts)
    # Escape sequences Markdown would otherwise treat as link syntax or strikethrough.
    return full_text.replace("](", r"]\(").replace("~", r"\~")


dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# pandas reads empty cells as NaN floats; map them to "" so the text utilities
# always receive strings.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].fillna("").tolist()
norm_pred_v2 = csv_v2["Norm Pred"].fillna("").tolist()

csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].fillna("").tolist()

target_dtype = np.int16
max_range = np.iinfo(target_dtype).max


def _to_row(idx):
    """Convert the 1-based sample number used by the UI into a 0-based row index."""
    # int() guards against the slider value arriving as a float.
    row = int(idx) - 1
    # Reject out-of-range values instead of letting negative indices wrap around.
    if not 0 <= row < len(norm_target):
        raise IndexError(f"Sample {idx} is out of range 1..{len(norm_target)}.")
    return row


def _get_prediction(row, model):
    """Return the normalised prediction of ``model`` for 0-based ``row``."""
    if model == "large-v2":
        return norm_pred_v2[row]
    if model == "large-32-2":
        return norm_pred_32_2[row]
    raise ValueError(
        f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
    )


def _load_audio(row):
    """Return ``(sampling_rate, samples)`` for dataset ``row`` with samples cast to ``target_dtype``."""
    audio = dataset[row]["audio"]
    # Clip before scaling: float samples outside [-1, 1] would otherwise wrap
    # around when cast to the integer dtype.
    array = (np.clip(audio["array"], -1.0, 1.0) * max_range).astype(target_dtype)
    return audio["sampling_rate"], array


def _score(text1, text2, round_dp):
    """Score prediction ``text2`` against reference ``text1``.

    Returns:
        ``(wer_percentage, ier_percentage, rel_length, styled_diff)``.
    """
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # Insertion error rate: insertions normalised by the number of reference words.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )
    # Prediction length relative to reference length, in whitespace-split words.
    rel_length = round(len(text2.split()) / len(text1.split()), round_dp)
    full_text = style_text(compare_string(text1, text2))
    return wer_percentage, ier_percentage, rel_length, full_text


def get_visualisation(idx, model="large-v2", round_dp=2):
    """Compute metrics and a styled text diff for one model on 1-based sample ``idx``.

    Args:
        idx: 1-based sample number (as shown on the UI slider).
        model: ``"large-v2"`` or ``"large-32-2"``.
        round_dp: number of decimal places for the reported metrics.

    Returns:
        ``((sampling_rate, int16_samples), wer_percentage, ier_percentage,
        rel_length, styled_diff)``.

    Raises:
        ValueError: if ``model`` is unknown.
        IndexError: if ``idx`` is outside the dataset.
    """
    row = _to_row(idx)
    text1 = norm_target[row]
    # Validate the model before the (comparatively expensive) audio decode.
    text2 = _get_prediction(row, model)
    wer_percentage, ier_percentage, rel_length, full_text = _score(text1, text2, round_dp)
    return _load_audio(row), wer_percentage, ier_percentage, rel_length, full_text


def get_side_by_side_visualisation(idx):
    """Compare both models on 1-based sample ``idx``.

    Returns:
        ``(audio, table, diff_large_v2, diff_large_32_2)`` where ``table`` has one
        ``[model, WER, IER, rel_length]`` row per model.
    """
    row = _to_row(idx)
    text1 = norm_target[row]
    table = []
    diffs = []
    for model in ("large-v2", "large-32-2"):
        wer, ier, rel_length, diff = _score(text1, _get_prediction(row, model), round_dp=2)
        table.append([model, wer, ier, rel_length])
        diffs.append(diff)
    # The audio is the same for both models, so decode it only once.
    return _load_audio(row), table, diffs[0], diffs[1]


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(
            "Analyse the transcriptions generated by the Whisper large-v2 and large-32-2 models on the TEDLIUM dev set. "
            "The transcriptions for both models are shown at the bottom of the demo. The text diff for each is computed "
            f"relative to the target transcriptions. Insertions are displayed in <span style='{INSERTION_STYLE}'>green</span>, and "
            f"deletions in <span style='{DELETION_STYLE}'><s>red</s></span>."
        )
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Column():
            table = gr.Dataframe(
                headers=[
                    "Model",
                    "Word Error Rate (WER)",
                    "Insertion Error Rate (IER)",
                    "Rel length (pred length / ref length)",
                ],
                height=1000,
            )
        with gr.Row():
            gr.Markdown("**large-v2 text diff**")
            gr.Markdown("**large-32-2 text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")
        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, table, text_out_v2, text_out_32_2],
        )
    demo.launch()