from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset, Dataset, DatasetDict
import random
import string
import torch
from torchmetrics.text import WordErrorRate, CharErrorRate

wer = WordErrorRate()
cer = CharErrorRate()


def process(text):
    # Lower-case every letter
    text = text.lower()
    # Remove punctuation, keeping apostrophes
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Strip leading/trailing whitespace (the original character-by-character
    # loop raised an IndexError on empty strings)
    return text.strip()


import jiwer
from edit_distance import SequenceMatcher


def correct_text(text):
    # Normalize text before scoring: expand contractions, lower-case,
    # collapse whitespace, drop punctuation, and split into word lists.
    transforms = jiwer.Compose(
        [
            jiwer.ExpandCommonEnglishContractions(),
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.RemovePunctuation(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    )
    return transforms(text)


def align_gt_asr(gt, asr):
    # Align ground-truth and ASR word sequences via edit-distance opcodes,
    # padding deletions/insertions with empty strings.
    sm = SequenceMatcher(a=gt, b=asr)
    best_path = []
    opcodes = sm.get_opcodes()
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == "delete":
            for i in range(i1, i2):
                best_path.append([gt[i], ""])
        if tag == "replace" or tag == "equal":
            for i, j in zip(range(i1, i2), range(j1, j2)):
                best_path.append([gt[i], asr[j]])
        if tag == "insert":
            for j in range(j1, j2):
                best_path.append(["", asr[j]])
    return best_path


dtype = torch.float16
dataset_name = "./../libripseech_tokenized"
dataset = DatasetDict.load_from_disk(dataset_name)

with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]

tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "left"
# Add new tokens for introducing prompts
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
sot_token = tokenizer.encode("<|startoftranscript|>")[0]
eot_token = tokenizer.encode("<|endoftranscript|>")[0]

from math import ceil
from tqdm import tqdm

val_bs = 32
n_bwords = 25
context_length = 2048


def prepare(element):
    # Add audio tokens
    audio_tkns = element["audio_tokens"]
    data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns])
    # Sample context words and mix with the biasing list: if the utterance has
    # fewer than n_bwords biasing words, pad with random rare words as distractors
    b_words = element["b_words"]
    if n_bwords > len(b_words):
        context = b_words + random.sample(rarewords, n_bwords - len(b_words))
    else:
        context = random.sample(b_words, n_bwords)
    random.shuffle(context)
    # Add the context words
    data += "<|startofprompt|>" + "<|sepofprompt|>".join(context) + "<|endofprompt|>"
    # Add the transcript start token
    data += "<|startoftranscript|>"
    return {"data": data, "context": context}


@torch.no_grad()
def evaluate_model(model):
    transcripts = []
    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]
    for idx in tqdm(range(ceil(len(data) / val_bs))):
        outputs = tokenizer(
            data[idx * val_bs: (idx + 1) * val_bs],
            truncation=False,
            max_length=None,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        input_ids = outputs["input_ids"]
        # Prompt length after left padding; generations are sliced past it
        par = input_ids.shape[-1]
        generations = model.generate(
            input_ids,
            # Pass the attention mask so left-padded positions are ignored
            attention_mask=outputs["attention_mask"],
            max_new_tokens=context_length - par - 1,
            eos_token_id=eot_token,
        )
        transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)

    bias_word_cnt = 0
    normal_word_cnt = 0
    u_wer = 0.0
    b_wer = 0.0
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    for a, b, c in zip(pred_list, text_list, prompt_list):
        aligned_pair = align_gt_asr(b, a)
        for gt_word, asr_word in aligned_pair:
            if gt_word in c or asr_word in c:
                # Word appears in the biasing list: count toward biased WER
                if gt_word != asr_word:
                    b_wer += 1.0
                if gt_word in c:
                    bias_word_cnt += 1
            else:
                # Word outside the biasing list: count toward unbiased WER
                if gt_word != asr_word:
                    u_wer += 1.0
                if gt_word != "":
                    normal_word_cnt += 1
    u_wer = u_wer / normal_word_cnt * 100
    b_wer = b_wer / bias_word_cnt * 100
    return (
        wer(transcripts, processed_data["text"]).item() * 100,
        cer(transcripts, processed_data["text"]).item() * 100,
        b_wer,
        u_wer,
    )
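

# Usage sketch (not part of the original script): the GPT2LMHeadModel import
# above suggests the evaluated model is a GPT-2 LM head checkpoint. The
# checkpoint path "./../model" is a placeholder assumption; the resize call is
# a no-op if the checkpoint was already saved with the extended vocabulary.
if __name__ == "__main__":
    model = GPT2LMHeadModel.from_pretrained("./../model", torch_dtype=dtype)
    # Make room for the added prompt tokens in the embedding matrix
    model.resize_token_embeddings(len(tokenizer))
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    test_wer, test_cer, test_b_wer, test_u_wer = evaluate_model(model)
    print(f"WER: {test_wer:.2f}  CER: {test_cer:.2f}  "
          f"B-WER: {test_b_wer:.2f}  U-WER: {test_u_wer:.2f}")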