{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n", "\n", "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nteract": { "transient": { "deleting": false } }, "tags": [] }, "outputs": [], "source": [ "# %pip install accelerate -U" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nteract": { "transient": { "deleting": false } } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n", "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n", "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n", "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n", "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n", "Requirement already satisfied: scikit-multilearn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.2.0)\n", "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n", "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) 
(2.31.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n", "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n", "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n", "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n", "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n", "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n", "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n", "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
datasets) (2023.10.0)\n", "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n", "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n", "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n", "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n", "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n", "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n", "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n", "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n", "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n", "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n", "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n", "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n", "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n", "Requirement already satisfied: 
docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n", "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n", "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n", "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n", "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n", "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n", "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n", "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n", "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n", "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n", "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n", "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n", "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n", "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n", "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n", "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n", "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n", "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n", "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n", "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n", "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n", "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n", "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install transformers datasets shap watermark wandb scikit-multilearn evaluate codecarbon" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "datalore": { "hide_input_from_viewers": false, "hide_output_from_viewers": false, "node_id": "caZjjFP0OyQNMVgZDiwswE", "report_properties": { "rowId": "un8W7ez7ZwoGb5Co6nydEV" }, "type": "CODE" }, "gather": { "logged": 1706454586481 }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "2024-01-28 19:47:15.508449: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2024-01-28 19:47:16.502791: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2024-01-28 19:47:16.502915: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2024-01-28 19:47:16.502928: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "from typing import List, Union\n", "from sklearn.metrics import f1_score, accuracy_score, classification_report\n", "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, set_seed\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "from pyarrow import Table\n", "import shap\n", "import wandb\n", "import evaluate\n", "from codecarbon import EmissionsTracker\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "tracker = EmissionsTracker()\n", "\n", "%load_ext watermark" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "gather": { "logged": 1706454586654 }, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "SEED: int = 42\n", "# Seed Python's `random`, NumPy and PyTorch now, before any model is built,\n", "# so the newly initialised classification head is reproducible as well.\n", "set_seed(SEED)\n", "\n", "BATCH_SIZE: int = 16\n", "EPOCHS: int = 3\n", "model_ckpt: str = \"distilbert-base-uncased\"\n", "\n", "# WandB configuration\n", "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\"\n", "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n", "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\"" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy : 1.23.5\n", "re : 2.2.1\n", "evaluate: 0.4.1\n", "pandas : 2.0.2\n", "wandb : 0.16.2\n", "shap : 0.44.1\n", "torch : 1.12.0\n", "logging : 0.5.1.2\n", "\n" ] } ], "source": [ "%watermark --iversions" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "datalore": { "hide_input_from_viewers": true, "hide_output_from_viewers": true,
"node_id": "UU2oOJhwbIualogG1YyCMd", "type": "CODE" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sun Jan 28 19:47:19 2024 \n", "+---------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n", "|-----------------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|=========================================+======================+======================|\n", "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n", "| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+----------------------+----------------------+\n", "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n", "| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+----------------------+----------------------+\n", " \n", "+---------------------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=======================================================================================|\n", "| No running processes found |\n", "+---------------------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "datalore": { "hide_input_from_viewers": false, "hide_output_from_viewers": false, "node_id": "t45KHugmcPVaO0nuk8tGJ9", "report_properties": { "rowId": "40nN9Hvgi1clHNV5RAemI5" }, "type": "MD" } }, "source": [ "## Loading the data set" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "gather": { 
"logged": 1706449040507 }, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "dataset_full = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "gather": { "logged": 1706449044205 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'text', 'label'],\n", " num_rows: 1270444\n", " })\n", " test: Dataset({\n", " features: ['id', 'text', 'label'],\n", " num_rows: 272238\n", " })\n", " val: Dataset({\n", " features: ['id', 'text', 'label'],\n", " num_rows: 272238\n", " })\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_full" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "SUBSAMPLING = 0.1\n", "\n", "# Always derive `dataset` from the untouched `dataset_full`, so re-running this\n", "# cell does not subsample an already-subsampled dataset again.\n", "if SUBSAMPLING < 1:\n", "    dataset = DatasetDict({\n", "        split: ds.shuffle(seed=SEED).select(range(int(len(ds) * SUBSAMPLING)))\n", "        for split, ds in dataset_full.items()\n", "    })\n", "else:\n", "    dataset = dataset_full" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Tokenisation and encoding" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n", "    \"\"\"Tokenise the `text` column of `ds` with the `tokenizer_model` tokenizer.\n", "\n", "    Texts are truncated to the tokenizer's maximum length. Every column except\n", "    `label` is removed, so the result holds only the tokenizer outputs plus the\n", "    labels. Accepts a single `Dataset` or a `DatasetDict`; for a `DatasetDict`\n", "    all splits are assumed to share the schema of the first split.\n", "    \"\"\"\n", "    tok = AutoTokenizer.from_pretrained(tokenizer_model)\n", "\n", "    if isinstance(ds, DatasetDict):\n", "        if len(ds) == 0:\n", "            return ds\n", "        reference = next(iter(ds.values()))\n", "    else:\n", "        reference = ds\n", "    cols_to_drop = [col for col in reference.column_names if col != 'label']\n", "\n", "    return ds.map(lambda batch: tok(batch['text'], truncation=True), batched=True, remove_columns=cols_to_drop)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation metrics" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "accuracy = evaluate.load(\"accuracy\")\n", "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n", "f1 = evaluate.load(\"f1\")" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def compute_metrics(eval_pred):\n", "    \"\"\"Compute evaluation metrics for `Trainer`.\n", "\n", "    `eval_pred` is the `(predictions, labels)` pair passed in by `Trainer`; the\n", "    per-class scores are reduced to class ids with an argmax over axis 1.\n", "    Macro averages weight every class equally, which matters when classes are\n", "    imbalanced; for single-label multiclass data the micro averages all equal\n", "    plain accuracy.\n", "    \"\"\"\n", "    predictions, labels = eval_pred\n", "    predictions = np.argmax(predictions, axis=1)\n", "\n", "    scores = {'accuracy': accuracy.compute(predictions=predictions, references=labels)['accuracy']}\n", "    for average in ('macro', 'micro'):\n", "        scores[f'precision_{average}average'] = precision.compute(predictions=predictions, references=labels, average=average)['precision']\n", "        scores[f'recall_{average}average'] = recall.compute(predictions=predictions, references=labels, average=average)['recall']\n", "        scores[f'f1_{average}average'] = f1.compute(predictions=predictions, references=labels, average=average)['f1']\n", "    return scores" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We specify the label map manually. `datasets` already knows the label names (through the `ClassLabel` feature), but `AutoModelForSequenceClassification` needs explicit `id2label`/`label2id` mappings and a label count, so we build those dictionaries ourselves." ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 127044/127044 [00:53<00:00, 2384.54 examples/s]\n", "Map: 100%|██████████| 27223/27223 [00:11<00:00, 2396.71 examples/s]\n", "Map: 100%|██████████| 27223/27223 [00:11<00:00, 2375.38 examples/s]\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ],
"source": [ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n", "\n", "# Keep only `label`; the tokenizer outputs replace `text` (and `id` is dropped).\n", "cols = dataset[\"train\"].column_names\n", "cols.remove(\"label\")\n", "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", "    model_ckpt,\n", "    num_labels=len(label_map),\n", "    id2label=label_map,\n", "    label2id={v: k for k, v in label_map.items()})\n", "\n", "args = TrainingArguments(\n", "    output_dir=\"vaers\",\n", "    evaluation_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    learning_rate=2e-5,\n", "    per_device_train_batch_size=BATCH_SIZE,\n", "    per_device_eval_batch_size=BATCH_SIZE,\n", "    num_train_epochs=EPOCHS,\n", "    weight_decay=.01,\n", "    logging_steps=1,\n", "    load_best_model_at_end=True,\n", "    seed=SEED,\n", "    run_name=\"daedra-training\",\n", "    report_to=[\"wandb\"])\n", "\n", "# `load_best_model_at_end` picks the checkpoint using the evaluation split, so\n", "# that split must be `val`: evaluating on `test` would leak the held-out test\n", "# set into checkpoint selection and bias any later test-set scores.\n", "trainer = Trainer(\n", "    model=model,\n", "    args=args,\n", "    train_dataset=ds_enc[\"train\"],\n", "    eval_dataset=ds_enc[\"val\"],\n", "    tokenizer=tokenizer,\n", "    compute_metrics=compute_metrics)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194842-yvxddyg6" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run daedra_training_run to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Finishing last run (ID:yvxddyg6) before initializing another..." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run daedra_training_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20240128_194842-yvxddyg6/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Successfully finished last run (ID:yvxddyg6). Initializing new run:
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194845-9g8te2gf" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run daedra_training_run to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Tag the run with the sample size, the batch size and the base checkpoint.\n", "sample_tag: str = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n", "wandb_tag: List[str] = [sample_tag, f\"batch_size-{BATCH_SIZE}\", f\"base:{model_ckpt}\"]\n", "\n", "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [ 7943/11913 43:43 < 21:51, 3.03 it/s, Epoch 2/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyPrecision MacroaveragePrecision MicroaverageRecall MacroaverageRecall MicroaverageF1 Microaverage
10.2513000.3629170.8657750.7010810.8657750.5565700.8657750.865775
20.0360000.3521180.8705510.7280510.8705510.6097870.8705510.870551

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3971)... Done. 18.2s\n", "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-7942)... " ] } ], "source": [ "# Stop the CodeCarbon tracker even if training raises or is interrupted,\n", "# so it is never left running in the kernel.\n", "tracker.start()\n", "try:\n", "    trainer.train()\n", "finally:\n", "    emissions = tracker.stop()\n", "\n", "emissions" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "

