{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "982e76f5", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.570944Z", "iopub.status.busy": "2024-03-03T11:01:44.569998Z", "iopub.status.idle": "2024-03-03T11:01:44.611904Z", "shell.execute_reply": "2024-03-03T11:01:44.611054Z" }, "papermill": { "duration": 0.057756, "end_time": "2024-03-03T11:01:44.614024", "exception": false, "start_time": "2024-03-03T11:01:44.556268", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import joblib\n", "\n", "#joblib.parallel_backend(\"threading\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "675f0b41", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.642117Z", "iopub.status.busy": "2024-03-03T11:01:44.641699Z", "iopub.status.idle": "2024-03-03T11:01:44.648593Z", "shell.execute_reply": "2024-03-03T11:01:44.647728Z" }, "papermill": { "duration": 0.024578, "end_time": "2024-03-03T11:01:44.650963", "exception": false, "start_time": "2024-03-03T11:01:44.626385", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "%cd /kaggle/working\n", "#!git clone https://github.com/R-N/ml-utility-loss\n", "%cd ml-utility-loss\n", "!git pull\n", "#!pip install .\n", "!pip install . 
--no-deps --force-reinstall --upgrade\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "5ae30f5c", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.683632Z", "iopub.status.busy": "2024-03-03T11:01:44.683035Z", "iopub.status.idle": "2024-03-03T11:01:44.687788Z", "shell.execute_reply": "2024-03-03T11:01:44.686964Z" }, "papermill": { "duration": 0.02277, "end_time": "2024-03-03T11:01:44.689713", "exception": false, "start_time": "2024-03-03T11:01:44.666943", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [3,3]" ] }, { "cell_type": "code", "execution_count": 4, "id": "9f42c810", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.713693Z", "iopub.status.busy": "2024-03-03T11:01:44.713372Z", "iopub.status.idle": "2024-03-03T11:01:44.717894Z", "shell.execute_reply": "2024-03-03T11:01:44.717086Z" }, "executionInfo": { "elapsed": 678, "status": "ok", "timestamp": 1696841022168, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "ns5hFcVL2yvs", "papermill": { "duration": 0.019153, "end_time": "2024-03-03T11:01:44.719881", "exception": false, "start_time": "2024-03-03T11:01:44.700728", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "datasets = [\n", " \"insurance\",\n", " \"treatment\",\n", " \"contraceptive\"\n", "]\n", "\n", "study_dir = \"./\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "85d0c8ce", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.745548Z", "iopub.status.busy": "2024-03-03T11:01:44.745224Z", "iopub.status.idle": "2024-03-03T11:01:44.751772Z", "shell.execute_reply": "2024-03-03T11:01:44.750911Z" }, "papermill": { "duration": 0.021918, "end_time": "2024-03-03T11:01:44.753823", "exception": false, "start_time": "2024-03-03T11:01:44.731905", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": 
[], "source": [ "#Parameters\n", "import os\n", "\n", "path_prefix = \"../../../../\"\n", "\n", "dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n", "dataset_name = \"treatment\"\n", "model_name=\"ml_utility_2\"\n", "models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n", "single_model = \"lct_gan\"\n", "random_seed = 42\n", "gp = True\n", "gp_multiply = True\n", "folder = \"eval\"\n", "debug = False\n", "path = None\n", "param_index = 0\n", "allow_same_prediction = True\n", "log_wandb = False" ] }, { "cell_type": "code", "execution_count": 6, "id": "4e90a1af", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.780315Z", "iopub.status.busy": "2024-03-03T11:01:44.780032Z", "iopub.status.idle": "2024-03-03T11:01:44.785324Z", "shell.execute_reply": "2024-03-03T11:01:44.784384Z" }, "papermill": { "duration": 0.022708, "end_time": "2024-03-03T11:01:44.787932", "exception": false, "start_time": "2024-03-03T11:01:44.765224", "status": "completed" }, "tags": [ "injected-parameters" ] }, "outputs": [], "source": [ "# Parameters\n", "dataset = \"insurance\"\n", "dataset_name = \"insurance\"\n", "single_model = \"tab_ddpm_concat\"\n", "gp = False\n", "gp_multiply = False\n", "random_seed = 4\n", "debug = False\n", "folder = \"eval\"\n", "path_prefix = \"../../../../\"\n", "path = \"eval/insurance/tab_ddpm_concat/4\"\n", "param_index = 3\n", "allow_same_prediction = True\n", "log_wandb = False\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bd7c02d6", "metadata": { "papermill": { "duration": 0.012007, "end_time": "2024-03-03T11:01:44.816037", "exception": false, "start_time": "2024-03-03T11:01:44.804030", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "5f45b1d0", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.840912Z", "iopub.status.busy": "2024-03-03T11:01:44.840199Z", "iopub.status.idle":
"2024-03-03T11:01:44.851377Z", "shell.execute_reply": "2024-03-03T11:01:44.850578Z" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "UdvXYv3c3LXy", "papermill": { "duration": 0.02564, "end_time": "2024-03-03T11:01:44.853397", "exception": false, "start_time": "2024-03-03T11:01:44.827757", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n", "/kaggle/working/eval/insurance/tab_ddpm_concat/4\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "%cd /kaggle/working/\n", "\n", "# When no explicit output path was supplied, build one as folder/dataset/model/seed.\n", "if path is None:\n", " # os.path.join only accepts str/PathLike components, so the integer seed is converted.\n", " path = os.path.join(folder, dataset_name, single_model, str(random_seed))\n", "# Create the run directory (and any missing parents) before switching into it.\n", "Path(path).mkdir(parents=True, exist_ok=True)\n", "\n", "%cd {path}" ] }, { "cell_type": "code", "execution_count": 8, "id": "f85bf540", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:44.879869Z", "iopub.status.busy": "2024-03-03T11:01:44.879593Z", "iopub.status.idle": "2024-03-03T11:01:47.059565Z", "shell.execute_reply": "2024-03-03T11:01:47.058615Z" }, "papermill": { "duration": 2.195792, "end_time": "2024-03-03T11:01:47.061639", "exception": false, "start_time": "2024-03-03T11:01:44.865847", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set seed to \n" ] } ], "source": [ "from ml_utility_loss.util import seed\n", "# Append the selected single model to model_name.\n", "# NOTE: re-running this cell appends the suffix again; restart the kernel before re-executing.\n", "if single_model:\n", " model_name=f\"{model_name}_{single_model}\"\n", "if random_seed is not None:\n", " seed(random_seed)\n", " # `seed` is the seeding function; report the seed value that was applied instead.\n", " print(\"Set seed to\", random_seed)" ] }, { "cell_type": "code", "execution_count": 9, "id": "8489feae", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:47.086553Z", "iopub.status.busy": "2024-03-03T11:01:47.086168Z", "iopub.status.idle": "2024-03-03T11:01:47.097582Z", "shell.execute_reply": "2024-03-03T11:01:47.096887Z" }, "papermill": {
"duration": 0.02616, "end_time": "2024-03-03T11:01:47.099577", "exception": false, "start_time": "2024-03-03T11:01:47.073417", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os\n", "\n", "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", " info = json.load(f)" ] }, { "cell_type": "code", "execution_count": 10, "id": "debcc684", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:47.125152Z", "iopub.status.busy": "2024-03-03T11:01:47.124166Z", "iopub.status.idle": "2024-03-03T11:01:47.132075Z", "shell.execute_reply": "2024-03-03T11:01:47.131245Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "Vrl2QkoV3o_8", "papermill": { "duration": 0.02292, "end_time": "2024-03-03T11:01:47.134095", "exception": false, "start_time": "2024-03-03T11:01:47.111175", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "task = info[\"task\"]\n", "target = info[\"target\"]\n", "cat_features = info[\"cat_features\"]\n", "mixed_features = info[\"mixed_features\"]\n", "longtail_features = info[\"longtail_features\"]\n", "integer_features = info[\"integer_features\"]\n", "\n", "test = df.sample(frac=0.2, random_state=42)\n", "train = df[~df.index.isin(test.index)]" ] }, { "cell_type": "code", "execution_count": 11, "id": "7538184a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:47.161414Z", "iopub.status.busy": "2024-03-03T11:01:47.161162Z", "iopub.status.idle": "2024-03-03T11:01:47.266060Z", "shell.execute_reply": "2024-03-03T11:01:47.265250Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": 
"TilUuFk9vqMb", "papermill": { "duration": 0.121905, "end_time": "2024-03-03T11:01:47.268153", "exception": false, "start_time": "2024-03-03T11:01:47.146248", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", "from ml_utility_loss.util import filter_dict_2, filter_dict\n", "\n", "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", "\n", "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", "tab_ddpm_normalization=\"quantile\"\n", "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", "#tab_ddpm_cat_encoding=\"one-hot\"\n", "tab_ddpm_y_policy=\"default\"\n", "tab_ddpm_is_y_cond=True" ] }, { "cell_type": "code", "execution_count": 12, "id": "cca61838", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:47.292550Z", "iopub.status.busy": "2024-03-03T11:01:47.292267Z", "iopub.status.idle": "2024-03-03T11:01:51.912455Z", "shell.execute_reply": "2024-03-03T11:01:51.911426Z" }, "executionInfo": { "elapsed": 3113, "status": "ok", "timestamp": 1696841025277, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "7Abt8nStvr9Z", "papermill": { "duration": 4.635182, "end_time": "2024-03-03T11:01:51.914891", "exception": false, "start_time": "2024-03-03T11:01:47.279709", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"2024-03-03 11:01:49.494405: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-03-03 11:01:49.494461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-03-03 11:01:49.496116: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", "\n", "lct_ae = load_lct_ae(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"lct_ae\",\n", " df_name=\"df\",\n", ")\n", "lct_ae = None" ] }, { "cell_type": "code", "execution_count": 13, "id": "6f83b7b6", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:51.941837Z", "iopub.status.busy": "2024-03-03T11:01:51.941275Z", "iopub.status.idle": "2024-03-03T11:01:51.948431Z", "shell.execute_reply": "2024-03-03T11:01:51.947703Z" }, "papermill": { "duration": 0.022719, "end_time": "2024-03-03T11:01:51.950358", "exception": false, "start_time": "2024-03-03T11:01:51.927639", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", "\n", "rtf_embed = load_rtf_embed(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"realtabformer\",\n", " df_name=\"df\",\n", " ckpt_type=\"best-disc-model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "0026de74", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:01:51.975269Z", "iopub.status.busy": 
"2024-03-03T11:01:51.974924Z", "iopub.status.idle": "2024-03-03T11:02:00.346124Z", "shell.execute_reply": "2024-03-03T11:02:00.344770Z" }, "executionInfo": { "elapsed": 20137, "status": "ok", "timestamp": 1696841045408, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "tbaguWxAvtPi", "papermill": { "duration": 8.386665, "end_time": "2024-03-03T11:02:00.348756", "exception": false, "start_time": "2024-03-03T11:01:51.962091", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (6) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n", " .fit(X)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/1 [00:00 torch.Tensor>,\n", " 'pma_ffn_mode': 'shared',\n", " 'patience': 10,\n", " 'inds_init_mode': 'fixnorm',\n", " 'grad_clip': 0.6896836352825375,\n", " 'head_final_mul': 'identity',\n", " 'gradient_penalty_mode': {'gradient_penalty': False,\n", " 'calc_grad_m': False,\n", " 'avg_non_role_model_m': False,\n", " 'inverse_avg_non_role_model_m': False},\n", " 'dataset_size': 2048,\n", " 'batch_size': 4,\n", " 'epochs': 100,\n", " 'lr_mul': 0.08030439779404704,\n", " 'n_warmup_steps': 85,\n", " 'Optim': torch_optimizer.diffgrad.DiffGrad,\n", " 'fixed_role_model': 'tab_ddpm_concat',\n", " 'd_model': 256,\n", " 'attn_activation': torch.nn.modules.activation.Sigmoid,\n", " 'tf_d_inner': 256,\n", " 'tf_n_layers_enc': 5,\n", " 'tf_n_head': 128,\n", " 'tf_activation': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'tf_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n", " 'ada_d_hid': 256,\n", " 'ada_n_layers': 8,\n", " 'ada_activation': torch.nn.modules.activation.ReLU6,\n", " 'ada_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n", " 'head_d_hid': 256,\n", " 'head_n_layers': 8,\n", " 'head_n_head': 32,\n", " 'head_activation': torch.nn.modules.activation.ReLU6,\n", " 'head_activation_final': torch.nn.modules.activation.Softsign,\n", " 'single_model': True,\n", " 'models': ['tab_ddpm_concat'],\n", " 'max_seconds': 3600,\n", " 'Body': 'twin_encoder',\n", " 'loss_balancer_log': False,\n", " 'loss_balancer_lbtw': False,\n", " 'pma_skip_small': False,\n", " 'isab_skip_small': False,\n", " 'layer_norm': False,\n", " 'pma_layer_norm': False,\n", " 'attn_residual': True,\n", " 'tf_n_layers_dec': False,\n", " 'tf_isab_rank': 0,\n", " 'tf_layer_norm': False,\n", " 'tf_pma_start': -1,\n", " 'head_n_seeds': 0,\n", " 'dropout': 0,\n", " 'combine_mode': 
'diff_left',\n", " 'tf_isab_mode': 'separate',\n", " 'bias': True,\n", " 'bias_final': True,\n", " 'synth_data': 2,\n", " 'tf_lora': False,\n", " 'tf_num_inds': 64,\n", " 'ada_n_seeds': 0,\n", " 'gradient_penalty_kwargs': {'mag_loss': True,\n", " 'mse_mag': False,\n", " 'mag_corr': False,\n", " 'seq_mag': False,\n", " 'cos_loss': False,\n", " 'mag_corr_kwargs': {'only_sign': False},\n", " 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n", " 'mse_mag_kwargs': {'target': 0.13044551835398707, 'multiply': True}}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n", "import wandb\n", "\n", "#\"\"\"\n", "param_space = {\n", " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", "}\n", "params = {\n", " **getattr(PARAMS, dataset_name).BESTS[param_index],\n", "}\n", "if gp:\n", " params[\"gradient_penalty_mode\"] = \"ALL\"\n", " params[\"mse_mag\"] = True\n", " if gp_multiply:\n", " params[\"mse_mag_multiply\"] = True\n", " params[\"mse_mag_target\"] = 1.0\n", " else:\n", " params[\"mse_mag_multiply\"] = False\n", " params[\"mse_mag_target\"] = 0.1\n", "else:\n", " params[\"gradient_penalty_mode\"] = \"NONE\"\n", " params[\"mse_mag\"] = False\n", "params[\"single_model\"] = False\n", "if models:\n", " params[\"models\"] = models\n", "if single_model:\n", " params[\"fixed_role_model\"] = single_model\n", " params[\"single_model\"] = True\n", " params[\"models\"] = [single_model]\n", "if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n", " params[\"batch_size\"] = 2\n", "params[\"max_seconds\"] = 3600\n", "params[\"patience\"] = 10\n", "params[\"epochs\"] = 100\n", "if debug:\n", " params[\"epochs\"] = 2\n", "with open(\"params.json\", \"w\") as f:\n", " 
json.dump(params, f)\n", "params = map_parameters(params, param_space=param_space)\n", "params" ] }, { "cell_type": "code", "execution_count": 19, "id": "a48bd9e9", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:02:00.882262Z", "iopub.status.busy": "2024-03-03T11:02:00.881551Z", "iopub.status.idle": "2024-03-03T11:02:00.955360Z", "shell.execute_reply": "2024-03-03T11:02:00.954296Z" }, "papermill": { "duration": 0.08959, "end_time": "2024-03-03T11:02:00.957400", "exception": false, "start_time": "2024-03-03T11:02:00.867810", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load_dataset_3_factory 2\n", "Caching in ../../../../insurance/_cache/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_2/insurance [80, 20]\n", "Caching in ../../../../insurance/_cache4/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_4/insurance [80, 20]\n", "Caching in ../../../../insurance/_cache5/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_5/insurance [160, 40]\n", "[320, 80]\n", "[320, 80]\n" ] } ], "source": [ "train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], synth_data=params[\"synth_data\"])" ] }, { "cell_type": "code", "execution_count": 20, "id": "2fcb1418", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "execution": { "iopub.execute_input": "2024-03-03T11:02:00.985850Z", "iopub.status.busy": "2024-03-03T11:02:00.985495Z", "iopub.status.idle": "2024-03-03T11:02:01.418904Z", "shell.execute_reply": "2024-03-03T11:02:01.417808Z" }, "executionInfo": { "elapsed": 396850, "status": "error", "timestamp": 1696841446059, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, 
"id": "_bt1MQc5kpSk", "outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6", "papermill": { "duration": 0.4505, "end_time": "2024-03-03T11:02:01.421338", "exception": false, "start_time": "2024-03-03T11:02:00.970838", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model of type \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[*] Embedding False True\n", "['tab_ddpm_concat'] 1\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n", "from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n", "from ml_utility_loss.util import filter_dict, clear_memory\n", "\n", "clear_memory()\n", "\n", "params2 = remove_non_model_params(params)\n", "adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n", "\n", "model = create_model(\n", " adapters=adapters,\n", " #Body=\"twin_encoder\",\n", " **params2,\n", ")\n", "#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", "print(model.models, len(model.adapters))" ] }, { "cell_type": "code", "execution_count": 21, "id": "938f94fc", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:02:01.451176Z", "iopub.status.busy": "2024-03-03T11:02:01.450374Z", "iopub.status.idle": "2024-03-03T11:02:01.454846Z", "shell.execute_reply": "2024-03-03T11:02:01.453995Z" }, "papermill": { "duration": 0.021568, "end_time": "2024-03-03T11:02:01.456847", "exception": false, "start_time": "2024-03-03T11:02:01.435279", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "study_name=f\"{model_name}_{dataset_name}\"" ] }, { "cell_type": "code", "execution_count": 22, "id": "12fb613e", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:02:01.484109Z", "iopub.status.busy": "2024-03-03T11:02:01.483802Z", "iopub.status.idle": "2024-03-03T11:02:01.491045Z", "shell.execute_reply": "2024-03-03T11:02:01.490241Z" }, "papermill": { 
"duration": 0.023247, "end_time": "2024-03-03T11:02:01.493037", "exception": false, "start_time": "2024-03-03T11:02:01.469790", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "8696065" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "count_parameters(model)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bd386e57", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:02:01.520098Z", "iopub.status.busy": "2024-03-03T11:02:01.519804Z", "iopub.status.idle": "2024-03-03T11:02:01.622148Z", "shell.execute_reply": "2024-03-03T11:02:01.621289Z" }, "papermill": { "duration": 0.11822, "end_time": "2024-03-03T11:02:01.624199", "exception": false, "start_time": "2024-03-03T11:02:01.505979", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "========================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "========================================================================================================================\n", "MLUtilitySingle [2, 1071, 12] --\n", "├─Adapter: 1-1 [2, 1071, 12] --\n", "│ └─Sequential: 2-1 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-1 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-1 [2, 1071, 256] 3,328\n", "│ │ │ └─ReLU6: 4-2 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-2 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-3 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-4 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-3 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-5 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-6 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-4 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-7 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-8 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-5 [2, 1071, 256] --\n", "│ │ │ 
└─Linear: 4-9 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-10 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-6 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-11 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-12 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-7 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-13 [2, 1071, 256] 65,792\n", "│ │ │ └─ReLU6: 4-14 [2, 1071, 256] --\n", "│ │ └─FeedForward: 3-8 [2, 1071, 256] --\n", "│ │ │ └─Linear: 4-15 [2, 1071, 256] 65,792\n", "│ │ │ └─LeakyHardtanh: 4-16 [2, 1071, 256] --\n", "├─Adapter: 1-2 [2, 267, 12] (recursive)\n", "│ └─Sequential: 2-2 [2, 267, 256] (recursive)\n", "│ │ └─FeedForward: 3-9 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-17 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-18 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-10 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-19 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-20 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-11 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-21 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-22 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-12 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-23 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-24 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-13 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-25 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-26 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-14 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-27 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-28 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-15 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-29 [2, 267, 256] (recursive)\n", "│ │ │ └─ReLU6: 4-30 [2, 267, 256] --\n", "│ │ └─FeedForward: 3-16 [2, 267, 256] (recursive)\n", "│ │ │ └─Linear: 4-31 [2, 267, 256] (recursive)\n", "│ │ │ └─LeakyHardtanh: 4-32 [2, 267, 256] --\n", "├─TwinEncoder: 1-3 [2, 16384] --\n", "│ └─Encoder: 2-3 [2, 64, 256] --\n", "│ │ └─ModuleList: 3-18 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-33 [2, 1071, 256] --\n", "│ │ │ │ 
└─SimpleInducedSetAttention: 5-1 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 64, 256] 16,384\n", "│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-1 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-2 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-3 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-5 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-6 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-7 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-8 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-9 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 128, 1071, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 128, 1071, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-11 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-12 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-2 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-4 [2, 1071, 256] 65,792\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-5 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-6 [2, 1071, 256] 65,792\n", "│ │ │ └─EncoderLayer: 4-34 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 64, 256] 16,384\n", "│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-13 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-14 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-15 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-17 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-18 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-19 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-20 [2, 64, 256] 
65,536\n", "│ │ │ │ │ │ └─Linear: 7-21 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 128, 1071, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 128, 1071, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-23 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-24 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-4 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-10 [2, 1071, 256] 65,792\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-11 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-12 [2, 1071, 256] 65,792\n", "│ │ │ └─EncoderLayer: 4-35 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 64, 256] 16,384\n", "│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-25 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-26 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-27 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-29 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-30 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-31 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-32 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-33 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 128, 1071, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 128, 1071, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-35 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-36 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-6 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-16 [2, 1071, 256] 65,792\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-17 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-18 [2, 1071, 256] 65,792\n", "│ │ │ └─EncoderLayer: 4-36 [2, 1071, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-7 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 64, 256] 16,384\n", "│ │ │ │ │ 
└─MultiHeadAttention: 6-20 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-37 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-38 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-39 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-41 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-42 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-21 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-43 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-44 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-45 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 128, 1071, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 128, 1071, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-47 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-48 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-8 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-22 [2, 1071, 256] 65,792\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-23 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-24 [2, 1071, 256] 65,792\n", "│ │ │ └─EncoderLayer: 4-37 [2, 64, 256] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-9 [2, 1071, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-25 [2, 64, 256] 16,384\n", "│ │ │ │ │ └─MultiHeadAttention: 6-26 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-49 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-50 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-51 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-53 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-54 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-27 [2, 1071, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-55 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-56 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-57 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 128, 
1071, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 128, 1071, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-59 [2, 1071, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-60 [2, 1071, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-10 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-28 [2, 1071, 256] 65,792\n", "│ │ │ │ │ └─LeakyHardtanh: 6-29 [2, 1071, 256] --\n", "│ │ │ │ │ └─Linear: 6-30 [2, 1071, 256] 65,792\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-11 [2, 64, 256] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-31 [2, 64, 256] 16,384\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-32 [2, 64, 256] --\n", "│ │ │ │ │ │ └─Linear: 7-61 [2, 64, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-62 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─Linear: 7-63 [2, 1071, 256] 65,536\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 128, 64, 1071] --\n", "│ │ │ │ │ │ └─Linear: 7-65 [2, 64, 256] 65,792\n", "│ │ │ │ │ │ └─Sigmoid: 7-66 [2, 64, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-12 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-33 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-34 [2, 64, 256] --\n", "│ │ │ │ │ └─Linear: 6-35 [2, 64, 256] (recursive)\n", "│ └─Encoder: 2-4 [2, 64, 256] (recursive)\n", "│ │ └─ModuleList: 3-18 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-38 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-67 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-68 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-69 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-71 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-72 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 
6-38 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-73 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-74 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-75 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 128, 267, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 128, 267, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-77 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-78 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-14 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-39 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-40 [2, 267, 256] --\n", "│ │ │ │ │ └─Linear: 6-41 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-39 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-15 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-43 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-79 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-80 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-81 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-83 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-84 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-44 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-85 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-86 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-87 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-88 [2, 128, 267, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-15 [2, 128, 267, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-89 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-90 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-16 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-45 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-46 [2, 267, 256] 
--\n", "│ │ │ │ │ └─Linear: 6-47 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-40 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-17 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-48 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-49 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-91 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-92 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-93 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-94 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-16 [2, 128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-95 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-96 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-50 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-97 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-98 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-99 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-100 [2, 128, 267, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-17 [2, 128, 267, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-101 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-102 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-18 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-51 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-52 [2, 267, 256] --\n", "│ │ │ │ │ └─Linear: 6-53 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-41 [2, 267, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-19 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-54 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-55 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-103 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-104 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-105 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-106 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-18 [2, 
128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-107 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-108 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-56 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-109 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-110 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-111 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-112 [2, 128, 267, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-19 [2, 128, 267, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-113 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-114 [2, 267, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-20 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-57 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardsigmoid: 6-58 [2, 267, 256] --\n", "│ │ │ │ │ └─Linear: 6-59 [2, 267, 256] (recursive)\n", "│ │ │ └─EncoderLayer: 4-42 [2, 64, 256] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-21 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-60 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-61 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-115 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-116 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-117 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-118 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-20 [2, 128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-119 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-120 [2, 64, 256] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-62 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-121 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-122 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-123 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-124 [2, 128, 267, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-21 [2, 128, 267, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-125 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-126 [2, 267, 256] 
--\n", "│ │ │ │ └─DoubleFeedForward: 5-22 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-63 [2, 267, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-64 [2, 267, 256] --\n", "│ │ │ │ │ └─Linear: 6-65 [2, 267, 256] (recursive)\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-23 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-66 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-67 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-127 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-128 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-129 [2, 267, 256] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-130 [2, 128, 64, 2] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-22 [2, 128, 64, 267] --\n", "│ │ │ │ │ │ └─Linear: 7-131 [2, 64, 256] (recursive)\n", "│ │ │ │ │ │ └─Sigmoid: 7-132 [2, 64, 256] --\n", "│ │ │ │ └─DoubleFeedForward: 5-24 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─Linear: 6-68 [2, 64, 256] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-69 [2, 64, 256] --\n", "│ │ │ │ │ └─Linear: 6-70 [2, 64, 256] (recursive)\n", "├─Head: 1-4 [2] --\n", "│ └─Sequential: 2-5 [2, 1] --\n", "│ │ └─FeedForward: 3-19 [2, 256] --\n", "│ │ │ └─Linear: 4-43 [2, 256] 4,194,560\n", "│ │ │ └─ReLU6: 4-44 [2, 256] --\n", "│ │ └─FeedForward: 3-20 [2, 256] --\n", "│ │ │ └─Linear: 4-45 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-46 [2, 256] --\n", "│ │ └─FeedForward: 3-21 [2, 256] --\n", "│ │ │ └─Linear: 4-47 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-48 [2, 256] --\n", "│ │ └─FeedForward: 3-22 [2, 256] --\n", "│ │ │ └─Linear: 4-49 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-50 [2, 256] --\n", "│ │ └─FeedForward: 3-23 [2, 256] --\n", "│ │ │ └─Linear: 4-51 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-52 [2, 256] --\n", "│ │ └─FeedForward: 3-24 [2, 256] --\n", "│ │ │ └─Linear: 4-53 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-54 [2, 256] --\n", "│ │ └─FeedForward: 3-25 [2, 256] --\n", "│ │ │ └─Linear: 4-55 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-56 [2, 256] --\n", "│ │ 
└─FeedForward: 3-26 [2, 1] --\n", "│ │ │ └─Linear: 4-57 [2, 1] 257\n", "│ │ │ └─Softsign: 4-58 [2, 1] --\n", "========================================================================================================================\n", "Total params: 8,696,065\n", "Trainable params: 8,696,065\n", "Non-trainable params: 0\n", "Total mult-adds (M): 25.74\n", "========================================================================================================================\n", "Input size (MB): 0.13\n", "Forward/backward pass size (MB): 234.98\n", "Params size (MB): 34.78\n", "Estimated Total Size (MB): 269.89\n", "========================================================================================================================" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchinfo import summary\n", "\n", "role_model = params[\"fixed_role_model\"]\n", "s = train_set[0][role_model]\n", "summary(model[role_model], input_size=((2, *s[0].shape), (2, *s[1].shape)), depth=9) # 8 max" ] }, { "cell_type": "code", "execution_count": 24, "id": "0f42c4d1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:02:01.656538Z", "iopub.status.busy": "2024-03-03T11:02:01.655477Z", "iopub.status.idle": "2024-03-03T11:57:24.427329Z", "shell.execute_reply": "2024-03-03T11:57:24.426341Z" }, "papermill": { "duration": 3322.809941, "end_time": "2024-03-03T11:57:24.448804", "exception": false, "start_time": "2024-03-03T11:02:01.638863", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "g_loss_mul 0.1\n", "Epoch 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.11390677978924942, 'avg_role_model_std_loss': 1.6075929376992917, 'avg_role_model_mean_pred_loss': 0.05657142685045635, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.11390677978924942, 'n_size': 320, 'n_batch': 80, 'duration': 84.47398400306702, 'duration_batch': 1.0559248000383377, 'duration_size': 0.2639812000095844, 'avg_pred_std': 0.09700191575684584}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.016991531396342907, 'avg_role_model_std_loss': 16.768777330477747, 'avg_role_model_mean_pred_loss': 0.0011311913316453647, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.016991531396342907, 'n_size': 80, 'n_batch': 20, 'duration': 17.97755455970764, 'duration_batch': 0.8988777279853821, 'duration_size': 0.22471943199634553, 'avg_pred_std': 0.016752634930890055}\n", "Epoch 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.02035002698939934, 'avg_role_model_std_loss': 3.3028829722441513, 'avg_role_model_mean_pred_loss': 0.0012390867967089977, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.02035002698939934, 'n_size': 320, 'n_batch': 80, 'duration': 84.68451976776123, 'duration_batch': 1.0585564970970154, 'duration_size': 0.26463912427425385, 'avg_pred_std': 0.06200033854111098}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.011681753185985144, 'avg_role_model_std_loss': 4.98499947126902, 'avg_role_model_mean_pred_loss': 0.00011886494331716512, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011681753185985144, 'n_size': 80, 'n_batch': 20, 'duration': 
18.428929567337036, 'duration_batch': 0.9214464783668518, 'duration_size': 0.23036161959171295, 'avg_pred_std': 0.022514818981289864}\n", "Epoch 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.015553263086258085, 'avg_role_model_std_loss': 4.40048983076071, 'avg_role_model_mean_pred_loss': 0.0012610154882150097, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.015553263086258085, 'n_size': 320, 'n_batch': 80, 'duration': 84.47719740867615, 'duration_batch': 1.055964967608452, 'duration_size': 0.263991241902113, 'avg_pred_std': 0.04350193784048315}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.012001678717297182, 'avg_role_model_std_loss': 2.0838609129365295, 'avg_role_model_mean_pred_loss': 0.00018355202230315414, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012001678717297182, 'n_size': 80, 'n_batch': 20, 'duration': 18.23084259033203, 'duration_batch': 0.9115421295166015, 'duration_size': 0.22788553237915038, 'avg_pred_std': 0.026969157496932895}\n", "Epoch 3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.012385944992274744, 'avg_role_model_std_loss': 3.628137231519884, 'avg_role_model_mean_pred_loss': 0.0006446951736899908, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012385944992274744, 'n_size': 320, 'n_batch': 80, 'duration': 84.30943512916565, 'duration_batch': 1.0538679391145707, 'duration_size': 0.26346698477864267, 'avg_pred_std': 0.03804389561410062}\n" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.011361928719270508, 'avg_role_model_std_loss': 2.9160453390245267, 'avg_role_model_mean_pred_loss': 0.00016351417628470699, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011361928719270508, 'n_size': 80, 'n_batch': 20, 'duration': 18.110987901687622, 'duration_batch': 0.9055493950843811, 'duration_size': 0.2263873487710953, 'avg_pred_std': 0.02250743337208405}\n", "Epoch 4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.012263958598123282, 'avg_role_model_std_loss': 3.76344175813676, 'avg_role_model_mean_pred_loss': 0.0004104777633484336, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012263958598123282, 'n_size': 320, 'n_batch': 80, 'duration': 84.6977150440216, 'duration_batch': 1.05872143805027, 'duration_size': 0.2646803595125675, 'avg_pred_std': 0.03591603521199431}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.011301952104258817, 'avg_role_model_std_loss': 2.2325485425771605, 'avg_role_model_mean_pred_loss': 0.00016856359858614668, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011301952104258817, 'n_size': 80, 'n_batch': 20, 'duration': 18.113478422164917, 'duration_batch': 0.9056739211082458, 'duration_size': 0.22641848027706146, 'avg_pred_std': 0.028445655293762685}\n", "Epoch 5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011240640739742958, 
'avg_role_model_std_loss': 2.6737590567842524, 'avg_role_model_mean_pred_loss': 0.00015413710455663006, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011240640739742958, 'n_size': 320, 'n_batch': 80, 'duration': 84.4088442325592, 'duration_batch': 1.05511055290699, 'duration_size': 0.2637776382267475, 'avg_pred_std': 0.04411998361465521}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009927860215248075, 'avg_role_model_std_loss': 1.5344439736139974, 'avg_role_model_mean_pred_loss': 2.0057166005040678e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009927860215248075, 'n_size': 80, 'n_batch': 20, 'duration': 18.26319670677185, 'duration_batch': 0.9131598353385926, 'duration_size': 0.22828995883464814, 'avg_pred_std': 0.03175867693498731}\n", "Epoch 6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011858757758818683, 'avg_role_model_std_loss': 3.5301289926825974, 'avg_role_model_mean_pred_loss': 0.0002819792471252053, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011858757758818683, 'n_size': 320, 'n_batch': 80, 'duration': 84.27680397033691, 'duration_batch': 1.0534600496292115, 'duration_size': 0.26336501240730287, 'avg_pred_std': 0.03954663624172099}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.012078169028973207, 'avg_role_model_std_loss': 4.102863686283408, 'avg_role_model_mean_pred_loss': 0.000203830946862138, 'avg_role_model_g_mag_loss': 0.0, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012078169028973207, 'n_size': 80, 'n_batch': 20, 'duration': 18.098918199539185, 'duration_batch': 0.9049459099769592, 'duration_size': 0.2262364774942398, 'avg_pred_std': 0.021373049879912287}\n", "Epoch 7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.012539188255323097, 'avg_role_model_std_loss': 3.3092838149226282, 'avg_role_model_mean_pred_loss': 0.00023790401172929277, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012539188255323097, 'n_size': 320, 'n_batch': 80, 'duration': 84.52259421348572, 'duration_batch': 1.0565324276685715, 'duration_size': 0.26413310691714287, 'avg_pred_std': 0.04043100443377625}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009674683933417328, 'avg_role_model_std_loss': 0.9977919148524961, 'avg_role_model_mean_pred_loss': 1.1186481361031685e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009674683933417328, 'n_size': 80, 'n_batch': 20, 'duration': 18.143858671188354, 'duration_batch': 0.9071929335594178, 'duration_size': 0.22679823338985444, 'avg_pred_std': 0.04099587097298354}\n", "Epoch 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011895681292844528, 'avg_role_model_std_loss': 2.8657706266691036, 'avg_role_model_mean_pred_loss': 0.0002743295693480711, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011895681292844528, 'n_size': 320, 'n_batch': 80, 'duration': 84.31953191757202, 'duration_batch': 1.0539941489696503, 'duration_size': 0.26349853724241257, 'avg_pred_std': 0.042222958011552694}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009975977733847684, 'avg_role_model_std_loss': 1.1835033869873883, 'avg_role_model_mean_pred_loss': 2.0587880793299097e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009975977733847684, 'n_size': 80, 'n_batch': 20, 'duration': 18.042185306549072, 'duration_batch': 0.9021092653274536, 'duration_size': 0.2255273163318634, 'avg_pred_std': 0.035635373927652834}\n", "Epoch 9\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010924837501352157, 'avg_role_model_std_loss': 2.593468050354761, 'avg_role_model_mean_pred_loss': 9.935248510884783e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010924837501352157, 'n_size': 320, 'n_batch': 80, 'duration': 84.26527738571167, 'duration_batch': 1.0533159673213959, 'duration_size': 0.26332899183034897, 'avg_pred_std': 0.045242260512895885}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.010778275438860873, 'avg_role_model_std_loss': 3.146427918606969, 'avg_role_model_mean_pred_loss': 7.896810024603518e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010778275438860873, 'n_size': 80, 'n_batch': 20, 'duration': 18.036173820495605, 
'duration_batch': 0.9018086910247802, 'duration_size': 0.22545217275619506, 'avg_pred_std': 0.022399809048511087}\n", "Epoch 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011383894624395907, 'avg_role_model_std_loss': 2.930524671732963, 'avg_role_model_mean_pred_loss': 0.00018457749493845378, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011383894624395907, 'n_size': 320, 'n_batch': 80, 'duration': 84.58404278755188, 'duration_batch': 1.0573005348443985, 'duration_size': 0.2643251337110996, 'avg_pred_std': 0.042266642485628836}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009755220862280112, 'avg_role_model_std_loss': 1.5220770264881138, 'avg_role_model_mean_pred_loss': 1.059227741908586e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009755220862280112, 'n_size': 80, 'n_batch': 20, 'duration': 17.980788469314575, 'duration_batch': 0.8990394234657287, 'duration_size': 0.22475985586643218, 'avg_pred_std': 0.03333416555542499}\n", "Epoch 11\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010721845665102592, 'avg_role_model_std_loss': 1.8063977722115752, 'avg_role_model_mean_pred_loss': 9.594455589044034e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010721845665102592, 'n_size': 320, 'n_batch': 80, 'duration': 84.16543889045715, 'duration_batch': 1.0520679861307145, 'duration_size': 0.2630169965326786, 'avg_pred_std': 0.047608432272681966}\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009559676682692952, 'avg_role_model_std_loss': 1.3243680567820775, 'avg_role_model_mean_pred_loss': 9.947778448005095e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009559676682692952, 'n_size': 80, 'n_batch': 20, 'duration': 18.158392667770386, 'duration_batch': 0.9079196333885193, 'duration_size': 0.22697990834712983, 'avg_pred_std': 0.031689733476378025}\n", "Epoch 12\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011144705655783581, 'avg_role_model_std_loss': 2.4096494605510026, 'avg_role_model_mean_pred_loss': 0.00032173466922595813, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011144705655783581, 'n_size': 320, 'n_batch': 80, 'duration': 84.13634371757507, 'duration_batch': 1.0517042964696883, 'duration_size': 0.2629260741174221, 'avg_pred_std': 0.044573025617864914}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009567469572357368, 'avg_role_model_std_loss': 1.5696021520542787, 'avg_role_model_mean_pred_loss': 1.1404201347278014e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009567469572357368, 'n_size': 80, 'n_batch': 20, 'duration': 17.92456030845642, 'duration_batch': 0.896228015422821, 'duration_size': 0.22405700385570526, 'avg_pred_std': 0.03133328107651323}\n", "Epoch 13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010633787492497503, 'avg_role_model_std_loss': 
2.5790517816062755, 'avg_role_model_mean_pred_loss': 0.0003470440018317923, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010633787492497503, 'n_size': 320, 'n_batch': 80, 'duration': 84.38668012619019, 'duration_batch': 1.0548335015773773, 'duration_size': 0.26370837539434433, 'avg_pred_std': 0.04384268364228774}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009436554487911053, 'avg_role_model_std_loss': 1.4380358837069962, 'avg_role_model_mean_pred_loss': 1.7537142142520778e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009436554487911053, 'n_size': 80, 'n_batch': 20, 'duration': 18.151715517044067, 'duration_batch': 0.9075857758522033, 'duration_size': 0.22689644396305084, 'avg_pred_std': 0.03785984092392027}\n", "Epoch 14\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.01104943135223948, 'avg_role_model_std_loss': 2.1088742282857766, 'avg_role_model_mean_pred_loss': 0.0002007622467717723, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01104943135223948, 'n_size': 320, 'n_batch': 80, 'duration': 84.27291560173035, 'duration_batch': 1.0534114450216294, 'duration_size': 0.26335286125540736, 'avg_pred_std': 0.0478535434929654}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.01036879940838844, 'avg_role_model_std_loss': 0.8904372502282059, 'avg_role_model_mean_pred_loss': 7.801706653562946e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01036879940838844, 'n_size': 80, 'n_batch': 20, 'duration': 18.29806423187256, 'duration_batch': 0.914903211593628, 'duration_size': 0.228725802898407, 'avg_pred_std': 0.051562142791226506}\n", "Epoch 15\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011128615495272243, 'avg_role_model_std_loss': 1.7709791813538458, 'avg_role_model_mean_pred_loss': 0.00021516576313122763, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011128615495272243, 'n_size': 320, 'n_batch': 80, 'duration': 85.58019542694092, 'duration_batch': 1.0697524428367615, 'duration_size': 0.2674381107091904, 'avg_pred_std': 0.049932096980046484}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00949138020951068, 'avg_role_model_std_loss': 1.0204340409804673, 'avg_role_model_mean_pred_loss': 1.8300672786680795e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00949138020951068, 'n_size': 80, 'n_batch': 20, 'duration': 18.28961968421936, 'duration_batch': 0.914480984210968, 'duration_size': 0.228620246052742, 'avg_pred_std': 0.04147392325103283}\n", "Epoch 16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010429539207643756, 'avg_role_model_std_loss': 1.5362692275628125, 'avg_role_model_mean_pred_loss': 0.00025102296517166747, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 
0.010429539207643756, 'n_size': 320, 'n_batch': 80, 'duration': 84.87365674972534, 'duration_batch': 1.0609207093715667, 'duration_size': 0.26523017734289167, 'avg_pred_std': 0.05195366198895499}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009322703200996329, 'avg_role_model_std_loss': 1.743853185043554, 'avg_role_model_mean_pred_loss': 1.2576797408925255e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009322703200996329, 'n_size': 80, 'n_batch': 20, 'duration': 17.977401971817017, 'duration_batch': 0.8988700985908509, 'duration_size': 0.22471752464771272, 'avg_pred_std': 0.03613527067936957}\n", "Epoch 17\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010465346425053212, 'avg_role_model_std_loss': 1.4406497691373554, 'avg_role_model_mean_pred_loss': 0.00016243213518344694, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010465346425053212, 'n_size': 320, 'n_batch': 80, 'duration': 84.5495491027832, 'duration_batch': 1.0568693637847901, 'duration_size': 0.26421734094619753, 'avg_pred_std': 0.04888715610140935}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009582025398412953, 'avg_role_model_std_loss': 3.081415921854597, 'avg_role_model_mean_pred_loss': 5.574405601427302e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009582025398412953, 'n_size': 80, 'n_batch': 20, 'duration': 17.912070274353027, 'duration_batch': 0.8956035137176513, 'duration_size': 
0.22390087842941284, 'avg_pred_std': 0.025306159909814597}\n", "Epoch 18\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010162418862455525, 'avg_role_model_std_loss': 1.5145359354102208, 'avg_role_model_mean_pred_loss': 6.209586736886912e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010162418862455525, 'n_size': 320, 'n_batch': 80, 'duration': 84.28248190879822, 'duration_batch': 1.0535310238599778, 'duration_size': 0.26338275596499444, 'avg_pred_std': 0.0519078379671555}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009476123469903541, 'avg_role_model_std_loss': 1.0078544585018676, 'avg_role_model_mean_pred_loss': 1.8715846066100564e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009476123469903541, 'n_size': 80, 'n_batch': 20, 'duration': 17.901034355163574, 'duration_batch': 0.8950517177581787, 'duration_size': 0.22376292943954468, 'avg_pred_std': 0.03873717384412885}\n", "Epoch 19\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010770342201794847, 'avg_role_model_std_loss': 2.009929773015307, 'avg_role_model_mean_pred_loss': 0.00013072784897420476, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010770342201794847, 'n_size': 320, 'n_batch': 80, 'duration': 84.82972049713135, 'duration_batch': 1.060371506214142, 'duration_size': 0.2650928765535355, 'avg_pred_std': 0.04557969000888988}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss 
{'avg_role_model_loss': 0.009593866099567094, 'avg_role_model_std_loss': 0.7345563380621798, 'avg_role_model_mean_pred_loss': 1.3211068784391156e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009593866099567094, 'n_size': 80, 'n_batch': 20, 'duration': 18.183232307434082, 'duration_batch': 0.9091616153717041, 'duration_size': 0.22729040384292604, 'avg_pred_std': 0.04364799705799669}\n", "Epoch 20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011184974256775605, 'avg_role_model_std_loss': 1.735422194222258, 'avg_role_model_mean_pred_loss': 0.00015653051535653267, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011184974256775605, 'n_size': 320, 'n_batch': 80, 'duration': 84.4685800075531, 'duration_batch': 1.0558572500944137, 'duration_size': 0.2639643125236034, 'avg_pred_std': 0.04943769198143855}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00922505634080153, 'avg_role_model_std_loss': 1.1232440773048438, 'avg_role_model_mean_pred_loss': 2.5773531761252856e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00922505634080153, 'n_size': 80, 'n_batch': 20, 'duration': 18.073370695114136, 'duration_batch': 0.9036685347557067, 'duration_size': 0.22591713368892669, 'avg_pred_std': 0.0375345426844433}\n", "Epoch 21\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010670667553040403, 'avg_role_model_std_loss': 1.6657181285418345, 'avg_role_model_mean_pred_loss': 
0.00026890914427888377, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010670667553040403, 'n_size': 320, 'n_batch': 80, 'duration': 84.49483013153076, 'duration_batch': 1.0561853766441345, 'duration_size': 0.2640463441610336, 'avg_pred_std': 0.0491774610709399}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.011584205193139496, 'avg_role_model_std_loss': 4.0657152492325626, 'avg_role_model_mean_pred_loss': 3.660291929232784e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011584205193139496, 'n_size': 80, 'n_batch': 20, 'duration': 17.99208927154541, 'duration_batch': 0.8996044635772705, 'duration_size': 0.22490111589431763, 'avg_pred_std': 0.030582628422416748}\n", "Epoch 22\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011918405044889368, 'avg_role_model_std_loss': 1.9266218843145224, 'avg_role_model_mean_pred_loss': 0.000546820462109244, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011918405044889368, 'n_size': 320, 'n_batch': 80, 'duration': 84.82066988945007, 'duration_batch': 1.060258373618126, 'duration_size': 0.2650645934045315, 'avg_pred_std': 0.049488182202912866}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009260490916494746, 'avg_role_model_std_loss': 2.0702868502448837, 'avg_role_model_mean_pred_loss': 1.115468241086326e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009260490916494746, 'n_size': 80, 'n_batch': 20, 'duration': 18.213539123535156, 'duration_batch': 0.9106769561767578, 'duration_size': 0.22766923904418945, 'avg_pred_std': 0.03006206527352333}\n", "Epoch 23\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010364986025706457, 'avg_role_model_std_loss': 6.042574923066354, 'avg_role_model_mean_pred_loss': 0.00012542121491789134, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010364986025706457, 'n_size': 320, 'n_batch': 80, 'duration': 84.50646162033081, 'duration_batch': 1.0563307702541351, 'duration_size': 0.2640826925635338, 'avg_pred_std': 0.045613451191994156}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009321882369113155, 'avg_role_model_std_loss': 1.242719502127511, 'avg_role_model_mean_pred_loss': 2.0408335805721654e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009321882369113155, 'n_size': 80, 'n_batch': 20, 'duration': 17.993885278701782, 'duration_batch': 0.8996942639350891, 'duration_size': 0.22492356598377228, 'avg_pred_std': 0.03894345450680703}\n", "Epoch 24\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010127854481288523, 'avg_role_model_std_loss': 1.3548822252588457, 'avg_role_model_mean_pred_loss': 9.947211313221551e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010127854481288523, 'n_size': 320, 
'n_batch': 80, 'duration': 84.6096601486206, 'duration_batch': 1.0576207518577576, 'duration_size': 0.2644051879644394, 'avg_pred_std': 0.05492281899787486}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00973101281633717, 'avg_role_model_std_loss': 2.803459626334097, 'avg_role_model_mean_pred_loss': 3.2155193029839366e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00973101281633717, 'n_size': 80, 'n_batch': 20, 'duration': 18.34537410736084, 'duration_batch': 0.917268705368042, 'duration_size': 0.2293171763420105, 'avg_pred_std': 0.02301093057030812}\n", "Epoch 25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010673620513580317, 'avg_role_model_std_loss': 1.630422155917519, 'avg_role_model_mean_pred_loss': 0.0002799555220394656, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010673620513580317, 'n_size': 320, 'n_batch': 80, 'duration': 84.58660197257996, 'duration_batch': 1.0573325246572494, 'duration_size': 0.26433313116431234, 'avg_pred_std': 0.046520658489316705}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009462428228289355, 'avg_role_model_std_loss': 2.6746044414641346, 'avg_role_model_mean_pred_loss': 1.355185217843946e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009462428228289355, 'n_size': 80, 'n_batch': 20, 'duration': 18.011495113372803, 'duration_batch': 0.9005747556686401, 'duration_size': 0.22514368891716002, 'avg_pred_std': 
0.025657533458434044}\n", "Epoch 26\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.009991624747397055, 'avg_role_model_std_loss': 1.627927773291799, 'avg_role_model_mean_pred_loss': 8.125562307790029e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009991624747397055, 'n_size': 320, 'n_batch': 80, 'duration': 84.53830814361572, 'duration_batch': 1.0567288517951965, 'duration_size': 0.2641822129487991, 'avg_pred_std': 0.050572005746653305}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.01111485290057317, 'avg_role_model_std_loss': 2.272924814505939, 'avg_role_model_mean_pred_loss': 0.000244029233883869, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01111485290057317, 'n_size': 80, 'n_batch': 20, 'duration': 18.182761192321777, 'duration_batch': 0.9091380596160888, 'duration_size': 0.2272845149040222, 'avg_pred_std': 0.02627429796848446}\n", "Epoch 27\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.011247798605666048, 'avg_role_model_std_loss': 2.117468389116715, 'avg_role_model_mean_pred_loss': 0.00020426704126478152, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011247798605666048, 'n_size': 320, 'n_batch': 80, 'duration': 84.47703838348389, 'duration_batch': 1.0559629797935486, 'duration_size': 0.26399074494838715, 'avg_pred_std': 0.04596492229029536}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 
0.009433183281817036, 'avg_role_model_std_loss': 1.450164285198241, 'avg_role_model_mean_pred_loss': 1.8529243547119787e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009433183281817036, 'n_size': 80, 'n_batch': 20, 'duration': 18.082876205444336, 'duration_batch': 0.9041438102722168, 'duration_size': 0.2260359525680542, 'avg_pred_std': 0.03382655227323994}\n", "Epoch 28\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.009869320008056093, 'avg_role_model_std_loss': 1.5068146906768447, 'avg_role_model_mean_pred_loss': 9.115362455692777e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009869320008056093, 'n_size': 320, 'n_batch': 80, 'duration': 84.38812756538391, 'duration_batch': 1.054851594567299, 'duration_size': 0.26371289864182473, 'avg_pred_std': 0.05035865947138518}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009240709943878756, 'avg_role_model_std_loss': 1.3003549632573517, 'avg_role_model_mean_pred_loss': 6.990615373461684e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009240709943878756, 'n_size': 80, 'n_batch': 20, 'duration': 18.295058012008667, 'duration_batch': 0.9147529006004333, 'duration_size': 0.22868822515010834, 'avg_pred_std': 0.03447220445377752}\n", "Epoch 29\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.01015008869671874, 'avg_role_model_std_loss': 1.6426295430076934, 'avg_role_model_mean_pred_loss': 8.031235901584562e-05, 
'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01015008869671874, 'n_size': 320, 'n_batch': 80, 'duration': 84.30770874023438, 'duration_batch': 1.0538463592529297, 'duration_size': 0.2634615898132324, 'avg_pred_std': 0.04944546818442177}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008979192694823723, 'avg_role_model_std_loss': 1.013558323823196, 'avg_role_model_mean_pred_loss': 1.5330314378037e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008979192694823723, 'n_size': 80, 'n_batch': 20, 'duration': 17.981907844543457, 'duration_batch': 0.8990953922271728, 'duration_size': 0.2247738480567932, 'avg_pred_std': 0.035569448093883696}\n", "Epoch 30\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.010348305049683404, 'avg_role_model_std_loss': 1.0208274961811328, 'avg_role_model_mean_pred_loss': 0.00023567322375602772, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010348305049683404, 'n_size': 320, 'n_batch': 80, 'duration': 84.40145015716553, 'duration_batch': 1.055018126964569, 'duration_size': 0.2637545317411423, 'avg_pred_std': 0.05335634221555665}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009130275297502521, 'avg_role_model_std_loss': 1.2715821872590367, 'avg_role_model_mean_pred_loss': 7.716895467813068e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009130275297502521, 'n_size': 80, 'n_batch': 20, 'duration': 18.11615824699402, 'duration_batch': 0.9058079123497009, 'duration_size': 0.22645197808742523, 'avg_pred_std': 0.03190355768892914}\n", "Stopped False\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Eval loss {'role_model': 'tab_ddpm_concat', 'n_size': 399, 'n_batch': 100, 'role_model_metrics': {'avg_loss': 0.019684911810006377, 'avg_g_mag_loss': nan, 'avg_g_cos_loss': 0.0021236211053054007, 'pred_duration': 2.178314447402954, 'grad_duration': 1.2552392482757568, 'total_duration': 3.433553695678711, 'pred_std': 0.06186112388968468, 'std_loss': 0.5796476006507874, 'mean_pred_loss': 4.718968921224587e-05, 'pred_rmse': 0.14030292630195618, 'pred_mae': 0.10000143945217133, 'pred_mape': 0.7038320302963257, 'grad_rmse': 0.2803622782230377, 'grad_mae': 0.19953711330890656, 'grad_mape': 0.9926933646202087}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.019684911810006377, 'avg_g_mag_loss': nan, 'avg_g_cos_loss': 0.0021236211053054007, 'avg_pred_duration': 2.178314447402954, 'avg_grad_duration': 1.2552392482757568, 'avg_total_duration': 3.433553695678711, 'avg_pred_std': 0.06186112388968468, 'avg_std_loss': 0.5796476006507874, 'avg_mean_pred_loss': 4.718968921224587e-05}, 'min_metrics': {'avg_loss': 0.019684911810006377, 'avg_g_mag_loss': nan, 'avg_g_cos_loss': 0.0021236211053054007, 'pred_duration': 2.178314447402954, 'grad_duration': 1.2552392482757568, 'total_duration': 3.433553695678711, 'pred_std': 0.06186112388968468, 'std_loss': 0.5796476006507874, 'mean_pred_loss': 4.718968921224587e-05, 'pred_rmse': 0.14030292630195618, 'pred_mae': 0.10000143945217133, 'pred_mape': 0.7038320302963257, 'grad_rmse': 0.2803622782230377, 'grad_mae': 
0.19953711330890656, 'grad_mape': 0.9926933646202087}, 'model_metrics': {'tab_ddpm_concat': {'avg_loss': 0.019684911810006377, 'avg_g_mag_loss': nan, 'avg_g_cos_loss': 0.0021236211053054007, 'pred_duration': 2.178314447402954, 'grad_duration': 1.2552392482757568, 'total_duration': 3.433553695678711, 'pred_std': 0.06186112388968468, 'std_loss': 0.5796476006507874, 'mean_pred_loss': 4.718968921224587e-05, 'pred_rmse': 0.14030292630195618, 'pred_mae': 0.10000143945217133, 'pred_mape': 0.7038320302963257, 'grad_rmse': 0.2803622782230377, 'grad_mae': 0.19953711330890656, 'grad_mape': 0.9926933646202087}}}\n" ] } ], "source": [ "import torch\n", "from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n", "from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n", "from ml_utility_loss.params import GradientPenaltyMode\n", "from ml_utility_loss.util import clear_memory\n", "import time\n", "#torch.autograd.set_detect_anomaly(True)\n", "\n", "clear_memory()\n", "\n", "opt = params[\"Optim\"](model.parameters())\n", "loss = train_2(\n", " [train_set, val_set, test_set],\n", " preprocessor=preprocessor,\n", " whole_model=model,\n", " optim=opt,\n", " log_dir=\"logs\",\n", " checkpoint_dir=\"checkpoints\",\n", " verbose=True,\n", " allow_same_prediction=allow_same_prediction,\n", " wandb=wandb if log_wandb else None,\n", " study_name=study_name,\n", " **params\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "9b514a07", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:57:24.490374Z", "iopub.status.busy": "2024-03-03T11:57:24.489536Z", "iopub.status.idle": "2024-03-03T11:57:24.493889Z", "shell.execute_reply": "2024-03-03T11:57:24.493023Z" }, "papermill": { "duration": 0.027217, "end_time": "2024-03-03T11:57:24.495842", "exception": false, "start_time": "2024-03-03T11:57:24.468625", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "model = loss[\"whole_model\"]\n", 
"opt = loss[\"optim\"]" ] }, { "cell_type": "code", "execution_count": 26, "id": "331a49e1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:57:24.536429Z", "iopub.status.busy": "2024-03-03T11:57:24.535822Z", "iopub.status.idle": "2024-03-03T11:57:24.615653Z", "shell.execute_reply": "2024-03-03T11:57:24.614420Z" }, "papermill": { "duration": 0.103308, "end_time": "2024-03-03T11:57:24.618128", "exception": false, "start_time": "2024-03-03T11:57:24.514820", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from copy import deepcopy\n", "\n", "torch.save(deepcopy(model.state_dict()), \"model.pt\")\n", "#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "123b4b17", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:57:24.659332Z", "iopub.status.busy": "2024-03-03T11:57:24.659016Z", "iopub.status.idle": "2024-03-03T11:57:24.929981Z", "shell.execute_reply": "2024-03-03T11:57:24.928997Z" }, "papermill": { "duration": 0.294048, "end_time": "2024-03-03T11:57:24.931923", "exception": false, "start_time": "2024-03-03T11:57:24.637875", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "history = loss[\"history\"]\n", "history.to_csv(\"history.csv\")\n", "history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()" ] }, { "cell_type": "code", "execution_count": 28, "id": "2586ba0a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:57:24.972376Z", "iopub.status.busy": "2024-03-03T11:57:24.972050Z", "iopub.status.idle": "2024-03-03T11:59:09.938182Z", "shell.execute_reply": "2024-03-03T11:59:09.937120Z" }, "papermill": { "duration": 104.989386, "end_time": "2024-03-03T11:59:09.940839", "exception": false, "start_time": "2024-03-03T11:57:24.951453", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "\n", "from ml_utility_loss.loss_learning.estimator.pipeline import eval\n", "#eval_loss = loss[\"eval_loss\"]\n", "\n", "batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n", "\n", "eval_loss = eval(\n", " test_set, model,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "187137f6", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:59:09.984105Z", "iopub.status.busy": "2024-03-03T11:59:09.983752Z", "iopub.status.idle": "2024-03-03T11:59:10.005079Z", "shell.execute_reply": "2024-03-03T11:59:10.004227Z" }, "papermill": { "duration": 0.045165, "end_time": "2024-03-03T11:59:10.007079", "exception": false, "start_time": "2024-03-03T11:59:09.961914", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg_g_cos_lossavg_g_mag_lossavg_lossgrad_durationgrad_maegrad_mapegrad_rmsemean_pred_losspred_durationpred_maepred_mapepred_rmsepred_stdstd_losstotal_duration
tab_ddpm_concat0.0025560.5241130.0196851.2617410.1995370.9926930.2803620.0000472.1811290.1000010.7038320.1403030.0618610.5796483.442869
\n", "
" ], "text/plain": [ " avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration \\\n", "tab_ddpm_concat 0.002556 0.524113 0.019685 1.261741 \n", "\n", " grad_mae grad_mape grad_rmse mean_pred_loss \\\n", "tab_ddpm_concat 0.199537 0.992693 0.280362 0.000047 \n", "\n", " pred_duration pred_mae pred_mape pred_rmse pred_std \\\n", "tab_ddpm_concat 2.181129 0.100001 0.703832 0.140303 0.061861 \n", "\n", " std_loss total_duration \n", "tab_ddpm_concat 0.579648 3.442869 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n", "metrics.to_csv(\"eval.csv\")\n", "metrics" ] }, { "cell_type": "code", "execution_count": 30, "id": "123d305b", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:59:10.059613Z", "iopub.status.busy": "2024-03-03T11:59:10.059114Z", "iopub.status.idle": "2024-03-03T11:59:10.519619Z", "shell.execute_reply": "2024-03-03T11:59:10.518672Z" }, "papermill": { "duration": 0.492551, "end_time": "2024-03-03T11:59:10.521766", "exception": false, "start_time": "2024-03-03T11:59:10.029215", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import clear_memory\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 31, "id": "a3eecc2a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T11:59:10.564330Z", "iopub.status.busy": "2024-03-03T11:59:10.563999Z", "iopub.status.idle": "2024-03-03T12:01:02.226778Z", "shell.execute_reply": "2024-03-03T12:01:02.225923Z" }, "papermill": { "duration": 111.687045, "end_time": "2024-03-03T12:01:02.229479", "exception": false, "start_time": "2024-03-03T11:59:10.542434", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../insurance/_cache_test/tab_ddpm_concat/all inf False\n" ] } ], "source": [ "#\"\"\"\n", "from 
ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n", "from ml_utility_loss.util import stack_samples\n", "\n", "#samples = test_set[list(range(len(test_set)))]\n", "#y = {m: pred(model[m], s) for m, s in samples.items()}\n", "y = pred_2(model, test_set, batch_size=batch_size)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 32, "id": "6ab51db8", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:02.273874Z", "iopub.status.busy": "2024-03-03T12:01:02.273107Z", "iopub.status.idle": "2024-03-03T12:01:02.290903Z", "shell.execute_reply": "2024-03-03T12:01:02.290177Z" }, "papermill": { "duration": 0.042098, "end_time": "2024-03-03T12:01:02.292899", "exception": false, "start_time": "2024-03-03T12:01:02.250801", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from ml_utility_loss.util import transpose_dict\n", "\n", "os.makedirs(\"pred\", exist_ok=True)\n", "y2 = transpose_dict(y)\n", "for k, v in y2.items():\n", " df = pd.DataFrame(v)\n", " df.to_csv(f\"pred/{k}.csv\")" ] }, { "cell_type": "code", "execution_count": 33, "id": "d81a30f1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:02.334500Z", "iopub.status.busy": "2024-03-03T12:01:02.334198Z", "iopub.status.idle": "2024-03-03T12:01:02.339555Z", "shell.execute_reply": "2024-03-03T12:01:02.338676Z" }, "papermill": { "duration": 0.027945, "end_time": "2024-03-03T12:01:02.341506", "exception": false, "start_time": "2024-03-03T12:01:02.313561", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tab_ddpm_concat': 0.038165719780903475}\n" ] } ], "source": [ "print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})" ] }, { "cell_type": "code", "execution_count": 34, "id": "3b3ff322", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:02.383613Z", "iopub.status.busy": "2024-03-03T12:01:02.383327Z", 
"iopub.status.idle": "2024-03-03T12:01:02.706189Z", "shell.execute_reply": "2024-03-03T12:01:02.705209Z" }, "papermill": { "duration": 0.346517, "end_time": "2024-03-03T12:01:02.708385", "exception": false, "start_time": "2024-03-03T12:01:02.361868", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n", "\n", "_ = plot_pred_density_2(y)" ] }, { "cell_type": "code", "execution_count": 35, "id": "e79e4b0f", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:02.750620Z", "iopub.status.busy": "2024-03-03T12:01:02.750272Z", "iopub.status.idle": "2024-03-03T12:01:03.054973Z", "shell.execute_reply": "2024-03-03T12:01:03.054008Z" }, "papermill": { "duration": 0.328422, "end_time": "2024-03-03T12:01:03.057207", "exception": false, "start_time": "2024-03-03T12:01:02.728785", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_density_3\n", "\n", "_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 36, "id": "745adde1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:03.102557Z", "iopub.status.busy": "2024-03-03T12:01:03.101747Z", "iopub.status.idle": "2024-03-03T12:01:03.274006Z", "shell.execute_reply": "2024-03-03T12:01:03.273036Z" }, "papermill": { "duration": 0.197286, "end_time": "2024-03-03T12:01:03.276114", "exception": false, "start_time": "2024-03-03T12:01:03.078828", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_box_3\n", "\n", "_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 37, "id": "eabe1bab", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T12:01:03.322849Z", "iopub.status.busy": "2024-03-03T12:01:03.322537Z", "iopub.status.idle": "2024-03-03T12:01:03.618114Z", "shell.execute_reply": "2024-03-03T12:01:03.617173Z" }, "papermill": { "duration": 0.322102, "end_time": "2024-03-03T12:01:03.620373", "exception": false, "start_time": "2024-03-03T12:01:03.298271", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n", "import matplotlib.pyplot as plt\n", "\n", "#plot_grad_2(y, model.models)\n", "for m in model.models:\n", " ym = y[m]\n", " fig, ax = plt.subplots()\n", " plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "54c0e9f3", "metadata": { "papermill": { "duration": 0.022768, "end_time": "2024-03-03T12:01:03.666008", "exception": false, "start_time": "2024-03-03T12:01:03.643240", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "celltoolbar": "Tags", "colab": { "authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie", "gpuType": "T4", "mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg", "provenance": [] }, "kaggle": { "accelerator": "gpu", "dataSources": [], "dockerImageVersionId": 30648, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "papermill": { "default_parameters": {}, "duration": 3563.29435, "end_time": "2024-03-03T12:01:06.412574", "environment_variables": {}, "exception": null, "input_path": "eval/insurance/tab_ddpm_concat/4/mlu-eval.ipynb", "output_path": "eval/insurance/tab_ddpm_concat/4/mlu-eval.ipynb", "parameters": { "allow_same_prediction": true, "dataset": "insurance", "dataset_name": "insurance", "debug": false, "folder": "eval", "gp": false, "gp_multiply": false, "log_wandb": false, "param_index": 3, "path": "eval/insurance/tab_ddpm_concat/4", 
"path_prefix": "../../../../", "random_seed": 4, "single_model": "tab_ddpm_concat" }, "start_time": "2024-03-03T11:01:43.118224", "version": "2.5.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }