|
from transformers import AutoTokenizer, GPT2LMHeadModel |
|
from datasets import load_dataset, Dataset, DatasetDict |
|
import random |
|
import string |
|
import torch |
|
|
|
from torchmetrics.text import WordErrorRate, CharErrorRate |
|
|
|
# Corpus-level ASR metrics, instantiated once and reused by evaluate_model().
wer = WordErrorRate()

cer = CharErrorRate()
|
|
|
def process(text):
    """Normalize a string for scoring.

    Lowercases the text, removes all punctuation except apostrophes
    (so contractions like "don't" survive), and trims leading/trailing
    spaces.

    Args:
        text: Raw input string.

    Returns:
        The normalized string (may be empty).
    """
    text = text.lower()

    # Strip every punctuation character except the apostrophe in one
    # C-level pass via str.translate.
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # BUGFIX: the old `while text[0] == ' ' ...` loop raised IndexError
    # when the text was empty (or became empty after punctuation removal,
    # e.g. input "!!!"). strip(' ') trims the same leading/trailing
    # spaces and is safe on "".
    return text.strip(' ')
|
|
|
import jiwer |
|
from edit_distance import SequenceMatcher |
|
def correct_text(text):
    """Normalize transcript text with the standard jiwer pipeline.

    Expands common English contractions, lowercases, collapses repeated
    spaces, strips surrounding whitespace, removes punctuation, and
    finally converts each sentence into a list of words.

    Args:
        text: A string or list of strings.

    Returns:
        A list of word lists (one inner list per input sentence).
    """
    pipeline = jiwer.Compose([
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.RemovePunctuation(),
        jiwer.ReduceToListOfListOfWords(),
    ])
    return pipeline(text)
|
|
|
def align_gt_asr(gt, asr):
    """Align a ground-truth word list against an ASR hypothesis.

    Uses SequenceMatcher opcodes to pair words up. Deleted reference
    words are paired with "", inserted hypothesis words are paired
    with "" on the reference side.

    Args:
        gt: List of ground-truth words.
        asr: List of ASR-output words.

    Returns:
        List of [gt_word, asr_word] pairs covering the full alignment.
    """
    matcher = SequenceMatcher(a=gt, b=asr)
    alignment = []
    for tag, g_lo, g_hi, a_lo, a_hi in matcher.get_opcodes():
        if tag == "delete":
            # Reference words missing from the hypothesis.
            alignment.extend([gt[i], ""] for i in range(g_lo, g_hi))
        elif tag == "insert":
            # Hypothesis words with no reference counterpart.
            alignment.extend(["", asr[j]] for j in range(a_lo, a_hi))
        elif tag in ("replace", "equal"):
            # Pair words positionally across the matched/substituted span.
            alignment.extend(
                [gt[i], asr[j]]
                for i, j in zip(range(g_lo, g_hi), range(a_lo, a_hi))
            )
    return alignment
|
|
|
# Inference precision for the model (half precision).
dtype = torch.float16

# Pre-tokenized LibriSpeech dataset saved to disk by an earlier pipeline step.
# NOTE(review): "libripseech" looks like a typo in the directory name, but the
# path must match the on-disk artifact — confirm before renaming.
dataset_name = "./../libripseech_tokenized"

dataset = DatasetDict.load_from_disk(dataset_name)

# Pool of rare words used as distractors when an utterance has fewer
# biasing words than n_bwords; normalized with process().
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]

tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")

# Left padding so generation starts right after the prompt for every
# sequence in a batch; pad token pinned to id 0.
tokenizer.pad_token_id = 0

tokenizer.pad_token = "<|padding|>"

tokenizer.padding_side = "left"

# Special markers delimiting the biasing-word prompt section.
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])

# Single-token ids for the transcript start/end markers; eot_token is the
# generation stop token in evaluate_model().
sot_token = tokenizer.encode("<|startoftranscript|>")[0]

eot_token = tokenizer.encode("<|endoftranscript|>")[0]
|
|
|
from math import ceil |
|
from tqdm import tqdm |
|
|
|
val_bs = 32  # evaluation batch size

