|
import os
from typing import List

import numpy as np
import torch
import wandb
import evaluate
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, DatasetDict

# Hugging Face tokenizers parallelise internally, which can deadlock once the
# process forks (e.g. in DataLoader workers), so parallelism is disabled up front.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# The Trainer handles device placement on its own; `device` is kept for ad-hoc
# tensor work.
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

SEED: int = 42

BATCH_SIZE: int = 16
EPOCHS: int = 3
# Fraction of each split to use; set to 1.0 to train on the full dataset.
SUBSAMPLING: float = 0.1

os.environ["WANDB_PROJECT"] = "DAEDRA multiclass model training"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
os.environ["WANDB_NOTEBOOK_NAME"] = "DAEDRA.ipynb"

dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes") |
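
# An optional sanity check of what was just loaded (split names, sizes, and
# the label feature) might look like this:
#   print(dataset)
#   print(dataset["train"].features["label"].names)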
|
|
|
# For comparison runs, keep a deterministic, seeded random subsample of every split.
if SUBSAMPLING < 1.0:
    subsampled = DatasetDict()
    for split in dataset.keys():
        subsampled[split] = dataset[split].shuffle(seed=SEED).select(
            range(int(len(dataset[split]) * SUBSAMPLING))
        )
    dataset = subsampled

accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall = evaluate.load("recall")
f1 = evaluate.load("f1")
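
# Macro averaging weights every class equally; micro averaging weights every
# prediction equally, so reporting both surfaces the effect of class imbalance.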
|
|
|
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {
        'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
        'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
        'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
        'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
        'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
        'f1_macroaverage': f1.compute(predictions=predictions, references=labels, average='macro')["f1"],
        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
    }
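
# Rough sketch of the input contract: the Trainer passes (logits, labels);
# e.g. with two classes,
#   compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0])))
# argmaxes the logits row-wise and then scores the predictions.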
|
|
|
# Map integer class ids to human-readable names from the dataset's ClassLabel feature.
label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}

def train_from_model(model_ckpt: str):
    print(f"Initialising training based on {model_ckpt}...")

    print("Tokenising...")
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

    # Tokenise every split, dropping all original columns except the label so
    # the encoded dataset holds only model inputs.
    cols = dataset["train"].column_names
    cols.remove("label")
    ds_enc = dataset.map(
        lambda x: tokenizer(x["text"], truncation=True, max_length=512),
        batched=True,
        remove_columns=cols,
    )
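
    # Sequences are truncated to 512 tokens here; padding is deferred to batch
    # time, where the Trainer's default collator pads each batch to its longest
    # member rather than to a fixed length.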
|
|
|
    print("Loading model...")
    # Prefer native PyTorch weights; some checkpoints only ship TensorFlow
    # weights, in which case loading falls back to on-the-fly conversion.
    num_labels = len(dataset["test"].features["label"].names)
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_ckpt,
            num_labels=num_labels,
            id2label=label_map,
            label2id={v: k for k, v in label_map.items()},
        )
    except OSError:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_ckpt,
            num_labels=num_labels,
            id2label=label_map,
            label2id={v: k for k, v in label_map.items()},
            from_tf=True,
        )

    args = TrainingArguments(
        output_dir="vaers",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        logging_steps=1,
        run_name=f"daedra-minisample-comparison-{SUBSAMPLING}",
        report_to=["wandb"],
    )
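
    # Evaluation fires every 100 steps while checkpoints are written once per
    # epoch; logging_steps=1 streams a metric point to W&B on every step, which
    # is noisy but convenient for comparing short subsample runs.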
|
|
|
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds_enc["train"],
        eval_dataset=ds_enc["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
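
    # Note that the test split doubles as the evaluation set; this comparison
    # setup uses no separate validation split.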
|
|
|
    if SUBSAMPLING < 1.0:
        wandb_tag: List[str] = [f"subsample-{SUBSAMPLING}"]
    else:
        wandb_tag: List[str] = ["full_sample"]

    wandb_tag.append(f"batch_size-{BATCH_SIZE}")
    wandb_tag.append(f"base:{model_ckpt}")
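
    # With the defaults above, the tags come out as something like:
    #   ["subsample-0.1", "batch_size-16", "base:bert-base-uncased"]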
|
|
|
    # Strip the organisation prefix from checkpoint names such as
    # "emilyalsentzer/Bio_ClinicalBERT" to keep the run name short.
    sanitised_model_name = model_ckpt.split("/")[-1]

    wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)
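
    # Because the run is initialised explicitly, the Trainer's W&B callback
    # attaches to it rather than starting a new one, so this name takes
    # precedence over run_name in TrainingArguments.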
|
|
|
    print("Starting training...")
    trainer.train()
    print("Training finished.")

    wandb.finish()


if __name__ == "__main__":
    # Close any run left over from a previous session before starting fresh.
    wandb.finish()

    for mname in (
        # "dmis-lab/biobert-base-cased-v1.2",
        "emilyalsentzer/Bio_ClinicalBERT",
        "bert-base-uncased",
        "distilbert-base-uncased",
    ):
        print(f"Now training on subsample with {mname}...")
        train_from_model(mname)