{ "cells": [ { "cell_type": "markdown", "source": [ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n", "\n", "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "%pip install accelerate -U" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nNote: you may need to restart the kernel to use updated packages.\n" } ], "execution_count": 1, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "%pip install transformers datasets shap watermark wandb" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" }, { "output_type": "stream", "name": "stdout", "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nCollecting wandb\n Using cached wandb-0.16.2-py3-none-any.whl (2.2 MB)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: 
safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already 
satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nCollecting sentry-sdk>=1.0.0\n Using cached sentry_sdk-1.39.2-py2.py3-none-any.whl (254 kB)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nCollecting docker-pycreds>=0.4.0\n Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nCollecting setproctitle\n Using cached setproctitle-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\nCollecting appdirs>=1.4.3\n Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement 
already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: 
matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: appdirs, setproctitle, sentry-sdk, docker-pycreds, wandb\nSuccessfully installed appdirs-1.4.4 docker-pycreds-0.4.0 sentry-sdk-1.39.2 setproctitle-1.3.3 wandb-0.16.2\nNote: you may need to restart the kernel to use updated packages.\n" } ], "execution_count": 17, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "from typing import List\n", "from sklearn.metrics import f1_score, accuracy_score, classification_report\n", "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "from pyarrow import Table\n", 
"import shap\n", "\n", "%load_ext watermark" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "The watermark extension is already loaded. To reload it, use:\n %reload_ext watermark\n" } ], "execution_count": 99, "metadata": { "datalore": { "node_id": "caZjjFP0OyQNMVgZDiwswE", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "un8W7ez7ZwoGb5Co6nydEV" } }, "gather": { "logged": 1706413522853 } } }, { "cell_type": "code", "source": [ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "SEED: int = 42\n", "\n", "BATCH_SIZE: int = 8\n", "EPOCHS: int = 1\n", "model_ckpt: str = \"distilbert-base-uncased\"\n", "\n", "CLASS_NAMES: List[str] = [\"DIED\",\n", " \"ER_VISIT\",\n", " \"HOSPITAL\",\n", " \"OFC_VISIT\",\n", " #\"X_STAY\", # pruned\n", " #\"DISABLE\", # pruned\n", " #\"D_PRESENTED\" # pruned\n", " ]\n", "\n", "\n", "\n", "\n", "# WandB configuration\n", "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n", "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints" ], "outputs": [], "execution_count": 87, "metadata": { "collapsed": false, "gather": { "logged": 1706413214901 } } }, { "cell_type": "code", "source": [ "%watermark --iversion" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "re : 2.2.1\nnumpy : 1.23.5\nlogging: 0.5.1.2\npandas : 2.0.2\ntorch : 1.12.0\nshap : 0.44.1\n\n" } ], "execution_count": 5, "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "!nvidia-smi" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Sun Jan 28 02:27:31 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | 
Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 27C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n" } ], "execution_count": 6, "metadata": { "datalore": { "node_id": "UU2oOJhwbIualogG1YyCMd", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "attachments": {}, "cell_type": "markdown", "source": [ "## Loading the data set" ], "metadata": { "datalore": { "node_id": "t45KHugmcPVaO0nuk8tGJ9", "type": "MD", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "40nN9Hvgi1clHNV5RAemI5" } } } }, { "cell_type": "code", "source": [ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")" ], "outputs": [], "execution_count": 105, "metadata": { "collapsed": false, "gather": { "logged": 1706413798729 } } }, { "cell_type": "markdown", "source": [ "We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`." 
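, "\n",
"\n",
"Each `labels` entry is a multi-hot vector of outcome flags. The `[:4]` slice in the next cell keeps only the first four positions, which assumes the flags are stored in the order listed in `CLASS_NAMES`, with the pruned outcomes (`X_STAY`, `DISABLE`, `D_PRESENTED`) last. A minimal sketch on made-up toy vectors (not VAERS data) shows the effect:\n",
"\n",
"```python\n",
"# Hypothetical multi-hot vectors with seven outcome flags, assumed to be ordered\n",
"# DIED, ER_VISIT, HOSPITAL, OFC_VISIT, X_STAY, DISABLE, D_PRESENTED\n",
"toy_labels = [\n",
"    [0, 1, 0, 0, 1, 0, 0],\n",
"    [1, 0, 1, 0, 0, 1, 0],\n",
"    [0, 0, 0, 0, 0, 1, 0],\n",
"]\n",
"\n",
"# Keep only the first four flags, mirroring the pruning step\n",
"pruned = [labels[:4] for labels in toy_labels]\n",
"print(pruned)  # [[0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]\n",
"```\n",
"\n",
"Note that a report whose only recorded outcome was a pruned category ends up with an all-zero label vector after this step."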
], "metadata": { "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "ds = DatasetDict()\n", "\n", "# Keep only the first four outcome labels of each report (see CLASS_NAMES)\n", "for split in [\"test\", \"train\", \"val\"]:\n", "    pruned_labels = [labels[:4] for labels in dataset[split][\"labels\"]]\n", "    tab = Table.from_arrays([dataset[split][\"id\"], dataset[split][\"text\"], pruned_labels], names=[\"id\", \"text\", \"labels\"])\n", "    ds[split] = Dataset(tab)\n", "\n", "dataset = ds" ], "outputs": [], "execution_count": 106, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } }, "gather": { "logged": 1706413801396 } } }, { "cell_type": "markdown", "source": [ "### Tokenisation and encoding" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)" ], "outputs": [], "execution_count": 8, "metadata": { "datalore": { "node_id": "I7n646PIscsUZRoHu6m7zm", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408853475 } } }, { "cell_type": "code", "source": [ "def tokenize_and_encode(examples):\n", "    return tokenizer(examples[\"text\"], truncation=True)" ], "outputs": [], "execution_count": 9, "metadata": { "datalore": { "node_id": "QBLOSI0yVIslV7v7qX9ZC3", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408853684 } } }, { "cell_type": "code", "source": [ "cols = dataset[\"train\"].column_names\n", "cols.remove(\"labels\")\n", "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10990.82 examples/s]\n" } ], "execution_count": 10, "metadata": { "datalore": { "node_id": "slHeNysZOX9uWS9PB7jFDb", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408854738 } } }, { "cell_type": "markdown", "source": [ "### Training" ], "metadata": { "collapsed": false } }, 
{ "cell_type": "code", "source": [ "class MultiLabelTrainer(Trainer):\n", "    # Trainer with a multi-label loss: each outcome is scored by its own sigmoid\n", "    def compute_loss(self, model, inputs, return_outputs=False):\n", "        labels = inputs.pop(\"labels\")\n", "        outputs = model(**inputs)\n", "        logits = outputs.logits\n", "        # BCEWithLogitsLoss treats every label as an independent binary decision\n", "        loss_fct = torch.nn.BCEWithLogitsLoss()\n", "        loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n", "                        labels.float().view(-1, self.model.config.num_labels))\n", "        return (loss, outputs) if return_outputs else loss" ], "outputs": [], "execution_count": 11, "metadata": { "datalore": { "node_id": "itXWkbDw9sqbkMuDP84QoT", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408854925 } } }, { "cell_type": "code", "source": [ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(device)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" } ], "execution_count": 12, "metadata": { "datalore": { "node_id": "ZQU7aW6TV45VmhHOQRzcnF", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857008 } } }, { "cell_type": "code", "source": [ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n", "    # Share of individual label decisions (one per report and label) that match the ground truth\n", "    y_pred = torch.from_numpy(y_pred)\n", "    y_true = torch.from_numpy(y_true)\n", "\n", "    if sigmoid:\n", "        y_pred = y_pred.sigmoid()\n", "\n", "    return ((y_pred > threshold) == y_true.bool()).float().mean().item()" ], "outputs": [], "execution_count": 13, "metadata": { "datalore": { "node_id": "swhgyyyxoGL8HjnXJtMuSW", "type": "CODE", 
"hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857297 } } }, { "cell_type": "code", "source": [ "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}" ], "outputs": [], "execution_count": 14, "metadata": { "datalore": { "node_id": "1Uq3HtkaBxtHNAnSwit5cI", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857499 } } }, { "cell_type": "code", "source": [ "args = TrainingArguments(\n", " output_dir=\"vaers\",\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " per_device_eval_batch_size=BATCH_SIZE,\n", " num_train_epochs=EPOCHS,\n", " weight_decay=.01,\n", " report_to=[\"wandb\"]\n", ")" ], "outputs": [], "execution_count": 15, "metadata": { "datalore": { "node_id": "1iPZOTKPwSkTgX5dORqT89", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857680 } } }, { "cell_type": "code", "source": [ "multi_label_trainer = MultiLabelTrainer(\n", " model, \n", " args, \n", " train_dataset=ds_enc[\"train\"], \n", " eval_dataset=ds_enc[\"test\"], \n", " compute_metrics=compute_metrics, \n", " tokenizer=tokenizer\n", ")" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" } ], "execution_count": 18, "metadata": { "datalore": { "node_id": "bnRkNvRYltLun6gCEgL7v0", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408895305 } } }, { "cell_type": "code", "source": [ "multi_label_trainer.evaluate()" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": "", "text/html": "\n
\n \n \n [987/987 21:41]\n
\n " }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Tracking run with wandb version 0.16.2" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_022947-hh1sxw9i" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Syncing run icy-firebrand-1 to Weights & Biases (docs)
" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i" }, "metadata": {} }, { "output_type": "execute_result", "execution_count": 19, "data": { "text/plain": "{'eval_loss': 0.7153111100196838,\n 'eval_accuracy_thresh': 0.2938227355480194,\n 'eval_runtime': 82.3613,\n 'eval_samples_per_second': 191.668,\n 'eval_steps_per_second': 11.984}" }, "metadata": {} } ], "execution_count": 19, "metadata": { "datalore": { "node_id": "LO54PlDkWQdFrzV25FvduB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408991752 } } }, { "cell_type": "code", "source": [ "multi_label_trainer.train()" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": "", "text/html": "\n
[4605/4605 20:25, Epoch 1/1]\nEpoch | Training Loss | Validation Loss | Accuracy Thresh\n1 | 0.086700 | 0.093388 | 0.962897\n
" }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.9s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 12.5s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 21.9s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 13.8s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 15.7s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 21.7s\nCheckpoint destination directory vaers/checkpoint-3500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 
10.6s\nCheckpoint destination directory vaers/checkpoint-4000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4000)... Done. 15.0s\nCheckpoint destination directory vaers/checkpoint-4500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4500)... Done. 16.7s\n" }, { "output_type": "execute_result", "execution_count": 21, "data": { "text/plain": "TrainOutput(global_step=4605, training_loss=0.09062977189220382, metrics={'train_runtime': 1223.2444, 'train_samples_per_second': 60.223, 'train_steps_per_second': 3.765, 'total_flos': 9346797199425174.0, 'train_loss': 0.09062977189220382, 'epoch': 1.0})" }, "metadata": {} } ], "execution_count": 21, "metadata": { "datalore": { "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411445752 } } }, { "cell_type": "markdown", "source": [ "### Evaluation" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "We instantiate a classifier `pipeline` and push it to CUDA." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "classifier = pipeline(\"text-classification\", \n", " model, \n", " tokenizer=tokenizer, \n", " device=\"cuda:0\")" ], "outputs": [], "execution_count": 24, "metadata": { "datalore": { "node_id": "kHoUdBeqcyVXDSGv54C4aE", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411459928 } } }, { "cell_type": "markdown", "source": [ "We use the same tokenizer used for training to tokenize/encode the validation set." 
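, "\n",
"\n",
"Padding every sequence to the same length is what makes the next steps work: `torch.tensor` can only build a tensor from a rectangular nested list, so variable-length `input_ids` could not be stacked into the `TensorDataset`. With `max_length=None`, `padding='max_length'` pads to the tokenizer's maximum input size (512 tokens for `distilbert-base-uncased`). The sketch below mimics the idea with a hypothetical `pad_to_length` helper, made-up token IDs and a pad ID of 0 chosen for illustration; it is not the tokenizer's actual implementation:\n",
"\n",
"```python\n",
"from typing import List, Tuple\n",
"\n",
"import torch\n",
"\n",
"\n",
"def pad_to_length(ids: List[int], length: int, pad_id: int = 0) -> Tuple[List[int], List[int]]:\n",
"    # Hypothetical helper: truncate, then right-pad the IDs and build the attention mask\n",
"    ids = ids[:length]\n",
"    n_pad = length - len(ids)\n",
"    return ids + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad\n",
"\n",
"\n",
"toy_ids = [[101, 7592, 102], [101, 2023, 2003, 1037, 3231, 102]]  # made-up IDs\n",
"padded = [pad_to_length(ids, 8) for ids in toy_ids]\n",
"\n",
"input_ids = torch.tensor([p[0] for p in padded])\n",
"attention_mask = torch.tensor([p[1] for p in padded])\n",
"print(input_ids.shape)  # torch.Size([2, 8])\n",
"print(attention_mask[0].tolist())  # [1, 1, 1, 0, 0, 0, 0, 0]\n",
"```"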
], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n", " max_length=None, \n", " padding='max_length', \n", " return_token_type_ids=True, \n", " truncation=True)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n" } ], "execution_count": 26, "metadata": { "datalore": { "node_id": "Dr5WCWA6jL51NR1fSrQu6Z", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411523285 } } }, { "cell_type": "markdown", "source": [ "We wrap the encoded validation set (input IDs, attention masks, labels and token type IDs) in a `TensorDataset` and iterate over it in order with a `DataLoader`. For each batch, we run the model, apply a sigmoid to the logits to get an independent probability per class, and threshold the probabilities at 0.5 to obtain multi-label predictions." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n", " torch.tensor(test_encodings['attention_mask']), \n", " torch.tensor(ds_enc[\"val\"][\"labels\"]), \n", " torch.tensor(test_encodings['token_type_ids']))\n", "test_dataloader = torch.utils.data.DataLoader(test_data, \n", " sampler=torch.utils.data.SequentialSampler(test_data), \n", " batch_size=BATCH_SIZE)" ], "outputs": [], "execution_count": 29, "metadata": { "datalore": { "node_id": "MWfGq2tTkJNzFiDoUPq2X7", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411543379 } } }, { "cell_type": "code", "source": [ "model.eval()\n", "\n", "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n", "\n", "for i, batch in enumerate(test_dataloader):\n", " batch = 
tuple(t.to(device) for t in batch)\n", " \n", " # Unpack the inputs from our dataloader\n", " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n", " \n", " with torch.no_grad():\n", " outs = model(b_input_ids, attention_mask=b_input_mask)\n", " b_logit_pred = outs[0]\n", " # Multi-label output: an independent sigmoid probability per class\n", " pred_label = torch.sigmoid(b_logit_pred)\n", "\n", " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n", " pred_label = pred_label.to('cpu').numpy()\n", " b_labels = b_labels.to('cpu').numpy()\n", "\n", " # Keep CPU copies of the input IDs so they do not accumulate in GPU memory\n", " tokenized_texts.append(b_input_ids.cpu())\n", " logit_preds.append(b_logit_pred)\n", " true_labels.append(b_labels)\n", " pred_labels.append(pred_label)\n", "\n", "# Flatten outputs\n", "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n", "pred_labels = [item for sublist in pred_labels for item in sublist]\n", "true_labels = [item for sublist in true_labels for item in sublist]\n", "\n", "# Convert true labels to booleans and threshold predicted probabilities at 0.5\n", "true_bools = [tl == 1 for tl in true_labels]\n", "pred_bools = [pl > 0.50 for pl in pred_labels] " ], "outputs": [], "execution_count": 30, "metadata": { "datalore": { "node_id": "1SJCSrQTRCexFCNCIyRrzL", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411587843 } } }, { "cell_type": "markdown", "source": [ "We compute the micro-averaged F1 score (printed as 'F1 Accuracy'), the exact-match subset accuracy (printed as 'Flat Accuracy') and a per-class classification report:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n", "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n", "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n", "print(clf_report)" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Test F1 Accuracy: 0.8148841961852862\nTest Flat Accuracy: 0.8456129236617042 \n\n precision recall f1-score support\n\n DIED 0.98 0.83 0.90 312\n ER_VISIT 0.75 0.57 0.65 1143\n HOSPITAL 0.94 0.90 0.92 
2361\n OFC_VISIT 0.77 0.66 0.71 2835\n X_STAY 0.00 0.00 0.00 9\n DISABLE 0.62 0.28 0.39 313\n D_PRESENTED 0.89 0.85 0.87 5392\n\n micro avg 0.86 0.77 0.81 12365\n macro avg 0.71 0.59 0.63 12365\nweighted avg 0.85 0.77 0.81 12365\n samples avg 0.29 0.28 0.28 12365\n\n" }, { "output_type": "stream", "name": "stderr", "text": "Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\nPrecision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\nRecall and F-score are ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.\n" } ], "execution_count": 31, "metadata": { "datalore": { "node_id": "eBprrgF086mznPbPVBpOLS", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411588249 } } }, { "cell_type": "markdown", "source": [ "Finally, we build a side-by-side comparison table that pairs each decoded report text with its true and predicted labels, and save it to `comparisons.csv`."
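, "\n",
"\n",
"The index-to-label mapping in the next cells can be sketched in isolation. In this minimal example, the class order follows the classification report above, the probabilities are made up, and the 0.5 threshold matches the one used to build `pred_bools`. Each class is thresholded independently because a report can have several outcomes.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Class order as printed in the classification report above\n",
"CLASS_NAMES = ['DIED', 'ER_VISIT', 'HOSPITAL', 'OFC_VISIT', 'X_STAY', 'DISABLE', 'D_PRESENTED']\n",
"idx2label = dict(enumerate(CLASS_NAMES))\n",
"\n",
"# Made-up sigmoid outputs, one row per report\n",
"probs = np.array([\n",
"    [0.97, 0.10, 0.88, 0.05, 0.01, 0.02, 0.30],\n",
"    [0.02, 0.61, 0.04, 0.12, 0.00, 0.01, 0.55],\n",
"    [0.01, 0.20, 0.10, 0.30, 0.00, 0.05, 0.40],\n",
"])\n",
"pred_bools = probs > 0.5  # one boolean per class\n",
"\n",
"# np.where(row)[0] gives the indices of the classes predicted for that row\n",
"pred_label_texts = [[idx2label[i] for i in np.where(row)[0]] for row in pred_bools]\n",
"print(pred_label_texts)\n",
"```\n",
"\n",
"Here the first report maps to `['DIED', 'HOSPITAL']` and the second to `['ER_VISIT', 'D_PRESENTED']`. The third maps to an empty list because no class clears the threshold; the same empty lists appear in the comparison table."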
], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "# Map class indices to class names\n", "idx2label = dict(enumerate(CLASS_NAMES))" ], "outputs": [], "execution_count": 32, "metadata": { "datalore": { "node_id": "yELHY0IEwMlMw3x6e7hoD1", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411588638 } } }, { "cell_type": "code", "source": [ "# Indices of the True entries in each row of the boolean label arrays\n", "true_label_idxs = [np.where(vals)[0].tolist() for vals in true_bools]\n", "pred_label_idxs = [np.where(vals)[0].tolist() for vals in pred_bools]" ], "outputs": [], "execution_count": 33, "metadata": { "datalore": { "node_id": "jH0S35dDteUch01sa6me6e", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411589004 } } }, { "cell_type": "code", "source": [ "# Translate index lists to label names (an empty list stays empty)\n", "true_label_texts = [[idx2label[val] for val in vals] for vals in true_label_idxs]\n", "pred_label_texts = [[idx2label[val] for val in vals] for vals in pred_label_idxs]" ], "outputs": [], "execution_count": 34, "metadata": { "datalore": { "node_id": "h4vHL8XdGpayZ6xLGJUF6F", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411589301 } } }, { "cell_type": "code", "source": [ "symptom_texts = [tokenizer.decode(text,\n", " skip_special_tokens=True,\n", " clean_up_tokenization_spaces=False) for text in tokenized_texts]" ], "outputs": [], "execution_count": 35, "metadata": { "datalore": { "node_id": "SxUmVHfQISEeptg1SawOmB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411591952 } } 
}, { "cell_type": "code", "source": [ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n", " 'true_labels': true_label_texts, \n", " 'pred_labels':pred_label_texts})\n", "comparisons_df.to_csv('comparisons.csv')\n", "comparisons_df" ], "outputs": [ { "output_type": "execute_result", "execution_count": 36, "data": { "text/plain": " symptom_text true_labels \\\n0 pt was due for hepb, hib, ipv. i gave pentacel... [] \n1 cold ; covid - 19 twice, he tested positive ; ... [] \n2 patient described pain in both shoulders and r... [] \n3 error : improper storage ( ex. temp. / locatio... [] \n4 vaccine was stored in as unapproved storage unit [] \n... ... ... \n15780 allergic reaction ; this is a spontaneous repo... [] \n15781 immediate side effects were in line with expec... [] \n15782 anaphylaxis immediately after administration o... [] \n15783 no additional ae ' s were reported ; the hcp r... [] \n15784 vaccine was frozen rather than refrigerated. p... [] \n\n pred_labels \n0 [] \n1 [] \n2 [] \n3 [] \n4 [] \n... ... \n15780 [] \n15781 [] \n15782 [] \n15783 [] \n15784 [] \n\n[15785 rows x 3 columns]", "text/html": "

<div>
<table>
<thead>
<tr><th></th><th>symptom_text</th><th>true_labels</th><th>pred_labels</th></tr>
</thead>
<tbody>
<tr><th>0</th><td>pt was due for hepb, hib, ipv. i gave pentacel...</td><td>[]</td><td>[]</td></tr>
<tr><th>1</th><td>cold ; covid - 19 twice, he tested positive ; ...</td><td>[]</td><td>[]</td></tr>
<tr><th>2</th><td>patient described pain in both shoulders and r...</td><td>[]</td><td>[]</td></tr>
<tr><th>3</th><td>error : improper storage ( ex. temp. / locatio...</td><td>[]</td><td>[]</td></tr>
<tr><th>4</th><td>vaccine was stored in as unapproved storage unit</td><td>[]</td><td>[]</td></tr>
<tr><th>...</th><td>...</td><td>...</td><td>...</td></tr>
<tr><th>15780</th><td>allergic reaction ; this is a spontaneous repo...</td><td>[]</td><td>[]</td></tr>
<tr><th>15781</th><td>immediate side effects were in line with expec...</td><td>[]</td><td>[]</td></tr>
<tr><th>15782</th><td>anaphylaxis immediately after administration o...</td><td>[]</td><td>[]</td></tr>
<tr><th>15783</th><td>no additional ae ' s were reported ; the hcp r...</td><td>[]</td><td>[]</td></tr>
<tr><th>15784</th><td>vaccine was frozen rather than refrigerated. p...</td><td>[]</td><td>[]</td></tr>
</tbody>
</table>
<p>15785 rows × 3 columns</p>
</div>
" }, "metadata": {} } ], "execution_count": 36, "metadata": { "datalore": { "node_id": "BxFNigNGRLTOqraI55BPSH", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411592512 } } }, { "cell_type": "markdown", "source": [ "### Shapley analysis" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)" ], "outputs": [], "execution_count": 160, "metadata": { "datalore": { "node_id": "OpdZcoenX2HwzLdai7K5UA", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706415109071 } } }, { "cell_type": "markdown", "source": [ "#### Sampling correct predictions\n", "\n", "First, we explain a random sample of six validation reports for which death was both the only true label and the only predicted label:" ], "metadata": { "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "# Keep reports whose true and predicted labels are both exactly ['DIED']\n", "correct_death_predictions = comparisons_df[(comparisons_df['true_labels'].astype(str) == \"['DIED']\") & \n", " (comparisons_df['pred_labels'].astype(str) == \"['DIED']\")]" ], "outputs": [], "execution_count": 153, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } }, "gather": { "logged": 1706414973990 } } }, { "cell_type": "code", "source": [ "# Sample six reports and keep the first 512 characters of each\n", "texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n", "idxs = list(range(len(texts)))\n", "\n", "d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))" ], "outputs": [], "execution_count": 161, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } }, "gather": { "logged": 1706415114683 } } }, { "cell_type": "code", "source": [ "shap_values = explainer(d_s[\"texts\"])" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "You seem to be using the pipelines sequentially on GPU.
In order to maximize efficiency please use a dataset\nPartitionExplainer explainer: 7it [00:14, 3.70s/it] \n" } ], "execution_count": 162, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } }, "gather": { "logged": 1706415129229 } } }, { "cell_type": "code", "source": [ "shap.plots.text(shap_values)" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": "", "text/html": "\n
[SHAP text plot, first sample ([0]), output DIED: f_DIED(inputs) = 0.999384. Largest positive token contributions: '( serious criteria death' (0.104), 'deceased' (0.09), 'death' (0.038), 'shingr' (0.035), 'reported' (0.034). Negative contributions: 'by a lawyer' (-0.006), 'this case' (-0.005), '.' (-0.002). Output tabs: DIED, ER_VISIT, HOSPITAL, OFC_VISIT, X, xx, XY.]