n_bwords = 25  # number of biasing words placed in each prompt

context_length = 2048  # model context window, caps prompt + generated tokens
|
|
|
def prepare(element):
    """Build the generation prompt for a single dataset example.

    The prompt is: the example's audio tokens, then exactly n_bwords
    biasing words wrapped in prompt markers, then the start-of-transcript
    marker. When the example has fewer than n_bwords biasing words, the
    list is padded with random distractors from the global rare-word pool;
    when it has more, a random subset is taken. The final list is shuffled.

    Args:
        element: Dataset row with "audio_tokens" and "b_words" fields.

    Returns:
        Dict with the prompt string under "data" and the biasing-word
        list under "context".
    """
    audio_part = "".join(f"<|audio:{tkn}|>" for tkn in element["audio_tokens"])

    bias_words = element["b_words"]
    if len(bias_words) < n_bwords:
        # Pad with distractors so every prompt carries n_bwords entries.
        context = bias_words + random.sample(rarewords, n_bwords - len(bias_words))
    else:
        context = random.sample(bias_words, n_bwords)
    random.shuffle(context)

    prompt = (
        audio_part
        + "<|startofprompt|>"
        + "<|sepofprompt|>".join(context)
        + "<|endofprompt|>"
        + "<|startoftranscript|>"
    )

    return {"data": prompt, "context": context}
|
|
|
@torch.no_grad()
def evaluate_model(model):
    """Evaluate *model* on the test.clean split and return error rates.

    Generates transcripts batch-by-batch, then computes:
      - overall WER and CER (torchmetrics, in percent),
      - B-WER: error rate restricted to biasing words in the prompt,
      - U-WER: error rate over the remaining (unbiased) words.

    Returns:
        Tuple (wer %, cer %, b_wer %, u_wer %).
    """

    transcripts = []

    # prepare() adds the "data" (prompt string) and "context" (biasing
    # words) columns on top of the existing dataset fields.
    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]

    for idx in tqdm(range(ceil(len(data)/val_bs))):

        # Left-padded batch; prompts are pre-built strings of special tokens.
        # NOTE(review): no attention_mask is passed to generate() below even
        # though the batch is padded — confirm the model handles this.
        outputs = tokenizer(data[idx * val_bs: (idx + 1) * val_bs], truncation=False, max_length=None, padding=True, return_tensors="pt").to(model.device)
        input_ids = outputs["input_ids"]
        # par = padded prompt length; generation output is sliced past it.
        par = input_ids.shape[-1]

        generations = model.generate(
            input_ids,
            # Leave room in the context window for the generated transcript.
            max_new_tokens=context_length - par - 1,
            eos_token_id = eot_token
        )
        # Decode only the newly generated tokens (drop the prompt prefix).
        transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)

    bias_word_cnt = 0
    normal_word_cnt = 0
    u_wer = 0.0
    b_wer = 0.0
    # Word-level normalization (lists of word lists) for alignment.
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    # a = predicted words, b = reference words, c = biasing-word list.
    for a, b, c in zip(pred_list, text_list, prompt_list):
        aligned_pair = align_gt_asr(b, a)
        for gt_word, asr_word in aligned_pair:
            if gt_word in c or asr_word in c:
                # Biased bucket: errors touching a prompt word count toward
                # B-WER; the denominator counts reference words in the prompt.
                if gt_word != asr_word:
                    b_wer += 1.0
                if gt_word in c:
                    bias_word_cnt += 1
            else:
                # Unbiased bucket: insertions ("" reference) are errors but
                # do not add to the denominator.
                if gt_word != asr_word:
                    u_wer += 1.0
                if gt_word != "":
                    normal_word_cnt += 1
    # NOTE(review): raises ZeroDivisionError if a bucket is empty — assumed
    # impossible on test.clean; verify for other splits.
    u_wer = u_wer / normal_word_cnt * 100
    b_wer = b_wer / bias_word_cnt * 100

    return wer(transcripts, processed_data["text"]).item() * 100, cer(transcripts, processed_data["text"]).item() * 100, b_wer, u_wer