"""Fine-tune several pretrained checkpoints on the VAERS outcomes dataset.

The script loads ``chrisvoncsefalvay/vaers-outcomes``, optionally subsamples
every split, and then trains one sequence-classification model per base
checkpoint, reporting metrics to Weights & Biases.

NOTE: the dataset is downloaded and subsampled at import time, because
``label_map`` and ``train_from_model`` read the module-level ``dataset``.
"""
import logging
import os
from typing import List, Union

import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    Trainer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    DataCollatorWithPadding,
    pipeline,
    AutoModel,
)
from datasets import load_dataset, Dataset, DatasetDict
import shap
import wandb
import evaluate

os.environ["TOKENIZERS_PARALLELISM"] = "false"

device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

SEED: int = 42
BATCH_SIZE: int = 16
EPOCHS: int = 3
# Fraction of every split that is kept; values >= 1 keep the full data.
SUBSAMPLING: float = 0.1

# WandB configuration
os.environ["WANDB_PROJECT"] = "DAEDRA multiclass model training"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
os.environ["WANDB_NOTEBOOK_NAME"] = "DAEDRA.ipynb"

dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")

if SUBSAMPLING < 1:
    # Shuffle deterministically, then keep the leading fraction of each split.
    dataset = DatasetDict({
        split_name: split.shuffle(seed=SEED).select(
            range(int(len(split) * SUBSAMPLING))
        )
        for split_name, split in dataset.items()
    })

accuracy = evaluate.load("accuracy")
precision, recall = evaluate.load("precision"), evaluate.load("recall")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred) -> dict:
    """Compute accuracy, precision, recall and F1 for one evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` is
            reduced with ``argmax`` over axis 1, so it is presumably a
            per-class score matrix (to be confirmed against the Trainer).

    Returns:
        A dict mapping metric names to scalar values.
    """
    scores, labels = eval_pred
    predicted = np.argmax(scores, axis=1)
    pair = {"predictions": predicted, "references": labels}

    return {
        'accuracy': accuracy.compute(**pair)["accuracy"],
        'precision_macroaverage': precision.compute(**pair, average='macro')["precision"],
        'precision_microaverage': precision.compute(**pair, average='micro')["precision"],
        'recall_macroaverage': recall.compute(**pair, average='macro')["recall"],
        'recall_microaverage': recall.compute(**pair, average='micro')["recall"],
        'f1_microaverage': f1.compute(**pair, average='micro')["f1"],
    }


# Class index -> class name, taken from the test split's label feature.
label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}


def train_from_model(model_ckpt: str, push: bool = False) -> None:
    """Fine-tune ``model_ckpt`` on the module-level ``dataset`` and log to W&B.

    Args:
        model_ckpt: Name or path of the pretrained checkpoint used for both
            the tokenizer and the classification model.
        push: Currently unused; kept so existing callers keep working.
    """
    print(f"Initialising training based on {model_ckpt}...")

    print("Tokenising...")
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

    # Drop every raw column except the label once the text is tokenised.
    cols = [name for name in dataset["train"].column_names if name != "label"]
    ds_enc = dataset.map(
        lambda x: tokenizer(x["text"], truncation=True, max_length=512),
        batched=True,
        remove_columns=cols,
    )

    print("Loading model...")
    # Built once so the primary load and the fallback cannot drift apart.
    model_kwargs = {
        "num_labels": len(label_map),
        "id2label": label_map,
        "label2id": {v: k for k, v in label_map.items()},
    }
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_ckpt, **model_kwargs
        )
    except OSError:
        # Retry asking the library to load from TensorFlow weights.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_ckpt, from_tf=True, **model_kwargs
        )

    args = TrainingArguments(
        output_dir="vaers",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=.01,
        logging_steps=1,
        # Same seed as the subsampling, so SEED is the single source of truth.
        seed=SEED,
        run_name=f"daedra-minisample-comparison-{SUBSAMPLING}",
        report_to=["wandb"],
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds_enc["train"],
        eval_dataset=ds_enc["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Same condition as the subsampling step above, so the tag always
    # reflects what was actually done to the data.
    if SUBSAMPLING < 1:
        wandb_tag: List[str] = [f"subsample-{SUBSAMPLING}"]
    else:
        wandb_tag = ["full_sample"]
    wandb_tag.append(f"batch_size-{BATCH_SIZE}")
    wandb_tag.append(f"base:{model_ckpt}")

    # Use the last path component, so "org/name" yields "name" and deeper
    # paths (or paths with a trailing slash) still yield the final name.
    sanitised_model_name = model_ckpt.rstrip("/").rsplit("/", 1)[-1]

    wandb.init(
        name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}",
        tags=wandb_tag,
        magic=True,
    )

    print("Starting training...")
    trainer.train()
    print("Training finished.")

    wandb.finish()


if __name__ == "__main__":
    # Presumably closes any W&B run left open before the loop starts.
    wandb.finish()

    for mname in (
        # "dmis-lab/biobert-base-cased-v1.2",
        "emilyalsentzer/Bio_ClinicalBERT",
        "bert-base-uncased",
        "distilbert-base-uncased",
    ):
        print(f"Now training on subsample with {mname}...")
        train_from_model(mname)