push / eval_model.py
darshanmakwana's picture
Upload eval_model.py
2675a94 verified
raw
history blame
4.76 kB
from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset, Dataset, DatasetDict
import random
import string
import torch
from torchmetrics.text import WordErrorRate, CharErrorRate
# Corpus-level metric objects, reused across evaluation calls
# (torchmetrics accumulates internally; here they are called functionally).
wer = WordErrorRate()
cer = CharErrorRate()
def process(text):
    """Normalize a transcript string for scoring.

    Lower-cases the text, removes all punctuation except apostrophes
    (so contractions like "don't" survive), and trims surrounding spaces.

    Args:
        text: raw transcript string.

    Returns:
        The normalized string; may be empty.
    """
    # Lower case every letter
    text = text.lower()
    # Remove punctuation, keeping apostrophes
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Trim leading/trailing spaces. BUG FIX: the original while-loop
    # indexed text[0]/text[-1] and raised IndexError on empty or
    # all-space input; strip(' ') is equivalent and safe (only the
    # space character is trimmed, matching the original behavior).
    return text.strip(' ')
import jiwer
from edit_distance import SequenceMatcher
def correct_text(text):
    """Apply the standard jiwer normalization pipeline to *text*.

    The pipeline expands common English contractions, lower-cases,
    collapses repeated spaces, strips, drops punctuation, and finally
    splits into a list of word lists (one inner list per input string).
    """
    pipeline = [
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.RemovePunctuation(),
        jiwer.ReduceToListOfListOfWords(),
    ]
    normalize = jiwer.Compose(pipeline)
    return normalize(text)
def align_gt_asr(gt, asr):
    """Align ground-truth and ASR word lists into [gt_word, asr_word] pairs.

    Deleted ground-truth words are paired with "", inserted ASR words are
    paired with "" on the ground-truth side; replaced and equal spans are
    paired positionally.
    """
    matcher = SequenceMatcher(a=gt, b=asr)
    alignment = []
    for tag, g_lo, g_hi, a_lo, a_hi in matcher.get_opcodes():
        if tag in ("replace", "equal"):
            alignment.extend(
                [gt[g], asr[a]]
                for g, a in zip(range(g_lo, g_hi), range(a_lo, a_hi))
            )
        elif tag == "delete":
            alignment.extend([gt[g], ""] for g in range(g_lo, g_hi))
        elif tag == "insert":
            alignment.extend(["", asr[a]] for a in range(a_lo, a_hi))
    return alignment
# Inference dtype. NOTE(review): defined but not visibly used in this
# chunk — presumably applied when the model is loaded elsewhere; confirm.
dtype = torch.float16
# Pre-tokenized LibriSpeech saved with DatasetDict.save_to_disk
# (directory name contains a typo but must match the on-disk path).
dataset_name = "./../libripseech_tokenized"
dataset = DatasetDict.load_from_disk(dataset_name)
# Biasing list: one rare word per line, normalized the same way as transcripts
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]
tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
# Left padding so all prompts end at the same position for generation
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "left"
# Adding new tokens for introducing prompts
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# Ids of the transcript start/end markers used by evaluate_model
sot_token = tokenizer.encode("<|startoftranscript|>")[0]
eot_token = tokenizer.encode("<|endoftranscript|>")[0]
from math import ceil
from tqdm import tqdm
# Evaluation batch size, biasing-context size, and max model context
val_bs = 32
n_bwords = 25
context_length = 2048
def prepare(element):
    """Build the generation prompt for one dataset example.

    Concatenates the audio tokens (as <|audio:N|> specials), a shuffled
    biasing context of exactly n_bwords words wrapped in prompt markers,
    and the transcript start marker.

    Returns a dict with the prompt string ("data") and the context list.
    """
    # Render the audio token ids as special tokens
    prompt = "".join(f"<|audio:{tkn}|>" for tkn in element["audio_tokens"])
    # Build the biasing context: pad the gold biasing words with random
    # rare words, or subsample when there are more than n_bwords of them.
    gold = element["b_words"]
    if len(gold) < n_bwords:
        context = gold + random.sample(rarewords, n_bwords - len(gold))
    else:
        context = random.sample(gold, n_bwords)
    random.shuffle(context)
    # Context words between prompt markers
    prompt += "<|startofprompt|>" + "<|sepofprompt|>".join(context) + "<|endofprompt|>"
    # Transcript start marker — generation continues from here
    prompt += "<|startoftranscript|>"
    return {"data": prompt, "context": context}
@torch.no_grad()
def evaluate_model(model):
    """Evaluate *model* on the test.clean split with biasing prompts.

    Generates transcripts batch by batch, then scores them.

    Args:
        model: a causal LM with .generate() and .device.

    Returns:
        Tuple of percentages: (WER, CER, biased WER, unbiased WER).
        Biased WER is counted over aligned word pairs that touch the
        example's biasing context; unbiased WER over all other pairs.
    """
    transcripts = []
    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]
    for idx in tqdm(range(ceil(len(data) / val_bs))):
        batch = tokenizer(
            data[idx * val_bs: (idx + 1) * val_bs],
            truncation=False,
            max_length=None,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        input_ids = batch["input_ids"]
        prompt_len = input_ids.shape[-1]
        generations = model.generate(
            input_ids,
            # BUG FIX: with padding_side="left" the attention mask must be
            # forwarded, otherwise pad tokens are attended to and the
            # generations for shorter prompts degrade.
            attention_mask=batch["attention_mask"],
            max_new_tokens=context_length - prompt_len - 1,
            eos_token_id=eot_token,
        )
        # Decode only the newly generated suffix (after the prompt)
        transcripts += tokenizer.batch_decode(generations[:, prompt_len:], skip_special_tokens=True)
    bias_word_cnt = 0
    normal_word_cnt = 0
    u_wer = 0.0
    b_wer = 0.0
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    for pred, ref, ctx in zip(pred_list, text_list, prompt_list):
        aligned_pair = align_gt_asr(ref, pred)
        for gt_word, asr_word in aligned_pair:
            if gt_word in ctx or asr_word in ctx:
                # Pair touches the biasing context: count toward biased WER.
                if gt_word != asr_word:
                    b_wer += 1.0
                # Denominator counts only genuine ground-truth bias words.
                if gt_word in ctx:
                    bias_word_cnt += 1
            else:
                if gt_word != asr_word:
                    u_wer += 1.0
                # "" means an insertion — not a real reference word.
                if gt_word != "":
                    normal_word_cnt += 1
    # BUG FIX: guard both denominators — the original raised
    # ZeroDivisionError when no bias (or no normal) words were seen.
    u_wer = u_wer / normal_word_cnt * 100 if normal_word_cnt else 0.0
    b_wer = b_wer / bias_word_cnt * 100 if bias_word_cnt else 0.0
    return (
        wer(transcripts, processed_data["text"]).item() * 100,
        cer(transcripts, processed_data["text"]).item() * 100,
        b_wer,
        u_wer,
    )