{ "cells": [ { "cell_type": "markdown", "source": [ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n", "\n", "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States." ], "metadata": {} }, { "cell_type": "code", "source": [ "# %pip install accelerate -U" ], "outputs": [], "execution_count": 1, "metadata": { "nteract": { "transient": { "deleting": false } }, "tags": [] } }, { "cell_type": "code", "source": [ "%pip install transformers datasets shap watermark wandb scikit-multilearn" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nCollecting scikit-multilearn\n Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.4/89.4 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) 
(0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: 
aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already 
satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: stack-data in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: scikit-multilearn\nSuccessfully installed scikit-multilearn-0.2.0\nNote: you may need to restart the kernel to use updated packages.\n" } ], "execution_count": 1, "metadata": { "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "from typing import List\n", "from sklearn.metrics import f1_score, accuracy_score, classification_report\n", "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "from pyarrow import Table\n", "import shap\n", "import wandb\n", "from skmultilearn.problem_transform import LabelPowerset\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "\n", "%load_ext watermark" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 15:09:42.856486: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 15:09:43.818179: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818307: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818321: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" } ], "execution_count": 2, "metadata": { "datalore": { "hide_input_from_viewers": false, "hide_output_from_viewers": false, "node_id": "caZjjFP0OyQNMVgZDiwswE", "report_properties": { "rowId": "un8W7ez7ZwoGb5Co6nydEV" }, "type": "CODE" }, "gather": { "logged": 1706454586481 }, "tags": [] } }, { "cell_type": "code", "source": [ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "SEED: int = 42\n", "\n", "BATCH_SIZE: int = 16\n", "EPOCHS: int = 3\n", "model_ckpt: str = \"distilbert-base-uncased\"\n", "\n", "CLASS_NAMES: List[str] = [\"DIED\",\n", " \"ER_VISIT\",\n", " \"HOSPITAL\",\n", " \"OFC_VISIT\",\n", " #\"X_STAY\", # pruned\n", " #\"DISABLE\", # pruned\n", " #\"D_PRESENTED\" # pruned\n", " ]\n", "\n", "\n", "\n", "\n", "# WandB configuration\n", "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n", "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n", "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\"" ], "outputs": [], "execution_count": 3, "metadata": { "collapsed": false, "gather": { "logged": 1706454586654 }, "jupyter": { "outputs_hidden": false } } }, { "cell_type": "code", "source": [ "%watermark --iversion" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "shap : 0.44.1\nlogging: 0.5.1.2\npandas : 2.0.2\nnumpy : 1.23.5\ntorch : 1.12.0\nwandb : 0.16.2\nre : 2.2.1\n\n" } ], "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } } }, { "cell_type": "code", "source": [ "!nvidia-smi" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Sun Jan 28 15:09:47 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 
|\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 30C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n" } ], "execution_count": 5, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "UU2oOJhwbIualogG1YyCMd", "type": "CODE" } } }, { "cell_type": "markdown", "source": [ "## Loading the data set" ], "metadata": { "datalore": { "hide_input_from_viewers": false, "hide_output_from_viewers": false, "node_id": "t45KHugmcPVaO0nuk8tGJ9", "report_properties": { "rowId": "40nN9Hvgi1clHNV5RAemI5" }, "type": "MD" } } }, { "cell_type": "code", "source": [ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")" ], "outputs": [], "execution_count": 7, "metadata": { "collapsed": false, "gather": { "logged": 1706449040507 }, "jupyter": { "outputs_hidden": false } } }, { "cell_type": "code", "source": [ "dataset" ], "outputs": [ { 
"output_type": "execute_result", "execution_count": 8, "data": { "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n})" }, "metadata": {} } ], "execution_count": 8, "metadata": { "collapsed": false, "gather": { "logged": 1706449044205 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "SUBSAMPLING: float = 0.1" ], "outputs": [], "execution_count": 9, "metadata": {} }, { "cell_type": "code", "source": [ "def minisample(ds: DatasetDict, fraction: float, seed: int = SEED) -> DatasetDict:\n", "    \"\"\"Return a new `DatasetDict` holding a random `fraction` of every split in `ds`.\n", "\n", "    Each split is shuffled with `seed`, so the same subsample is drawn on every\n", "    run, and is then cut down to `round(len(split) * fraction)` rows.\n", "\n", "    Raises `ValueError` if `fraction` is outside [0, 1].\n", "    \"\"\"\n", "    if not 0.0 <= fraction <= 1.0:\n", "        raise ValueError(f\"fraction must be between 0 and 1, got {fraction}\")\n", "\n", "    # `select` keeps the rows as a view on the shuffled split instead of\n", "    # round-tripping them through a Python dict.\n", "    return DatasetDict({\n", "        name: split.shuffle(seed=seed).select(range(round(len(split) * fraction)))\n", "        for name, split in ds.items()\n", "    })" ], "outputs": [], "execution_count": 10, "metadata": { "collapsed": false, "gather": { "logged": 1706449378281 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "dataset = minisample(dataset, SUBSAMPLING)" ], "outputs": [], "execution_count": 11, "metadata": { "collapsed": false, "gather": { "logged": 1706449384162 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "dataset" ], "outputs": [ { "output_type": "execute_result", "execution_count": 12, "data": { "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 
'labels'],\n num_rows: 27224\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 27224\n })\n})" }, "metadata": {} } ], "execution_count": 12, "metadata": { "collapsed": false, "gather": { "logged": 1706449387981 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "markdown", "source": [ "We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`." ], "metadata": { "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "# Keep only the leading len(CLASS_NAMES) entries of every label vector, so the\n", "# targets always have exactly one entry per name in CLASS_NAMES.\n", "ds = DatasetDict()\n", "\n", "for split in [\"test\", \"train\", \"val\"]:\n", "    tab = Table.from_arrays(\n", "        [\n", "            dataset[split][\"id\"],\n", "            dataset[split][\"text\"],\n", "            [row[:len(CLASS_NAMES)] for row in dataset[split][\"labels\"]],\n", "        ],\n", "        names=[\"id\", \"text\", \"labels\"],\n", "    )\n", "    ds[split] = Dataset(tab)\n", "\n", "dataset = ds" ], "outputs": [], "execution_count": 13, "metadata": { "collapsed": false, "gather": { "logged": 1706449443055 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "markdown", "source": [ "### Tokenisation and encoding" ], "metadata": {} }, { "cell_type": "code", "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)" ], "outputs": [], "execution_count": 14, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "I7n646PIscsUZRoHu6m7zm", "type": "CODE" }, "gather": { "logged": 1706449638377 } } }, { "cell_type": "code", "source": [ "def tokenize_and_encode(examples):\n", "    \"\"\"Tokenize the `text` field of `examples`, truncating over-long inputs.\"\"\"\n", "    return tokenizer(examples[\"text\"], truncation=True)" ], "outputs": [], "execution_count": 15, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "QBLOSI0yVIslV7v7qX9ZC3", "type": "CODE" }, "gather": { "logged": 1706449642580 } } }, { "cell_type": "code", "source": [ "cols = dataset[\"train\"].column_names\n", "cols.remove(\"labels\")\n", 
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Map: 100%|██████████| 27224/27224 [00:10<00:00, 2638.52 examples/s]\nMap: 100%|██████████| 127044/127044 [00:48<00:00, 2633.40 examples/s]\nMap: 100%|██████████| 27224/27224 [00:10<00:00, 2613.19 examples/s]\n" } ], "execution_count": 16, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "slHeNysZOX9uWS9PB7jFDb", "type": "CODE" }, "gather": { "logged": 1706449721161 } } }, { "cell_type": "markdown", "source": [ "### Training" ], "metadata": {} }, { "cell_type": "code", "source": [ "class MultiLabelTrainer(Trainer):\n", "    \"\"\"`Trainer` whose loss scores every label independently (sigmoid + BCE).\"\"\"\n", "\n", "    def compute_loss(self, model, inputs, return_outputs=False):\n", "        \"\"\"Return the binary cross-entropy loss for one batch.\n", "\n", "        `labels` is removed from `inputs` before the forward pass. Returns\n", "        `(loss, outputs)` when `return_outputs` is true, otherwise only `loss`.\n", "        \"\"\"\n", "        labels = inputs.pop(\"labels\")\n", "        outputs = model(**inputs)\n", "        logits = outputs.logits\n", "        num_labels = self.model.config.num_labels\n", "        loss_fct = torch.nn.BCEWithLogitsLoss()\n", "        # The targets are cast to float and reshaped to match the logits.\n", "        loss = loss_fct(logits.view(-1, num_labels),\n", "                        labels.float().view(-1, num_labels))\n", "        return (loss, outputs) if return_outputs else loss" ], "outputs": [], "execution_count": 17, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "itXWkbDw9sqbkMuDP84QoT", "type": "CODE" }, "gather": { "logged": 1706449743072 } } }, { "cell_type": "code", "source": [ "# Use the device chosen in the configuration cell so the notebook also runs without a GPU.\n", "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(device)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" } ], "execution_count": 18, "metadata": { "datalore": { 
"hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "ZQU7aW6TV45VmhHOQRzcnF", "type": "CODE" }, "gather": { "logged": 1706449761205 } } }, { "cell_type": "code", "source": [ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n", "    \"\"\"Return the fraction of individual label decisions that agree with `y_true`.\n", "\n", "    `y_pred` and `y_true` are NumPy arrays. When `sigmoid` is true, `y_pred` is\n", "    passed through a sigmoid first; a label counts as predicted when its score\n", "    exceeds `threshold`. The mean is taken over every element of the arrays.\n", "    \"\"\"\n", "    scores = torch.from_numpy(y_pred)\n", "    targets = torch.from_numpy(y_true)\n", "\n", "    if sigmoid:\n", "        scores = scores.sigmoid()\n", "\n", "    return ((scores > threshold) == targets.bool()).float().mean().item()" ], "outputs": [], "execution_count": 19, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "swhgyyyxoGL8HjnXJtMuSW", "type": "CODE" }, "gather": { "logged": 1706449761541 } } }, { "cell_type": "code", "source": [ "def compute_metrics(eval_pred):\n", "    \"\"\"Unpack `eval_pred` into `(predictions, labels)` and report the thresholded accuracy.\"\"\"\n", "    predictions, labels = eval_pred\n", "    return {'accuracy_thresh': accuracy_threshold(predictions, labels)}" ], "outputs": [], "execution_count": 20, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "1Uq3HtkaBxtHNAnSwit5cI", "type": "CODE" }, "gather": { "logged": 1706449761720 } } }, { "cell_type": "code", "source": [ "args = TrainingArguments(\n", "    output_dir=\"vaers\",\n", "    evaluation_strategy=\"epoch\",\n", "    learning_rate=2e-5,\n", "    per_device_train_batch_size=BATCH_SIZE,\n", "    per_device_eval_batch_size=BATCH_SIZE,\n", "    num_train_epochs=EPOCHS,\n", "    weight_decay=.01,\n", "    logging_steps=1,\n", "    seed=SEED,  # tie the trainer's seed to the notebook-wide constant\n", "    run_name=\"daedra-training\",\n", "    report_to=[\"wandb\"]\n", ")" ], "outputs": [], "execution_count": 21, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "1iPZOTKPwSkTgX5dORqT89", "type": "CODE" }, "gather": { "logged": 1706449761893 } } }, { "cell_type": "code", "source": [ "multi_label_trainer = MultiLabelTrainer(\n", "    model, \n", "    args, \n", "    train_dataset=ds_enc[\"train\"], \n", "    eval_dataset=ds_enc[\"test\"], \n", "    compute_metrics=compute_metrics, \n", " 
tokenizer=tokenizer\n", ")" ], "outputs": [], "execution_count": 22, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "bnRkNvRYltLun6gCEgL7v0", "type": "CODE" }, "gather": { "logged": 1706449769103 } } }, { "cell_type": "code", "source": [ "# Tag the run with the subsampling fraction so partial and full runs can be told apart.\n", "wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"] if SUBSAMPLING != 1.0 else [\"full_sample\"]\n", "\n", "# No `magic=True`: with it, wandb warned that the init arguments were ignored\n", "# and a second run was started when this cell was executed.\n", "wandb.init(name=\"init_evaluation_run\", tags=wandb_tag)\n", "\n", "multi_label_trainer.evaluate()\n", "wandb.finish()" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n" }, { "output_type": "display_data", "data": { "text/html": "Tracking run with wandb version 0.16.2", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141956-9lniqjvz", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Syncing run init_evaluation_run to Weights & Biases (docs)
", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Finishing last run (ID:9lniqjvz) before initializing another...", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View run init_evaluation_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Find logs at: ./wandb/run-20240128_141956-9lniqjvz/logs", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Successfully finished last run (ID:9lniqjvz). Initializing new run:
", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Tracking run with wandb version 0.16.2", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141958-5idmkcie", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Syncing run init_evaluation_run to Weights & Biases (docs)
", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "\n
\n \n \n [851/851 26:26]\n
\n ", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "\n

Run history:


eval/accuracy_thresh
eval/loss
eval/runtime
eval/samples_per_second
eval/steps_per_second
train/global_step

Run summary:


eval/accuracy_thresh0.55198
eval/loss0.68442
eval/runtime105.0436
eval/samples_per_second259.168
eval/steps_per_second8.101
train/global_step0

", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View run init_evaluation_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Find logs at: ./wandb/run-20240128_141958-5idmkcie/logs", "text/plain": "" }, "metadata": {} } ], "execution_count": 23, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "LO54PlDkWQdFrzV25FvduB", "type": "CODE" }, "gather": { "logged": 1706449880674 } } }, { "cell_type": "code", "source": [ "# Tag the run with the subsampling fraction so partial and full runs can be told apart.\n", "wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"] if SUBSAMPLING != 1.0 else [\"full_sample\"]\n", "\n", "# Plain init, without `magic=True`, so that the name and tags given here are applied.\n", "wandb.init(name=\"daedra_training_run\", tags=wandb_tag)\n", "\n", "multi_label_trainer.train()\n", "wandb.finish()" ], "outputs": [ { "output_type": "display_data", "data": { "text/html": "Tracking run with wandb version 0.16.2", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_142151-2mcc0ibc", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "Syncing run daedra_training_run to Weights & Biases (docs)
", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc", "text/plain": "" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": "\n
\n \n \n [ 3972/11913 24:20 < 48:40, 2.72 it/s, Epoch 1/3]\n
\n \n \n \n \n \n \n \n \n \n \n
EpochTraining LossValidation Loss

", "text/plain": "" }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.6s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 22.7s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 15.2s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 12.4s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 13.4s\n" } ], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw", "type": "CODE" }, "gather": { "logged": 1706449934637 } } }, { "cell_type": "markdown", "source": [ "### Evaluation" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "We instantiate a classifier `pipeline` and place it on the first CUDA device when one is available, falling back to the CPU otherwise." ], "metadata": {} }, { "cell_type": "code", "source": [
 "# `device` takes a CUDA ordinal, or -1 for the CPU; fall back so the cell also runs without a GPU.\n",
 "classifier = pipeline(\"text-classification\",\n",
 "                      model=model,\n",
 "                      tokenizer=tokenizer,\n",
 "                      device=0 if torch.cuda.is_available() else -1)"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "kHoUdBeqcyVXDSGv54C4aE", "type": "CODE" }, "gather": { "logged": 1706411459928 } } }, { "cell_type": "markdown", "source": [ "We use the same tokenizer used for training to tokenize/encode the validation set."
], "metadata": {} }, { "cell_type": "code", "source": [
 "# With max_length=None and padding='max_length', sequences are padded (and truncated)\n",
 "# to the maximum input length the tokenizer's model accepts.\n",
 "test_encodings = tokenizer.batch_encode_plus(\n",
 "    dataset[\"val\"][\"text\"],\n",
 "    max_length=None,\n",
 "    padding='max_length',\n",
 "    return_token_type_ids=True,\n",
 "    truncation=True,\n",
 ")"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "Dr5WCWA6jL51NR1fSrQu6Z", "type": "CODE" }, "gather": { "logged": 1706411523285 } } }, { "cell_type": "markdown", "source": [ "Once we've made the data loadable by putting it into a `DataLoader`, we run the model over the validation set batch by batch and collect its predictions alongside the true labels." ], "metadata": {} }, { "cell_type": "code", "source": [
 "# Tensor order matters: the evaluation loop unpacks each batch as\n",
 "# (input_ids, attention_mask, labels, token_type_ids).\n",
 "test_data = torch.utils.data.TensorDataset(\n",
 "    torch.tensor(test_encodings['input_ids']),\n",
 "    torch.tensor(test_encodings['attention_mask']),\n",
 "    torch.tensor(ds_enc[\"val\"][\"labels\"]),\n",
 "    torch.tensor(test_encodings['token_type_ids']),\n",
 ")\n",
 "# A sequential sampler keeps predictions in the same order as the validation set.\n",
 "test_dataloader = torch.utils.data.DataLoader(\n",
 "    test_data,\n",
 "    sampler=torch.utils.data.SequentialSampler(test_data),\n",
 "    batch_size=BATCH_SIZE,\n",
 ")"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "MWfGq2tTkJNzFiDoUPq2X7", "type": "CODE" }, "gather": { "logged": 1706411543379 } } }, { "cell_type": "code", "source": [
 "model.eval()\n",
 "\n",
 "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
 "\n",
 "for batch in test_dataloader:\n",
 "    # Unpack on the CPU. The token type IDs travel with the dataset but are not fed to the model.\n",
 "    b_input_ids, b_input_mask, b_labels, _ = batch\n",
 "\n",
 "    # Only the tensors the model consumes are moved to the compute device.\n",
 "    with torch.no_grad():\n",
 "        outs = model(b_input_ids.to(device), attention_mask=b_input_mask.to(device))\n",
 "        b_logit_pred = outs[0]\n",
 "        pred_label = torch.sigmoid(b_logit_pred)\n",
 "\n",
 "    # Accumulate CPU copies only, so the device never has to hold more than one batch.\n",
 "    tokenized_texts.append(b_input_ids)\n",
 "    logit_preds.append(b_logit_pred.cpu().numpy())\n",
 "    true_labels.append(b_labels.numpy())\n",
 "    pred_labels.append(pred_label.cpu().numpy())\n",
 "\n",
 "# Flatten the per-batch lists into per-example lists\n",
 "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
 "logit_preds = [item for sublist in logit_preds for item in sublist]\n",
 "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
 "true_labels = [item for sublist in true_labels for item in sublist]\n",
 "\n",
 "# Converting flattened binary values to boolean values\n",
 "true_bools = [tl == 1 for tl in true_labels]\n",
 "pred_bools = [pl > 0.50 for pl in pred_labels]"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "1SJCSrQTRCexFCNCIyRrzL", "type": "CODE" }, "gather": { "logged": 1706411587843 } } }, { "cell_type": "markdown", "source": [ "We create a classification report:" ], "metadata": {} }, { "cell_type": "code", "source": [
 "# The first figure is the micro-averaged F1 score, not an accuracy, so it is labelled as such.\n",
 "print('Test F1 (micro): ', f1_score(true_bools, pred_bools, average='micro'))\n",
 "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
 "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
 "print(clf_report)"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "eBprrgF086mznPbPVBpOLS", "type": "CODE" }, "gather": { "logged": 1706411588249 } } }, { "cell_type": "markdown", "source": [ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
], "metadata": {} }, { "cell_type": "code", "source": [
 "# Lookup table from class index to class name\n",
 "idx2label = dict(enumerate(CLASS_NAMES))"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "yELHY0IEwMlMw3x6e7hoD1", "type": "CODE" }, "gather": { "logged": 1706411588638 } } }, { "cell_type": "code", "source": [
 "# For every example, list the positions of the labels that are switched on\n",
 "true_label_idxs = [np.nonzero(row)[0].tolist() for row in true_bools]\n",
 "pred_label_idxs = [np.nonzero(row)[0].tolist() for row in pred_bools]"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "jH0S35dDteUch01sa6me6e", "type": "CODE" }, "gather": { "logged": 1706411589004 } } }, { "cell_type": "code", "source": [
 "# Translate label positions into class names; an example with no labels maps to an empty list\n",
 "true_label_texts = [[idx2label[idx] for idx in idxs] for idxs in true_label_idxs]\n",
 "pred_label_texts = [[idx2label[idx] for idx in idxs] for idxs in pred_label_idxs]"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "h4vHL8XdGpayZ6xLGJUF6F", "type": "CODE" }, "gather": { "logged": 1706411589301 } } }, { "cell_type": "code", "source": [
 "# Turn the token IDs back into text, leaving out special tokens\n",
 "symptom_texts = [\n",
 "    tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)\n",
 "    for token_ids in tokenized_texts\n",
 "]"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "SxUmVHfQISEeptg1SawOmB", "type": "CODE" }, "gather": { "logged": 1706411591952 } } }, { "cell_type": "code", "source": [
 "# One row per validation example: decoded text, true labels, predicted labels\n",
 "comparisons_df = pd.DataFrame(\n",
 "    {\n",
 "        'symptom_text': symptom_texts,\n",
 "        'true_labels': true_label_texts,\n",
 "        'pred_labels': pred_label_texts,\n",
 "    }\n",
 ")\n",
 "comparisons_df.to_csv('comparisons.csv')\n",
 "comparisons_df"
 ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true, "node_id": "BxFNigNGRLTOqraI55BPSH", "type": "CODE" }, "gather": { "logged": 1706411592512 } } } ], "metadata": { "datalore": { "base_environment": "default", "computation_mode": "JUPYTER", "package_manager": "pip", "packages": [ { "name": "datasets", "source": "PIP", "version": "2.16.1" }, { "name": "torch", "source": "PIP", "version": "2.1.2" }, { "name": "accelerate", "source": "PIP", "version": "0.26.1" } ], "report_row_ids": [ "un8W7ez7ZwoGb5Co6nydEV", "40nN9Hvgi1clHNV5RAemI5", "TgRD90H5NSPpKS41OeXI1w", "ZOm5BfUs3h1EGLaUkBGeEB", "kOP0CZWNSk6vqE3wkPp7Vc", "W4PWcOu2O2pRaZyoE2W80h", "RolbOnQLIftk0vy9mIcz5M", "8OPhUgbaNJmOdiq5D3a6vK", "5Qrt3jSvSrpK6Ne1hS6shL", "hTq7nFUrovN5Ao4u6dIYWZ", "I8WNZLpJ1DVP2wiCW7YBIB", "SawhU3I9BewSE1XBPstpNJ", "80EtLEl2FIE4FqbWnUD3nT" ], "version": 3 }, "kernelspec": { "display_name": "Python 3.8 - Pytorch and Tensorflow", "language": "python", "name": "python38-azureml-pt-tf" }, "language_info": { "name": "python", "version": "3.8.5", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "microsoft": { "host": { "AzureML": { "notebookHasBeenCompleted": true } }, "ms_spell_check": { "ms_spell_check_language": "en" } }, "nteract": { "version": "nteract-front-end@1.0.0" } }, "nbformat": 4, "nbformat_minor": 4 }