{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
"\n",
"DAEDRA is a language model that predicts the disposition (outcome) of an adverse event from the text of the event report. It is intended for classifying reports in passive reporting systems and is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"gather": {
"logged": 1706475754655
},
"nteract": {
"transient": {
"deleting": false
}
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Pin the version for reproducibility (0.26.1 is what this notebook was last run with).\n",
"%pip install -q accelerate==0.26.1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
"Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
"Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
"Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
"Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
"Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
"Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
"Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Pinned to the versions this notebook was last run with, so reruns resolve the same environment.\n",
"%pip install -q transformers==4.37.1 datasets==2.16.1 shap==0.44.1 watermark==2.4.3 wandb==0.16.2 evaluate==0.4.1 codecarbon==2.3.3"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"datalore": {
"hide_input_from_viewers": false,
"hide_output_from_viewers": false,
"node_id": "caZjjFP0OyQNMVgZDiwswE",
"report_properties": {
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
},
"type": "CODE"
},
"gather": {
"logged": 1706503443742
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
" View run daedra_0.05-distilbert-base-uncased at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7<br/>",
"View job at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3<br/>",
"Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: ./wandb/run-20240129_152136-cwkdl3x7/logs<br/>"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The watermark extension is already loaded. To reload it, use:\n",
" %reload_ext watermark\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import os\n",
"from typing import List, Union\n",
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"import shap\n",
"import wandb\n",
"import evaluate\n",
"import logging\n",
"\n",
"# Finish any W&B run still open from an earlier execution, so the next wandb.init() starts cleanly.\n",
"wandb.finish()\n",
"\n",
"# Disable Hugging Face tokenizers' internal parallelism (presumably to avoid the fork warning when datasets.map uses worker processes).\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"# %reload_ext loads the extension if it is not loaded yet, so re-running this cell does not print an 'already loaded' notice.\n",
"%reload_ext watermark"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706503443899
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Reproducibility and hardware\n",
"SEED: int = 42\n",
"device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Base checkpoint and training hyperparameters\n",
"model_ckpt: str = \"distilbert-base-uncased\"\n",
"BATCH_SIZE: int = 32\n",
"EPOCHS: int = 5\n",
"\n",
"# Weights & Biases configuration (read from the environment when the W&B run is set up)\n",
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\"\n",
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\"\n",
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\"  # log every saved model checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"re : 2.2.1\n",
"torch : 1.12.0\n",
"wandb : 0.16.2\n",
"logging : 0.5.1.2\n",
"numpy : 1.23.5\n",
"pandas : 2.0.2\n",
"evaluate: 0.4.1\n",
"shap : 0.44.1\n",
"\n"
]
}
],
"source": [
"# Versions of the imported packages, recorded for reproducibility\n",
"%watermark --iversions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"datalore": {
"hide_input_from_viewers": true,
"hide_output_from_viewers": true,
"node_id": "UU2oOJhwbIualogG1YyCMd",
"type": "CODE"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
"Mon Jan 29 15:20:22 2024 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
"| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
"| N/A 26C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
"| N/A 28C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| No running processes found |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"# Record the GPUs visible to this kernel and their current memory usage\n",
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"datalore": {
"hide_input_from_viewers": false,
"hide_output_from_viewers": false,
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
"report_properties": {
"rowId": "40nN9Hvgi1clHNV5RAemI5"
},
"type": "MD"
}
},
"source": [
"## Loading the data set"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706503446033
},
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# VAERS reports (id, text, label), pre-split into train / test / val on the Hugging Face Hub\n",
"dataset = load_dataset(path=\"chrisvoncsefalvay/vaers-outcomes\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"gather": {
"logged": 1706503446252
},
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'text', 'label'],\n",
" num_rows: 1270444\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'text', 'label'],\n",
" num_rows: 272238\n",
" })\n",
" val: Dataset({\n",
" features: ['id', 'text', 'label'],\n",
" num_rows: 272238\n",
" })\n",
"})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the available splits and their sizes\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"gather": {
"logged": 1706503446498
}
},
"outputs": [],
"source": [
"# Fraction of each split to keep (1 = use the full dataset); useful for quick experiments.\n",
"SUBSAMPLING = 0.01\n",
"assert 0 < SUBSAMPLING <= 1, \"SUBSAMPLING must be in (0, 1]\"\n",
"\n",
"# NOTE: this rebinds `dataset`, so re-running this cell without re-running the loading cell\n",
"# subsamples the already-subsampled data again.\n",
"if SUBSAMPLING < 1:\n",
"    dataset = DatasetDict({\n",
"        split: split_ds.shuffle(seed=SEED).select(range(int(len(split_ds) * SUBSAMPLING)))\n",
"        for split, split_ds in dataset.items()\n",
"    })"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenisation and encoding"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"gather": {
"logged": 1706503446633
}
},
"outputs": [],
"source": [
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
"    \"\"\"Tokenise the `text` column of a Dataset, or of every split of a DatasetDict.\n",
"\n",
"    Args:\n",
"        ds: dataset (or dict of splits) with a `text` column and a `label` column.\n",
"        tokenizer_model: checkpoint whose tokenizer is used; defaults to `model_ckpt`.\n",
"\n",
"    Returns:\n",
"        The same container type with every column except `label` replaced by the tokenizer\n",
"        outputs (e.g. `input_ids`, `attention_mask`), truncated to at most 512 tokens.\n",
"    \"\"\"\n",
"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)\n",
"\n",
"    # A DatasetDict reports column names per split; all splits are assumed to share one schema.\n",
"    if isinstance(ds, DatasetDict):\n",
"        splits = list(ds.values())\n",
"        column_names = splits[0].column_names if splits else []\n",
"    else:\n",
"        column_names = ds.column_names\n",
"    cols = [c for c in column_names if c != \"label\"]\n",
"\n",
"    ds_enc = ds.map(lambda batch: tokenizer(batch[\"text\"], truncation=True, max_length=512),\n",
"                    batched=True, remove_columns=cols)\n",
"    return ds_enc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation metrics"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"gather": {
"logged": 1706503446863
}
},
"outputs": [],
"source": [
"# Metric modules from Hugging Face `evaluate`\n",
"accuracy = evaluate.load(\"accuracy\")\n",
"precision = evaluate.load(\"precision\")\n",
"recall = evaluate.load(\"recall\")\n",
"f1 = evaluate.load(\"f1\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"gather": {
"logged": 1706503447004
}
},
"outputs": [],
"source": [
"def compute_metrics(eval_pred):\n",
"    \"\"\"Compute classification metrics for `Trainer` evaluation.\n",
"\n",
"    Args:\n",
"        eval_pred: `(predictions, labels)` pair supplied by `Trainer`; `predictions` holds\n",
"            per-class scores (or a tuple whose first element does), reduced with argmax over axis 1.\n",
"\n",
"    Returns:\n",
"        dict with accuracy plus macro- and micro-averaged precision, recall and F1. For\n",
"        single-label multiclass data the micro averages all equal accuracy, so the macro\n",
"        averages are the informative ones for imbalanced classes.\n",
"    \"\"\"\n",
"    predictions, labels = eval_pred\n",
"    # Trainer passes a tuple when the model returns more than the logits; keep only the logits.\n",
"    if isinstance(predictions, tuple):\n",
"        predictions = predictions[0]\n",
"    predictions = np.argmax(predictions, axis=1)\n",
"\n",
"    def _score(metric, key, **kwargs):\n",
"        return metric.compute(predictions=predictions, references=labels, **kwargs)[key]\n",
"\n",
"    return {\n",
"        'accuracy': _score(accuracy, \"accuracy\"),\n",
"        'precision_macroaverage': _score(precision, \"precision\", average='macro'),\n",
"        'precision_microaverage': _score(precision, \"precision\", average='micro'),\n",
"        'recall_macroaverage': _score(recall, \"recall\", average='macro'),\n",
"        'recall_microaverage': _score(recall, \"recall\", average='micro'),\n",
"        'f1_macroaverage': _score(f1, \"f1\", average='macro'),\n",
"        'f1_microaverage': _score(f1, \"f1\", average='micro')\n",
"    }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We specify the label map manually: even though `datasets` can convert between label ids and names (via the `ClassLabel` feature), `AutoModelForSequenceClassification` requires an object with a length (such as a plain `dict`) for `id2label` / `label2id` :("
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"gather": {
"logged": 1706503447186
}
},
"outputs": [],
"source": [
"# id -> class name, taken from the dataset's ClassLabel feature; given to the model as id2label\n",
"label_map = dict(enumerate(dataset[\"test\"].features[\"label\"].names))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"jupyter": {
"outputs_hidden": false,
"source_hidden": false
},
"nteract": {
"transient": {
"deleting": false
}
}
},
"outputs": [],
"source": [
"def train_from_model(model_ckpt: str, push: bool = False):\n",
"    \"\"\"Fine-tune `model_ckpt` on the (possibly subsampled) VAERS data and optionally push it to the Hub.\n",
"\n",
"    Uses the notebook globals `dataset`, `label_map`, `compute_metrics`, `SEED`, `SUBSAMPLING`,\n",
"    `BATCH_SIZE` and `EPOCHS`. Checkpoint selection (`load_best_model_at_end`) is done on the\n",
"    `val` split so that the `test` split stays untouched for the final evaluation.\n",
"\n",
"    Args:\n",
"        model_ckpt: Hub id or local path of the base checkpoint.\n",
"        push: if True, push the tokenizer and the fine-tuned model to `chrisvoncsefalvay/daedra`.\n",
"\n",
"    Returns:\n",
"        The `Trainer` used for fine-tuning, e.g. for `trainer.evaluate(...)` on the test split.\n",
"    \"\"\"\n",
"    from transformers import set_seed\n",
"\n",
"    print(f\"Initialising training based on {model_ckpt}...\")\n",
"\n",
"    print(\"Tokenising...\")\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
"\n",
"    # Keep `label`, drop the other raw columns; the tokenizer adds input_ids / attention_mask.\n",
"    cols = [c for c in dataset[\"train\"].column_names if c != \"label\"]\n",
"    ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
"\n",
"    print(\"Loading model...\")\n",
"    # Seed right before building the model so the newly initialised classification head is reproducible.\n",
"    set_seed(SEED)\n",
"    model_kwargs = dict(num_labels=len(dataset[\"test\"].features[\"label\"].names),\n",
"                        id2label=label_map,\n",
"                        label2id={v: k for k, v in label_map.items()})\n",
"    try:\n",
"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, **model_kwargs)\n",
"    except OSError:\n",
"        # Presumably the checkpoint only ships TensorFlow weights; retry converting from TF.\n",
"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, from_tf=True, **model_kwargs)\n",
"\n",
"    run_name = f\"daedra_{SUBSAMPLING}-{model_ckpt}\"\n",
"\n",
"    args = TrainingArguments(\n",
"        output_dir=\"vaers\",\n",
"        evaluation_strategy=\"epoch\",\n",
"        save_strategy=\"epoch\",\n",
"        learning_rate=2e-5,\n",
"        per_device_train_batch_size=BATCH_SIZE,\n",
"        per_device_eval_batch_size=BATCH_SIZE,\n",
"        num_train_epochs=EPOCHS,\n",
"        weight_decay=.01,\n",
"        logging_steps=1,\n",
"        load_best_model_at_end=True,\n",
"        seed=SEED,\n",
"        run_name=run_name,\n",
"        report_to=[\"wandb\"])\n",
"\n",
"    trainer = Trainer(\n",
"        model=model,\n",
"        args=args,\n",
"        train_dataset=ds_enc[\"train\"],\n",
"        # Select checkpoints on the validation split; using `test` here would leak it into model selection.\n",
"        eval_dataset=ds_enc[\"val\"],\n",
"        tokenizer=tokenizer,\n",
"        compute_metrics=compute_metrics)\n",
"\n",
"    wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"] if SUBSAMPLING != 1.0 else [\"full_sample\"]\n",
"    wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
"    wandb_tag.append(f\"base:{model_ckpt}\")\n",
"\n",
"    wandb.init(name=run_name, tags=wandb_tag, magic=True)\n",
"    try:\n",
"        print(\"Starting training...\")\n",
"        trainer.train()\n",
"        print(\"Training finished.\")\n",
"    finally:\n",
"        # Close the run so the next call (e.g. one per entry of base_models) logs to a fresh W&B run.\n",
"        wandb.finish()\n",
"\n",
"    if push:\n",
"        variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
"        tokenizer._tokenizer.save(\"tokenizer.json\")\n",
"        tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
"        sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
"\n",
"        model.push_to_hub(\"chrisvoncsefalvay/daedra\",\n",
"                          variant=variant,\n",
"                          commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")\n",
"\n",
"    return trainer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"gather": {
"logged": 1706503552083
}
},
"outputs": [],
"source": [
"# Base checkpoints to compare (presumably each is passed to train_from_model as `model_ckpt`)\n",
"base_models = [\"bert-base-uncased\", \"distilbert-base-uncased\"]"
]
}
],
"metadata": {
"datalore": {
"base_environment": "default",
"computation_mode": "JUPYTER",
"package_manager": "pip",
"packages": [
{
"name": "datasets",
"source": "PIP",
"version": "2.16.1"
},
{
"name": "torch",
"source": "PIP",
"version": "2.1.2"
},
{
"name": "accelerate",
"source": "PIP",
"version": "0.26.1"
}
],
"report_row_ids": [
"un8W7ez7ZwoGb5Co6nydEV",
"40nN9Hvgi1clHNV5RAemI5",
"TgRD90H5NSPpKS41OeXI1w",
"ZOm5BfUs3h1EGLaUkBGeEB",
"kOP0CZWNSk6vqE3wkPp7Vc",
"W4PWcOu2O2pRaZyoE2W80h",
"RolbOnQLIftk0vy9mIcz5M",
"8OPhUgbaNJmOdiq5D3a6vK",
"5Qrt3jSvSrpK6Ne1hS6shL",
"hTq7nFUrovN5Ao4u6dIYWZ",
"I8WNZLpJ1DVP2wiCW7YBIB",
"SawhU3I9BewSE1XBPstpNJ",
"80EtLEl2FIE4FqbWnUD3nT"
],
"version": 3
},
"kernel_info": {
"name": "python38-azureml-pt-tf"
},
"kernelspec": {
"display_name": "azureml_py38_PT_TF",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"microsoft": {
"host": {
"AzureML": {
"notebookHasBeenCompleted": true
}
},
"ms_spell_check": {
"ms_spell_check_language": "en"
}
},
"nteract": {
"version": "nteract-front-end@1.0.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}