|
import os |
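# Select the training GPUs before CUDA is initialized; PCI_BUS_ID makes the
# device indices match the physical ordering reported by nvidia-smi.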
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,5,7" |
|
|
|
|
|
from transformers import AutoTokenizer, AutoConfig, GPT2LMHeadModel, AutoModel, AutoModelForCausalLM
|
from transformers import Trainer, TrainingArguments |
|
from datasets import Dataset, DatasetDict, concatenate_datasets, Sequence, Value |
|
from torch.nn import functional as F |
|
from tqdm import tqdm |
|
import time |
|
import torch |
|
import wandb |
|
import random |
|
import string |
|
from eval_model import evaluate_model |
|
|
|
def process(text): |
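    """Lowercase text, drop punctuation except apostrophes, and trim surrounding whitespace."""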
|
|
|
|
|
text = text.lower() |
|
|
|
|
|
punctuation_to_remove = string.punctuation.replace("'", "") |
|
translation_table = str.maketrans('', '', punctuation_to_remove) |
|
text = text.translate(translation_table) |
|
|
|
|
|
    # Trim surrounding whitespace; unlike the character-by-character loop this is
    # also safe when the text ends up empty after punctuation removal.
    text = text.strip()
|
|
|
return text |
|
|
|
dataset_name = "entity_tokenized" |
|
tokenizer_path = "./../tokenizer" |
|
max_length = 2048 |
|
|
|
|
|
|
|
n_bwords = 25 |
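# Entity-tokenized dataset: keep only the tokenized columns and cast their dtypes
# so the features line up with the ASR and prompt datasets concatenated below.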
|
|
|
dataset = Dataset.load_from_disk(dataset_name) |
|
dataset = dataset.remove_columns(["audio_tokens", "raw_text", "transcript", "entities", "prompt"]) |
|
feat = dataset.features.copy() |
|
feat["input_ids"] = Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None) |
|
feat["attention_mask"] = Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None) |
|
dataset = dataset.cast(feat) |
|
dataset = dataset.train_test_split(test_size=0.025) |
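# Pre-tokenized LibriSpeech ASR splits from the local HF cache; token_type_ids is
# dropped so the columns match the other training sets.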
|
|
|
asr_dataset = DatasetDict.load_from_disk("/root/.cache/huggingface/hub/models--darshanmakwana--storage/snapshots/b6e4caa73046e02ad19b48b39c097ba7b9980210/ASR/tokenized_librispeech/").remove_columns(["token_type_ids"]) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |
|
tokenizer.pad_token_id = 0 |
|
tokenizer.pad_token = "<|padding|>" |
|
tokenizer.padding_side = "right" |
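# Register prompt-delimiter and entity special tokens; the model's embedding
# matrix is resized to cover them after loading.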
|
|
|
|
|
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"]) |
|
|
|
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"]) |
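# Rare-word biasing list, used below to pad each prompt's context up to n_bwords entries.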
|
|
|
|
|
|
|
|
|
with open("./../prompting/blist/all_rare_words.txt") as fin: |
|
rarewords = [process(word.strip()) for word in fin] |
|
|
|
def tokenize(element): |
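    """Build one training sequence: audio tokens, a biasing-word prompt, then the target transcript."""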
|
|
|
|
|
audio_tkns = element["audio_tokens"] |
|
data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns]) + "<|startofprompt|>" |
|
|
|
|
|
b_words = element["b_words"] |
|
if n_bwords > len(b_words): |
|
context = b_words + random.sample(rarewords, n_bwords - len(b_words)) |
|
else: |
|
context = random.sample(b_words, n_bwords) |
|
random.shuffle(context) |
|
|
|
|
|
data += "<|sepofprompt|>".join(context) |
|
|
|
|
|
data += "<|endofprompt|><|startoftranscript|>" + element["text"] + "<|endoftranscript|>" |
|
|
|
outputs = tokenizer(data, truncation=True, max_length=max_length, padding="max_length") |
|
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} |
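# Map the prompted LibriSpeech copy through tokenize(), keeping only input_ids/attention_mask.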
|
|
|
p_dataset = DatasetDict.load_from_disk("./../libripseech_tokenized") |
|
prompt_dataset = p_dataset.map(
    tokenize, batched=False, remove_columns=p_dataset["train.clean.100"].column_names
)
|
|
|
print("Total Vocab Size:", len(tokenizer)) |
|
|
|
model = GPT2LMHeadModel.from_pretrained("./../models/checkpoint-prompting") |
|
model.resize_token_embeddings(len(tokenizer)) |
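# Causal-LM collator: labels are a copy of input_ids with padding positions masked to -100.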
|
|
|
from transformers import DataCollatorForLanguageModeling |
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
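# Trainer hyperparameters: evaluate every 500 steps, save every 1000, and reload
# the best checkpoint (by eval loss) at the end of training.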
|
|
|
config = { |
|
"output_dir": "./out", |
|
"max_steps": 20000, |
|
"per_device_train_batch_size": 5, |
|
"per_device_eval_batch_size": 5, |
|
"gradient_accumulation_steps": 1, |
|
"eval_strategy": "steps", |
|
"save_strategy": "steps", |
|
"eval_steps": 500, |
|
"logging_steps": 1, |
|
"logging_first_step": True, |
|
"save_total_limit": 5, |
|
"load_best_model_at_end": True, |
|
"save_steps": 1000, |
|
"lr_scheduler_type": "cosine", |
|
"learning_rate": 1e-4, |
|
"warmup_steps": 10, |
|
"weight_decay": 0.01, |
|
"report_to": "wandb", |
|
"fp16": True |
|
} |
|
|
|
from argparse import Namespace |
|
|
|
args = Namespace(**config) |
|
train_args = TrainingArguments(**config) |
|
|
|
wandb.init(project="multi_modal_exps", name="entity") |
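# Custom Trainer: computes the next-token cross-entropy explicitly and, after each
# evaluation pass, logs error-rate metrics from the external eval_model helper.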
|
|
|
class GPTTrainer(Trainer): |
|
def compute_loss(self, model, inputs, return_outputs=False): |
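        # Shift by one position so each token is predicted from the tokens before it;
        # -100 labels (padding) are ignored by the cross-entropy.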
|
|
|
labels = inputs.get("labels") |
|
outputs = model(**inputs) |
|
logits = outputs.get("logits") |
|
|
|
labels = labels[:, 1:] |
|
logits = logits[:, :-1, :] |
|
|
|
        # Debug output: shapes and value ranges of the shifted logits/labels.
        print(logits.shape, labels.shape, torch.max(logits).item(), torch.max(labels).item(), torch.min(logits).item(), torch.min(labels).item())
|
|
|
loss = F.cross_entropy(torch.reshape(logits, (-1, logits.size(-1))), torch.reshape(labels, (-1, )), ignore_index=-100) |
|
|
|
return (loss, outputs) if return_outputs else loss |
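    # After the standard evaluation loop, run the external decoding evaluation and
    # log the error rates to wandb.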
|
|
|
@torch.no_grad() |
|
def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"): |
|
|
|
eval_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix) |
|
|
|
        # Evaluate with the trainer's model reference rather than the module-level global.
        wer, cer, b_wer, u_wer = evaluate_model(self.model)
|
|
|
wandb.log({ |
|
"Word Error Rate": wer, |
|
"Char Error Rate": cer, |
|
"Biased Word Error Rate": b_wer, |
|
"Unbiased Word Error Rate": u_wer |
|
}) |
|
|
|
return eval_output |
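# Train on the concatenation of the entity, plain ASR, and prompted LibriSpeech
# datasets; evaluation uses the matching validation mixture.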
|
|
|
trainer = GPTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=train_args,
    data_collator=data_collator,
    train_dataset=concatenate_datasets([dataset["train"], asr_dataset["train.clean.100"], prompt_dataset["train.clean.100"]]),
    eval_dataset=concatenate_datasets([dataset["test"], asr_dataset["validation.clean"], prompt_dataset["validation.clean"]]),
)
|
|
|
trainer.train() |