{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "982e76f5", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.297106Z", "iopub.status.busy": "2024-02-29T17:50:43.296820Z", "iopub.status.idle": "2024-02-29T17:50:43.328516Z", "shell.execute_reply": "2024-02-29T17:50:43.327846Z" }, "papermill": { "duration": 0.045827, "end_time": "2024-02-29T17:50:43.330441", "exception": false, "start_time": "2024-02-29T17:50:43.284614", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import joblib\n", "\n", "#joblib.parallel_backend(\"threading\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "675f0b41", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.355733Z", "iopub.status.busy": "2024-02-29T17:50:43.354986Z", "iopub.status.idle": "2024-02-29T17:50:43.361827Z", "shell.execute_reply": "2024-02-29T17:50:43.360970Z" }, "papermill": { "duration": 0.021501, "end_time": "2024-02-29T17:50:43.363729", "exception": false, "start_time": "2024-02-29T17:50:43.342228", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "%cd /kaggle/working\n", "#!git clone https://github.com/R-N/ml-utility-loss\n", "%cd ml-utility-loss\n", "!git pull\n", "#!pip install .\n", "!pip install . 
--no-deps --force-reinstall --upgrade\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "5ae30f5c", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.386989Z", "iopub.status.busy": "2024-02-29T17:50:43.386508Z", "iopub.status.idle": "2024-02-29T17:50:43.390341Z", "shell.execute_reply": "2024-02-29T17:50:43.389543Z" }, "papermill": { "duration": 0.017484, "end_time": "2024-02-29T17:50:43.392283", "exception": false, "start_time": "2024-02-29T17:50:43.374799", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [3,3]" ] }, { "cell_type": "code", "execution_count": 4, "id": "9f42c810", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.415937Z", "iopub.status.busy": "2024-02-29T17:50:43.415284Z", "iopub.status.idle": "2024-02-29T17:50:43.419146Z", "shell.execute_reply": "2024-02-29T17:50:43.418407Z" }, "executionInfo": { "elapsed": 678, "status": "ok", "timestamp": 1696841022168, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "ns5hFcVL2yvs", "papermill": { "duration": 0.017482, "end_time": "2024-02-29T17:50:43.421010", "exception": false, "start_time": "2024-02-29T17:50:43.403528", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "datasets = [\n", " \"insurance\",\n", " \"treatment\",\n", " \"contraceptive\"\n", "]\n", "\n", "study_dir = \"./\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "85d0c8ce", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.444561Z", "iopub.status.busy": "2024-02-29T17:50:43.444080Z", "iopub.status.idle": "2024-02-29T17:50:43.449576Z", "shell.execute_reply": "2024-02-29T17:50:43.448694Z" }, "papermill": { "duration": 0.019709, "end_time": "2024-02-29T17:50:43.451711", "exception": false, "start_time": "2024-02-29T17:50:43.432002", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": 
[], "source": [ "#Parameters\n", "import os\n", "\n", "path_prefix = \"../../../../\"\n", "\n", "dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n", "dataset_name = \"treatment\"\n", "model_name=\"ml_utility_2\"\n", "models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n", "single_model = \"lct_gan\"\n", "random_seed = 42\n", "gp = True\n", "gp_multiply = True\n", "folder = \"eval\"\n", "debug = False\n", "path = None\n", "param_index = 0\n", "allow_same_prediction = False" ] }, { "cell_type": "code", "execution_count": 6, "id": "eb7d978f", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.479388Z", "iopub.status.busy": "2024-02-29T17:50:43.479092Z", "iopub.status.idle": "2024-02-29T17:50:43.484500Z", "shell.execute_reply": "2024-02-29T17:50:43.483600Z" }, "papermill": { "duration": 0.021864, "end_time": "2024-02-29T17:50:43.486766", "exception": false, "start_time": "2024-02-29T17:50:43.464902", "status": "completed" }, "tags": [ "injected-parameters" ] }, "outputs": [], "source": [ "# Parameters\n", "dataset = \"insurance\"\n", "dataset_name = \"insurance\"\n", "single_model = \"tvae\"\n", "gp = False\n", "gp_multiply = False\n", "random_seed = 4\n", "debug = False\n", "folder = \"eval\"\n", "path_prefix = \"../../../../\"\n", "path = \"eval/insurance/tvae/4\"\n", "param_index = 2\n", "allow_same_prediction = True\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bd7c02d6", "metadata": { "papermill": { "duration": 0.011535, "end_time": "2024-02-29T17:50:43.510177", "exception": false, "start_time": "2024-02-29T17:50:43.498642", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "5f45b1d0", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.535341Z", "iopub.status.busy": "2024-02-29T17:50:43.534859Z", "iopub.status.idle": "2024-02-29T17:50:43.544013Z", "shell.execute_reply": 
"2024-02-29T17:50:43.543203Z" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "UdvXYv3c3LXy", "papermill": { "duration": 0.023834, "end_time": "2024-02-29T17:50:43.545888", "exception": false, "start_time": "2024-02-29T17:50:43.522054", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n", "/kaggle/working/eval/insurance/tvae/4\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "%cd /kaggle/working/\n", "\n", "if path is None:\n", " # os.path.join only accepts str/PathLike parts; random_seed is an int\n", " path = os.path.join(folder, dataset_name, single_model, str(random_seed))\n", "Path(path).mkdir(parents=True, exist_ok=True)\n", "\n", "%cd {path}" ] }, { "cell_type": "code", "execution_count": 8, "id": "f85bf540", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:43.570861Z", "iopub.status.busy": "2024-02-29T17:50:43.570608Z", "iopub.status.idle": "2024-02-29T17:50:45.715341Z", "shell.execute_reply": "2024-02-29T17:50:45.714481Z" }, "papermill": { "duration": 2.159597, "end_time": "2024-02-29T17:50:45.717481", "exception": false, "start_time": "2024-02-29T17:50:43.557884", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set seed to \n" ] } ], "source": [ "from ml_utility_loss.util import seed\n", "# Guard so re-running this cell does not append the model suffix twice\n", "if single_model and not model_name.endswith(f\"_{single_model}\"):\n", " model_name=f\"{model_name}_{single_model}\"\n", "if random_seed is not None:\n", " seed(random_seed)\n", " # report the seed value, not the imported seed() function object\n", " print(\"Set seed to\", random_seed)" ] }, { "cell_type": "code", "execution_count": 9, "id": "8489feae", "metadata": { "execution": { "iopub.execute_input": 
"2024-02-29T17:50:45.744285Z", "iopub.status.busy": "2024-02-29T17:50:45.743770Z", "iopub.status.idle": "2024-02-29T17:50:45.755553Z", "shell.execute_reply": "2024-02-29T17:50:45.754865Z" }, "papermill": { "duration": 0.027328, "end_time": "2024-02-29T17:50:45.757462", 
"exception": false, "start_time": "2024-02-29T17:50:45.730134", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os\n", "\n", "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", " info = json.load(f)" ] }, { "cell_type": "code", "execution_count": 10, "id": "debcc684", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:45.780934Z", "iopub.status.busy": "2024-02-29T17:50:45.780680Z", "iopub.status.idle": "2024-02-29T17:50:45.787778Z", "shell.execute_reply": "2024-02-29T17:50:45.787084Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "Vrl2QkoV3o_8", "papermill": { "duration": 0.020988, "end_time": "2024-02-29T17:50:45.789578", "exception": false, "start_time": "2024-02-29T17:50:45.768590", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "task = info[\"task\"]\n", "target = info[\"target\"]\n", "cat_features = info[\"cat_features\"]\n", "mixed_features = info[\"mixed_features\"]\n", "longtail_features = info[\"longtail_features\"]\n", "integer_features = info[\"integer_features\"]\n", "\n", "test = df.sample(frac=0.2, random_state=42)\n", "train = df[~df.index.isin(test.index)]" ] }, { "cell_type": "code", "execution_count": 11, "id": "7538184a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:45.813449Z", "iopub.status.busy": "2024-02-29T17:50:45.812848Z", "iopub.status.idle": "2024-02-29T17:50:45.911458Z", "shell.execute_reply": "2024-02-29T17:50:45.910597Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "TilUuFk9vqMb", "papermill": { "duration": 0.112764, 
"end_time": "2024-02-29T17:50:45.913541", "exception": false, "start_time": "2024-02-29T17:50:45.800777", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", "from ml_utility_loss.util import filter_dict_2, filter_dict\n", "\n", "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", "\n", "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", "tab_ddpm_normalization=\"quantile\"\n", "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", "#tab_ddpm_cat_encoding=\"one-hot\"\n", "tab_ddpm_y_policy=\"default\"\n", "tab_ddpm_is_y_cond=True" ] }, { "cell_type": "code", "execution_count": 12, "id": "cca61838", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:45.939348Z", "iopub.status.busy": "2024-02-29T17:50:45.939072Z", "iopub.status.idle": "2024-02-29T17:50:50.480136Z", "shell.execute_reply": "2024-02-29T17:50:50.479366Z" }, "executionInfo": { "elapsed": 3113, "status": "ok", "timestamp": 1696841025277, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "7Abt8nStvr9Z", "papermill": { "duration": 4.556435, "end_time": "2024-02-29T17:50:50.482677", "exception": false, "start_time": "2024-02-29T17:50:45.926242", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-02-29 17:50:48.176006: E 
external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-02-29 17:50:48.176066: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-02-29 17:50:48.177642: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", "\n", "lct_ae = load_lct_ae(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"lct_ae\",\n", " df_name=\"df\",\n", ")\n", "lct_ae = None" ] }, { "cell_type": "code", "execution_count": 13, "id": "6f83b7b6", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:50.508054Z", "iopub.status.busy": "2024-02-29T17:50:50.507215Z", "iopub.status.idle": "2024-02-29T17:50:50.513947Z", "shell.execute_reply": "2024-02-29T17:50:50.513275Z" }, "papermill": { "duration": 0.021175, "end_time": "2024-02-29T17:50:50.515900", "exception": false, "start_time": "2024-02-29T17:50:50.494725", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", "\n", "rtf_embed = load_rtf_embed(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"realtabformer\",\n", " df_name=\"df\",\n", " ckpt_type=\"best-disc-model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "0026de74", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:50.541991Z", "iopub.status.busy": "2024-02-29T17:50:50.541717Z", "iopub.status.idle": 
"2024-02-29T17:50:58.595638Z", "shell.execute_reply": "2024-02-29T17:50:58.594710Z" }, "executionInfo": { "elapsed": 20137, "status": "ok", "timestamp": 1696841045408, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "tbaguWxAvtPi", "papermill": { "duration": 8.069751, "end_time": "2024-02-29T17:50:58.598157", "exception": false, "start_time": "2024-02-29T17:50:50.528406", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (6) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n", " .fit(X)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/1 [00:00 torch.Tensor>,\n", " 'single_model': True,\n", " 'bias': True,\n", " 'bias_final': True,\n", " 'pma_ffn_mode': 'shared',\n", " 'patience': 10,\n", " 'inds_init_mode': 'fixnorm',\n", " 'grad_clip': 0.77,\n", " 'head_final_mul': 'identity',\n", " 'gradient_penalty_mode': {'gradient_penalty': False,\n", " 'calc_grad_m': False,\n", " 'avg_non_role_model_m': False,\n", " 'inverse_avg_non_role_model_m': False},\n", " 'synth_data': 2,\n", " 'dataset_size': 2048,\n", " 'batch_size': 8,\n", " 'epochs': 100,\n", " 'n_warmup_steps': 100,\n", " 'Optim': torch_optimizer.diffgrad.DiffGrad,\n", " 'loss_balancer_beta': 0.75,\n", " 'loss_balancer_r': 0.95,\n", " 'fixed_role_model': 'tvae',\n", " 'd_model': 256,\n", " 'attn_activation': torch.nn.modules.activation.LeakyReLU,\n", " 'tf_d_inner': 512,\n", " 'tf_n_layers_enc': 4,\n", " 'tf_n_head': 64,\n", " 'tf_activation': torch.nn.modules.activation.ReLU6,\n", " 'tf_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'ada_d_hid': 1024,\n", " 'ada_n_layers': 7,\n", " 'ada_activation': torch.nn.modules.activation.ReLU,\n", " 'ada_activation_final': torch.nn.modules.activation.Softsign,\n", " 'head_d_hid': 128,\n", " 'head_n_layers': 9,\n", " 'head_n_head': 64,\n", " 'head_activation': torch.nn.modules.activation.RReLU,\n", " 'head_activation_final': torch.nn.modules.activation.Softsign,\n", " 'models': ['tvae'],\n", " 'max_seconds': 3600,\n", " 'tf_lora': False,\n", " 'tf_num_inds': 32,\n", " 'ada_n_seeds': 0,\n", " 'gradient_penalty_kwargs': {'mag_loss': True,\n", " 'mse_mag': False,\n", " 'mag_corr': False,\n", " 'seq_mag': False,\n", " 'cos_loss': False,\n", " 'mag_corr_kwargs': {'only_sign': False},\n", " 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n", " 'mse_mag_kwargs': {'target': 0.1, 'multiply': 
False}}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n", "import wandb\n", "\n", "#\"\"\"\n", "param_space = {\n", " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", "}\n", "params = {\n", " **getattr(PARAMS, dataset_name).BESTS[param_index],\n", "}\n", "if gp:\n", " params[\"gradient_penalty_mode\"] = \"ALL\"\n", " params[\"mse_mag\"] = True\n", " if gp_multiply:\n", " params[\"mse_mag_multiply\"] = True\n", " params[\"mse_mag_target\"] = 1.0\n", " else:\n", " params[\"mse_mag_multiply\"] = False\n", " params[\"mse_mag_target\"] = 0.1\n", "else:\n", " params[\"gradient_penalty_mode\"] = \"NONE\"\n", " params[\"mse_mag\"] = False\n", "params[\"single_model\"] = False\n", "if models:\n", " params[\"models\"] = models\n", "if single_model:\n", " params[\"fixed_role_model\"] = single_model\n", " params[\"single_model\"] = True\n", " params[\"models\"] = [single_model]\n", "if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n", " params[\"batch_size\"] = 2\n", "params[\"max_seconds\"] = 3600\n", "params[\"patience\"] = 10\n", "params[\"epochs\"] = 100\n", "if debug:\n", " params[\"epochs\"] = 2\n", "with open(\"params.json\", \"w\") as f:\n", " json.dump(params, f)\n", "params = map_parameters(params, param_space=param_space)\n", "params" ] }, { "cell_type": "code", "execution_count": 19, "id": "a48bd9e9", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:59.090944Z", "iopub.status.busy": "2024-02-29T17:50:59.090399Z", "iopub.status.idle": "2024-02-29T17:50:59.156842Z", "shell.execute_reply": "2024-02-29T17:50:59.156038Z" }, "papermill": { "duration": 0.081787, "end_time": "2024-02-29T17:50:59.158740", "exception": false, 
"start_time": "2024-02-29T17:50:59.076953", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load_dataset_3_factory 2\n", "Caching in ../../../../insurance/_cache/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_2/insurance [80, 20]\n", "Caching in ../../../../insurance/_cache4/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_4/insurance [80, 20]\n", "Caching in ../../../../insurance/_cache5/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_5/insurance [160, 40]\n", "[320, 80]\n", "[320, 80]\n" ] } ], "source": [ "train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], synth_data=params[\"synth_data\"])" ] }, { "cell_type": "code", "execution_count": 20, "id": "2fcb1418", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "execution": { "iopub.execute_input": "2024-02-29T17:50:59.187387Z", "iopub.status.busy": "2024-02-29T17:50:59.187119Z", "iopub.status.idle": "2024-02-29T17:50:59.608009Z", "shell.execute_reply": "2024-02-29T17:50:59.607099Z" }, "executionInfo": { "elapsed": 396850, "status": "error", "timestamp": 1696841446059, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "_bt1MQc5kpSk", "outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6", "papermill": { "duration": 0.438163, "end_time": "2024-02-29T17:50:59.610083", "exception": false, "start_time": "2024-02-29T17:50:59.171920", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model of type \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[*] Embedding False True\n", "['tvae'] 1\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.model.pipeline 
import remove_non_model_params\n", "from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n", "from ml_utility_loss.util import filter_dict, clear_memory\n", "\n", "clear_memory()\n", "\n", "params2 = remove_non_model_params(params)\n", "adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n", "\n", "model = create_model(\n", " adapters=adapters,\n", " #Body=\"twin_encoder\",\n", " **params2,\n", ")\n", "#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", "print(model.models, len(model.adapters))" ] }, { "cell_type": "code", "execution_count": 21, "id": "938f94fc", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:59.638580Z", "iopub.status.busy": "2024-02-29T17:50:59.638286Z", "iopub.status.idle": "2024-02-29T17:50:59.642446Z", "shell.execute_reply": "2024-02-29T17:50:59.641637Z" }, "papermill": { "duration": 0.020458, "end_time": "2024-02-29T17:50:59.644327", "exception": false, "start_time": "2024-02-29T17:50:59.623869", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "study_name=f\"{model_name}_{dataset_name}\"" ] }, { "cell_type": "code", "execution_count": 22, "id": "12fb613e", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:59.670684Z", "iopub.status.busy": "2024-02-29T17:50:59.670408Z", "iopub.status.idle": "2024-02-29T17:50:59.676722Z", "shell.execute_reply": "2024-02-29T17:50:59.675923Z" }, "papermill": { "duration": 0.021573, "end_time": "2024-02-29T17:50:59.678594", "exception": false, "start_time": "2024-02-29T17:50:59.657021", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "9638529" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "count_parameters(model)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bd386e57", "metadata": { "execution": 
{ "iopub.execute_input": "2024-02-29T17:50:59.705226Z", "iopub.status.busy": "2024-02-29T17:50:59.704962Z", "iopub.status.idle": "2024-02-29T17:50:59.791397Z", "shell.execute_reply": "2024-02-29T17:50:59.790570Z" }, "papermill": { "duration": 0.101763, "end_time": "2024-02-29T17:50:59.793213", "exception": false, "start_time": "2024-02-29T17:50:59.691450", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "========================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "========================================================================================================================\n", "MLUtilitySingle [2, 1071, 36] --\n", "├─Adapter: 1-1 [2, 1071, 36] --\n", "│ └─Sequential: 2-1 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-1 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-1 [2, 1071, 1024] 37,888\n", "│ │ │ └─ReLU: 4-2 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-2 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-3 [2, 1071, 1024] 1,049,600\n", "│ │ │ └─ReLU: 4-4 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-3 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-5 [2, 1071, 1024] 1,049,600\n", "│ │ │ └─ReLU: 4-6 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-4 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-7 [2, 1071, 1024] 1,049,600\n", "│ │ │ └─ReLU: 4-8 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-5 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-9 [2, 1071, 1024] 1,049,600\n", "│ │ │ └─ReLU: 4-10 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-6 [2, 1071, 1024] --\n", "│ │ │ └─Linear: 4-11 [2, 1071, 1024] 1,049,600\n", "│ │ │ └─ReLU: 4-12 [2, 1071, 1024] --\n", "│ │ └─FeedForward: 3-7 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-13 [2, 1071, 256] 262,400\n", "│ │ │ └─Softsign: 4-14 [2, 1071, 256] --\n", "├─Adapter: 1-2 [2, 267, 36] (recursive)\n", "│ └─Sequential: 2-2 [2, 267, 256] (recursive)\n", "│ │ └─FeedForward: 3-8 [2, 267, 1024] (recursive)\n", "│ │ │ 
└─Linear: 4-15 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-16 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-9 [2, 267, 1024] (recursive)\n", "│ │ │ └─Linear: 4-17 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-18 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-10 [2, 267, 1024] (recursive)\n", "│ │ │ └─Linear: 4-19 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-20 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-11 [2, 267, 1024] (recursive)\n", "│ │ │ └─Linear: 4-21 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-22 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-12 [2, 267, 1024] (recursive)\n", "│ │ │ └─Linear: 4-23 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-24 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-13 [2, 267, 1024] (recursive)\n", "│ │ │ └─Linear: 4-25 [2, 267, 1024] (recursive)\n", "│ │ │ └─ReLU: 4-26 [2, 267, 1024] --\n", "│ │ └─FeedForward: 3-14 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-27 [2, 267, 256] (recursive)\n", "│ │ │ └─Softsign: 4-28 [2, 267, 256] --\n", "├─TwinEncoder: 1-3 [2, 4096] --\n", "│ └─Encoder: 2-3 [2, 16, 256] --\n", "│ │ └─ModuleList: 3-16 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-29 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 32, 256] 8,192\n", "│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 32, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-1 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-2 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-3 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 64, 32, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-5 [2, 32, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-6 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-7 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-8 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-9 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 64, 
1071, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 64, 1071, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-11 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-12 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-2 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-4 [2, 1071, 512] 131,584\n", "│ │ │ │ │ └─ReLU6: 6-5 [2, 1071, 512] --\n", "│ │ │ │ │ └─Linear: 6-6 [2, 1071, 256] 131,328\n", "│ │ │ └─EncoderLayer: 4-30 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 32, 256] 8,192\n", "│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 32, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-13 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-14 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-15 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 64, 32, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-17 [2, 32, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-18 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-19 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-20 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-21 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 64, 1071, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 64, 1071, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-23 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-24 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-4 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-10 [2, 1071, 512] 131,584\n", "│ │ │ │ │ └─ReLU6: 6-11 [2, 1071, 512] --\n", "│ │ │ │ │ └─Linear: 6-12 [2, 1071, 256] 131,328\n", "│ │ │ └─EncoderLayer: 4-31 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 32, 256] 8,192\n", "│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 32, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-25 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-26 [2, 1071, 256] 
65,536\n", "│ │ │ │ │ │ └─Linear: 7-27 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 64, 32, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-29 [2, 32, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-30 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-31 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-32 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-33 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 64, 1071, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 64, 1071, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-35 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-36 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-6 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-16 [2, 1071, 512] 131,584\n", "│ │ │ │ │ └─ReLU6: 6-17 [2, 1071, 512] --\n", "│ │ │ │ │ └─Linear: 6-18 [2, 1071, 256] 131,328\n", "│ │ │ └─EncoderLayer: 4-32 [2, 16, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-7 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 32, 256] 8,192\n", "│ │ │ │ │ └─MultiHeadAttention: 6-20 [2, 32, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-37 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-38 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-39 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 64, 32, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-41 [2, 32, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-42 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-21 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-43 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-44 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-45 [2, 32, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 64, 1071, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 64, 1071, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-47 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 
7-48 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-8 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-22 [2, 1071, 512] 131,584\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-23 [2, 1071, 512] --\n", "│ │ │ │ │ └─Linear: 6-24 [2, 1071, 256] 131,328\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-9 [2, 16, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-25 [2, 16, 256] 4,096\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-26 [2, 16, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-49 [2, 16, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-50 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-51 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 64, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 64, 16, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-53 [2, 16, 256] 65,792\n", "│ │ │ │ │ │ └─LeakyReLU: 7-54 [2, 16, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-10 [2, 16, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-27 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-28 [2, 16, 512] --\n", "│ │ │ │ │ └─Linear: 6-29 [2, 16, 256] (recursive)\n", "│ └─Encoder: 2-4 [2, 16, 256] (recursive)\n", "│ │ └─ModuleList: 3-16 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-33 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-11 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-30 [2, 32, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-31 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-55 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-56 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-57 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 64, 32, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-59 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-60 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-32 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-61 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-62 [2, 32, 256] (recursive)\n", 
"│ │ │ │ │ │ └─Linear: 7-63 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 64, 267, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 64, 267, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-65 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-66 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-12 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-33 [2, 267, 512] (recursive)\n", "│ │ │ │ │ └─ReLU6: 6-34 [2, 267, 512] --\n", "│ │ │ │ │ └─Linear: 6-35 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-34 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 32, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-67 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-68 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-69 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 64, 32, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-71 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-72 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-38 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-73 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-74 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-75 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 64, 267, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 64, 267, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-77 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-78 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-14 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-39 [2, 267, 512] (recursive)\n", "│ │ │ │ │ └─ReLU6: 6-40 [2, 267, 512] --\n", "│ │ │ │ │ └─Linear: 6-41 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-35 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-15 [2, 267, 256] 
(recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 32, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-43 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-79 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-80 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-81 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 64, 32, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-83 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-84 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-44 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-85 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-86 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-87 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-88 [2, 64, 267, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-15 [2, 64, 267, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-89 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-90 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-16 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-45 [2, 267, 512] (recursive)\n", "│ │ │ │ │ └─ReLU6: 6-46 [2, 267, 512] --\n", "│ │ │ │ │ └─Linear: 6-47 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-36 [2, 16, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-17 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-48 [2, 32, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-49 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-91 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-92 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-93 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-94 [2, 64, 32, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-16 [2, 64, 32, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-95 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-96 [2, 32, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-50 [2, 267, 256] (recursive)\n", "│ │ 
│ │ │ │ └─Linear: 7-97 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-98 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-99 [2, 32, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-100 [2, 64, 267, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-17 [2, 64, 267, 32] --\n", "│ │ │ │ │ │ └─Linear: 7-101 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-102 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-18 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-51 [2, 267, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-52 [2, 267, 512] --\n", "│ │ │ │ │ └─Linear: 6-53 [2, 267, 256] (recursive)\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-19 [2, 16, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-54 [2, 16, 256] (recursive)\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-55 [2, 16, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-103 [2, 16, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-104 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-105 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-106 [2, 64, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-18 [2, 64, 16, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-107 [2, 16, 256] (recursive)\n", "│ │ │ │ │ │ └─LeakyReLU: 7-108 [2, 16, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-20 [2, 16, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-56 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-57 [2, 16, 512] --\n", "│ │ │ │ │ └─Linear: 6-58 [2, 16, 256] (recursive)\n", "├─Head: 1-4 [2] --\n", "│ └─Sequential: 2-5 [2, 1] --\n", "│ │ └─FeedForward: 3-17 [2, 128] --\n", "│ │ │ └─Linear: 4-37 [2, 128] 524,416\n", "│ │ │ └─RReLU: 4-38 [2, 128] --\n", "│ │ └─FeedForward: 3-18 [2, 128] --\n", "│ │ │ └─Linear: 4-39 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-40 [2, 128] --\n", "│ │ └─FeedForward: 3-19 [2, 128] --\n", "│ │ │ └─Linear: 4-41 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-42 [2, 128] --\n", "│ │ └─FeedForward: 3-20 [2, 128] --\n", "│ │ │ └─Linear: 4-43 [2, 128] 16,512\n", 
"│ │ │ └─RReLU: 4-44 [2, 128] --\n", "│ │ └─FeedForward: 3-21 [2, 128] --\n", "│ │ │ └─Linear: 4-45 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-46 [2, 128] --\n", "│ │ └─FeedForward: 3-22 [2, 128] --\n", "│ │ │ └─Linear: 4-47 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-48 [2, 128] --\n", "│ │ └─FeedForward: 3-23 [2, 128] --\n", "│ │ │ └─Linear: 4-49 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-50 [2, 128] --\n", "│ │ └─FeedForward: 3-24 [2, 128] --\n", "│ │ │ └─Linear: 4-51 [2, 128] 16,512\n", "│ │ │ └─RReLU: 4-52 [2, 128] --\n", "│ │ └─FeedForward: 3-25 [2, 1] --\n", "│ │ │ └─Linear: 4-53 [2, 1] 129\n", "│ │ │ └─Softsign: 4-54 [2, 1] --\n", "========================================================================================================================\n", "Total params: 9,638,529\n", "Trainable params: 9,638,529\n", "Non-trainable params: 0\n", "Total mult-adds (M): 38.18\n", "========================================================================================================================\n", "Input size (MB): 0.39\n", "Forward/backward pass size (MB): 307.47\n", "Params size (MB): 38.55\n", "Estimated Total Size (MB): 346.41\n", "========================================================================================================================" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchinfo import summary\n", "\n", "role_model = params[\"fixed_role_model\"]\n", "s = train_set[0][role_model]\n", "summary(model[role_model], input_size=((2, *s[0].shape), (2, *s[1].shape)), depth=9) # 8 max" ] }, { "cell_type": "code", "execution_count": 24, "id": "0f42c4d1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T17:50:59.822160Z", "iopub.status.busy": "2024-02-29T17:50:59.821866Z", "iopub.status.idle": "2024-02-29T18:10:05.205760Z", "shell.execute_reply": "2024-02-29T18:10:05.204793Z" }, "papermill": { "duration": 1145.418287, "end_time": "2024-02-29T18:10:05.225558", "exception": false, 
"start_time": "2024-02-29T17:50:59.807271", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.16.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: W&B syncing is set to \u001b[1m`offline`\u001b[0m in this directory. \n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb online`\u001b[0m or set \u001b[1mWANDB_MODE=online\u001b[0m to enable cloud syncing.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "g_loss_mul 0.1\n", "Epoch 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.05458167113538366, 'avg_role_model_std_loss': 4.561985811768864, 'avg_role_model_mean_pred_loss': 0.023550471702759836, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.05458167113538366, 'n_size': 320, 'n_batch': 40, 'duration': 39.33485436439514, 'duration_batch': 0.9833713591098785, 'duration_size': 0.12292141988873481, 'avg_pred_std': 0.12334999229060487}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.01106511988909915, 'avg_role_model_std_loss': 7.3775769050087545, 'avg_role_model_mean_pred_loss': 0.00039846873109325995, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01106511988909915, 'n_size': 80, 'n_batch': 10, 'duration': 8.30074167251587, 'duration_batch': 0.8300741672515869, 'duration_size': 0.10375927090644836, 'avg_pred_std': 0.04176213040482253}\n", "Epoch 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010921533098735382, 
'avg_role_model_std_loss': 3.7708608118317897, 'avg_role_model_mean_pred_loss': 0.0006090835865870864, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010921533098735382, 'n_size': 320, 'n_batch': 40, 'duration': 38.92951965332031, 'duration_batch': 0.9732379913330078, 'duration_size': 0.12165474891662598, 'avg_pred_std': 0.07502402040408924}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.002461973318713717, 'avg_role_model_std_loss': 0.2656627141033823, 'avg_role_model_mean_pred_loss': 8.382165523856955e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002461973318713717, 'n_size': 80, 'n_batch': 10, 'duration': 8.352109670639038, 'duration_batch': 0.8352109670639039, 'duration_size': 0.10440137088298798, 'avg_pred_std': 0.07963283583521844}\n", "Epoch 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.004752650485897902, 'avg_role_model_std_loss': 4.5005246672456, 'avg_role_model_mean_pred_loss': 7.41932757047259e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.004752650485897902, 'n_size': 320, 'n_batch': 40, 'duration': 39.09458088874817, 'duration_batch': 0.9773645222187042, 'duration_size': 0.12217056527733802, 'avg_pred_std': 0.0816895533236675}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0009612412060960196, 'avg_role_model_std_loss': 0.23112409779214432, 'avg_role_model_mean_pred_loss': 2.9920761611550857e-06, 'avg_role_model_g_mag_loss': 0.0, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0009612412060960196, 'n_size': 80, 'n_batch': 10, 'duration': 8.301867723464966, 'duration_batch': 0.8301867723464966, 'duration_size': 0.10377334654331208, 'avg_pred_std': 0.08093988439068198}\n", "Epoch 3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0029934452861198222, 'avg_role_model_std_loss': 1.4091149369219238, 'avg_role_model_mean_pred_loss': 3.706777377407988e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029934452861198222, 'n_size': 320, 'n_batch': 40, 'duration': 39.057528257369995, 'duration_batch': 0.9764382064342498, 'duration_size': 0.12205477580428123, 'avg_pred_std': 0.08644149880856275}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0017080451536457986, 'avg_role_model_std_loss': 0.5054739748910834, 'avg_role_model_mean_pred_loss': 2.0698145459556274e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0017080451536457986, 'n_size': 80, 'n_batch': 10, 'duration': 8.39680528640747, 'duration_batch': 0.839680528640747, 'duration_size': 0.10496006608009338, 'avg_pred_std': 0.0637943553738296}\n", "Epoch 4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0022114409464847997, 'avg_role_model_std_loss': 1.4571088086362807, 'avg_role_model_mean_pred_loss': 9.073904502000795e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0022114409464847997, 'n_size': 320, 'n_batch': 40, 'duration': 38.94040822982788, 'duration_batch': 0.973510205745697, 'duration_size': 0.12168877571821213, 'avg_pred_std': 0.08093992052599788}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0034676186623983085, 'avg_role_model_std_loss': 0.354912094264597, 'avg_role_model_mean_pred_loss': 1.1135635656955855e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0034676186623983085, 'n_size': 80, 'n_batch': 10, 'duration': 8.30032467842102, 'duration_batch': 0.830032467842102, 'duration_size': 0.10375405848026276, 'avg_pred_std': 0.10819654231891036}\n", "Epoch 5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0016322427756676916, 'avg_role_model_std_loss': 0.8344269889868698, 'avg_role_model_mean_pred_loss': 2.8054938205387956e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016322427756676916, 'n_size': 320, 'n_batch': 40, 'duration': 39.14137244224548, 'duration_batch': 0.9785343110561371, 'duration_size': 0.12231678888201714, 'avg_pred_std': 0.09135764897800983}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0034494245337555185, 'avg_role_model_std_loss': 2.7931900787574704, 'avg_role_model_mean_pred_loss': 6.050434956339501e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0034494245337555185, 'n_size': 80, 'n_batch': 10, 'duration': 8.366892337799072, 
'duration_batch': 0.8366892337799072, 'duration_size': 0.1045861542224884, 'avg_pred_std': 0.055338869569823146}\n", "Epoch 6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.002849590677578817, 'avg_role_model_std_loss': 0.8129531741204119, 'avg_role_model_mean_pred_loss': 4.8211906484207924e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002849590677578817, 'n_size': 320, 'n_batch': 40, 'duration': 38.968292236328125, 'duration_batch': 0.9742073059082031, 'duration_size': 0.1217759132385254, 'avg_pred_std': 0.0901852805633098}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0025212633569026365, 'avg_role_model_std_loss': 0.6178526908131061, 'avg_role_model_mean_pred_loss': 1.4414640320481454e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0025212633569026365, 'n_size': 80, 'n_batch': 10, 'duration': 8.266654014587402, 'duration_batch': 0.8266654014587402, 'duration_size': 0.10333317518234253, 'avg_pred_std': 0.05773084256798029}\n", "Epoch 7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0034268524424987843, 'avg_role_model_std_loss': 1.5629836895840525, 'avg_role_model_mean_pred_loss': 1.5161425290398687e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0034268524424987843, 'n_size': 320, 'n_batch': 40, 'duration': 39.01256036758423, 'duration_batch': 0.9753140091896058, 'duration_size': 0.12191425114870072, 'avg_pred_std': 0.0829970414401032}\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0014418774226214737, 'avg_role_model_std_loss': 0.05386366389284376, 'avg_role_model_mean_pred_loss': 2.4331560492640845e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0014418774226214737, 'n_size': 80, 'n_batch': 10, 'duration': 8.32570481300354, 'duration_batch': 0.832570481300354, 'duration_size': 0.10407131016254426, 'avg_pred_std': 0.08880755109712482}\n", "Epoch 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0016761758448410546, 'avg_role_model_std_loss': 0.571136603817564, 'avg_role_model_mean_pred_loss': 7.011435811053887e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016761758448410546, 'n_size': 320, 'n_batch': 40, 'duration': 38.852670669555664, 'duration_batch': 0.9713167667388916, 'duration_size': 0.12141459584236144, 'avg_pred_std': 0.09045831263065338}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0006263804327318212, 'avg_role_model_std_loss': 0.24181758030463243, 'avg_role_model_mean_pred_loss': 6.295153740953907e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0006263804327318212, 'n_size': 80, 'n_batch': 10, 'duration': 8.345160722732544, 'duration_batch': 0.8345160722732544, 'duration_size': 0.1043145090341568, 'avg_pred_std': 0.08191414531320333}\n", "Epoch 9\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0008744017197386711, 
'avg_role_model_std_loss': 0.17949836104246067, 'avg_role_model_mean_pred_loss': 4.63043962907906e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008744017197386711, 'n_size': 320, 'n_batch': 40, 'duration': 39.11712980270386, 'duration_batch': 0.9779282450675965, 'duration_size': 0.12224103063344956, 'avg_pred_std': 0.09466907754540443}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0011390350133297033, 'avg_role_model_std_loss': 0.004834387120854444, 'avg_role_model_mean_pred_loss': 3.0137808032293377e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0011390350133297033, 'n_size': 80, 'n_batch': 10, 'duration': 8.316069841384888, 'duration_batch': 0.8316069841384888, 'duration_size': 0.1039508730173111, 'avg_pred_std': 0.0990539627149701}\n", "Epoch 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0004748740824652486, 'avg_role_model_std_loss': 0.1777749692730623, 'avg_role_model_mean_pred_loss': 2.3089836414527056e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0004748740824652486, 'n_size': 320, 'n_batch': 40, 'duration': 39.06293201446533, 'duration_batch': 0.9765733003616333, 'duration_size': 0.12207166254520416, 'avg_pred_std': 0.09201494687004015}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00032443252712255344, 'avg_role_model_std_loss': 0.0010629200933180982, 'avg_role_model_mean_pred_loss': 3.426863805611191e-07, 'avg_role_model_g_mag_loss': 0.0, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00032443252712255344, 'n_size': 80, 'n_batch': 10, 'duration': 8.351998329162598, 'duration_batch': 0.8351998329162598, 'duration_size': 0.10439997911453247, 'avg_pred_std': 0.0884638118557632}\n", "Epoch 11\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.00030916042924218347, 'avg_role_model_std_loss': 0.04881817966124018, 'avg_role_model_mean_pred_loss': 2.0088672352989394e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00030916042924218347, 'n_size': 320, 'n_batch': 40, 'duration': 38.86811137199402, 'duration_batch': 0.9717027842998505, 'duration_size': 0.12146284803748131, 'avg_pred_std': 0.10129309091717005}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00028257269877940416, 'avg_role_model_std_loss': 1.0737754437432159, 'avg_role_model_mean_pred_loss': 2.8357685949442768e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00028257269877940416, 'n_size': 80, 'n_batch': 10, 'duration': 8.252684354782104, 'duration_batch': 0.8252684354782105, 'duration_size': 0.10315855443477631, 'avg_pred_std': 0.08038602282758803}\n", "Epoch 12\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0013487103491570452, 'avg_role_model_std_loss': 0.43372808683234754, 'avg_role_model_mean_pred_loss': 1.0047366970687786e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0013487103491570452, 'n_size': 320, 'n_batch': 40, 'duration': 39.09407997131348, 'duration_batch': 0.9773519992828369, 'duration_size': 0.12216899991035461, 'avg_pred_std': 0.0899976636399515}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.003439919964876026, 'avg_role_model_std_loss': 0.015614798056776635, 'avg_role_model_mean_pred_loss': 2.0251521429592856e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.003439919964876026, 'n_size': 80, 'n_batch': 10, 'duration': 8.342942476272583, 'duration_batch': 0.8342942476272583, 'duration_size': 0.1042867809534073, 'avg_pred_std': 0.11240037991665304}\n", "Epoch 13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0008618889094577753, 'avg_role_model_std_loss': 0.1384840221481113, 'avg_role_model_mean_pred_loss': 4.446653539750059e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008618889094577753, 'n_size': 320, 'n_batch': 40, 'duration': 38.886531829833984, 'duration_batch': 0.9721632957458496, 'duration_size': 0.1215204119682312, 'avg_pred_std': 0.09283134532161057}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.000532695987567422, 'avg_role_model_std_loss': 0.6308531300281175, 'avg_role_model_mean_pred_loss': 1.360833180625437e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.000532695987567422, 'n_size': 80, 'n_batch': 10, 'duration': 8.28925633430481, 
'duration_batch': 0.828925633430481, 'duration_size': 0.10361570417881012, 'avg_pred_std': 0.0891546759288758}\n", "Epoch 14\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.00030356911156559363, 'avg_role_model_std_loss': 0.3619111133062688, 'avg_role_model_mean_pred_loss': 5.198837278813596e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00030356911156559363, 'n_size': 320, 'n_batch': 40, 'duration': 38.988784074783325, 'duration_batch': 0.9747196018695832, 'duration_size': 0.1218399502336979, 'avg_pred_std': 0.09746413570828735}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0005432431978988461, 'avg_role_model_std_loss': 0.001004549844947178, 'avg_role_model_mean_pred_loss': 2.781989758560144e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005432431978988461, 'n_size': 80, 'n_batch': 10, 'duration': 8.42578673362732, 'duration_batch': 0.842578673362732, 'duration_size': 0.1053223341703415, 'avg_pred_std': 0.09348368076607586}\n", "Epoch 15\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.00029625174347529536, 'avg_role_model_std_loss': 0.0572095896306493, 'avg_role_model_mean_pred_loss': 6.03426183574306e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00029625174347529536, 'n_size': 320, 'n_batch': 40, 'duration': 39.27268958091736, 'duration_batch': 0.981817239522934, 'duration_size': 0.12272715494036675, 'avg_pred_std': 0.09890737304231152}\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00036838351952610536, 'avg_role_model_std_loss': 0.7212186768025276, 'avg_role_model_mean_pred_loss': 2.6941624464704718e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00036838351952610536, 'n_size': 80, 'n_batch': 10, 'duration': 8.407346963882446, 'duration_batch': 0.8407346963882446, 'duration_size': 0.10509183704853058, 'avg_pred_std': 0.08277125156018883}\n", "Epoch 16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005824315209792986, 'avg_role_model_std_loss': 0.32841089839253074, 'avg_role_model_mean_pred_loss': 9.46431564320671e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005824315209792986, 'n_size': 320, 'n_batch': 40, 'duration': 39.24979209899902, 'duration_batch': 0.9812448024749756, 'duration_size': 0.12265560030937195, 'avg_pred_std': 0.09515800991794095}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0007761759276036173, 'avg_role_model_std_loss': 1.2551941490234082, 'avg_role_model_mean_pred_loss': 2.9720144345546373e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0007761759276036173, 'n_size': 80, 'n_batch': 10, 'duration': 8.318324089050293, 'duration_batch': 0.8318324089050293, 'duration_size': 0.10397905111312866, 'avg_pred_std': 0.09093907248461619}\n", "Epoch 17\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0012332158104982228, 
'avg_role_model_std_loss': 0.6590279597393532, 'avg_role_model_mean_pred_loss': 8.36102873709775e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0012332158104982228, 'n_size': 320, 'n_batch': 40, 'duration': 39.119892835617065, 'duration_batch': 0.9779973208904267, 'duration_size': 0.12224966511130334, 'avg_pred_std': 0.0890660552540794}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.001825959722918924, 'avg_role_model_std_loss': 1.9564546512207017, 'avg_role_model_mean_pred_loss': 1.1242688799484313e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.001825959722918924, 'n_size': 80, 'n_batch': 10, 'duration': 8.359159708023071, 'duration_batch': 0.8359159708023072, 'duration_size': 0.1044894963502884, 'avg_pred_std': 0.06533113070763648}\n", "Epoch 18\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.001502491006613127, 'avg_role_model_std_loss': 0.4666562590015076, 'avg_role_model_mean_pred_loss': 3.179587947280127e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.001502491006613127, 'n_size': 320, 'n_batch': 40, 'duration': 39.09075927734375, 'duration_batch': 0.9772689819335938, 'duration_size': 0.12215862274169922, 'avg_pred_std': 0.09002331190858967}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0008729565364774316, 'avg_role_model_std_loss': 0.20973070683976403, 'avg_role_model_mean_pred_loss': 2.998759428507469e-06, 'avg_role_model_g_mag_loss': 0.0, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008729565364774316, 'n_size': 80, 'n_batch': 10, 'duration': 8.317886352539062, 'duration_batch': 0.8317886352539062, 'duration_size': 0.10397357940673828, 'avg_pred_std': 0.08415974881500006}\n", "Epoch 19\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.001246437881127349, 'avg_role_model_std_loss': 0.6120949116166994, 'avg_role_model_mean_pred_loss': 2.0787087329172948e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.001246437881127349, 'n_size': 320, 'n_batch': 40, 'duration': 39.05946326255798, 'duration_batch': 0.9764865815639496, 'duration_size': 0.1220608226954937, 'avg_pred_std': 0.09171894917380996}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.002248370127927046, 'avg_role_model_std_loss': 4.686978222953622, 'avg_role_model_mean_pred_loss': 2.895269359082242e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002248370127927046, 'n_size': 80, 'n_batch': 10, 'duration': 8.319932222366333, 'duration_batch': 0.8319932222366333, 'duration_size': 0.10399915277957916, 'avg_pred_std': 0.0715100662317127}\n", "Epoch 20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0029011325517785736, 'avg_role_model_std_loss': 0.9372176351432528, 'avg_role_model_mean_pred_loss': 1.1204305509332328e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029011325517785736, 'n_size': 320, 'n_batch': 40, 'duration': 38.93023109436035, 'duration_batch': 0.9732557773590088, 'duration_size': 0.1216569721698761, 'avg_pred_std': 0.08712862803367898}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.0013807336257741555, 'avg_role_model_std_loss': 2.265311992234274, 'avg_role_model_mean_pred_loss': 1.8232192309453056e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0013807336257741555, 'n_size': 80, 'n_batch': 10, 'duration': 8.43259859085083, 'duration_batch': 0.843259859085083, 'duration_size': 0.10540748238563538, 'avg_pred_std': 0.07126395150553436}\n", "Epoch 21\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005831055974340416, 'avg_role_model_std_loss': 0.7094658932399625, 'avg_role_model_mean_pred_loss': 5.484142581780628e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005831055974340416, 'n_size': 320, 'n_batch': 40, 'duration': 38.932344913482666, 'duration_batch': 0.9733086228370667, 'duration_size': 0.12166357785463333, 'avg_pred_std': 0.09576541467686184}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: - 0.000 MB of 0.000 MB uploaded\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00040745751502981877, 'avg_role_model_std_loss': 0.01960964320030456, 'avg_role_model_mean_pred_loss': 8.632079813164495e-08, 'avg_role_model_g_mag_loss': 0.0, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00040745751502981877, 'n_size': 80, 'n_batch': 10, 'duration': 8.343457698822021, 'duration_batch': 0.8343457698822021, 'duration_size': 0.10429322123527526, 'avg_pred_std': 0.0816122055053711}\n", "Stopped False\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run history:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test █▂▁▂▃▃▂▂▁▂▁▁▃▁▁▁▁▂▁▂▂▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train █▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test ▁▅▅▃█▂▃▆▅▇▆▅█▆▆▅▆▃▅▄▄▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train █▁▂▃▂▃▃▂▃▄▃▅▃▄▄▄▄▃▃▃▃▄\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test █▂▁▂▃▃▂▂▁▂▁▁▃▁▁▁▁▂▁▂▂▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train █▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", 
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test █▁▁▁▁▄▂▁▁▁▁▂▁▂▁▂▂▃▁▅▃▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train █▇█▃▃▂▂▃▂▁▁▁▂▁▁▁▁▂▂▂▂▂\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test ▃▅▃▇▃▅▂▄▅▃▅▁▅▂█▇▄▅▄▄█▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train █▂▅▄▂▅▃▃▁▅▄▁▅▁▃▇▇▅▄▄▂▂\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test ▃▅▃▇▃▅▂▄▅▃▅▁▅▂█▇▄▅▄▄█▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train █▂▅▄▂▅▃▃▁▅▄▁▅▁▃▇▇▅▄▄▂▂\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_test ▃▅▃▇▃▅▂▄▅▃▅▁▅▂█▇▄▅▄▄█▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_train █▂▅▄▂▅▃▃▁▅▄▁▅▁▃▇▇▅▄▄▂▂\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run summary:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test 0.00041\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train 0.00058\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test 0.08161\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train 
0.09577\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test 0.00041\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train 0.00058\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test 0.01961\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train 0.70947\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test 0.83435\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train 0.97331\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test 0.10429\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train 0.12166\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_test 8.34346\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_train 38.93234\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test 10\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train 40\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test 80\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train 320\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: You can sync this run to the cloud by running:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[1mwandb sync /kaggle/working/eval/insurance/tvae/4/wandb/offline-run-20240229_175101-y38e2cwk\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Find logs at: \u001b[35m\u001b[1m./wandb/offline-run-20240229_175101-y38e2cwk/logs\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Eval loss 
{'role_model': 'tvae', 'n_size': 399, 'n_batch': 50, 'role_model_metrics': {'avg_loss': 0.00027497335946021655, 'avg_g_mag_loss': 0.044538616604244054, 'avg_g_cos_loss': 0.14042933068628, 'pred_duration': 0.8693411350250244, 'grad_duration': 0.5576837062835693, 'total_duration': 1.4270248413085938, 'pred_std': 0.15040387213230133, 'std_loss': 0.0008385280380025506, 'mean_pred_loss': 1.3921320984877639e-08, 'pred_rmse': 0.016582321375608444, 'pred_mae': 0.012896367348730564, 'pred_mape': 0.13851648569107056, 'grad_rmse': 0.034425459802150726, 'grad_mae': 0.018821561709046364, 'grad_mape': 0.6824872493743896}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.00027497335946021655, 'avg_g_mag_loss': 0.044538616604244054, 'avg_g_cos_loss': 0.14042933068628, 'avg_pred_duration': 0.8693411350250244, 'avg_grad_duration': 0.5576837062835693, 'avg_total_duration': 1.4270248413085938, 'avg_pred_std': 0.15040387213230133, 'avg_std_loss': 0.0008385280380025506, 'avg_mean_pred_loss': 1.3921320984877639e-08}, 'min_metrics': {'avg_loss': 0.00027497335946021655, 'avg_g_mag_loss': 0.044538616604244054, 'avg_g_cos_loss': 0.14042933068628, 'pred_duration': 0.8693411350250244, 'grad_duration': 0.5576837062835693, 'total_duration': 1.4270248413085938, 'pred_std': 0.15040387213230133, 'std_loss': 0.0008385280380025506, 'mean_pred_loss': 1.3921320984877639e-08, 'pred_rmse': 0.016582321375608444, 'pred_mae': 0.012896367348730564, 'pred_mape': 0.13851648569107056, 'grad_rmse': 0.034425459802150726, 'grad_mae': 0.018821561709046364, 'grad_mape': 0.6824872493743896}, 'model_metrics': {'tvae': {'avg_loss': 0.00027497335946021655, 'avg_g_mag_loss': 0.044538616604244054, 'avg_g_cos_loss': 0.14042933068628, 'pred_duration': 0.8693411350250244, 'grad_duration': 0.5576837062835693, 'total_duration': 
1.4270248413085938, 'pred_std': 0.15040387213230133, 'std_loss': 0.0008385280380025506, 'mean_pred_loss': 1.3921320984877639e-08, 'pred_rmse': 0.016582321375608444, 'pred_mae': 0.012896367348730564, 'pred_mape': 0.13851648569107056, 'grad_rmse': 0.034425459802150726, 'grad_mae': 0.018821561709046364, 'grad_mape': 0.6824872493743896}}}\n" ] } ], "source": [ "import torch\n", "from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n", "from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n", "from ml_utility_loss.params import GradientPenaltyMode\n", "from ml_utility_loss.util import clear_memory\n", "import time\n", "#torch.autograd.set_detect_anomaly(True)\n", "\n", "clear_memory()\n", "\n", "opt = params[\"Optim\"](model.parameters())\n", "loss = train_2(\n", " [train_set, val_set, test_set],\n", " preprocessor=preprocessor,\n", " whole_model=model,\n", " optim=opt,\n", " log_dir=\"logs\",\n", " checkpoint_dir=\"checkpoints\",\n", " verbose=True,\n", " allow_same_prediction=allow_same_prediction,\n", " wandb=wandb,\n", " study_name=study_name,\n", " **params\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "9b514a07", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:05.264464Z", "iopub.status.busy": "2024-02-29T18:10:05.264165Z", "iopub.status.idle": "2024-02-29T18:10:05.268178Z", "shell.execute_reply": "2024-02-29T18:10:05.267324Z" }, "papermill": { "duration": 0.026093, "end_time": "2024-02-29T18:10:05.270073", "exception": false, "start_time": "2024-02-29T18:10:05.243980", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "model = loss[\"whole_model\"]\n", "opt = loss[\"optim\"]" ] }, { "cell_type": "code", "execution_count": 26, "id": "331a49e1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:05.307164Z", "iopub.status.busy": "2024-02-29T18:10:05.306497Z", "iopub.status.idle": "2024-02-29T18:10:05.387248Z", 
"shell.execute_reply": "2024-02-29T18:10:05.386472Z" }, "papermill": { "duration": 0.101501, "end_time": "2024-02-29T18:10:05.389452", "exception": false, "start_time": "2024-02-29T18:10:05.287951", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from copy import deepcopy\n", "\n", "torch.save(deepcopy(model.state_dict()), \"model.pt\")\n", "#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "123b4b17", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:05.429312Z", "iopub.status.busy": "2024-02-29T18:10:05.428645Z", "iopub.status.idle": "2024-02-29T18:10:05.699888Z", "shell.execute_reply": "2024-02-29T18:10:05.699094Z" }, "papermill": { "duration": 0.29332, "end_time": "2024-02-29T18:10:05.701797", "exception": false, "start_time": "2024-02-29T18:10:05.408477", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAASQAAAESCAYAAABU2qhcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5v0lEQVR4nO3deXxTVf7/8Ve6JN0XWrphoYAtZS2brcUF0UpBFIoIyPAF5Ie4DDhqFRUHqDPOdyqOOi7w1a+Ogs5YQWbEcZAvCgVcoKwtAgIVsFC2tpQl3bfk/P5IGxrolrY0gXyej8d9QG9Obs5tknfPPffcczVKKYUQQtgBJ1tXQAgh6kggCSHshgSSEMJuSCAJIeyGBJIQwm5IIAkh7IYEkhDCbrjYugLtwWg0cvr0aby9vdFoNLaujhDiMkopiouLCQsLw8mp8XbQdRFIp0+fJjw83NbVEEI048SJE9xwww2NPn5dBJK3tzdg2lkfHx8b10YIcbmioiLCw8PN39XGXBeBVHeY5uPjI4EkhB1rrktFOrWFEHZDAkkIYTckkIQQduO66EMSbWMwGKiurrZ1NcQ1zNXVFWdn5zZvRwLJgSmlyMvL4+LFi7auirgO+Pn5ERIS0qaxgBJIDqwujIKCgvDw8JBBpaJVlFKUlZVRUFAAQGhoaKu35VCBZDAqMnMvUFRezR29gnB2ctwvoMFgMIdRQECArasjrnHu7u4AFBQUEBQU1OrDN4fq1DYYFRPfy2DWx7soqaixdXVsqq7PyMPDw8Y1EdeLus9SW/ojHSqQtC5O6FxMu1xUIZ240PxANSFaqj0+Sw4VSADebq6ABJIQ9sjhAsnH3dRtVuzgh2xC2COHC6S6FpIEkmgtjUbDl19+aetqtKuXXnqJgQMH2roajhdIPm6mFlJRuRyyiWvX8uXL8fPza7ftPfvss6Snp7fb9lrLoU77A3i71R2ySSCJ619VVRVarbbZcl5eXnh5eXVAjZrmgC0kOWRrjFKKsqqaDl+svXnyunXruPXWW/Hz8yMgIIB7772Xo0ePAjBs2DCef/55i/Jnz57F1dWV77//HoAzZ84wZswY3N3d6d69O2lpaURERPDmm2+26ve2b98+7rzzTtzd3QkICOCRRx6hpKTE/PjmzZuJjY3F09MTPz8/brnlFo4fPw7ATz/9xIgRI/D29sbHx4chQ4awa9euJl9v8+bNzJw5E71ej0ajQaPR8NJLLwEQERHByy+/zPTp0/Hx8eGRRx4B4PnnnycqKgoPDw969OjBwoULLU7PX37I9tBDD5GUlMRrr71GaGgoAQEBzJkz56pfYuS4LaRKCaTLlVcb6LPomw5/3QN/TMRD2/KPYmlpKcnJyQwYMICSkhIWLVrE+PHj2bNnD1OnTuXVV1/llVdeMZ+GXrlyJWFhYdx2220ATJ8+ncLCQjZv3oyrqyvJycnmUcbWKi0tJTExkfj4eHbu3ElBQQEPP/wwc+fOZfny5dTU1JCUlMTs2bP57LPPqKqqYseOHea6TZ06lUGDBvHuu+/i7OzMnj17cHV1bfI1hw0bxptvvsmiRYvIzs4GsGjdvPbaayxatIiUlBTzOm9vb5YvX05YWBj79u1j9uzZeHt789xzzzX6Ops2bSI0NJRNmzZx5MgRJk+ezMCBA5k9e3arflct4YCBVHvaX/qQrlkTJkyw+Pmjjz6ic+fOHDhwgEmTJvHUU0/x448/mgMoLS2NKVOmoNFoOHToEBs2bGDnzp0MHToUgL/97W9ERka2qi5paWlUVFTwySef4OnpCcCSJUu47777WLx4Ma6uruj1eu6991569uwJQO/evc3Pz83NZd68eURHRwO0qB5arRZfX180Gg0hISFXPH7nnXfyzDPPWKxbsGCB+f8RERE8++yzrFixoslA8vf3Z8mSJTg7OxMdHc2YMWNIT0+XQGpPl/qQpIV0OXdXZw78MdEmr2uNw4cPs2jRIrZv305hYSFGoxEwfbn79evHyJE
j+fTTT7ntttvIyckhIyOD//3f/wUgOzsbFxcXBg8ebN7ejTfeiL+/f6vqfvDgQWJiYsxhBHDLLbdgNBrJzs7m9ttv56GHHiIxMZG7776bhIQEJk2aZL7eKzk5mYcffpi///3vJCQkMHHiRHNwtVZd0Na3cuVK3n77bY4ePUpJSQk1NTXNzq7at29fi0tAQkND2bdvX5vq1hyH7UOSgZFX0mg0eGhdOnyxdoTvfffdx/nz5/nggw/Yvn0727dvB0wduGA6DPrnP/9JdXU1aWlp9O/fn/79+7f776ulli1bRkZGBsOGDWPlypVERUWxbds2wNR38/PPPzNmzBg2btxInz59WL16dZter344AmRkZDB16lTuuece1qxZQ1ZWFr///e/Nv6/GXH7oqNFozOF/tThcINW1kIqkhXRNOnfuHNnZ2SxYsIC77rqL3r17c+HCBYsy48aNo6KignXr1pGWlsbUqVPNj/Xq1YuamhqysrLM644cOXLFNlqqd+/e/PTTT5SWlprXbdmyBScnJ3r16mVeN2jQIObPn8/WrVvp168faWlp5seioqJ4+umn+fbbb7n//vtZtmxZs6+r1WoxGAwtquPWrVvp1q0bv//97xk6dCiRkZHmTnV744CBVHeWTVpI1yJ/f38CAgJ4//33OXLkCBs3biQ5OdmijKenJ0lJSSxcuJCDBw8yZcoU82PR0dEkJCTwyCOPsGPHDrKysnjkkUdwd3dv1bVYU6dOxc3NjRkzZrB//342bdrEE088wbRp0wgODiYnJ4f58+eTkZHB8ePH+fbbbzl8+DC9e/emvLycuXPnsnnzZo4fP86WLVvYuXOnRR9TYyIiIigpKSE9PZ3CwkLKysoaLRsZGUlubi4rVqzg6NGjvP32221uhV0tDhhI0od0LXNycmLFihXs3r2bfv368fTTT/OXv/zlinJTp07lp59+4rbbbqNr164Wj33yyScEBwdz++23M378ePMZJzc3N6vr4+HhwTfffMP58+e56aabeOCBB7jrrrtYsmSJ+fFDhw4xYcIEoqKieOSRR5gzZw6PPvoozs7OnDt3junTpxMVFcWkSZMYPXo0f/jDH5p93WHDhvHYY48xefJkOnfuzKuvvtpo2bFjx/L0008zd+5cBg4cyNatW1m4cKHV+9ohVCssWbJEdevWTel0OhUbG6u2b9/eZPnPP/9c9erVS+l0OtWvXz/19ddfWzw+Y8YMBVgsiYmJLa6PXq9XgNLr9c2WzT1Xqro9v0b1WrC2xdu/HpWXl6sDBw6o8vJyW1fF5k6cOKEAtWHDBltX5ZrW1Geqpd9Rq1tIK1euJDk5mZSUFDIzM4mJiSExMbHRcRxbt25lypQpzJo1i6ysLJKSkkhKSmL//v0W5UaNGsWZM2fMy2effWZt1VqkroVUUW2kqubqdtAJ+7Rx40a++uorcnJy2Lp1Kw8++CARERHcfvvttq6aw7M6kN544w1mz57NzJkz6dOnD++99x4eHh589NFHDZZ/6623GDVqFPPmzaN37968/PLLDB482NykraPT6QgJCTEvTZ2GrayspKioyGJpKS/dpZEO0o/kmKqrq3nxxRfp27cv48ePp3PnzuZBkp9++qn5MorLl759+3ZYHUePHt1oPf785z93WD06mlXjkKqqqti9ezfz5883r3NyciIhIYGMjIwGn5ORkXFFp2NiYuIVV0tv3ryZoKAg/P39ufPOO/nTn/7U6NSqqampLTrOboiLsxOeWmdKqwwUV9QQ4KVr1XbEtSsxMZHExIbHW40dO5a4uLgGH2tuBHV7+tvf/kZ5eXmDj3Xq1KnD6tHRrAqkwsJCDAYDwcHBFuuDg4M5dOhQg8/Jy8trsHxeXp7551GjRnH//ffTvXt3jh49yosvvsjo0aPJyMhocG7e+fPnW4Rc3X3DW8rbzdUcSELU5+3t3ez95ztCly5dbF0Fm7CLkdoPPvig+f/9+/dnwIAB9OzZk82bN3PXXXddUV6
n06HTtb5l4+3mQl6RDI4Uwt5Y1YcUGBiIs7Mz+fn5Fuvz8/MbvKYGICQkxKryAD169CAwMJAjR45YU70WkylIhLBPVgWSVqtlyJAhFhM5GY1G0tPTiY+Pb/A58fHxV0z8tH79+kbLA5w8eZJz58616f5OTfFxr7t8RA7ZhLAnVp9lS05O5oMPPuDjjz/m4MGDPP7445SWljJz5kzANLVD/U7vJ598knXr1vH6669z6NAhXnrpJXbt2sXcuXMBKCkpYd68eWzbto1jx46Rnp7OuHHjuPHGGxvteGwrmcZWCPtkdR/S5MmTOXv2LIsWLSIvL4+BAweybt06c8d1bm4uTk6Xcm7YsGGkpaWxYMECXnzxRSIjI/nyyy/p168fAM7Ozuzdu5ePP/6YixcvEhYWxsiRI3n55Zfb1E/UFG+ZxlYIu9SqTu25c+eaWziX27x58xXrJk6cyMSJExss7+7uzjffdOykYHL5iGgLjUbD6tWrSUpKsnVVrjsOdy0b1J/GVlpI4trU3pP8g6kxodFouHjxYrtu1xoOGkjSQhLCHjlkIMndaxuhFFSVdvwik/y32yT/lZWVPPvss3Tp0gVPT0/i4uIsulGOHz/Offfdh7+/P56envTt25e1a9dy7NgxRowYAZimeNFoNDz00EOt+n20hV0MjOxo0ofUiOoy+HNYx7/ui6dB69l8uVoyyX/jk/zPnTuXAwcOsGLFCsLCwli9ejWjRo1i3759REZGMmfOHKqqqvj+++/x9PTkwIEDeHl5ER4ezr/+9S8mTJhAdnY2Pj4+uLu7t+p30hYOGUh145CkD+naJJP8NzzJf25uLsuWLSM3N5ewMNMflmeffZZ169axbNky/vznP5Obm8uECRPMU/r26NHD/Py6a+SCgoLavX+qpRwykKSF1AhXD1NrxRavawWZ5L9h+/btw2AwEBUVZbG+srLSfKH67373Ox5//HG+/fZbEhISmDBhAgMGDGjV610NDt+HpKzsv7iuaTSmQ6eOXmSS/3aZ5L+kpARnZ2d2797Nnj17zMvBgwd56623AHj44Yf59ddfmTZtGvv27WPo0KG888477bavbeWggWRqIVUbFJUySds1RSb5N2lokv9BgwZhMBgoKCjgxhtvtFjqH9qFh4fz2GOP8cUXX/DMM8/wwQcfmLcJtPjmAVeDQwaSl9bF/EdZzrRdW2SSf5OGJvmPiopi6tSpTJ8+nS+++IKcnBx27NhBamoqX3/9NQBPPfUU33zzDTk5OWRmZrJp0ybz63Xr1g2NRsOaNWs4e/asxZnCDnOVptftUNbMqV2nX8o61e35NepIQfFVrJn9upbn1F6/fr3q3bu30ul0asCAAWrz5s0KUKtXrzaXWbt2rQLU7bfffsXzT58+rUaPHq10Op3q1q2bSktLU0FBQeq9995r0etf/lp79+5VI0aMUG5ubqpTp05q9uzZqrjY9LnKy8tTSUlJKjQ0VGm1WtWtWze1aNEiZTAYVGVlpXrwwQdVeHi40mq1KiwsTM2dO7fF78ljjz2mAgICFKBSUlKUUkpVVVWpRYsWqYiICOXq6qpCQ0PV+PHj1d69e5VSSs2dO1f17NlT6XQ61blzZzVt2jRVWFho3uYf//hHFRISojQajZoxY0aL6lGnPebU1ih17XeiFBUV4evri16vb/ZunHVueWUjpy6Ws/q3wxjUtXUdmteyiooKcnJy6N69e6vutnE9OXnyJOHh4WzYsKHB+bdEyzT1mWrpd9Qhz7KBnGlzZBs3bqSkpIT+/ftz5swZnnvuOZnk3044ZB8S1L+eTQLJ0cgk//ZLWkjSqe1wZJJ/++XwgSRn2UR9Msm/bTnuIZu7HLIBMjBUtJv2+Cw5bCA5eqd23eFHWVmZjWsirhd1n6W2HNo68CGbY09B4uzsjJ+fn/kqdw8Pj1YNDBRCKUVZWRkFBQX4+fk1eC/FlnLgQKqbV9sxW0iA+XKC1k69IUR9fn5+Td7erCUcNpBkGlv
T3NChoaEEBQVRXe24vwfRdq6urm1qGdVx2EBy9D6k+pydndvlwyREWzlwp3ZtC6lSWgZC2AuHDSQf6UMSwu44biDVjkMqqayRsThC2AmHDaS6PiSDUVFWZbsJqYQQlzhsILm7OuPsZBp3Ix3bQtgHhw0kjUYj17MJYWdaFUhLly4lIiICNzc34uLi2LFjR5PlV61aRXR0NG5ubvTv35+1a9c2Wvaxxx5Do9G0+qZ91pCxSELYF6sDaeXKlSQnJ5OSkkJmZiYxMTEkJiY2Otp369atTJkyhVmzZpGVlUVSUhJJSUns37//irKrV69m27Zt5ntKXW2XWkhyyCaEPbA6kN544w1mz57NzJkz6dOnD++99x4eHh589NFHDZZ/6623GDVqFPPmzaN37968/PLLDB48mCVLlliUO3XqFE888QSffvpph807I4MjhbAvVgVSVVUVu3fvJiEh4dIGnJxISEggIyOjwedkZGRYlAfTBFn1yxuNRqZNm8a8efNaNCtfZWUlRUVFFktrmC+wLZdDNiHsgVWBVFhYiMFgIDg42GJ9cHAweXl5DT4nLy+v2fKLFy/GxcWF3/3udy2qR2pqKr6+vuYlPDzcmt0wk2lshbAvNj/Ltnv3bt566y2WL1/e4ukv5s+fj16vNy8nTpxo1WvLNLZC2BerAikwMBBnZ2fy8/Mt1ufn5zc67UBISEiT5X/44QcKCgro2rUrLi4uuLi4cPz4cZ555hkiIiIa3KZOp8PHx8diaQ0f6UMSwq5YFUharZYhQ4aQnp5uXmc0GklPTyc+Pr7B58THx1uUB1i/fr25/LRp09i7d6/FvcjDwsKYN28e33zzjbX7Y5W6y0dkHJIQ9sHq6UeSk5OZMWMGQ4cOJTY2ljfffJPS0lJmzpwJwPTp0+nSpQupqakAPPnkkwwfPpzXX3+dMWPGsGLFCnbt2sX7778PQEBAAAEBARav4erqSkhIiMW90a8GOcsmhH2xOpAmT57M2bNnWbRoEXl5eQwcOJB169aZO65zc3NxcrrU8Bo2bBhpaWksWLCAF198kcjISL788kv69evXfnvRSt4yMFIIu+Kwt9IG+OHwWaZ9uINewd5887TctVSIq6Wl31Gbn2WzJbl0RAj74tCBJH1IQtgXBw+k2knaqmowGq/5I1chrnkOHkimFpJSUFwprSQhbM2hA8nN1Rmti+lXIP1IQtieQwcSyGhtIeyJwweSt1xgK4TdkEAy3w5JDtmEsDWHDyQfuWGkEHbD4QNJxiIJYT8kkCSQhLAbEkgyja0QdsPhA6muD0nuPCKE7Tl8IMk0tkLYDwkk6UMSwm44fCDJNLZC2A+HDyRpIQlhPxw+kGSSNiHsh8MHkrSQhLAfDh9IdS2ksioD1QajjWsjhGNz+EDycrt045USaSUJYVMOH0iuzk64uzoDctgmhK05fCBBvSlIpGNbCJuSQELGIglhLySQkDNtQtgLCSRkGlsh7IUEEnKBrRD2QgKJelOQlEsLSQhbalUgLV26lIiICNzc3IiLi2PHjh1Nll+1ahXR0dG4ubnRv39/1q5da/H4Sy+9RHR0NJ6envj7+5OQkMD27dtbU7VW8ZEWkhB2wepAWrlyJcnJyaSkpJCZmUlMTAyJiYkUFBQ0WH7r1q1MmTKFWbNmkZWVRVJSEklJSezfv99cJioqiiVLlrBv3z5+/PFHIiIiGDlyJGfPnm39nllBOrWFsBPKSrGxsWrOnDnmnw0GgwoLC1OpqakNlp80aZIaM2aMxbq4uDj16KOPNvoaer1eAWrDhg0tqlNdeb1e36Lyl1u+JUd1e36Nevwfu1r1fCFE01r6HbWqhVRVVcXu3btJSEgwr3NyciIhIYGMjIwGn5ORkWFRHiAxMbHR8lVVVbz//vv4+voSExPTYJnKykqKiooslrbwca+7N5u0kISwJasCqbCwEIPBQHBwsMX64OBg8vLyGnxOXl5ei8qvWbMGLy8v3Nzc+Otf/8r69esJDAx
scJupqan4+vqal/DwcGt24wreOpmCRAh7YDdn2UaMGMGePXvYunUro0aNYtKkSY32S82fPx+9Xm9eTpw40abXlj4kIeyDVYEUGBiIs7Mz+fn5Fuvz8/MJCQlp8DkhISEtKu/p6cmNN97IzTffzIcffoiLiwsffvhhg9vU6XT4+PhYLG1x6dIRCSQhbMmqQNJqtQwZMoT09HTzOqPRSHp6OvHx8Q0+Jz4+3qI8wPr16xstX3+7lZWV1lSv1eTiWiHsg0vzRSwlJyczY8YMhg4dSmxsLG+++SalpaXMnDkTgOnTp9OlSxdSU1MBePLJJxk+fDivv/46Y8aMYcWKFezatYv3338fgNLSUv77v/+bsWPHEhoaSmFhIUuXLuXUqVNMnDixHXe1cXWXjlTVGKmsMaBzce6Q1xVCWLI6kCZPnszZs2dZtGgReXl5DBw4kHXr1pk7rnNzc3FyutTwGjZsGGlpaSxYsIAXX3yRyMhIvvzyS/r16weAs7Mzhw4d4uOPP6awsJCAgABuuukmfvjhB/r27dtOu9k0L92lX0NxRQ06LwkkIWxBo5RStq5EWxUVFeHr64ter291f1L/lG8orqxh07N30D3Qs51rKIRja+l31G7OstmauR+pXPqRhLAVCaRaMgWJELYngVRLpiARwvYkkGrVjUWSFpIQtiOBVEvGIglhexJItS4FkrSQhLAVCaRalzq1pYUkhK1IINXykbNsQticBFItGYckhO1JINWSKUiEsD0JpFrmQ7ZKaSEJYSsSSLVkGlshbE8CqZacZRPC9iSQatXvQ7oOJkAQ4pokgVSrrg+pxqioqDbauDZCOCYJpFoeWmecnTSAXD4ihK1IINXSaDTmmSOlH0kI25BAqkeuZxPCtiSQ6pHLR4SwLQmkeuTyESFsSwKpHpnGVgjbkkCqx0emsRXCpiSQ6pFpbIWwLQmkemQaWyFsSwKpHpmCRAjbkkCqRy6wFcK2JJDqqRuHJAMjhbANCaR6ZBySELbVqkBaunQpERERuLm5ERcXx44dO5osv2rVKqKjo3Fzc6N///6sXbvW/Fh1dTXPP/88/fv3x9PTk7CwMKZPn87p06dbU7U2kT4kIWzL6kBauXIlycnJpKSkkJmZSUxMDImJiRQUFDRYfuvWrUyZMoVZs2aRlZVFUlISSUlJ7N+/H4CysjIyMzNZuHAhmZmZfPHFF2RnZzN27Ni27VkrSB+SEDamrBQbG6vmzJlj/tlgMKiwsDCVmpraYPlJkyapMWPGWKyLi4tTjz76aKOvsWPHDgWo48ePt6hOer1eAUqv17eofGPyi8pVt+fXqO4vrFEGg7FN2xJCXNLS76hVLaSqqip2795NQkKCeZ2TkxMJCQlkZGQ0+JyMjAyL8gCJiYmNlgfQ6/VoNBr8/PwafLyyspKioiKLpT3UdWobFZRWyWGbEB3NqkAqLCzEYDAQHBxssT44OJi8vLwGn5OXl2dV+YqKCp5//nmmTJmCj49Pg2VSU1Px9fU1L+Hh4dbsRqN0Lk64OpsmaZN+JCE6nl2dZauurmbSpEkopXj33XcbLTd//nz0er15OXHiRLu8vkajkQtshbAhF2sKBwYG4uzsTH5+vsX6/Px8QkJCGnxOSEhIi8rXhdHx48fZuHFjo60jAJ1Oh06ns6bqLebj5sL50irp2BbCBqxqIWm1WoYMGUJ6erp5ndFoJD09nfj4+AafEx8fb1EeYP369Rbl68Lo8OHDbNiwgYCAAGuq1a68zYMjJZCE6GhWtZAAkpOTmTFjBkOHDiU2NpY333yT0tJSZs6cCcD06dPp0qULqampADz55JMMHz6c119/nTFjxrBixQp27drF+++/D5jC6IEHHiAzM5M1a9ZgMBjM/UudOnVCq9W21762iIxFEsJ2rA6kyZMnc/bsWRYtWkReXh4DBw5k3bp15o7r3NxcnJwuNbyGDRtGWloaCxYs4MUXXyQyMpIvv/ySfv36AXDq1Cm++uorAAYOHGj
xWps2beKOO+5o5a61jlw+IoTtaJS69u+KWFRUhK+vL3q9vsm+p5aYt+onVu0+yXOjevHbO25spxoK4dha+h21q7Ns9sDch1QuLSQhOpoE0mW8ZRpbIWxGAukyMo2tELYjgXQZaSEJYTsSSJfxkbvXCmEzEkiXkSlIhLAdCaTLyO20hbAdCaTLyEhtIWxHAukydYFUUlmDwXjNjxkV4poigXSZuj4kgBJpJQnRoSSQLqN1ccLN1fRrkSv+hehYEkgNkEnahLANCaQGmO/PJi0kITqUBFIDpIUkhG1IIDXARy4fEcImJJAaIIMjhbANCaQGmPuQyqWFJERHkkBqgHkKkkppIQnRkSSQGuCtkz4kIWxBAqkBlw7ZpIUkREeSQGqA3JtNCNuQQGqATGMrhG1IIDVAprEVwjYkkBrgLdPYCmETEkgN8JFpbIWwCccMpJKCJh+uC6SKaiPVBmNH1EgIgaMFkv4kvD0Y3hkKhsZbP161h2wgHdtCdKRWBdLSpUuJiIjAzc2NuLg4duzY0WT5VatWER0djZubG/3792ft2rUWj3/xxReMHDmSgIAANBoNe/bsaU21mucdChV6qNRDbkajxZydNObBkSfOl12dugghrmB1IK1cuZLk5GRSUlLIzMwkJiaGxMRECgoaPgzaunUrU6ZMYdasWWRlZZGUlERSUhL79+83lyktLeXWW29l8eLFrd+TlnByhqhE0/+z/6/Jorf36gzAip0nrm6dhBCXKCvFxsaqOXPmmH82GAwqLCxMpaamNlh+0qRJasyYMRbr4uLi1KOPPnpF2ZycHAWorKwsq+qk1+sVoPR6ffOFD3ylVIqPUm/GKGU0Nlps+6/nVLfn16joBf+nLpZWWVUfIYSlln5HrWohVVVVsXv3bhISEszrnJycSEhIICOj4UOgjIwMi/IAiYmJjZZvicrKSoqKiiyWFusxApy1cCEHzmY3WuymCH+iQ7wprzaware0koToCFYFUmFhIQaDgeDgYIv1wcHB5OXlNficvLw8q8q3RGpqKr6+vuYlPDy85U/WeUH34ab//9L4YZtGo2HGsAgA/r7tOEa5JZIQV901eZZt/vz56PV683LihJUtmF6jTP820480bmAYPm4uHD9XxneHz7aytkKIlrIqkAIDA3F2diY/P99ifX5+PiEhIQ0+JyQkxKryLaHT6fDx8bFYrBJVG0gndkBpYaPFPLQuTBpqan19vPVYK2srhGgpqwJJq9UyZMgQ0tPTzeuMRiPp6enEx8c3+Jz4+HiL8gDr169vtHyH8L0BQgYACg5/22TR/7q5GxoNbM4+y7HC0o6pnxAOyupDtuTkZD744AM+/vhjDh48yOOPP05paSkzZ84EYPr06cyfP99c/sknn2TdunW8/vrrHDp0iJdeeoldu3Yxd+5cc5nz58+zZ88eDhw4AEB2djZ79uxpUz9Ts3qNNv2bvbbJYhGBntwRZRoC8Pdtx69efYQQ1p/2V0qpd955R3Xt2lVptVoVGxurtm3bZn5s+PDhasaMGRblP//8cxUVFaW0Wq3q27ev+vrrry0eX7ZsmQKuWFJSUlpUH6tO+9c5lWk6/f+nUKWqypssuvFQvur2/BrVL2WdKq2sbvlrCCGUUi3/jmqUUtf86aOioiJ8fX3R6/Ut709SCt7oDcVnYOq/IDKh0aJGo2LE65s5fq6MP4/vz2/iurZTzYVwDC39jl6TZ9nahUZzqXO7idP/AE5OGqbd3A2ATzKOcR1kuBB2yXEDCer1I60ztZiaMHFIOO6uzhzKK2ZHzvkOqJwQjsexA6n77eDqAUUnIW9fk0V9PVxJGtQFgE8ypHNbiKvBsQPJ1d10KQk0O0gSYHq86bBt3c955OkrrmbNhHBIjh1IcGnUdjP9SAC9Q32I7d4Jg1GRtl1aSUK0NwmkqFGABk5nQdGZZovPiI8AIG1HLpU1hqtbNyEcjASSVxB0GWL6/y/rmi0+sm8wwT46CkuqWLf/Kg7cFMI
BSSBBvbNtzR+2uTo7MTXO1Jck17cJ0b4kkOBSIOV8B1XNT1n7YGw4rs4aMnMvsu+k/ipXTgjHIYEEENQH/LpCTQX8urn54t5u3NM/FDANlBRCtA8JJKgdtd2yi23rTK/t3P73T6e5UFp1lSomhGORQKpTd9j2yzdgbP5ebIO7+tGviw9VNUZW7pIpboVoDxJIdbrdAjofKC2A05nNFtdoNOZW0t8zjlNRLUMAhGgrCaQ6Llroeafp/y042wYwNiYMfw9XTl0sZ8zbP7DnxMWrVz8hHIAEUn297jH928JAcnN1ZunUwXT21nH0bCn3/88W/vLNIRkwKUQrSSDVF3k3aJyg4Ge40LJLQ4b1DOTbp25n3MAwjAqWbjrKuCVb+Pm0DAcQwloSSPV5dIKutXN9t2DUdh1/Ty1vPTiId6cOppOnlkN5xYxbsoW3Nhym2tB8B7kQwsTF1hWwO1Gj4PgW02Fb3KNWPXV0/1Bu6t6JBav3s+7nPP664Rc2HMzn9UkxRAV7t2wjOT/Awf9gmsVXY2qxaZxMQxM09X92Aq2nabhCcB+rd7O+imoD+0/pOa2vYHhkZ3w9XNu0PZtSCg59DQE3QlC0rWsjrOS4U9g2pvAILBkCTq7w3K/gZv32lFJ89dNpFv37Z/Tl1WidnXj67igeub0Hzk6ahp9UWQLrF8GuD62vc3A/GDAJ+j0Avl2ardvJC+VknbhI5vELZOVe4MCZIqoNpo9BkLeOVyb0587o4Ca3Y7e+fw02vmya5+qhNZeuUxQ21dLvqARSQ94ZAueOwMTl0Hd8qzeTX1TB/C/2sfFQAQD9u/gycegN3BkdxA3+HpcKHtsC//4tXDhm+nnAg6aR48p4aUHV/r/evxePw+H1YKyu3ZAGIm6F/hOhzzhw98NgVGTmXiDz+AXTv7kXOVtceUVdA7106FycOHWxHICJQ25g4X198HG7hlpLh76GFb+59LNHIMz6FgJ62q5O16CCogr2ndIzPKozLs7t06sjgdQW3y6Are/AgMlw//tt2pRSilW7T/Lyfw5QXFljXh8d4s2oKF9+U7qczj8vQ4MC33AY+w70HNHyFyg7Dwf+DftWmQ41617XWcuJwNt4/+JQPtf3oYpLweLipKFPmA+Du/ozqKsfg7v6c4O/O5U1Rl77JpsPt+SgFIT6uvHKhAEMr70NlF3L2w8fjoTqUhg8A87sgTM/gX8EzFpvmtVBNMloVPxj+3FeXZdNSWUNd0UH8c5vBuGhbXvPjgRSWxzbAsvvAXd/eDwDfELbvMk8fQWrs06RfjCfzNwLxHCY11zfo6eTaQ6m7X5jKBr+R4b1icBT18oPwMVcKrI+p2xnGp3KjppXF+HJTt+RnI2eRs/eg+jfxRc3V+dGN7Pz2HnmrfqJY+dMFxpPiQ3n92P64NXael1tpYXw/gjQ50L34fBfX0DZOfjwblMrMnQgPPQ16LxsXVO79Ut+MS/8ay+ZuRct1g8M9+PDGUMJ8NK1afsSSG1hqIG/9oWSPHD1hNuSIX4uuLq1fds1lZSv/xNu25egwUiB8uf56ofZZBwEgNbFibjunYiN6MTgbv7EhPu1KAgKiitYtuUY/8g4TnFlNb01uUz12M54l614VhZcKth9ONz0sGnMlXPj2y2rquHVddksr51ipYufO68+MIBbbgxs0+63u5oq+GQc5G6FTj3g4XTT2VKAc0dNoVR2zjTodcpK0wBYR1R02tS/lrfPNJd83yQI7kdFjZH/2XSEd787SrVB4aVz4blRvYgO8WH2J7vQl1fTPdCTj2fG0jXAo9mXafTlJZDaKG8/rHkKTu40/ezXFUb+N/S+z3S2qzVOZ8Hqx+HsQdPPAyZTdXcqO/IUGw7mk34onxPnyy2e4qSBqGBvBnfzZ3BXfwZ39aN7oCea2jocKyzl/R9+5Z+7T1JVYxpiEBnkxWPDezJ2YBiuGgVHN5k6y39ZV9sfBXiHwZCHYMgM8A5ptMoZR88x758/cfKCqV7
Tbu7GC6OjW9+Ka09KwX+ehMyPTZf9PLwBOveyLHNyN3x8L1SXmfrmxr/X+vevo5Wdh6oS02evLdv48Q3Y8YFpNot6yr0j+FflTXxaPJiDqisJvUN4Oakvob7uABwpKGbGRzs5dbGcQC8tyx6Kpf8Nvq2qhgRSe1DK1DezPgWKT5vWRdwGo16BkH4t20aF3jSlyS/fwN6VYKwxdbbe96Yp3CxeTnG4oIQtRwrJzDWdBavrZK7P38OVQV39cXXWsP5APsbad3BwVz8ev+NG7ooOwqmhs3kXc2HXMsj8BMoKTeucXCD6XlOrKeLWBr+spZU1pP7fQf6xLReA8E7uPHJbDxL6BJs/vDax/X/h/54DNPCbzyFqZMPlDq+HtMmgDHDLU3D3HzqyltZRCo79CLs+Mg3/MFZDyACIedB0sqKlfWGVJbDtXdj6NlQWmdZ1jYf+D1D9Szoc2YCrujRLRalXNzwGTURT23Kq+xzkF1Xw0LKdHDxThIfWmXf/a0ir+hQlkNpTVSn8+Kbpza2pMI0BGvIQjPg9eF52CKMUFByAw9/C4Q1wYpsphOr0GQdj3rjyeY3IL6ogq/bsWObxC+w9pTe3hOrcGR3EY8N7clOEv7nl1KSaSjjwFez8m6l+dTpHQ78JpoAK6n1FOP14uJDn/7XXIiT7dfHh7t4h3N0nmN6h3i17/fZwdBP8Y4IpZO5+GW75HQCVNQa2HCnESWPquA/yrj3MzvrUdCYTYPSrVo8xu+rKzsOeNNi9HM4dNq9WGic0ta1ao8aZM4HDyPIfxTbXOE6WKPKLKskvqqC8ykBnbx1hXhruV+u558KneNVcAKDErzcX4l/As+9oMn49T8pXP1NecpG7nLJ4NPAn+pTuQGOod+a1U0/T2eUBk6FzFMUV1Tz2j91sOXIOFycNr0wYwANDbrBq9ySQroaLuaaxQj+vNv2s84U7XjD99Tq+xfSX+MgGKDpl+byASNNlKb3uabQV0lJVNUYOnCki8/gFzpZUMjYmjN6hbdjnvH2w80PY+7npDFUd/+4QPcbUirvhJnAydYIXV1Tz2Y5cvv05n925Fyzur3mDvzsjowMYG1ZMP+djuJw7bPprG3k3uPu1vo6XO3cUPhhhan3GTIGkd/m1sJQVO0/wz90nOV9vfqpALx19wnzoE+rDuOI0eh94C4UGzcRlbRrS0S6UgtxtsHsZ/Pwl1IZCjYsHe/xGsrxyBNvPu5NIBhOcf2CQ0xHzU4uUO2sNcXxhuI2dqhcaYLzTjzzt+k9u0Jhav8eMwbxRM5H/GG9GXXZRRo/OnqSO709cjwCoLDa14H9ebfoM1w+nG26Cgb+hKno8z605xpd7TEcK8xJ78ds7erb4D5AE0tV0bAuse77xm0u6uEP32yByJNyYAJ26X/06tVWF3jR84NDXptZH/Q+lZ2fTfFHR99XeXNPU6ii8WETmrgzOHNqGtmAvvckhWpOLm6baYtNGjQtVNwxD228sTr3HgE9Y2+r5twQo/AVjl6GsHfI3/rErj22/XrqbcLCPDi+dC78Wll52Q2LFH12WM91lPdW4sKzHX/HqdQcDw/2ICvZqtzE3TVIKivPg4Femw+e6/kTgiHNPPqq4g38bhlHKpUNhjcYUrIM8C7mP77m1NB3/6ks3mKjyugFc3dBeMAVWqa4zP4T9P9LdEskrqaGgqJKC4goulFWjdXHiseE9mTOiJzqXBs601oXTvlWmcFK1F4o761DRY/i85jbm/9QZI05Mu7kbL43t2/hg33quaiAtXbqUv/zlL+Tl5RETE8M777xDbGxso+VXrVrFwoULOXbsGJGRkSxevJh77rnH/LhSipSUFD744AMuXrzILbfcwrvvvktkZGSL6tPhgQRgNEDW3yH9ZVN/TKcetQF0N0TcYroJ5bWqssTU0jv0tenDWVnvQmGtl2nuqJI8yD9Qb1DmJcV4sN8QQY4KZojTYXo5nbR4PEcXzcngO6m8cRSdIwbQI8gL75YMwDQ
aIG0SHNlAkWsQ99f8iSPlplP5ThoY0SuIKbFduaOXaUBfWVUN2XnFHDhTxIHTRRw4U8QvZ/S8xhuMdt5JkfIgufpxipQHbi5O9OjsQWSQFzd29qRnZ08CPF1N48OUAhc3UyvP3R/c/Ezvb1OtA0M1nM+Bwl+g8BdqCn6hpuAQLueP4FJdbC5WpnT8xxBPmuFOflI9AQ2RQV7c3COAm3sEEBPuS7CPG671w9JoNJ1V/GmFqWVVVbs9Nz+49WmIfQS0V54Rq6oxYlSqySEfForzYd/npkPJggPm1aW6znxSejP/NNxOz96DeXvKoGa3edUCaeXKlUyfPp333nuPuLg43nzzTVatWkV2djZBQVd2uG3dupXbb7+d1NRU7r33XtLS0li8eDGZmZn062fqGF68eDGpqal8/PHHdO/enYULF7Jv3z4OHDiAm1vzp9ptEkh1qiug/EK7jFWyS4ZqUyfroa9NS13nfh03PwgbCKExpvE+oTEY/CLIPKHnu+yzHC4opjLvML2LfuAuzU4Gaw7jpLn0kTtqDGW9cQj5TkG4urqiddWi1WrR1S06LW5aHe46LRHnfyD65CrKlZaJVYvYr3oQ4uPG5JvCmXRTOF38mv8jYDAqcvML8fnnZALO7W79r8XJlWpXX2q0PlRpfalx9aFa6wuVJbgXHcW3/CTONDwNjUFpOKS6ssIwgi8NtxISFGQOoNjunejsbcWYn6oy001Oyy+a+v/a89C4jlKmQaZ70kwBVX7B/NAeY09+6jaDGbOebHITVy2Q4uLiuOmmm1iyZAkARqOR8PBwnnjiCV544YUryk+ePJnS0lLWrFljXnfzzTczcOBA3nvvPZRShIWF8cwzz/Dss88CoNfrCQ4OZvny5Tz44IPN1smmgeRIlDINXcjNMI0qD40xnZJuQT9CjcHIqYvlnMjNQWWvo/OpDfQs3oUrV7awmvNE9ROUR41jSmzX1l/eUH4BvvodFBxAoaHaoCivMVJRbaS82khljcLUlWzaNx1V+GpK8aUUF03LZnAoVTqOqjDTYgzjlPMNFHn3xOgfwQ2B/sT16ERc9wDrAsjWaipNreY9aajD36JRBi7e/kf87myfQLJqMElVVRW7d+9m/vz55nVOTk4kJCSQkZHR4HMyMjJITk62WJeYmMiXX34JQE5ODnl5eSQkJJgf9/X1JS4ujoyMjAYDqbKyksrKS30cRUVF1uyGaC2NBroMNi1WcnF2oluAJ90C+sGgfsCzpv6KIxuoyl5PTZmemupqamqqMdRUYaypwWioxmgw/YvRgNFo5Ej4A8y/70nCWtAaapK7P0z+u2m3AG3tUjfKprzKwL5Ten46cZE9Jy9yobTKdBmhMuKmyvE0FOOhSvAyluBpLMbTWIKnKgZnHZV+kRAYiU9QV0L8POjr68Zdvm7X1nWBjXHRQZ+x0GcsmpICjHs/x2/A5PbbvDWFCwsLMRgMBAdbXgkeHBzMoUOHGnxOXl5eg+Xz8vLMj9eta6zM5VJTU/nDH+x4LIloGZ039B2Ptu94Wjp+uttVrdAl7lpnYrt3IrZ7pw56xWuQVxBOw+a26yavyQna5s+fj16vNy8nTshdP4S4HlgVSIGBgTg7O5Ofn2+xPj8/n5CQhi8/CAkJabJ83b/WbFOn0+Hj42OxCCGufVYFklarZciQIaSnp5vXGY1G0tPTiY+Pb/A58fHxFuUB1q9fby7fvXt3QkJCLMoUFRWxffv2RrcphLhOKSutWLFC6XQ6tXz5cnXgwAH1yCOPKD8/P5WXl6eUUmratGnqhRdeMJffsmWLcnFxUa+99po6ePCgSklJUa6urmrfvn3mMq+88ory8/NT//73v9XevXvVuHHjVPfu3VV5eXmL6qTX6xWg9Hq9tbsjhOgALf2OWh1ISin1zjvvqK5duyqtVqtiY2PVtm3bzI8NHz5czZgxw6L8559/rqKiopRWq1V9+/ZVX3/9tcXjRqNRLVy4UAUHByudTqfuuusulZ2
d3eL6SCAJYd9a+h2VS0eEEFddS7+j1+RZNiHE9ckOZtlqu7pGngyQFMI+1X03mzsguy4CqbjYdHFheHi4jWsihGhKcXExvr6Nzzp5XfQhGY1GTp8+jbd38xOEFRUVER4ezokTJ67p/qbrZT/g+tkX2Y/GKaUoLi4mLCwMJ6fGe4quixaSk5MTN9xg3Qx218uAyutlP+D62RfZj4Y11TKqI53aQgi7IYEkhLAbDhdIOp2OlJQUdLpraA6aBlwv+wHXz77IfrTdddGpLYS4PjhcC0kIYb8kkIQQdkMCSQhhNySQhBB2QwJJCGE3HC6Qli5dSkREBG5ubsTFxbFjxw5bV8kqL730EhqNxmKJjo62dbWa9f3333PfffcRFhaGRqMx33WmjlKKRYsWERoairu7OwkJCRw+fLjhjdlYc/vy0EMPXfEejRo1yjaVbUJqaio33XQT3t7eBAUFkZSURHZ2tkWZiooK5syZQ0BAAF5eXkyYMOGK6abbk0MF0sqVK0lOTiYlJYXMzExiYmJITEykoKDA1lWzSt++fTlz5ox5+fHHH21dpWaVlpYSExPD0qVLG3z81Vdf5e233+a9995j+/bteHp6kpiYSEVFRQfXtHnN7QvAqFGjLN6jzz77rANr2DLfffcdc+bMYdu2baxfv57q6mpGjhxJaWmpuczTTz/Nf/7zH1atWsV3333H6dOnuf/++69epa7qNHF2JjY2Vs2ZM8f8s8FgUGFhYSo1NdWGtbJOSkqKiomJsXU12gRQq1evNv9sNBpVSEiI+stf/mJed/HiRaXT6dRnn31mgxq23OX7opRSM2bMUOPGjbNJfdqioKBAAeq7775TSpneA1dXV7Vq1SpzmYMHDypAZWRkXJU6OEwLqe4ml/VvSNncTS7t1eHDhwkLC6NHjx5MnTqV3NxcW1epTZq7Wei1aPPmzQQFBdGrVy8ef/xxzp07Z+sqNUuv1wPQqZPpXnS7d++murra4n2Jjo6ma9euV+19cZhAauoml43dkNIexcXFsXz5ctatW8e7775LTk4Ot912m3lOqGtRa24Was9GjRrFJ598Qnp6OosXL+a7775j9OjRGAwGW1etUUajkaeeeopbbrmFfv36Aab3RavV4ufnZ1H2ar4v18X0I45k9OjR5v8PGDCAuLg4unXrxueff86sWbNsWDNRp/7t3/v378+AAQPo2bMnmzdv5q677rJhzRo3Z84c9u/fb/P+SIdpIbXmJpfXAj8/P6Kiojhy5Iitq9JqrblZ6LWkR48eBAYG2u17NHfuXNasWcOmTZss5hULCQmhqqqKixcvWpS/mu+LwwRSa25yeS0oKSnh6NGjhIaG2roqrXa93yz05MmTnDt3zu7eI6UUc+fOZfXq1WzcuJHu3btbPD5kyBBcXV0t3pfs7Gxyc3Ov3vtyVbrK7VRzN7m8FjzzzDNq8+bNKicnR23ZskUlJCSowMBAVVBQYOuqNam4uFhlZWWprKwsBag33nhDZWVlqePHjyul2n6z0I7U1L4UFxerZ599VmVkZKicnBy1YcMGNXjwYBUZGakqKipsXXULjz/+uPL19VWbN29WZ86cMS9lZWXmMo899pjq2rWr2rhxo9q1a5eKj49X8fHxV61ODhVISjV9k8trweTJk1VoaKjSarWqS5cuavLkyerIkSO2rlazNm3apIArlrqbirb1ZqEdqal9KSsrUyNHjlSdO3dWrq6uqlu3bmr27Nl2+UevoX0A1LJly8xlysvL1W9/+1vl7++vPDw81Pjx49WZM2euWp1kPiQhhN1wmD4kIYT9k0ASQtgNCSQhhN2QQBJC2A0JJCGE3ZBAEkLYDQkkIYTdkEASQtgNCSQhhN2QQBJC2A0JJCGE3fj/d6S3UxKs9ToAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "history = loss[\"history\"]\n", "history.to_csv(\"history.csv\")\n", "history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()" ] }, { "cell_type": "code", "execution_count": 28, "id": "2586ba0a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:05.739775Z", "iopub.status.busy": "2024-02-29T18:10:05.739483Z", "iopub.status.idle": "2024-02-29T18:10:52.699221Z", "shell.execute_reply": "2024-02-29T18:10:52.698208Z" }, "papermill": { "duration": 46.981509, "end_time": "2024-02-29T18:10:52.701858", "exception": false, "start_time": "2024-02-29T18:10:05.720349", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "\n", "from ml_utility_loss.loss_learning.estimator.pipeline import eval\n", "#eval_loss = loss[\"eval_loss\"]\n", "\n", "batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n", "\n", "eval_loss = eval(\n", " test_set, model,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "187137f6", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:52.742539Z", "iopub.status.busy": "2024-02-29T18:10:52.742224Z", "iopub.status.idle": "2024-02-29T18:10:52.761904Z", "shell.execute_reply": "2024-02-29T18:10:52.761097Z" }, "papermill": { "duration": 0.041848, "end_time": "2024-02-29T18:10:52.763762", "exception": false, "start_time": "2024-02-29T18:10:52.721914", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg_g_cos_lossavg_g_mag_lossavg_lossgrad_durationgrad_maegrad_mapegrad_rmsemean_pred_losspred_durationpred_maepred_mapepred_rmsepred_stdstd_losstotal_duration
tvae0.1347840.0363680.0002750.5683860.0188240.6825660.0344341.392132e-080.8837040.0128970.1385170.0165830.1504040.0008391.45209
\n", "
" ], "text/plain": [ " avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration grad_mae \\\n", "tvae 0.134784 0.036368 0.000275 0.568386 0.018824 \n", "\n", " grad_mape grad_rmse mean_pred_loss pred_duration pred_mae \\\n", "tvae 0.682566 0.034434 1.392132e-08 0.883704 0.012897 \n", "\n", " pred_mape pred_rmse pred_std std_loss total_duration \n", "tvae 0.138517 0.016583 0.150404 0.000839 1.45209 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n", "metrics.to_csv(\"eval.csv\")\n", "metrics" ] }, { "cell_type": "code", "execution_count": 30, "id": "123d305b", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:52.802551Z", "iopub.status.busy": "2024-02-29T18:10:52.801916Z", "iopub.status.idle": "2024-02-29T18:10:53.233160Z", "shell.execute_reply": "2024-02-29T18:10:53.232302Z" }, "papermill": { "duration": 0.452737, "end_time": "2024-02-29T18:10:53.235205", "exception": false, "start_time": "2024-02-29T18:10:52.782468", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import clear_memory\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 31, "id": "a3eecc2a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:10:53.275648Z", "iopub.status.busy": "2024-02-29T18:10:53.275331Z", "iopub.status.idle": "2024-02-29T18:11:41.971985Z", "shell.execute_reply": "2024-02-29T18:11:41.971200Z" }, "papermill": { "duration": 48.719326, "end_time": "2024-02-29T18:11:41.974270", "exception": false, "start_time": "2024-02-29T18:10:53.254944", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../insurance/_cache_test/tvae/all inf False\n" ] } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n", "from ml_utility_loss.util import 
stack_samples\n", "\n", "#samples = test_set[list(range(len(test_set)))]\n", "#y = {m: pred(model[m], s) for m, s in samples.items()}\n", "y = pred_2(model, test_set, batch_size=batch_size)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 32, "id": "6ab51db8", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:42.014824Z", "iopub.status.busy": "2024-02-29T18:11:42.014495Z", "iopub.status.idle": "2024-02-29T18:11:42.030851Z", "shell.execute_reply": "2024-02-29T18:11:42.030176Z" }, "papermill": { "duration": 0.038526, "end_time": "2024-02-29T18:11:42.032661", "exception": false, "start_time": "2024-02-29T18:11:41.994135", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from ml_utility_loss.util import transpose_dict\n", "\n", "os.makedirs(\"pred\", exist_ok=True)\n", "y2 = transpose_dict(y)\n", "for k, v in y2.items():\n", " df = pd.DataFrame(v)\n", " df.to_csv(f\"pred/{k}.csv\")" ] }, { "cell_type": "code", "execution_count": 33, "id": "d81a30f1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:42.070440Z", "iopub.status.busy": "2024-02-29T18:11:42.070179Z", "iopub.status.idle": "2024-02-29T18:11:42.075287Z", "shell.execute_reply": "2024-02-29T18:11:42.074351Z" }, "papermill": { "duration": 0.026305, "end_time": "2024-02-29T18:11:42.077266", "exception": false, "start_time": "2024-02-29T18:11:42.050961", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tvae': 0.05389225258396234}\n" ] } ], "source": [ "print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})" ] }, { "cell_type": "code", "execution_count": 34, "id": "3b3ff322", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:42.116430Z", "iopub.status.busy": "2024-02-29T18:11:42.116168Z", "iopub.status.idle": "2024-02-29T18:11:42.432730Z", "shell.execute_reply": "2024-02-29T18:11:42.431890Z" }, "papermill": { 
"duration": 0.338569, "end_time": "2024-02-29T18:11:42.434727", "exception": false, "start_time": "2024-02-29T18:11:42.096158", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n", "\n", "_ = plot_pred_density_2(y)" ] }, { "cell_type": "code", "execution_count": 35, "id": "e79e4b0f", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:42.474328Z", "iopub.status.busy": "2024-02-29T18:11:42.474060Z", "iopub.status.idle": "2024-02-29T18:11:42.757409Z", "shell.execute_reply": "2024-02-29T18:11:42.756574Z" }, "papermill": { "duration": 0.305409, "end_time": "2024-02-29T18:11:42.759353", "exception": false, "start_time": "2024-02-29T18:11:42.453944", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_density_3\n", "\n", "_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 36, "id": "745adde1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:42.801100Z", "iopub.status.busy": "2024-02-29T18:11:42.800581Z", "iopub.status.idle": "2024-02-29T18:11:43.019886Z", "shell.execute_reply": "2024-02-29T18:11:43.019064Z" }, "papermill": { "duration": 0.242308, "end_time": "2024-02-29T18:11:43.021663", "exception": false, "start_time": "2024-02-29T18:11:42.779355", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_box_3\n", "\n", "_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 37, "id": "eabe1bab", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T18:11:43.063972Z", "iopub.status.busy": "2024-02-29T18:11:43.063705Z", "iopub.status.idle": "2024-02-29T18:11:43.256163Z", "shell.execute_reply": "2024-02-29T18:11:43.255366Z" }, "papermill": { "duration": 0.215855, "end_time": "2024-02-29T18:11:43.258087", "exception": false, "start_time": "2024-02-29T18:11:43.042232", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n", "import matplotlib.pyplot as plt\n", "\n", "#plot_grad_2(y, model.models)\n", "for m in model.models:\n", " ym = y[m]\n", " fig, ax = plt.subplots()\n", " plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "54c0e9f3", "metadata": { "papermill": { "duration": 0.021173, "end_time": "2024-02-29T18:11:43.300089", "exception": false, "start_time": "2024-02-29T18:11:43.278916", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "celltoolbar": "Tags", "colab": { "authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie", "gpuType": "T4", "mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg", "provenance": [] }, "kaggle": { "accelerator": "gpu", "dataSources": [], "dockerImageVersionId": 30648, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "papermill": { "default_parameters": {}, "duration": 1264.119122, "end_time": "2024-02-29T18:11:46.042253", "environment_variables": {}, "exception": null, "input_path": "eval/insurance/tvae/4/mlu-eval.ipynb", "output_path": "eval/insurance/tvae/4/mlu-eval.ipynb", "parameters": { "allow_same_prediction": true, "dataset": "insurance", "dataset_name": "insurance", "debug": false, "folder": "eval", "gp": false, "gp_multiply": false, "param_index": 2, "path": "eval/insurance/tvae/4", "path_prefix": "../../../../", "random_seed": 4, 
"single_model": "tvae" }, "start_time": "2024-02-29T17:50:41.923131", "version": "2.5.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }