File size: 8,614 Bytes
9d85ee2 8a02493 eed20cf 8a02493 9d85ee2 8a02493 e676bd8 8a02493 e676bd8 8a02493 e676bd8 8a02493 3e3e17d 941081c 3e3e17d eed20cf 74e4942 8a02493 3155f54 65f6dc4 3155f54 65f6dc4 8a02493 3155f54 8a02493 eed20cf 8a02493 65f6dc4 e676bd8 3155f54 e676bd8 3155f54 e676bd8 3155f54 e67d080 e676bd8 8a02493 ea82efc 3155f54 d515eda 4447566 58bf923 4447566 58bf923 d515eda 58bf923 d515eda 3155f54 3e3e17d 3155f54 eed20cf 3155f54 3843f4e e676bd8 e67d080 e676bd8 3155f54 8a02493 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
from nltk import ngrams
class Action(Enum):
    """Edit operation codes appearing in the ``(op, text)`` diff tuples.

    ``compare_string`` returns a list of such tuples, and ``style_text``
    compares each ``op`` against these members' ``.value`` to decide how the
    accompanying text is highlighted.
    """

    # Text found only in the second string given to ``compare_string``.
    INSERTION = 1
    # Text found only in the first string given to ``compare_string``.
    DELETION = -1
    # Text shared by both strings.
    EQUAL = 0
def compare_string(text1: str, text2: str) -> list:
    """Diff ``text1`` against ``text2`` after Unicode NFKC normalization.

    Both inputs are NFKC-normalized first so that compatibility-equivalent
    characters do not show up as spurious differences. The raw diff from
    diff-match-patch is then passed through its semantic cleanup.

    Args:
        text1: The "before" string (the ground-truth target in this app).
        text2: The "after" string (the model prediction in this app).

    Returns:
        The list of ``(op, text)`` diff entries produced by diff-match-patch,
        in the form consumed by ``style_text``.
    """
    reference, hypothesis = (
        unicodedata.normalize("NFKC", raw) for raw in (text1, text2)
    )
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(reference, hypothesis)
    # diff_cleanupSemantic rewrites ``edits`` in place.
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a list of diff entries as highlighted Markdown/HTML text.

    Args:
        diff: Iterable of ``(action, text)`` pairs whose ``action`` is one of
            the ``Action`` values, e.g. the output of ``compare_string``.

    Returns:
        A single string with the following rendering:
          - Inserted text is wrapped in a light-green ``<span>``.
          - Deleted text is wrapped in a red ``<span>`` and struck through with ``<s>``.
          - Equal text is left unchanged.
          - Every ``"]("`` and ``"~"`` is backslash-escaped.

    Raises:
        ValueError: If an entry carries an action that is not an ``Action`` value.
    """
    pieces = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            pieces.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            pieces.append(f"<span style='background-color:#FFCCCB'><s>{text}</s></span>")
        elif action == Action.EQUAL.value:
            pieces.append(f"{text}")
        else:
            raise ValueError(f"Not Implemented: unknown diff action {action!r}")
    # Build once with join instead of repeated string concatenation.
    full_text = "".join(pieces)
    # Raw strings keep the literal backslash without relying on invalid escape
    # sequences. Presumably this stops the Markdown renderer from treating "]("
    # as a link target and "~" as strikethrough.
    return full_text.replace("](", r"]\(").replace("~", r"\~")
# Long-form TED-LIUM validation audio; indexed per sample by get_visualisation.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Normalized targets and predictions, one entry per CSV row.
# NOTE(review): rows are presumably in the same order as `dataset`; TODO confirm.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# Separate frame for the Distil-Whisper (large-32-2) predictions, so csv_v2
# is not rebound to a different file.
csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

# Integer PCM format used when handing audio to the UI, and its peak value
# for scaling.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    """Compute corpus-level WER, IER and repeated n-gram count for one model.

    Args:
        model: ``"large-v2"`` (Whisper) or ``"large-32-2"`` (Distil-Whisper).
        round_dp: Number of decimal places for the WER/IER percentages.
        ngram_degree: Size of the word n-grams counted for repetitions.

    Returns:
        Tuple ``(wer_percentage, ier_percentage, repeated_ngrams)``, where:
          - ``repeated_ngrams`` is the number of n-gram occurrences beyond the
            first occurrence of each distinct n-gram.
          - IER is insertions divided by the total reference length across
            all samples.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    text1 = norm_target
    if model == "large-v2":
        text2 = norm_pred_v2
    elif model == "large-32-2":
        text2 = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)

    # Denominator: summed length of every transformed reference.
    num_reference_words = sum(len(ref) for ref in wer_output.references)
    ier_percentage = round(
        100 * wer_output.insertions / num_reference_words, round_dp
    )

    # NOTE: predictions are concatenated before n-gram extraction, so n-grams
    # may span sample boundaries. This is kept to preserve the reported numbers.
    all_ngrams = list(ngrams(" ".join(text2).split(), ngram_degree))
    # A set finds the distinct n-grams in O(n), replacing the O(n^2)
    # list-membership scan over the whole corpus.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
    return wer_percentage, ier_percentage, repeated_ngrams
