|
import os |
|
|
|
import numpy as np |
|
import unicodedata |
|
import diff_match_patch as dmp_module |
|
from enum import Enum |
|
import gradio as gr |
|
from datasets import load_dataset |
|
import pandas as pd |
|
from jiwer import process_words, wer_default |
|
|
|
|
|
class Action(Enum):
    """Diff opcodes as returned by ``diff_match_patch.diff_main``.

    Each diff entry is a ``(op, text)`` tuple; ``op`` matches one of these
    values (1 = inserted text, -1 = deleted text, 0 = unchanged text).
    """

    INSERTION = 1

    DELETION = -1

    EQUAL = 0
|
|
|
|
|
def compare_string(text1: str, text2: str) -> list:
    """Compute a character-level diff between two strings.

    Both inputs are NFKC-normalised first, so visually equivalent Unicode
    sequences (e.g. full-width vs. ASCII forms) compare as equal.

    Returns the diff as a list of ``(op, text)`` tuples, where ``op`` is
    1 (insertion), -1 (deletion) or 0 (equal), after semantic cleanup to
    merge trivial edit fragments into human-readable chunks.
    """
    normalised_a = unicodedata.normalize("NFKC", text1)
    normalised_b = unicodedata.normalize("NFKC", text2)

    differ = dmp_module.diff_match_patch()
    result = differ.diff_main(normalised_a, normalised_b)
    differ.diff_cleanupSemantic(result)

    return result
|
|
|
|
|
def style_text(diff):
    """Render a diff as HTML with insertions and deletions highlighted.

    Args:
        diff: List of ``(op, text)`` tuples as produced by ``compare_string``,
            where ``op`` is 1 (insertion), -1 (deletion) or 0 (equal).

    Returns:
        An HTML string: insertions wrapped in green ``<span>``s, deletions in
        red struck-through ``<span>``s, equal text passed through verbatim.
        ``](`` and ``~`` are backslash-escaped so Gradio's Markdown renderer
        does not interpret them as link / strikethrough syntax.

    Raises:
        ValueError: If the diff contains an unknown opcode.
    """
    # Collect fragments and join once — avoids quadratic string concatenation.
    parts = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            parts.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            parts.append(f"<span style='background-color:#FFCCCB'><s>{text}</s></span>")
        elif action == Action.EQUAL.value:
            parts.append(text)
        else:
            raise ValueError(f"Unknown diff action: {action!r}")
    full_text = "".join(parts)
    # Escape Markdown link / strikethrough triggers. Raw strings give the same
    # replacement text as before while avoiding the invalid "\(" / "\~"
    # escape sequences (a SyntaxWarning on modern Python).
    full_text = full_text.replace("](", r"]\(").replace("~", r"\~")
    return full_text
|
|
|
|
|
# --- Module-level data: audio samples and precomputed transcriptions -------

# Long-form TEDLIUM validation audio; loaded in parallel across CPU cores.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Normalised target transcriptions and large-v2 predictions, precomputed
# offline. Series.tolist() replaces the manual index-loop conversion.
csv_v2 = pd.read_csv("assets/large-v2.csv")

norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# large-32-2 predictions — same targets, so only the predictions are kept.
csv_v2 = pd.read_csv("assets/large-32-2.csv")

norm_pred_32_2 = csv_v2["Norm Pred"].tolist()

# Audio arrays arrive as floats in [-1, 1]; scale by max_range and cast to
# target_dtype to obtain int16 PCM for gr.Audio playback.
target_dtype = np.int16

max_range = np.iinfo(target_dtype).max
|
|
|
def get_visualisation(idx, model="large-v2", round_dp=2):
    """Build the audio, error metrics and styled diff for one dataset sample.

    Args:
        idx: 1-based sample index (as shown on the UI slider).
        model: Which model's predictions to analyse: ``"large-v2"`` or
            ``"large-32-2"``.
        round_dp: Decimal places for the reported percentages/ratios.

    Returns:
        Tuple of ``((sampling_rate, int16_array), wer_percentage,
        ier_percentage, rel_length, styled_diff_html)``.

    Raises:
        ValueError: If ``model`` is not one of the two known model names.
    """
    idx -= 1  # slider is 1-based; dataset / prediction lists are 0-based
    audio = dataset[idx]["audio"]
    # Cast via target_dtype (not a hard-coded np.int16) so the scaling factor
    # max_range = iinfo(target_dtype).max and the cast can never diverge.
    array = (audio["array"] * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`.")

    # Word-level alignment of target (text1) against prediction (text2).
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # Insertion error rate: insertions relative to the reference word count.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    # Prediction length relative to the target length, in words.
    rel_length = round(len(text2.split()) / len(text1.split()), round_dp)

    diff = compare_string(text1, text2)
    full_text = style_text(diff)

    return (sampling_rate, array), wer_percentage, ier_percentage, rel_length, full_text
|
|
|
|
|
def get_side_by_side_visualisation(idx):
    """Run the per-sample analysis for both models and arrange the results
    for the UI: shared audio, a metrics table (one row per model), and the
    two styled text diffs."""
    model_names = ("large-v2", "large-32-2")
    results = {name: get_visualisation(idx, model=name) for name in model_names}

    # Each visualisation tuple is (audio, wer, ier, rel_length, diff_html);
    # the table keeps the middle metrics, prefixed with the model name.
    table = [[name, *results[name][1:-1]] for name in model_names]

    audio = results["large-v2"][0]
    return audio, table, results["large-v2"][-1], results["large-32-2"][-1]
|
|
|
|
|
if __name__ == "__main__":

    # Build the Gradio UI: title header, explanation, sample slider, and the
    # side-by-side metric/diff outputs for the two Whisper models.
    with gr.Blocks() as demo:

        # Page title (static HTML header).
        gr.HTML(

            """

        <div style="text-align: center; max-width: 700px; margin: 0 auto;">

          <div

            style="

              display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;

            "

          >

            <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">

              Whisper Transcription Analysis

            </h1>

          </div>

        </div>

        """

        )

        # Legend explaining the diff colour coding (green = insertion,
        # red strikethrough = deletion).
        gr.Markdown(

            "Analyse the transcriptions generated by the Whisper large-v2 and large-32-2 models on the TEDLIUM dev set."

            "The transcriptions for both models are shown at the bottom of the demo. The text diff for each is computed "

            "relative to the target transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>, and "

            "deletions in <span style='background-color:#FFCCCB'><s>red</s></span>."

        )

        # 1-based sample selector; get_visualisation converts to 0-based.
        slider = gr.Slider(

            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"

        )

        btn = gr.Button("Analyse")

        audio_out = gr.Audio(label="Audio input")

        with gr.Column():

            # Metrics table: one row per model (name, WER, IER, rel length).
            table = gr.Dataframe(

                headers=[

                    "Model",

                    "Word Error Rate (WER)",

                    "Insertion Error Rate (IER)",

                    "Rel length (ref length / tgt length)",

                ],

                height=1000,

            )

        # Column headers for the two diff panes below.
        with gr.Row():

            gr.Markdown("**large-v2 text diff**")

            gr.Markdown("**large-32-2 text diff**")

        with gr.Row():

            text_out_v2 = gr.Markdown(label="Text difference")

            text_out_32_2 = gr.Markdown(label="Text difference")



        # Wire the button: one slider input fans out to audio, table and
        # both diff panes.
        btn.click(

            fn=get_side_by_side_visualisation,

            inputs=slider,

            outputs=[audio_out, table, text_out_v2, text_out_32_2],

        )

    demo.launch()
|
|