import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,5,7"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import random
import string
import time

import torch
import wandb
from torch.nn import functional as F
from tqdm import tqdm

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset, DatasetDict, concatenate_datasets, Sequence, Value

from eval_model import evaluate_model


def process(text):
    # Lower case every letter
    text = text.lower()
    # Remove punctuation (keep apostrophes)
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans("", "", punctuation_to_remove)
    text = text.translate(translation_table)
    # Remove whitespace from front and back
    return text.strip()


dataset_name = "entity_tokenized"
tokenizer_path = "./../tokenizer"
max_length = 2048
# n_layer = 16
# n_head = 16
# n_emb = 1024
n_bwords = 25  # number of biasing/context words per prompt

# Entity-tagged dataset: keep only the tokenized columns and cast them to compact dtypes
dataset = Dataset.load_from_disk(dataset_name)
dataset = dataset.remove_columns(["audio_tokens", "raw_text", "transcript", "entities", "prompt"])
feat = dataset.features.copy()
feat["input_ids"] = Sequence(feature=Value(dtype="int32", id=None), length=-1, id=None)
feat["attention_mask"] = Sequence(feature=Value(dtype="int8", id=None), length=-1, id=None)
dataset = dataset.cast(feat)
dataset = dataset.train_test_split(test_size=0.025)

# Pre-tokenized LibriSpeech ASR dataset
asr_dataset = DatasetDict.load_from_disk(
    "/root/.cache/huggingface/hub/models--darshanmakwana--storage/snapshots/b6e4caa73046e02ad19b48b39c097ba7b9980210/ASR/tokenized_librispeech/"
).remove_columns(["token_type_ids"])

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "right"

# new tokens for prompting
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# new tokens for entities
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])
# new tokens for images
# tokenizer.add_tokens(["<|startofimage|>", "<|endofimage|>"])
# tokenizer.add_tokens([f"<|image:{tkn}|>" for tkn in range(16000)])

# Biasing list of rare words used to pad out the per-utterance context
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]


def tokenize(element):
    # Audio tokens followed by the prompt header
    audio_tkns = element["audio_tokens"]
    data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns]) + "<|startofprompt|>"

    # Sample context words and mix them with the biasing list
    b_words = element["b_words"]
    if n_bwords > len(b_words):
        context = b_words + random.sample(rarewords, n_bwords - len(b_words))
    else:
        context = random.sample(b_words, n_bwords)
    random.shuffle(context)

    # Add the context words
    data += "<|sepofprompt|>".join(context)

    # Add the transcript
    data += "<|endofprompt|><|startoftranscript|>" + element["text"] + "<|endoftranscript|>"

    outputs = tokenizer(data, truncation=True, max_length=max_length, padding="max_length")
    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}


# Prompted LibriSpeech dataset built on the fly with the tokenize() mapping above
p_dataset = DatasetDict.load_from_disk("./../libripseech_tokenized")
prompt_dataset = p_dataset.map(
    tokenize,
    batched=False,
    remove_columns=p_dataset["train.clean.100"].column_names,
)

print("Total Vocab Size:", len(tokenizer))

model = GPT2LMHeadModel.from_pretrained("./../models/checkpoint-prompting")
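# (Not in the original script) Optional sanity check, assuming the splits above loaded
# as expected: decode one mapped example to confirm the audio / prompt / transcript
# markup round-trips through the tokenizer with the newly added special tokens.
sample = prompt_dataset["train.clean.100"][0]
print(tokenizer.decode(sample["input_ids"], skip_special_tokens=False)[:500])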
model.resize_token_embeddings(len(tokenizer))

from transformers import DataCollatorForLanguageModeling

# Causal LM collator: copies input_ids to labels and masks padding with -100
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

config = {
    "output_dir": "./out",
    "max_steps": 20000,
    "per_device_train_batch_size": 5,
    "per_device_eval_batch_size": 5,
    "gradient_accumulation_steps": 1,
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "eval_steps": 500,
    "logging_steps": 1,
    "logging_first_step": True,
    "save_total_limit": 5,
    "load_best_model_at_end": True,
    "save_steps": 1000,
    "lr_scheduler_type": "cosine",
    "learning_rate": 1e-4,
    "warmup_steps": 10,
    "weight_decay": 0.01,
    "report_to": "wandb",
    "fp16": True,
}

from argparse import Namespace

args = Namespace(**config)
train_args = TrainingArguments(**config)

wandb.init(project="multi_modal_exps", name="entity")


class GPTTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Shift so that tokens < n predict token n
        labels = labels[:, 1:]
        logits = logits[:, :-1, :]

        # Debug logging of tensor shapes and value ranges
        print(
            logits.shape, labels.shape,
            torch.max(logits).item(), torch.max(labels).item(),
            torch.min(logits).item(), torch.min(labels).item(),
        )

        loss = F.cross_entropy(
            torch.reshape(logits, (-1, logits.size(-1))),
            torch.reshape(labels, (-1,)),
            ignore_index=-100,
        )
        return (loss, outputs) if return_outputs else loss

    @torch.no_grad()
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        eval_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

        # Run the external WER/CER evaluation and log it alongside the eval loss
        wer, cer, b_wer, u_wer = evaluate_model(self.model)
        wandb.log({
            "Word Error Rate": wer,
            "Char Error Rate": cer,
            "Biased Word Error Rate": b_wer,
            "Unbiased Word Error Rate": u_wer,
        })
        return eval_output


trainer = GPTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=train_args,
    data_collator=data_collator,
    train_dataset=concatenate_datasets([
        dataset["train"],
        asr_dataset["train.clean.100"],
        prompt_dataset["train.clean.100"],
    ]),
    eval_dataset=concatenate_datasets([
        dataset["test"],
        asr_dataset["validation.clean"],
        prompt_dataset["validation.clean"],
    ]),
)

trainer.train()
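# (Not in the original script) A minimal sketch of persisting the result once training
# finishes; "./out/final" is an assumed path, adjust as needed. With
# load_best_model_at_end=True, this saves the best checkpoint plus the extended tokenizer.
trainer.save_model("./out/final")
tokenizer.save_pretrained("./out/final")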