def get_overall_table():
    """Build the rows of the overall-statistics table.

    Returns:
        A list of two rows, Whisper (``large-v2``) first and Distil-Whisper
        (``large-32-2``) second. Each row is a list of the display label
        followed by the values returned by ``get_statistics``.
    """
    labelled_models = (("Whisper", "large-v2"), ("Distil-Whisper", "large-32-2"))
    return [[label, *get_statistics(model=name)] for label, name in labelled_models]
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Compute audio, per-sample statistics and the text diff for one model.

    Args:
        idx: 1-based sample index, as supplied by the UI slider.
        model: ``"large-v2"`` (Whisper) or ``"large-32-2"`` (Distil-Whisper).
        round_dp: Number of decimal places for the WER/IER percentages.
        ngram_degree: Size of the word n-grams counted for repetitions.

    Returns:
        A 5-tuple:
          - ``(sampling_rate, audio_array)``, with the audio scaled to ``target_dtype``.
          - ``wer_percentage``.
          - ``ier_percentage``.
          - ``repeated_ngrams``.
          - ``full_text``: the prediction-vs-target diff rendered by ``style_text``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # The slider is 1-based and, depending on the Gradio version, its value may
    # arrive as a float (e.g. 3.0). List and dataset indexing need a 0-based int.
    idx = int(round(idx)) - 1

    audio = dataset[idx]["audio"]
    # Clip before scaling so out-of-range samples saturate instead of wrapping
    # around on the integer cast (assumes float audio nominally in [-1, 1];
    # TODO confirm).
    array = (np.clip(audio["array"], -1.0, 1.0) * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # IER for a single sample: insertions over the length of its one reference.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    # Distinct n-grams via a set: O(n) instead of an O(n^2) list scan.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    diff = compare_string(text1, text2)
    full_text = style_text(diff)
    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )
def get_side_by_side_visualisation(idx):
    """Gather the per-sample outputs of both models for the UI.

    Args:
        idx: 1-based sample index, forwarded unchanged to ``get_visualisation``.

    Returns:
        A 4-tuple:
          - The audio tuple from the Whisper (``large-v2``) run.
          - A two-row statistics table, Whisper first then Distil-Whisper.
          - The Whisper text diff.
          - The Distil-Whisper text diff.
    """
    whisper_out = get_visualisation(idx, model="large-v2")
    distil_out = get_visualisation(idx, model="large-32-2")

    # Each result is (audio, wer, ier, repeated_ngrams, diff_text).
    audio, *whisper_stats, whisper_diff = whisper_out
    _, *distil_stats, distil_diff = distil_out

    rows = [
        ["Whisper", *whisper_stats],
        ["Distil-Whisper", *distil_stats],
    ]
    return audio, rows, whisper_diff, distil_diff
if __name__ == "__main__":
    # Column headers shared by the overall and per-sample statistics tables.
    # Row contents come from get_overall_table / get_side_by_side_visualisation.
    stats_headers = [
        "Model",
        "Word Error Rate (WER)",
        "Insertion Error Rate (IER)",
        "Repeated 5-grams",
    ]

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                  Whisper Transcription Analysis
                </h1>
              </div>
            </div>
            """
        )
        # The Markdown text is kept flush-left so no line can be rendered as an
        # indented code block. Blank lines separate the paragraphs.
        gr.Markdown(
            """
One of the major claims of the <a href="https://arxiv.org/abs/2311.00430"> Distil-Whisper paper</a> is
that Distil-Whisper hallucinates less than Whisper on long-form audio. To demonstrate this, we'll analyse the
transcriptions generated by <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper</a>
and <a href="https://huggingface.co/distil-whisper/distil-large-v2"> Distil-Whisper</a> on the
<a href="https://huggingface.co/datasets/distil-whisper/tedlium-long-form"> TED-LIUM</a> validation set.

To quantify the amount of repetition and hallucination in the predicted transcriptions, we measure the number
of repeated 5-gram word duplicates (5-Dup.) and the insertion error rate (IER). Analysis is performed at the
overall level, where statistics are computed over the entire dataset, and also at a per-sample level (i.e. on
an individual example basis).

The transcriptions for both models are shown at the bottom of the demo. We compute a text difference for each
model relative to the ground truth transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>,
and deletions in <span style='background-color:#FFCCCB'><s>red</s></span>. Multiple words in <span style='background-color:Lightgreen'>green</span>
indicate that a model has hallucinated, since it has inserted words not present in the ground truth transcription.

Overall, Distil-Whisper has roughly half the number of 5-Dup. and IER. This indicates that it has a lower
propensity to hallucinate compared to the Whisper model. Try both models with some of the TED-LIUM examples
and view the reduction in hallucinations for yourself!
"""
        )

        gr.Markdown("**Overall statistics:**")
        overall_table = gr.Dataframe(
            value=get_overall_table(),
            headers=stats_headers,
            row_count=2,
        )

        gr.Markdown("**Per-sample statistics:**")
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Column():
            # Separate name from overall_table, so the click handler below
            # targets the per-sample table.
            sample_table = gr.Dataframe(
                headers=stats_headers,
                row_count=2,
            )
        with gr.Row():
            gr.Markdown("**Whisper text diff**")
            gr.Markdown("**Distil-Whisper text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")

        # The outputs list lines up with get_side_by_side_visualisation's
        # returned tuple.
        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, sample_table, text_out_v2, text_out_32_2],
        )

    demo.launch()
|