{ "cells": [ { "cell_type": "markdown", "source": [ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n", "\n", "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "from typing import List\n", "from datasets import load_dataset\n", "import shap\n", "from sklearn.metrics import f1_score, accuracy_score, classification_report\n", "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n", "\n", "%load_ext watermark" ], "outputs": [ { "output_type": "error", "ename": "ModuleNotFoundError", "evalue": "No module named 'torch'", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'" ] } ], "execution_count": 2, "metadata": { "datalore": { "node_id": "caZjjFP0OyQNMVgZDiwswE", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "un8W7ez7ZwoGb5Co6nydEV" } }, "gather": { "logged": 1706406690290 } } }, { "cell_type": "code", "source": [ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "SEED: int = 42\n", "\n", "BATCH_SIZE: int = 8\n", "EPOCHS: int = 1\n", "model_ckpt: str = \"distilbert-base-uncased\"\n", "\n", "CLASS_NAMES: List[str] = [\"DIED\",\n", " \"ER_VISIT\",\n", " \"HOSPITAL\",\n", " \"OFC_VISIT\",\n", " \"X_STAY\",\n", " \"DISABLE\",\n", " \"D_PRESENTED\"]\n", "\n", "# WandB configuration\n", "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n", "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints" ], "outputs": [], "execution_count": null, "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "%watermark --iversion" ], "outputs": [], "execution_count": null, "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "!nvidia-smi" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Sun Jan 28 01:31:42 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 0MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 36W / 250W | 0MiB / 16384MiB | 1% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n" } ], "execution_count": 4, "metadata": { "datalore": { "node_id": "UU2oOJhwbIualogG1YyCMd", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "attachments": {}, "cell_type": "markdown", "source": [ "## Loading the data set" ], "metadata": { "datalore": { "node_id": "t45KHugmcPVaO0nuk8tGJ9", "type": "MD", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "40nN9Hvgi1clHNV5RAemI5" } } } }, { "cell_type": "code", "source": [ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")" ], "outputs": [], "execution_count": null, "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### Tokenisation and encoding" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "I7n646PIscsUZRoHu6m7zm", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "def tokenize_and_encode(examples):\n", " return tokenizer(examples[\"text\"], truncation=True)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "QBLOSI0yVIslV7v7qX9ZC3", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "cols = dataset[\"train\"].column_names\n", "cols.remove(\"labels\")\n", "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "slHeNysZOX9uWS9PB7jFDb", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "### Training" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "class MultiLabelTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False):\n", " labels = inputs.pop(\"labels\")\n", " outputs = model(**inputs)\n", " logits = outputs.logits\n", " loss_fct = torch.nn.BCEWithLogitsLoss()\n", " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n", " labels.float().view(-1, self.model.config.num_labels))\n", " return (loss, outputs) if return_outputs else loss" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "itXWkbDw9sqbkMuDP84QoT", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to(\"cuda\")" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "ZQU7aW6TV45VmhHOQRzcnF", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n", " y_pred = torch.from_numpy(y_pred)\n", " y_true = torch.from_numpy(y_true)\n", "\n", " if sigmoid:\n", " y_pred = y_pred.sigmoid()\n", "\n", " return ((y_pred > threshold) == y_true.bool()).float().mean().item()" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "swhgyyyxoGL8HjnXJtMuSW", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "1Uq3HtkaBxtHNAnSwit5cI", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "args = TrainingArguments(\n", " output_dir=\"vaers\",\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " per_device_eval_batch_size=BATCH_SIZE,\n", " num_train_epochs=EPOCHS,\n", " weight_decay=.01,\n", " report_to=[\"wandb\"]\n", ")" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "1iPZOTKPwSkTgX5dORqT89", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "multi_label_trainer = MultiLabelTrainer(\n", " model, \n", " args, \n", " train_dataset=ds_enc[\"train\"], \n", " eval_dataset=ds_enc[\"test\"], \n", " compute_metrics=compute_metrics, \n", " tokenizer=tokenizer\n", ")" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "bnRkNvRYltLun6gCEgL7v0", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "multi_label_trainer.evaluate()" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "LO54PlDkWQdFrzV25FvduB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "multi_label_trainer.train()" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "### Evaluation" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "We instantiate a classifier `pipeline` and push it to CUDA." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "classifier = pipeline(\"text-classification\", \n", " model, \n", " tokenizer=tokenizer, \n", " device=\"cuda:0\")" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "kHoUdBeqcyVXDSGv54C4aE", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "We use the same tokenizer used for training to tokenize/encode the validation set." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_encodings = tokenizer.batch_encode_plus(dataset[\"validate\"][\"text\"], \n", " max_length=255, \n", " pad_to_max_length=True, \n", " return_token_type_ids=True, \n", " truncation=True)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "Dr5WCWA6jL51NR1fSrQu6Z", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "Once we've made the data loadable by putting it into a `DataLoader`, we " ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n", " torch.tensor(test_encodings['attention_mask']), \n", " torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n", " torch.tensor(test_encodings['token_type_ids']))\n", "test_dataloader = torch.utils.data.DataLoader(test_data, \n", " sampler=torch.utils.data.SequentialSampler(test_data), \n", " batch_size=BATCH_SIZE)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "MWfGq2tTkJNzFiDoUPq2X7", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "model.eval()\n", "\n", "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n", "\n", "for i, batch in enumerate(test_dataloader):\n", " batch = tuple(t.to(device) for t in batch)\n", " # Unpack the inputs from our dataloader\n", " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n", " \n", " with torch.no_grad():\n", " outs = model(b_input_ids, attention_mask=b_input_mask)\n", " b_logit_pred = outs[0]\n", " pred_label = torch.sigmoid(b_logit_pred)\n", "\n", " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n", " pred_label = pred_label.to('cpu').numpy()\n", " b_labels = b_labels.to('cpu').numpy()\n", "\n", " tokenized_texts.append(b_input_ids)\n", " logit_preds.append(b_logit_pred)\n", " true_labels.append(b_labels)\n", " pred_labels.append(pred_label)\n", "\n", "# Flatten outputs\n", "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n", "pred_labels = [item for sublist in pred_labels for item in sublist]\n", "true_labels = [item for sublist in true_labels for item in sublist]\n", "\n", "# Converting flattened binary values to boolean values\n", "true_bools = [tl == 1 for tl in true_labels]\n", "pred_bools = [pl > 0.50 for pl in pred_labels] " ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "1SJCSrQTRCexFCNCIyRrzL", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "We create a classification report:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n", "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n", "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n", "print(clf_report)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "eBprrgF086mznPbPVBpOLS", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "# Creating a map of class names from class numbers\n", "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "yELHY0IEwMlMw3x6e7hoD1", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "true_label_idxs, pred_label_idxs = [], []\n", "\n", "for vals in true_bools:\n", " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n", "for vals in pred_bools:\n", " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "jH0S35dDteUch01sa6me6e", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "true_label_texts, pred_label_texts = [], []\n", "\n", "for vals in true_label_idxs:\n", " if vals:\n", " true_label_texts.append([idx2label[val] for val in vals])\n", " else:\n", " true_label_texts.append(vals)\n", "\n", "for vals in pred_label_idxs:\n", " if vals:\n", " pred_label_texts.append([idx2label[val] for val in vals])\n", " else:\n", " pred_label_texts.append(vals)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "h4vHL8XdGpayZ6xLGJUF6F", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "symptom_texts = [tokenizer.decode(text,\n", " skip_special_tokens=True,\n", " clean_up_tokenization_spaces=False) for text in tokenized_texts]" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "SxUmVHfQISEeptg1SawOmB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n", " 'true_labels': true_label_texts, \n", " 'pred_labels':pred_label_texts})\n", "comparisons_df.to_csv('comparisons.csv')\n", "comparisons_df" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "BxFNigNGRLTOqraI55BPSH", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "### Shapley analysis" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "OpdZcoenX2HwzLdai7K5UA", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "FvbCMfIDlcf16YSvb8wNQv", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "shap.plots.text(shap_values)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "TSxvakWLPCpjVMWi9ZdEbd", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } } ], "metadata": { "kernelspec": { "name": "python3", "language": "python", "display_name": "Python 3 (ipykernel)" }, "datalore": { "computation_mode": "JUPYTER", "package_manager": "pip", "base_environment": "default", "packages": [ { "name": "datasets", "version": "2.16.1", "source": "PIP" }, { "name": "torch", "version": "2.1.2", "source": "PIP" }, { "name": "accelerate", "version": "0.26.1", "source": "PIP" } ], "report_row_ids": [ "un8W7ez7ZwoGb5Co6nydEV", "40nN9Hvgi1clHNV5RAemI5", "TgRD90H5NSPpKS41OeXI1w", "ZOm5BfUs3h1EGLaUkBGeEB", "kOP0CZWNSk6vqE3wkPp7Vc", "W4PWcOu2O2pRaZyoE2W80h", "RolbOnQLIftk0vy9mIcz5M", "8OPhUgbaNJmOdiq5D3a6vK", "5Qrt3jSvSrpK6Ne1hS6shL", "hTq7nFUrovN5Ao4u6dIYWZ", "I8WNZLpJ1DVP2wiCW7YBIB", "SawhU3I9BewSE1XBPstpNJ", "80EtLEl2FIE4FqbWnUD3nT" ], "version": 3 }, "microsoft": { "ms_spell_check": { "ms_spell_check_language": "en" } }, "language_info": { "name": "python", "version": "3.8.5", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "nteract": { "version": "nteract-front-end@1.0.0" } }, "nbformat": 4, "nbformat_minor": 4 }