File size: 5,267 Bytes
9d85ee2 8a02493 9d85ee2 8a02493 e676bd8 8a02493 e676bd8 8a02493 e676bd8 8a02493 3155f54 74e4942 8a02493 3155f54 8a02493 3155f54 8a02493 3155f54 8a02493 e676bd8 3155f54 e676bd8 3155f54 e676bd8 3155f54 e676bd8 8a02493 ea82efc 3155f54 e676bd8 3155f54 e676bd8 3155f54 8a02493 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
class Action(Enum):
    """Numeric diff operation codes used to interpret ``(op, text)`` diff pairs.

    ``style_text`` compares each pair's ``op`` against ``Action.<member>.value``.
    The values are expected to match diff_match_patch's insert/delete/equal
    op codes (NOTE(review): confirm against the installed library).
    """

    INSERTION = 1  # text inserted in the second string relative to the first
    DELETION = -1  # text removed from the first string
    EQUAL = 0  # text shared by both strings
def compare_string(text1: str, text2: str) -> list:
    """Diff two strings after NFKC Unicode normalisation.

    Both inputs are normalised with NFKC so that compatibility-equivalent
    characters do not show up as spurious differences. A character-level diff is
    then computed with diff_match_patch and post-processed with
    ``diff_cleanupSemantic`` to merge fragments into more readable chunks.

    Args:
        text1: The "before" text (the target in this app).
        text2: The "after" text (the prediction in this app).

    Returns:
        The diff_match_patch diff: a list of ``(op, text)`` pairs.
    """
    left, right = (unicodedata.normalize("NFKC", s) for s in (text1, text2))
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(left, right)
    # Cleanup mutates the diff list in place; its return value is not used.
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a diff as an HTML/Markdown string with highlighted changes.

    Args:
        diff: Iterable of ``(action, text)`` pairs whose ``action`` is one of the
            ``Action`` values, e.g. the output of ``compare_string``.

    Returns:
        The concatenated text in which:
        - insertions are wrapped in a light-green ``<span>``;
        - deletions are wrapped in a light-red ``<span>`` with ``<s>`` strikethrough;
        - equal segments are emitted unchanged.
        Every ``](`` is then rewritten as ``]\\(`` and every ``~`` as ``\\~``.
        This presumably keeps the Markdown renderer from treating them as link
        or strikethrough syntax.

    Raises:
        Exception: if an action is not one of the ``Action`` values.
    """
    pieces = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            pieces.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            pieces.append(f"<span style='background-color:#FFCCCB'><s>{text}</s></span>")
        elif action == Action.EQUAL.value:
            pieces.append(f"{text}")
        else:
            raise Exception("Not Implemented")
    # Join once instead of repeated += (which is quadratic in the worst case).
    full_text = "".join(pieces)
    # Raw strings: "\(" and "\~" are invalid escape sequences in ordinary string
    # literals (SyntaxWarning since Python 3.12). r"..." yields the same
    # backslash-plus-character value without relying on that leniency.
    return full_text.replace("](", r"]\(").replace("~", r"\~")
# Validation split of the TED-LIUM long-form dataset. get_visualisation reads
# row i of this dataset alongside row i of the CSV lists below.
# NOTE(review): this assumes the CSVs were written in dataset order — confirm.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Normalised reference transcripts and large-v2 predictions, as plain lists.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# Normalised large-32-2 predictions. The references come from the large-v2 CSV
# above, so only the prediction column is read here.
csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

# Output PCM dtype and its maximum value, used to scale float audio to integers.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_visualisation(idx, model="large-v2", round_dp=2):
    """Compute audio, error metrics and a styled text diff for one sample.

    Args:
        idx: 1-based sample index, as produced by the demo's slider.
        model: Whose predictions to compare with the target; one of
            ``"large-v2"`` or ``"large-32-2"``.
        round_dp: Number of decimal places for the returned metrics.

    Returns:
        A 5-tuple ``((sampling_rate, array), wer, ier, rel_length, full_text)``:
        - ``(sampling_rate, array)``: the sample's audio converted to ``target_dtype`` PCM;
        - ``wer``: word error rate, in percent;
        - ``ier``: insertions as a percentage of the reference word count;
        - ``rel_length``: predicted word count divided by target word count;
        - ``full_text``: the styled diff of prediction vs. target.

    Raises:
        ValueError: if ``model`` is not a supported name.
        IndexError: if ``idx`` is outside ``1..len(norm_target)``.
    """
    predictions = {"large-v2": norm_pred_v2, "large-32-2": norm_pred_32_2}
    if model not in predictions:
        raise ValueError(f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`.")

    # The slider is 1-based; convert to a 0-based index. Check the bounds
    # explicitly: otherwise idx=0 would silently wrap to the last sample.
    idx = int(idx) - 1
    if not 0 <= idx < len(norm_target):
        raise IndexError(f"Sample index {idx + 1} out of range 1..{len(norm_target)}")

    audio = dataset[idx]["audio"]
    # Clip before scaling so samples slightly outside [-1, 1] saturate instead
    # of overflowing and wrapping around on the integer cast.
    array = (np.clip(audio["array"], -1.0, 1.0) * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    text2 = predictions[model][idx]

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # For a single reference string, references[0] is presumably its
    # transformed word list (jiwer WordOutput) — i.e. the reference word count.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )
    rel_length = round(len(text2.split()) / len(text1.split()), round_dp)

    diff = compare_string(text1, text2)
    full_text = style_text(diff)
    return (sampling_rate, array), wer_percentage, ier_percentage, rel_length, full_text
def get_side_by_side_visualisation(idx):
    """Evaluate both models on one sample for the side-by-side demo view.

    Args:
        idx: 1-based sample index, forwarded to ``get_visualisation``.

    Returns:
        ``(audio, table, v2_diff, v32_2_diff)``:
        - ``audio``: the ``(sampling_rate, array)`` pair from the large-v2 call;
        - ``table``: one row per model, ``[name, wer, ier, rel_length]``, large-v2 first;
        - the two styled text diffs, large-v2 then large-32-2.
    """
    model_names = ("large-v2", "large-32-2")
    outputs = [get_visualisation(idx, model=name) for name in model_names]
    # Drop the leading audio and trailing diff, keeping only the metrics,
    # prefixed by the model name.
    table = [[name, *result[1:-1]] for name, result in zip(model_names, outputs)]
    v2_result, v32_2_result = outputs
    return v2_result[0], table, v2_result[-1], v32_2_result[-1]
if __name__ == "__main__":
    # Build and launch the demo UI. A slider selects a sample; "Analyse" fills in
    # the audio player, a metrics table for both models and their text diffs.
    with gr.Blocks() as demo:
        gr.HTML(
            """
                <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                  <div
                    style="
                      display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                    "
                  >
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                      Whisper Transcription Analysis
                    </h1>
                  </div>
                </div>
            """
        )
        gr.Markdown(
            # Adjacent literals are concatenated verbatim, so each needs its own
            # trailing space (previously rendered as "dev set.The transcriptions").
            "Analyse the transcriptions generated by the Whisper large-v2 and large-32-2 models on the TEDLIUM dev set. "
            "The transcriptions for both models are shown at the bottom of the demo. The text diff for each is computed "
            "relative to the target transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>, and "
            "deletions in <span style='background-color:#FFCCCB'><s>red</s></span>."
        )
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Column():
            table = gr.Dataframe(
                headers=[
                    "Model",
                    "Word Error Rate (WER)",
                    "Insertion Error Rate (IER)",
                    # get_visualisation computes predicted words / target words.
                    "Rel length (pred length / target length)",
                ],
                height=1000,
            )
        with gr.Row():
            gr.Markdown("**large-v2 text diff**")
            gr.Markdown("**large-32-2 text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")
        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, table, text_out_v2, text_out_32_2],
        )
    demo.launch()
|