Run history:


eval/accuracy▁▇█
eval/f1_microaverage▁▇█
eval/loss█▃▁
eval/precision_macroaverage▁▇█
eval/precision_microaverage▁▇█
eval/recall_macroaverage▁▇█
eval/recall_microaverage▁▇█
eval/runtime▁▃█
eval/samples_per_second█▆▁
eval/steps_per_second█▆▁
train/epoch▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss█▅▆▆▅▄▄▃▆▅▃▃▅▄▆▄▄▄▂▄▄▅▄▃▄▄▁▄▂▂▃▃▃▂▂▃▂▃▃▂
train/total_flos
train/train_loss
train/train_runtime
train/train_samples_per_second
train/train_steps_per_second

Run summary:


eval/accuracy0.84019
eval/f1_microaverage0.84019
eval/loss0.44011
eval/precision_macroaverage0.415
eval/precision_microaverage0.84019
eval/recall_macroaverage0.40704
eval/recall_microaverage0.84019
eval/runtime10.0118
eval/samples_per_second271.878
eval/steps_per_second8.59
train/epoch3.0
train/global_step1191
train/learning_rate0.0
train/loss0.1782
train/total_flos4885522962505728.0
train/train_loss0.4724
train/train_runtime483.5027
train/train_samples_per_second78.825
train/train_steps_per_second2.463

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run daedra_training_run at: https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/3xvt3c2y
Synced 5 W&B file(s), 0 media file(s), 40 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20240128_192000-3xvt3c2y/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "wandb.finish()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/chrisvoncsefalvay/daedra/commit/c482ca6c8520142a3e67df4be25a408e6b557053', commit_message='DAEDRA model trained on 1.0% of the full sample of the VAERS dataset (training set size: 12,704)', commit_description='', oid='c482ca6c8520142a3e67df4be25a408e6b557053', pr_url=None, pr_revision=None, pr_num=None)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "HF_REPO = \"chrisvoncsefalvay/daedra\"\n", "\n", "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n", "# `:g` keeps the percentage tidy: SUBSAMPLING * 100 can carry float noise\n", "# (e.g. 0.07 * 100 == 7.000000000000001).\n", "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100:g}% of the full sample\"\n", "\n", "# Save the fast tokenizer's serialisation locally via the public\n", "# `backend_tokenizer` accessor, then upload the tokenizer and the model.\n", "tokenizer.backend_tokenizer.save(\"tokenizer.json\")\n", "tokenizer.push_to_hub(HF_REPO)\n", "\n", "model.push_to_hub(HF_REPO,\n", "                  variant=variant,\n", "                  commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "from typing import List, Optional, Tuple\n", "\n", "def get_most_frequent_unknown_tokens(tokenizer, dataset, top_n: Optional[int] = None) -> List[Tuple[str, int]]:\n", "    \"\"\"Find which pieces of input text the tokenizer cannot represent.\n", "\n", "    Each example's `text` is encoded with character offsets, and every position\n", "    encoded as `tokenizer.unk_token_id` is mapped back to the slice of the\n", "    original text it covers. Counting the unknown-token string itself would\n", "    only ever produce a single entry, so the original text spans are counted.\n", "\n", "    Args:\n", "        tokenizer: a fast (Rust-backed) tokenizer; offset mappings are not\n", "            available from slow (pure-Python) tokenizers.\n", "        dataset: iterable of examples that each have a `text` field.\n", "        top_n: number of entries to return; `None` returns all of them.\n", "\n", "    Returns:\n", "        `(text_span, count)` pairs, most frequent first.\n", "    \"\"\"\n", "    span_counts = Counter()\n", "    for example in dataset:\n", "        text = example['text']\n", "        encoding = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)\n", "        for token_id, (start, end) in zip(encoding['input_ids'], encoding['offset_mapping']):\n", "            if token_id == tokenizer.unk_token_id:\n", "                span_counts[text[start:end]] += 1\n", "    return span_counts.most_common(top_n)\n", "\n", "\n", "# Run on the training split with the tokenizer used for training; the result\n", "# goes into a new name so `tokenizer` and `dataset` are left untouched.\n", "most_frequent_unknown_tokens = get_most_frequent_unknown_tokens(tokenizer, dataset['train'], top_n=50)\n", "most_frequent_unknown_tokens" ] } ], "metadata": { "datalore": { "base_environment": "default", "computation_mode": "JUPYTER", "package_manager": "pip", "packages": [ { "name": "datasets", "source": "PIP", "version": "2.16.1" }, { "name": "torch", "source": "PIP", "version": "2.1.2" }, { "name": "accelerate", "source": "PIP", "version": "0.26.1" } ], "report_row_ids": [ "un8W7ez7ZwoGb5Co6nydEV", "40nN9Hvgi1clHNV5RAemI5", "TgRD90H5NSPpKS41OeXI1w", "ZOm5BfUs3h1EGLaUkBGeEB", "kOP0CZWNSk6vqE3wkPp7Vc", "W4PWcOu2O2pRaZyoE2W80h", "RolbOnQLIftk0vy9mIcz5M", "8OPhUgbaNJmOdiq5D3a6vK", "5Qrt3jSvSrpK6Ne1hS6shL", "hTq7nFUrovN5Ao4u6dIYWZ", "I8WNZLpJ1DVP2wiCW7YBIB", "SawhU3I9BewSE1XBPstpNJ", "80EtLEl2FIE4FqbWnUD3nT" ], "version": 3 }, "kernelspec": { "display_name": "Python 3.8 - Pytorch and Tensorflow", "language": "python", "name": "python38-azureml-pt-tf" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "microsoft": { "host": { "AzureML": { "notebookHasBeenCompleted": true } }, "ms_spell_check": { "ms_spell_check_language": "en" } }, "nteract": { "version": "nteract-front-end@1.0.0" } }, "nbformat": 4, "nbformat_minor": 4 }