import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,5,7"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
from datasets import Dataset, DatasetDict, concatenate_datasets, Sequence, Value
from torch.nn import functional as F
from tqdm import tqdm
import time
import torch
import wandb
import random
import string
from eval_model import evaluate_model
def process(text):
    """Normalize a transcript: lowercase, drop punctuation (keeping apostrophes), trim whitespace."""
    # Lower-case every letter
    text = text.lower()
    # Remove punctuation, keeping apostrophes
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Strip leading and trailing whitespace
    return text.strip()
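
# Illustrative example of the normalization above (not part of the pipeline):
#   process("  Hello, World! It's me.  ")  ->  "hello world it's me"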
dataset_name = "entity_tokenized"
tokenizer_path = "./../tokenizer"
max_length = 2048
# n_layer = 16
# n_head = 16
# n_emb = 1024
n_bwords = 25
dataset = Dataset.load_from_disk(dataset_name)
dataset = dataset.remove_columns(["audio_tokens", "raw_text", "transcript", "entities", "prompt"])
feat = dataset.features.copy()
feat["input_ids"] = Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)
feat["attention_mask"] = Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)
dataset = dataset.cast(feat)
dataset = dataset.train_test_split(test_size=0.025)
asr_dataset = DatasetDict.load_from_disk("/root/.cache/huggingface/hub/models--darshanmakwana--storage/snapshots/b6e4caa73046e02ad19b48b39c097ba7b9980210/ASR/tokenized_librispeech/").remove_columns(["token_type_ids"])
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "right"
# new tokens for prompting
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# new tokens for entities
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])
# new tokens for images
# tokenizer.add_tokens(["<|startofimage|>", "<|endofimage|>"])
# tokenizer.add_tokens([ f"<|image:{tkn}|>" for tkn in range(16000)])
with open("./../prompting/blist/all_rare_words.txt") as fin:
rarewords = [process(word.strip()) for word in fin]
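
# rarewords serves as the distractor pool: tokenize() below samples from it to pad each
# utterance's biasing list out to n_bwords context words.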
def tokenize(element):
# Add audio
audio_tkns = element["audio_tokens"]
data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns]) + "<|startofprompt|>"
# sample context words and mix with the biasing list
b_words = element["b_words"]
if n_bwords > len(b_words):
context = b_words + random.sample(rarewords, n_bwords - len(b_words))
else:
context = random.sample(b_words, n_bwords)
random.shuffle(context)
# add the context words
data += "<|sepofprompt|>".join(context)
# Add text
data += "<|endofprompt|><|startoftranscript|>" + element["text"] + "<|endoftranscript|>"
outputs = tokenizer(data, truncation=True, max_length=max_length, padding="max_length")
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
p_dataset = DatasetDict.load_from_disk("./../libripseech_tokenized")
prompt_dataset = p_dataset.map(
tokenize, batched=False, remove_columns = p_dataset["train.clean.100"].column_names
)
print("Total Vocab Size:", len(tokenizer))
model = GPT2LMHeadModel.from_pretrained("./../models/checkpoint-prompting")
model.resize_token_embeddings(len(tokenizer))
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
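
# With mlm=False, this collator builds labels as a copy of input_ids and masks padding positions
# to -100, so the cross-entropy in GPTTrainer.compute_loss below ignores the padded tail of each sequence.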
config = {
"output_dir": "./out",
"max_steps": 20000,
"per_device_train_batch_size": 5,
"per_device_eval_batch_size": 5,
"gradient_accumulation_steps": 1,
"eval_strategy": "steps",
"save_strategy": "steps",
"eval_steps": 500,
"logging_steps": 1,
"logging_first_step": True,
"save_total_limit": 5,
"load_best_model_at_end": True,
"save_steps": 1000,
"lr_scheduler_type": "cosine",
"learning_rate": 1e-4,
"warmup_steps": 10,
"weight_decay": 0.01,
"report_to": "wandb",
"fp16": True
}
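
# With the 4 GPUs made visible at the top of the script, the effective batch size works out to
# 4 devices * 5 sequences per device * 1 accumulation step = 20 sequences per optimizer step
# (assuming the Trainer launches on all four visible devices).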
from argparse import Namespace
args = Namespace(**config)
train_args = TrainingArguments(**config)
wandb.init(project="multi_modal_exps", name="entity")
class GPTTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer Trainer versions pass in
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Shift for next-token prediction: the logit at position t is scored against the token at t + 1
        labels = labels[:, 1:]
        logits = logits[:, :-1, :]
        # Debug: shapes and value ranges of the shifted logits/labels
        print(logits.shape, labels.shape, torch.max(logits).item(), torch.max(labels).item(), torch.min(logits).item(), torch.min(labels).item())
        loss = F.cross_entropy(torch.reshape(logits, (-1, logits.size(-1))), torch.reshape(labels, (-1,)), ignore_index=-100)
        return (loss, outputs) if return_outputs else loss
@torch.no_grad()
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        # Run the standard eval loop for the loss, then log task-level WER/CER metrics to wandb
        eval_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
        wer, cer, b_wer, u_wer = evaluate_model(model)
wandb.log({
"Word Error Rate": wer,
"Char Error Rate": cer,
"Biased Word Error Rate": b_wer,
"Unbiased Word Error Rate": u_wer
})
return eval_output
trainer = GPTTrainer(
model = model,
tokenizer = tokenizer,
args = train_args,
data_collator = data_collator,
train_dataset = concatenate_datasets([dataset["train"], asr_dataset["train.clean.100"], prompt_dataset["train.clean.100"]]),
eval_dataset = concatenate_datasets([dataset["test"], asr_dataset["validation.clean"], prompt_dataset["validation.clean"]]),
)
trainer.train()