{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
"\n",
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"nteract": {
"transient": {
"deleting": false
}
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# %pip install accelerate -U"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# %pip install transformers datasets shap watermark wandb"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"datalore": {
"hide_input_from_viewers": false,
"hide_output_from_viewers": false,
"node_id": "caZjjFP0OyQNMVgZDiwswE",
"report_properties": {
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
},
"type": "CODE"
},
"gather": {
"logged": 1706449625034
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The watermark extension is already loaded. To reload it, use:\n",
" %reload_ext watermark\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import os\n",
"from typing import List\n",
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from pyarrow import Table\n",
"import shap\n",
"import wandb\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"%load_ext watermark"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449721319
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"SEED: int = 42\n",
"\n",
"BATCH_SIZE: int = 32\n",
"EPOCHS: int = 3\n",
"model_ckpt: str = \"distilbert-base-uncased\"\n",
"\n",
"CLASS_NAMES: List[str] = [\"DIED\",\n",
" \"ER_VISIT\",\n",
" \"HOSPITAL\",\n",
" \"OFC_VISIT\",\n",
" #\"X_STAY\", # pruned\n",
" #\"DISABLE\", # pruned\n",
" #\"D_PRESENTED\" # pruned\n",
" ]\n",
"\n",
"\n",
"\n",
"\n",
"# WandB configuration\n",
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shap : 0.44.1\n",
"torch : 1.12.0\n",
"logging: 0.5.1.2\n",
"numpy : 1.23.5\n",
"pandas : 2.0.2\n",
"re : 2.2.1\n",
"\n"
]
}
],
"source": [
"%watermark --iversion"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "UU2oOJhwbIualogG1YyCMd",
"type": "CODE"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sun Jan 28 13:54:22 2024 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
"| N/A 30C P0 38W / 250W | 12830MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
"| N/A 30C P0 38W / 250W | 11960MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| 0 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 12826MiB |\n",
"| 1 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 11956MiB |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"datalore": {
"hide_input_from_viewers": false,
"hide_output_from_viewers": false,
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
"report_properties": {
"rowId": "40nN9Hvgi1clHNV5RAemI5"
},
"type": "MD"
}
},
"source": [
"## Loading the data set"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449040507
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449044205
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 1270444\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 272238\n",
" })\n",
" val: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 272238\n",
" })\n",
"})"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"SUBSAMPLING: float = 0.1"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449378281
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
" res = DatasetDict()\n",
"\n",
" res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
" res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
" res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
" \n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449384162
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"dataset = minisample(dataset, SUBSAMPLING)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449387981
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 127044\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 27224\n",
" })\n",
" val: Dataset({\n",
" features: ['id', 'text', 'labels'],\n",
" num_rows: 27224\n",
" })\n",
"})"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"nteract": {
"transient": {
"deleting": false
}
}
},
"source": [
"We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706449443055
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"ds = DatasetDict()\n",
"\n",
"for i in [\"test\", \"train\", \"val\"]:\n",
" tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
" ds[i] = Dataset(tab)\n",
"\n",
"dataset = ds"
]
},
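{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, assuming the `labels` column holds multi-hot vectors), each pruned label vector should now have exactly one position per retained class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (illustrative): after pruning, every label vector should be\n",
"# exactly as long as CLASS_NAMES.\n",
"example = dataset[\"train\"][0]\n",
"print(example[\"labels\"])\n",
"assert len(example[\"labels\"]) == len(CLASS_NAMES)"
]
},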
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenisation and encoding"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "I7n646PIscsUZRoHu6m7zm",
"type": "CODE"
},
"gather": {
"logged": 1706449638377
}
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
"type": "CODE"
},
"gather": {
"logged": 1706449642580
}
},
"outputs": [],
"source": [
"def tokenize_and_encode(examples):\n",
" return tokenizer(examples[\"text\"], truncation=True)"
]
},
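{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate what the tokeniser produces (a minimal sketch; the sentence below is made up), we can encode a short string and inspect the first few tokens:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: encode a short, made-up report fragment and show\n",
"# the first few token ids and their string forms.\n",
"sample = tokenize_and_encode({\"text\": [\"Patient developed a fever and chills after vaccination.\"]})\n",
"print(sample[\"input_ids\"][0][:12])\n",
"print(tokenizer.convert_ids_to_tokens(sample[\"input_ids\"][0][:12]))"
]
},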
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "slHeNysZOX9uWS9PB7jFDb",
"type": "CODE"
},
"gather": {
"logged": 1706449721161
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 27224/27224 [00:11<00:00, 2347.91 examples/s]\n",
"Map: 100%|██████████| 127044/127044 [00:52<00:00, 2417.41 examples/s]\n",
"Map: 100%|██████████| 27224/27224 [00:11<00:00, 2376.02 examples/s]\n"
]
}
],
"source": [
"cols = dataset[\"train\"].column_names\n",
"cols.remove(\"labels\")\n",
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "itXWkbDw9sqbkMuDP84QoT",
"type": "CODE"
},
"gather": {
"logged": 1706449743072
}
},
"outputs": [],
"source": [
"class MultiLabelTrainer(Trainer):\n",
" def compute_loss(self, model, inputs, return_outputs=False):\n",
" labels = inputs.pop(\"labels\")\n",
" outputs = model(**inputs)\n",
" logits = outputs.logits\n",
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
" labels.float().view(-1, self.model.config.num_labels))\n",
" return (loss, outputs) if return_outputs else loss"
]
},
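{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BCEWithLogitsLoss` applies a sigmoid to each logit and scores every class as an independent binary decision, which is what multi-label classification calls for. A toy example with made-up values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy example (values are illustrative, not model outputs): one report,\n",
"# four classes, ground truth DIED and HOSPITAL.\n",
"toy_logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])\n",
"toy_labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])\n",
"print(torch.nn.BCEWithLogitsLoss()(toy_logits, toy_labels).item())"
]
},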
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
"type": "CODE"
},
"gather": {
"logged": 1706449761205
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
"type": "CODE"
},
"gather": {
"logged": 1706449761541
}
},
"outputs": [],
"source": [
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
" y_pred = torch.from_numpy(y_pred)\n",
" y_true = torch.from_numpy(y_true)\n",
"\n",
" if sigmoid:\n",
" y_pred = y_pred.sigmoid()\n",
"\n",
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
]
},
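{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance (with made-up logits), `accuracy_threshold` sigmoids the predictions, thresholds at 0.5 and reports the element-wise agreement with the multi-hot truth:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: sigmoid(2.0), sigmoid(-1.0) and sigmoid(-3.0) land on the\n",
"# correct side of 0.5; sigmoid(0.5) does not, so 3 of 4 positions agree.\n",
"print(accuracy_threshold(np.array([[2.0, -1.0, 0.5, -3.0]]),\n",
"                         np.array([[1.0, 0.0, 0.0, 0.0]])))  # 0.75"
]
},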
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
"type": "CODE"
},
"gather": {
"logged": 1706449761720
}
},
"outputs": [],
"source": [
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "1iPZOTKPwSkTgX5dORqT89",
"type": "CODE"
},
"gather": {
"logged": 1706449761893
}
},
"outputs": [],
"source": [
"args = TrainingArguments(\n",
" output_dir=\"vaers\",\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" num_train_epochs=EPOCHS,\n",
" weight_decay=.01,\n",
" logging_steps=1,\n",
" run_name=f\"daedra-training\",\n",
" report_to=[\"wandb\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "bnRkNvRYltLun6gCEgL7v0",
"type": "CODE"
},
"gather": {
"logged": 1706449769103
}
},
"outputs": [],
"source": [
"multi_label_trainer = MultiLabelTrainer(\n",
" model, \n",
" args, \n",
" train_dataset=ds_enc[\"train\"], \n",
" eval_dataset=ds_enc[\"test\"], \n",
" compute_metrics=compute_metrics, \n",
" tokenizer=tokenizer\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "LO54PlDkWQdFrzV25FvduB",
"type": "CODE"
},
"gather": {
"logged": 1706449880674
}
},
"outputs": [
{
"data": {
"text/html": [
"Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to the W&B docs."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.16.2"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141352-spfdhiij
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run init_evaluation_run to Weights & Biases (docs)
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Finishing last run (ID:spfdhiij) before initializing another..."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run init_evaluation_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: ./wandb/run-20240128_141352-spfdhiij/logs
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Successfully finished last run (ID:spfdhiij). Initializing new run:
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.16.2"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141354-mpe6cpuz
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run init_evaluation_run to Weights & Biases (docs)
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"Run history:
eval/accuracy_thresh | ▁ |
eval/loss | ▁ |
eval/runtime | ▁ |
eval/samples_per_second | ▁ |
eval/steps_per_second | ▁ |
train/global_step | ▁ |
Run summary:
eval/accuracy_thresh | 0.42136 |
eval/loss | 0.69069 |
eval/runtime | 79.1475 |
eval/samples_per_second | 343.965 |
eval/steps_per_second | 2.691 |
train/global_step | 0 |
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run init_evaluation_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: ./wandb/run-20240128_141354-mpe6cpuz/logs
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"if SUBSAMPLING != 1.0:\n",
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
"else:\n",
" wandb_tag: List[str] = [f\"full_sample\"]\n",
" \n",
"wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
"\n",
"multi_label_trainer.evaluate()\n",
"wandb.finish()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
"type": "CODE"
},
"gather": {
"logged": 1706449934637
}
},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[62], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmulti_label_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1537\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1869\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1866\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1868\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1869\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1872\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1873\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1874\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1875\u001b[0m ):\n\u001b[1;32m 1876\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1877\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:2768\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2765\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2767\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2768\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2771\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
"Cell \u001b[0;32mIn[55], line 4\u001b[0m, in \u001b[0;36mMultiLabelTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_loss\u001b[39m(\u001b[38;5;28mself\u001b[39m, model, inputs, return_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 3\u001b[0m labels \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m 6\u001b[0m loss_fct \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mBCEWithLogitsLoss()\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:168\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 167\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 168\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:178\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas, inputs, kwargs):\n\u001b[0;32m--> 178\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py:86\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 84\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m---> 86\u001b[0m \u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/_utils.py:461\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 457\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 461\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
"\u001b[0;31mRuntimeError\u001b[0m: Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"
]
}
],
"source": [
"if SUBSAMPLING != 1.0:\n",
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
"else:\n",
" wandb_tag: List[str] = [f\"full_sample\"]\n",
" \n",
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
"\n",
"multi_label_trainer.train()\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We instantiate a classifier `pipeline` and push it to CUDA."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
"type": "CODE"
},
"gather": {
"logged": 1706411459928
}
},
"outputs": [],
"source": [
"classifier = pipeline(\"text-classification\", \n",
" model, \n",
" tokenizer=tokenizer, \n",
" device=\"cuda:0\")"
]
},
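{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick smoke test (the report text below is made up), the pipeline can score a single document:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test with a made-up report; returns the top label and its score.\n",
"print(classifier(\"Patient was admitted to hospital with chest pain two days after vaccination.\"))"
]
},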
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the same tokenizer used for training to tokenize/encode the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
"type": "CODE"
},
"gather": {
"logged": 1706411523285
}
},
"outputs": [],
"source": [
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
" max_length=None, \n",
" padding='max_length', \n",
" return_token_type_ids=True, \n",
" truncation=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once we've made the data loadable by putting it into a `DataLoader`, we "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
"type": "CODE"
},
"gather": {
"logged": 1706411543379
}
},
"outputs": [],
"source": [
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
" torch.tensor(test_encodings['attention_mask']), \n",
" torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
" torch.tensor(test_encodings['token_type_ids']))\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
" batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
"type": "CODE"
},
"gather": {
"logged": 1706411587843
}
},
"outputs": [],
"source": [
"model.eval()\n",
"\n",
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
"\n",
"for i, batch in enumerate(test_dataloader):\n",
" batch = tuple(t.to(device) for t in batch)\n",
" \n",
" # Unpack the inputs from our dataloader\n",
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
" \n",
" with torch.no_grad():\n",
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
" b_logit_pred = outs[0]\n",
" pred_label = torch.sigmoid(b_logit_pred)\n",
"\n",
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
" pred_label = pred_label.to('cpu').numpy()\n",
" b_labels = b_labels.to('cpu').numpy()\n",
"\n",
" tokenized_texts.append(b_input_ids)\n",
" logit_preds.append(b_logit_pred)\n",
" true_labels.append(b_labels)\n",
" pred_labels.append(pred_label)\n",
"\n",
"# Flatten outputs\n",
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
"true_labels = [item for sublist in true_labels for item in sublist]\n",
"\n",
"# Converting flattened binary values to boolean values\n",
"true_bools = [tl == 1 for tl in true_labels]\n",
"pred_bools = [pl > 0.50 for pl in pred_labels] "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create a classification report:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "eBprrgF086mznPbPVBpOLS",
"type": "CODE"
},
"gather": {
"logged": 1706411588249
}
},
"outputs": [],
"source": [
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
"print(clf_report)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
"type": "CODE"
},
"gather": {
"logged": 1706411588638
}
},
"outputs": [],
"source": [
"# Creating a map of class names from class numbers\n",
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "jH0S35dDteUch01sa6me6e",
"type": "CODE"
},
"gather": {
"logged": 1706411589004
}
},
"outputs": [],
"source": [
"true_label_idxs, pred_label_idxs = [], []\n",
"\n",
"for vals in true_bools:\n",
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
"for vals in pred_bools:\n",
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
"type": "CODE"
},
"gather": {
"logged": 1706411589301
}
},
"outputs": [],
"source": [
"true_label_texts, pred_label_texts = [], []\n",
"\n",
"for vals in true_label_idxs:\n",
" if vals:\n",
" true_label_texts.append([idx2label[val] for val in vals])\n",
" else:\n",
" true_label_texts.append(vals)\n",
"\n",
"for vals in pred_label_idxs:\n",
" if vals:\n",
" pred_label_texts.append([idx2label[val] for val in vals])\n",
" else:\n",
" pred_label_texts.append(vals)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "SxUmVHfQISEeptg1SawOmB",
"type": "CODE"
},
"gather": {
"logged": 1706411591952
}
},
"outputs": [],
"source": [
"symptom_texts = [tokenizer.decode(text,\n",
" skip_special_tokens=True,\n",
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "BxFNigNGRLTOqraI55BPSH",
"type": "CODE"
},
"gather": {
"logged": 1706411592512
}
},
"outputs": [],
"source": [
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
" 'true_labels': true_label_texts, \n",
" 'pred_labels':pred_label_texts})\n",
"comparisons_df.to_csv('comparisons.csv')\n",
"comparisons_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shapley analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "OpdZcoenX2HwzLdai7K5UA",
"type": "CODE"
},
"gather": {
"logged": 1706415109071
}
},
"outputs": [],
"source": [
"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
]
},
{
"cell_type": "markdown",
"metadata": {
"nteract": {
"transient": {
"deleting": false
}
}
},
"source": [
"#### Sampling correct predictions\n",
"\n",
"First, let's look at some correct predictions of deaths:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706414973990
},
"jupyter": {
"outputs_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"correct_death_predictions = comparisons_df[comparisons_df['true_labels'].astype(str) == \"['DIED']\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706415114683
},
"jupyter": {
"outputs_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n",
"idxs = [i for i in range(len(texts))]\n",
"\n",
"d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706415129229
},
"jupyter": {
"outputs_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"shap_values = explainer(d_s[\"texts\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706415151494
},
"jupyter": {
"outputs_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"shap.plots.text(shap_values)"
]
}
],
"metadata": {
"datalore": {
"base_environment": "default",
"computation_mode": "JUPYTER",
"package_manager": "pip",
"packages": [
{
"name": "datasets",
"source": "PIP",
"version": "2.16.1"
},
{
"name": "torch",
"source": "PIP",
"version": "2.1.2"
},
{
"name": "accelerate",
"source": "PIP",
"version": "0.26.1"
}
],
"report_row_ids": [
"un8W7ez7ZwoGb5Co6nydEV",
"40nN9Hvgi1clHNV5RAemI5",
"TgRD90H5NSPpKS41OeXI1w",
"ZOm5BfUs3h1EGLaUkBGeEB",
"kOP0CZWNSk6vqE3wkPp7Vc",
"W4PWcOu2O2pRaZyoE2W80h",
"RolbOnQLIftk0vy9mIcz5M",
"8OPhUgbaNJmOdiq5D3a6vK",
"5Qrt3jSvSrpK6Ne1hS6shL",
"hTq7nFUrovN5Ao4u6dIYWZ",
"I8WNZLpJ1DVP2wiCW7YBIB",
"SawhU3I9BewSE1XBPstpNJ",
"80EtLEl2FIE4FqbWnUD3nT"
],
"version": 3
},
"kernelspec": {
"display_name": "Python 3.8 - Pytorch and Tensorflow",
"language": "python",
"name": "python38-azureml-pt-tf"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"microsoft": {
"host": {
"AzureML": {
"notebookHasBeenCompleted": true
}
},
"ms_spell_check": {
"ms_spell_check_language": "en"
}
},
"nteract": {
"version": "nteract-front-end@1.0.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}