{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "982e76f5", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.074576Z", "iopub.status.busy": "2024-02-29T04:29:11.073762Z", "iopub.status.idle": "2024-02-29T04:29:11.113694Z", "shell.execute_reply": "2024-02-29T04:29:11.112810Z" }, "papermill": { "duration": 0.055415, "end_time": "2024-02-29T04:29:11.115700", "exception": false, "start_time": "2024-02-29T04:29:11.060285", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import joblib\n", "\n", "#joblib.parallel_backend(\"threading\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "675f0b41", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.141741Z", "iopub.status.busy": "2024-02-29T04:29:11.140838Z", "iopub.status.idle": "2024-02-29T04:29:11.148404Z", "shell.execute_reply": "2024-02-29T04:29:11.147580Z" }, "papermill": { "duration": 0.02269, "end_time": "2024-02-29T04:29:11.150466", "exception": false, "start_time": "2024-02-29T04:29:11.127776", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "%cd /kaggle/working\n", "#!git clone https://github.com/R-N/ml-utility-loss\n", "%cd ml-utility-loss\n", "!git pull\n", "#!pip install .\n", "!pip install . 
--no-deps --force-reinstall --upgrade\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "5ae30f5c", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.174255Z", "iopub.status.busy": "2024-02-29T04:29:11.173792Z", "iopub.status.idle": "2024-02-29T04:29:11.177803Z", "shell.execute_reply": "2024-02-29T04:29:11.176979Z" }, "papermill": { "duration": 0.018365, "end_time": "2024-02-29T04:29:11.179910", "exception": false, "start_time": "2024-02-29T04:29:11.161545", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [3,3]" ] }, { "cell_type": "code", "execution_count": 4, "id": "9f42c810", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.203700Z", "iopub.status.busy": "2024-02-29T04:29:11.203202Z", "iopub.status.idle": "2024-02-29T04:29:11.207400Z", "shell.execute_reply": "2024-02-29T04:29:11.206583Z" }, "executionInfo": { "elapsed": 678, "status": "ok", "timestamp": 1696841022168, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "ns5hFcVL2yvs", "papermill": { "duration": 0.018933, "end_time": "2024-02-29T04:29:11.209878", "exception": false, "start_time": "2024-02-29T04:29:11.190945", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "datasets = [\n", " \"insurance\",\n", " \"treatment\",\n", " \"contraceptive\"\n", "]\n", "\n", "study_dir = \"./\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "85d0c8ce", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.234360Z", "iopub.status.busy": "2024-02-29T04:29:11.233845Z", "iopub.status.idle": "2024-02-29T04:29:11.239162Z", "shell.execute_reply": "2024-02-29T04:29:11.238293Z" }, "papermill": { "duration": 0.019487, "end_time": "2024-02-29T04:29:11.241190", "exception": false, "start_time": "2024-02-29T04:29:11.221703", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": 
[], "source": [ "#Parameters\n", "import os\n", "\n", "path_prefix = \"../../../../\"\n", "\n", "dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n", "dataset_name = \"treatment\"\n", "model_name=\"ml_utility_2\"\n", "models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n", "single_model = \"lct_gan\"\n", "random_seed = 42\n", "gp = True\n", "gp_multiply = True\n", "folder = \"eval\"\n", "debug = False\n", "path = None" ] }, { "cell_type": "code", "execution_count": 6, "id": "e997d4e6", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.266264Z", "iopub.status.busy": "2024-02-29T04:29:11.266000Z", "iopub.status.idle": "2024-02-29T04:29:11.270684Z", "shell.execute_reply": "2024-02-29T04:29:11.269865Z" }, "papermill": { "duration": 0.019149, "end_time": "2024-02-29T04:29:11.272484", "exception": false, "start_time": "2024-02-29T04:29:11.253335", "status": "completed" }, "tags": [ "injected-parameters" ] }, "outputs": [], "source": [ "# Parameters\n", "dataset = \"treatment\"\n", "dataset_name = \"treatment\"\n", "single_model = \"tvae\"\n", "gp = False\n", "gp_multiply = False\n", "random_seed = 2\n", "debug = False\n", "folder = \"eval\"\n", "path_prefix = \"../../../../\"\n", "path = \"eval/treatment/tvae/2\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bd7c02d6", "metadata": { "papermill": { "duration": 0.011209, "end_time": "2024-02-29T04:29:11.294824", "exception": false, "start_time": "2024-02-29T04:29:11.283615", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "5f45b1d0", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.318911Z", "iopub.status.busy": "2024-02-29T04:29:11.318297Z", "iopub.status.idle": "2024-02-29T04:29:11.327653Z", "shell.execute_reply": "2024-02-29T04:29:11.326895Z" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1696841022169, "user": { 
"displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "UdvXYv3c3LXy", "papermill": { "duration": 0.023753, "end_time": "2024-02-29T04:29:11.329797", "exception": false, "start_time": "2024-02-29T04:29:11.306044", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n", "/kaggle/working/eval/treatment/tvae/2\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "%cd /kaggle/working/\n", "\n", "if path is None:\n", " # os.path.join only accepts str/PathLike, so the int seed must be converted\n", " path = os.path.join(folder, dataset_name, single_model, str(random_seed))\n", "Path(path).mkdir(parents=True, exist_ok=True)\n", "\n", "%cd {path}" ] }, { "cell_type": "code", "execution_count": 8, "id": "f85bf540", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:11.353602Z", "iopub.status.busy": "2024-02-29T04:29:11.353337Z", "iopub.status.idle": "2024-02-29T04:29:13.595630Z", "shell.execute_reply": "2024-02-29T04:29:13.594714Z" }, "papermill": { "duration": 2.256654, "end_time": "2024-02-29T04:29:13.597744", "exception": false, "start_time": "2024-02-29T04:29:11.341090", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set seed to \n" ] } ], "source": [ "from ml_utility_loss.util import seed\n", "if single_model:\n", " model_name=f\"{model_name}_{single_model}\"\n", "if random_seed is not None:\n", " seed(random_seed)\n", " # report the seed value itself, not the seed() helper function object\n", " print(\"Set seed to\", random_seed)" ] }, { "cell_type": "code", "execution_count": 9, "id": "8489feae", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:13.624599Z", "iopub.status.busy": "2024-02-29T04:29:13.623911Z", "iopub.status.idle": "2024-02-29T04:29:13.639836Z", "shell.execute_reply": "2024-02-29T04:29:13.639010Z" }, "papermill": { "duration": 0.031202, "end_time": "2024-02-29T04:29:13.641867", "exception": false, "start_time": "2024-02-29T04:29:13.610665", "status": "completed" }, "tags": [] }, "outputs": [],
"source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os\n", "\n", "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", " info = json.load(f)" ] }, { "cell_type": "code", "execution_count": 10, "id": "debcc684", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:13.666451Z", "iopub.status.busy": "2024-02-29T04:29:13.666195Z", "iopub.status.idle": "2024-02-29T04:29:13.673384Z", "shell.execute_reply": "2024-02-29T04:29:13.672673Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "Vrl2QkoV3o_8", "papermill": { "duration": 0.021525, "end_time": "2024-02-29T04:29:13.675216", "exception": false, "start_time": "2024-02-29T04:29:13.653691", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "task = info[\"task\"]\n", "target = info[\"target\"]\n", "cat_features = info[\"cat_features\"]\n", "mixed_features = info[\"mixed_features\"]\n", "longtail_features = info[\"longtail_features\"]\n", "integer_features = info[\"integer_features\"]\n", "\n", "test = df.sample(frac=0.2, random_state=42)\n", "train = df[~df.index.isin(test.index)]" ] }, { "cell_type": "code", "execution_count": 11, "id": "7538184a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:13.699417Z", "iopub.status.busy": "2024-02-29T04:29:13.699140Z", "iopub.status.idle": "2024-02-29T04:29:13.800471Z", "shell.execute_reply": "2024-02-29T04:29:13.799702Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "TilUuFk9vqMb", "papermill": { "duration": 0.116049, "end_time": "2024-02-29T04:29:13.802564", "exception": false, "start_time": "2024-02-29T04:29:13.686515", "status": "completed" 
}, "tags": [] }, "outputs": [], "source": [ "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", "from ml_utility_loss.util import filter_dict_2, filter_dict\n", "\n", "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", "\n", "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", "tab_ddpm_normalization=\"quantile\"\n", "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", "#tab_ddpm_cat_encoding=\"one-hot\"\n", "tab_ddpm_y_policy=\"default\"\n", "tab_ddpm_is_y_cond=True" ] }, { "cell_type": "code", "execution_count": 12, "id": "cca61838", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:13.829096Z", "iopub.status.busy": "2024-02-29T04:29:13.828793Z", "iopub.status.idle": "2024-02-29T04:29:18.514659Z", "shell.execute_reply": "2024-02-29T04:29:18.513888Z" }, "executionInfo": { "elapsed": 3113, "status": "ok", "timestamp": 1696841025277, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "7Abt8nStvr9Z", "papermill": { "duration": 4.702039, "end_time": "2024-02-29T04:29:18.517136", "exception": false, "start_time": "2024-02-29T04:29:13.815097", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-02-29 04:29:16.137890: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one 
has already been registered\n", "2024-02-29 04:29:16.137945: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-02-29 04:29:16.139721: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", "\n", "lct_ae = load_lct_ae(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"lct_ae\",\n", " df_name=\"df\",\n", ")\n", "lct_ae = None" ] }, { "cell_type": "code", "execution_count": 13, "id": "6f83b7b6", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:18.543026Z", "iopub.status.busy": "2024-02-29T04:29:18.542205Z", "iopub.status.idle": "2024-02-29T04:29:18.548277Z", "shell.execute_reply": "2024-02-29T04:29:18.547581Z" }, "papermill": { "duration": 0.021126, "end_time": "2024-02-29T04:29:18.550299", "exception": false, "start_time": "2024-02-29T04:29:18.529173", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", "\n", "rtf_embed = load_rtf_embed(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"realtabformer\",\n", " df_name=\"df\",\n", " ckpt_type=\"best-disc-model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "0026de74", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:18.576519Z", "iopub.status.busy": "2024-02-29T04:29:18.576245Z", "iopub.status.idle": "2024-02-29T04:29:40.762246Z", "shell.execute_reply": "2024-02-29T04:29:40.760891Z" }, "executionInfo": { "elapsed": 20137, "status": "ok", "timestamp": 
1696841045408, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "tbaguWxAvtPi", "papermill": { "duration": 22.202272, "end_time": "2024-02-29T04:29:40.764735", "exception": false, "start_time": "2024-02-29T04:29:18.562463", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/1 [00:00 torch.Tensor>,\n", " 'single_model': True,\n", " 'bias': True,\n", " 'bias_final': True,\n", " 'pma_ffn_mode': 'shared',\n", " 'patience': 10,\n", " 'inds_init_mode': 'torch',\n", " 'grad_clip': 0.8,\n", " 'gradient_penalty_mode': {'gradient_penalty': False,\n", " 'calc_grad_m': False,\n", " 'avg_non_role_model_m': False,\n", " 'inverse_avg_non_role_model_m': False},\n", " 'synth_data': 2,\n", " 'dataset_size': 2048,\n", " 'batch_size': 4,\n", " 'epochs': 100,\n", " 'lr_mul': 0.04,\n", " 'n_warmup_steps': 220,\n", " 'Optim': torch_optimizer.diffgrad.DiffGrad,\n", " 'loss_balancer_beta': 0.73,\n", " 'loss_balancer_r': 0.94,\n", " 'fixed_role_model': 'tvae',\n", " 'd_model': 512,\n", " 'attn_activation': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'tf_d_inner': 512,\n", " 'tf_n_layers_enc': 4,\n", " 'tf_n_head': 64,\n", " 'tf_activation': ml_utility_loss.activations.LeakyHardtanh,\n", " 
'tf_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n", " 'ada_d_hid': 1024,\n", " 'ada_n_layers': 7,\n", " 'ada_activation': torch.nn.modules.activation.SELU,\n", " 'ada_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'head_d_hid': 128,\n", " 'head_n_layers': 8,\n", " 'head_n_head': 64,\n", " 'head_activation': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'head_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'models': ['tvae'],\n", " 'max_seconds': 3600,\n", " 'tf_lora': False,\n", " 'tf_num_inds': 64,\n", " 'ada_n_seeds': 0,\n", " 'gradient_penalty_kwargs': {'mag_loss': True,\n", " 'mse_mag': False,\n", " 'mag_corr': False,\n", " 'seq_mag': False,\n", " 'cos_loss': False,\n", " 'mag_corr_kwargs': {'only_sign': False},\n", " 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n", " 'mse_mag_kwargs': {'target': 0.2, 'multiply': False}}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n", "import wandb\n", "\n", "#\"\"\"\n", "param_space = {\n", " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", "}\n", "params = {\n", " **getattr(PARAMS, dataset_name).BEST,\n", "}\n", "if gp:\n", " params[\"gradient_penalty_mode\"] = \"ALL\"\n", " params[\"mse_mag\"] = True\n", " if gp_multiply:\n", " params[\"mse_mag_multiply\"] = True\n", " params[\"mse_mag_target\"] = 1.0\n", " else:\n", " params[\"mse_mag_multiply\"] = False\n", " params[\"mse_mag_target\"] = 0.1\n", "else:\n", " params[\"gradient_penalty_mode\"] = \"NONE\"\n", " params[\"mse_mag\"] = False\n", "params[\"single_model\"] = False\n", "if models:\n", " params[\"models\"] = models\n", "if single_model:\n", " params[\"fixed_role_model\"] = single_model\n", " 
params[\"single_model\"] = True\n", " params[\"models\"] = [single_model]\n", "if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n", " params[\"batch_size\"] = 2\n", "params[\"max_seconds\"] = 3600\n", "params[\"patience\"] = 10\n", "params[\"epochs\"] = 100\n", "if debug:\n", " params[\"epochs\"] = 2\n", "with open(\"params.json\", \"w\") as f:\n", " json.dump(params, f)\n", "params = map_parameters(params, param_space=param_space)\n", "params" ] }, { "cell_type": "code", "execution_count": 19, "id": "a48bd9e9", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:41.281815Z", "iopub.status.busy": "2024-02-29T04:29:41.281494Z", "iopub.status.idle": "2024-02-29T04:29:41.349146Z", "shell.execute_reply": "2024-02-29T04:29:41.348155Z" }, "papermill": { "duration": 0.083785, "end_time": "2024-02-29T04:29:41.351100", "exception": false, "start_time": "2024-02-29T04:29:41.267315", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load_dataset_3_factory 2\n", "Caching in ../../../../treatment/_cache/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_2/treatment [80, 20]\n", "Caching in ../../../../treatment/_cache4/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_4/treatment [80, 20]\n", "Caching in ../../../../treatment/_cache5/tvae/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_5/treatment [160, 40]\n", "[320, 80]\n", "[320, 80]\n" ] } ], "source": [ "train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], synth_data=params[\"synth_data\"])" ] }, { "cell_type": "code", "execution_count": 20, "id": "2fcb1418", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "execution": { "iopub.execute_input": 
"2024-02-29T04:29:41.380722Z", "iopub.status.busy": "2024-02-29T04:29:41.380408Z", "iopub.status.idle": "2024-02-29T04:29:41.922049Z", "shell.execute_reply": "2024-02-29T04:29:41.921052Z" }, "executionInfo": { "elapsed": 396850, "status": "error", "timestamp": 1696841446059, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "_bt1MQc5kpSk", "outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6", "papermill": { "duration": 0.558849, "end_time": "2024-02-29T04:29:41.924289", "exception": false, "start_time": "2024-02-29T04:29:41.365440", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model of type \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[*] Embedding False True\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['tvae'] 1\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n", "from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n", "from ml_utility_loss.util import filter_dict, clear_memory\n", "\n", "clear_memory()\n", "\n", "params2 = remove_non_model_params(params)\n", "adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n", "\n", "model = create_model(\n", " adapters=adapters,\n", " #Body=\"twin_encoder\",\n", " **params2,\n", ")\n", "#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", "print(model.models, len(model.adapters))" ] }, { "cell_type": "code", "execution_count": 21, "id": "938f94fc", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:41.953848Z", "iopub.status.busy": "2024-02-29T04:29:41.953527Z", "iopub.status.idle": "2024-02-29T04:29:41.957492Z", "shell.execute_reply": "2024-02-29T04:29:41.956677Z" }, "papermill": { "duration": 0.021605, "end_time": "2024-02-29T04:29:41.959458", "exception": false, "start_time": "2024-02-29T04:29:41.937853", "status": "completed" 
}, "tags": [] }, "outputs": [], "source": [ "study_name=f\"{model_name}_{dataset_name}\"" ] }, { "cell_type": "code", "execution_count": 22, "id": "12fb613e", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:41.989740Z", "iopub.status.busy": "2024-02-29T04:29:41.989123Z", "iopub.status.idle": "2024-02-29T04:29:41.996421Z", "shell.execute_reply": "2024-02-29T04:29:41.995567Z" }, "papermill": { "duration": 0.024554, "end_time": "2024-02-29T04:29:41.998337", "exception": false, "start_time": "2024-02-29T04:29:41.973783", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "18701313" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "count_parameters(model)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bd386e57", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:42.026419Z", "iopub.status.busy": "2024-02-29T04:29:42.026154Z", "iopub.status.idle": "2024-02-29T04:29:42.124245Z", "shell.execute_reply": "2024-02-29T04:29:42.123367Z" }, "papermill": { "duration": 0.114636, "end_time": "2024-02-29T04:29:42.126241", "exception": false, "start_time": "2024-02-29T04:29:42.011605", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "========================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "========================================================================================================================\n", "MLUtilitySingle [2, 2648, 95] --\n", "├─Adapter: 1-1 [2, 2648, 95] --\n", "│ └─Sequential: 2-1 [2, 2648, 512] --\n", "│ │ └─FeedForward: 3-1 [2, 2648, 1024] --\n", "│ │ │ └─Linear: 4-1 [2, 2648, 1024] 98,304\n", "│ │ │ └─SELU: 4-2 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-2 [2, 2648, 
1024] --\n", "│ │ │ └─Linear: 4-3 [2, 2648, 1024] 1,049,600\n", "│ │ │ └─SELU: 4-4 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-3 [2, 2648, 1024] --\n", "│ │ │ └─Linear: 4-5 [2, 2648, 1024] 1,049,600\n", "│ │ │ └─SELU: 4-6 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-4 [2, 2648, 1024] --\n", "│ │ │ └─Linear: 4-7 [2, 2648, 1024] 1,049,600\n", "│ │ │ └─SELU: 4-8 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-5 [2, 2648, 1024] --\n", "│ │ │ └─Linear: 4-9 [2, 2648, 1024] 1,049,600\n", "│ │ │ └─SELU: 4-10 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-6 [2, 2648, 1024] --\n", "│ │ │ └─Linear: 4-11 [2, 2648, 1024] 1,049,600\n", "│ │ │ └─SELU: 4-12 [2, 2648, 1024] --\n", "│ │ └─FeedForward: 3-7 [2, 2648, 512] --\n", "│ │ │ └─Linear: 4-13 [2, 2648, 512] 524,800\n", "│ │ │ └─LeakyHardsigmoid: 4-14 [2, 2648, 512] --\n", "├─Adapter: 1-2 [2, 661, 95] (recursive)\n", "│ └─Sequential: 2-2 [2, 661, 512] (recursive)\n", "│ │ └─FeedForward: 3-8 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-15 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-16 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-9 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-17 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-18 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-10 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-19 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-20 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-11 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-21 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-22 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-12 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-23 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-24 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-13 [2, 661, 1024] (recursive)\n", "│ │ │ └─Linear: 4-25 [2, 661, 1024] (recursive)\n", "│ │ │ └─SELU: 4-26 [2, 661, 1024] --\n", "│ │ └─FeedForward: 3-14 [2, 661, 512] (recursive)\n", "│ │ │ └─Linear: 4-27 [2, 661, 512] (recursive)\n", "│ │ │ └─LeakyHardsigmoid: 4-28 [2, 661, 512] --\n", "├─TwinEncoder: 1-3 
[2, 8192] --\n", "│ └─Encoder: 2-3 [2, 16, 512] --\n", "│ │ └─ModuleList: 3-16 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-29 [2, 2648, 512] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 2648, 512] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 64, 512] 32,768\n", "│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 64, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-1 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-2 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-3 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 64, 64, 2648] --\n", "│ │ │ │ │ │ └─Linear: 7-5 [2, 64, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-6 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 2648, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-7 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-8 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-9 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 64, 2648, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 64, 2648, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-11 [2, 2648, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-12 [2, 2648, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-2 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-4 [2, 2648, 512] 262,656\n", "│ │ │ │ │ └─LeakyHardtanh: 6-5 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-6 [2, 2648, 512] 262,656\n", "│ │ │ └─EncoderLayer: 4-30 [2, 2648, 512] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 2648, 512] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 64, 512] 32,768\n", "│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 64, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-13 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-14 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-15 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 64, 64, 2648] --\n", "│ │ │ │ │ │ └─Linear: 7-17 [2, 64, 512] 262,656\n", "│ │ │ │ │ │ 
└─LeakyHardsigmoid: 7-18 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 2648, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-19 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-20 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-21 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 64, 2648, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 64, 2648, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-23 [2, 2648, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-24 [2, 2648, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-4 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-10 [2, 2648, 512] 262,656\n", "│ │ │ │ │ └─LeakyHardtanh: 6-11 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-12 [2, 2648, 512] 262,656\n", "│ │ │ └─EncoderLayer: 4-31 [2, 2648, 512] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 2648, 512] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 64, 512] 32,768\n", "│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 64, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-25 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-26 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-27 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 64, 64, 2648] --\n", "│ │ │ │ │ │ └─Linear: 7-29 [2, 64, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-30 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 2648, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-31 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-32 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-33 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 64, 2648, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 64, 2648, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-35 [2, 2648, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-36 [2, 2648, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-6 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-16 [2, 2648, 512] 262,656\n", "│ │ │ │ │ └─LeakyHardtanh: 6-17 [2, 2648, 512] --\n", "│ │ │ │ │ 
└─Linear: 6-18 [2, 2648, 512] 262,656\n", "│ │ │ └─EncoderLayer: 4-32 [2, 16, 512] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-7 [2, 2648, 512] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 64, 512] 32,768\n", "│ │ │ │ │ └─MultiHeadAttention: 6-20 [2, 64, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-37 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-38 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-39 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 64, 64, 2648] --\n", "│ │ │ │ │ │ └─Linear: 7-41 [2, 64, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-42 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-21 [2, 2648, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-43 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-44 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-45 [2, 64, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 64, 2648, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 64, 2648, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-47 [2, 2648, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-48 [2, 2648, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-8 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-22 [2, 2648, 512] 262,656\n", "│ │ │ │ │ └─LeakyHardtanh: 6-23 [2, 2648, 512] --\n", "│ │ │ │ │ └─Linear: 6-24 [2, 2648, 512] 262,656\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-9 [2, 16, 512] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-25 [2, 16, 512] 8,192\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-26 [2, 16, 512] --\n", "│ │ │ │ │ │ └─Linear: 7-49 [2, 16, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-50 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─Linear: 7-51 [2, 2648, 512] 262,144\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 64, 16, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 64, 16, 2648] --\n", "│ │ │ │ │ │ └─Linear: 7-53 [2, 16, 512] 262,656\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-54 [2, 16, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-10 [2, 16, 512] 
(recursive)\n", "│ │ │ │ │ └─Linear: 6-27 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-28 [2, 16, 512] --\n", "│ │ │ │ │ └─Linear: 6-29 [2, 16, 512] (recursive)\n", "│ └─Encoder: 2-4 [2, 16, 512] (recursive)\n", "│ │ └─ModuleList: 3-16 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-33 [2, 661, 512] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-11 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-30 [2, 64, 512] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-31 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-55 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-56 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-57 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 64, 64, 661] --\n", "│ │ │ │ │ │ └─Linear: 7-59 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-60 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-32 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-61 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-62 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-63 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 64, 661, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 64, 661, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-65 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-66 [2, 661, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-12 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─Linear: 6-33 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-34 [2, 661, 512] --\n", "│ │ │ │ │ └─Linear: 6-35 [2, 661, 512] (recursive)\n", "│ │ │ └─EncoderLayer: 4-34 [2, 661, 512] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 64, 512] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-67 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ 
└─Linear: 7-68 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-69 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 64, 64, 661] --\n", "│ │ │ │ │ │ └─Linear: 7-71 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-72 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-38 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-73 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-74 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-75 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 64, 661, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 64, 661, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-77 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-78 [2, 661, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-14 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─Linear: 6-39 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-40 [2, 661, 512] --\n", "│ │ │ │ │ └─Linear: 6-41 [2, 661, 512] (recursive)\n", "│ │ │ └─EncoderLayer: 4-35 [2, 661, 512] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-15 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 64, 512] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-43 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-79 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-80 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-81 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 64, 64, 661] --\n", "│ │ │ │ │ │ └─Linear: 7-83 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-84 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-44 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-85 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-86 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-87 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ 
└─ScaledDotProductAttention: 7-88 [2, 64, 661, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-15 [2, 64, 661, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-89 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-90 [2, 661, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-16 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─Linear: 6-45 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-46 [2, 661, 512] --\n", "│ │ │ │ │ └─Linear: 6-47 [2, 661, 512] (recursive)\n", "│ │ │ └─EncoderLayer: 4-36 [2, 16, 512] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-17 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-48 [2, 64, 512] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-49 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-91 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-92 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-93 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-94 [2, 64, 64, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-16 [2, 64, 64, 661] --\n", "│ │ │ │ │ │ └─Linear: 7-95 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-96 [2, 64, 512] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-50 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-97 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-98 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-99 [2, 64, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-100 [2, 64, 661, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-17 [2, 64, 661, 64] --\n", "│ │ │ │ │ │ └─Linear: 7-101 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-102 [2, 661, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-18 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─Linear: 6-51 [2, 661, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-52 [2, 661, 512] --\n", "│ │ │ │ │ └─Linear: 6-53 [2, 661, 512] (recursive)\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-19 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-54 [2, 16, 512] (recursive)\n", "│ 
│ │ │ │ └─SimpleMultiHeadAttention: 6-55 [2, 16, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-103 [2, 16, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-104 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-105 [2, 661, 512] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-106 [2, 64, 16, 8] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-18 [2, 64, 16, 661] --\n", "│ │ │ │ │ │ └─Linear: 7-107 [2, 16, 512] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-108 [2, 16, 512] --\n", "│ │ │ │ └─DoubleFeedForward: 5-20 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─Linear: 6-56 [2, 16, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-57 [2, 16, 512] --\n", "│ │ │ │ │ └─Linear: 6-58 [2, 16, 512] (recursive)\n", "├─Head: 1-4 [2] --\n", "│ └─Sequential: 2-5 [2, 1] --\n", "│ │ └─FeedForward: 3-17 [2, 128] --\n", "│ │ │ └─Linear: 4-37 [2, 128] 1,048,704\n", "│ │ │ └─LeakyHardsigmoid: 4-38 [2, 128] --\n", "│ │ └─FeedForward: 3-18 [2, 128] --\n", "│ │ │ └─Linear: 4-39 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-40 [2, 128] --\n", "│ │ └─FeedForward: 3-19 [2, 128] --\n", "│ │ │ └─Linear: 4-41 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-42 [2, 128] --\n", "│ │ └─FeedForward: 3-20 [2, 128] --\n", "│ │ │ └─Linear: 4-43 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-44 [2, 128] --\n", "│ │ └─FeedForward: 3-21 [2, 128] --\n", "│ │ │ └─Linear: 4-45 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-46 [2, 128] --\n", "│ │ └─FeedForward: 3-22 [2, 128] --\n", "│ │ │ └─Linear: 4-47 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-48 [2, 128] --\n", "│ │ └─FeedForward: 3-23 [2, 128] --\n", "│ │ │ └─Linear: 4-49 [2, 128] 16,512\n", "│ │ │ └─LeakyHardsigmoid: 4-50 [2, 128] --\n", "│ │ └─FeedForward: 3-24 [2, 1] --\n", "│ │ │ └─Linear: 4-51 [2, 1] 129\n", "│ │ │ └─LeakyHardsigmoid: 4-52 [2, 1] --\n", "========================================================================================================================\n", "Total params: 18,701,313\n", "Trainable params: 
18,701,313\n", "Non-trainable params: 0\n", "Total mult-adds (M): 74.05\n", "========================================================================================================================\n", "Input size (MB): 2.51\n", "Forward/backward pass size (MB): 1079.48\n", "Params size (MB): 74.81\n", "Estimated Total Size (MB): 1156.80\n", "========================================================================================================================" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchinfo import summary\n", "\n", "role_model = params[\"fixed_role_model\"]\n", "s = train_set[0][role_model]\n", "summary(model[role_model], input_size=((2, *s[0].shape), (2, *s[1].shape)), depth=9) # 8 max" ] }, { "cell_type": "code", "execution_count": 24, "id": "0f42c4d1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T04:29:42.157988Z", "iopub.status.busy": "2024-02-29T04:29:42.157605Z", "iopub.status.idle": "2024-02-29T05:32:09.727771Z", "shell.execute_reply": "2024-02-29T05:32:09.726851Z" }, "papermill": { "duration": 3747.588598, "end_time": "2024-02-29T05:32:09.729982", "exception": false, "start_time": "2024-02-29T04:29:42.141384", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.16.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: W&B syncing is set to \u001b[1m`offline`\u001b[0m in this directory. 
\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb online`\u001b[0m or set \u001b[1mWANDB_MODE=online\u001b[0m to enable cloud syncing.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "g_loss_mul 0.1\n", "Epoch 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.25079594189301135, 'avg_role_model_std_loss': 172.3493912117849, 'avg_role_model_mean_pred_loss': 0.0756463885481935, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.25079594189301135, 'n_size': 320, 'n_batch': 80, 'duration': 100.33574557304382, 'duration_batch': 1.2541968196630477, 'duration_size': 0.31354920491576194, 'avg_pred_std': 0.0283103398049775}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.15434771333821118, 'avg_role_model_std_loss': 0.8378194279823219, 'avg_role_model_mean_pred_loss': 0.03735067891972328, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.15434771333821118, 'n_size': 80, 'n_batch': 20, 'duration': 19.603960514068604, 'duration_batch': 0.9801980257034302, 'duration_size': 0.24504950642585754, 'avg_pred_std': 0.15204731123521925}\n", "Epoch 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.03561886841698651, 'avg_role_model_std_loss': 0.3229346345896033, 'avg_role_model_mean_pred_loss': 0.003486672353001108, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.03561886841698651, 'n_size': 320, 'n_batch': 80, 'duration': 99.75122737884521, 'duration_batch': 
1.2468903422355653, 'duration_size': 0.3117225855588913, 'avg_pred_std': 0.2135042036534287}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007464131528354301, 'avg_role_model_std_loss': 14.743032303131127, 'avg_role_model_mean_pred_loss': 0.0003255664299921697, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007464131528354301, 'n_size': 80, 'n_batch': 20, 'duration': 19.27774691581726, 'duration_batch': 0.963887345790863, 'duration_size': 0.24097183644771575, 'avg_pred_std': 0.033036651482689194}\n", "Epoch 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.008308970414873329, 'avg_role_model_std_loss': 2.2638450761613056, 'avg_role_model_mean_pred_loss': 0.0002922931804504517, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008308970414873329, 'n_size': 320, 'n_batch': 80, 'duration': 99.78807187080383, 'duration_batch': 1.2473508983850479, 'duration_size': 0.31183772459626197, 'avg_pred_std': 0.19011376096532331}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008792089688631677, 'avg_role_model_std_loss': 7.008411024302973, 'avg_role_model_mean_pred_loss': 0.00045460039846188566, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008792089688631677, 'n_size': 80, 'n_batch': 20, 'duration': 19.60298991203308, 'duration_batch': 0.9801494956016541, 'duration_size': 0.24503737390041352, 'avg_pred_std': 0.05645044087141286}\n", "Epoch 3\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "Train loss {'avg_role_model_loss': 0.008624525008644923, 'avg_role_model_std_loss': 0.4896122309531961, 'avg_role_model_mean_pred_loss': 0.00020031151738803265, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008624525008644923, 'n_size': 320, 'n_batch': 80, 'duration': 99.84555792808533, 'duration_batch': 1.2480694741010665, 'duration_size': 0.3120173685252666, 'avg_pred_std': 0.20240991505852435}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007710156342363916, 'avg_role_model_std_loss': 3.4492962674491308, 'avg_role_model_mean_pred_loss': 0.00016955623478978054, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007710156342363916, 'n_size': 80, 'n_batch': 20, 'duration': 19.69949173927307, 'duration_batch': 0.9849745869636536, 'duration_size': 0.2462436467409134, 'avg_pred_std': 0.057611686212476344}\n", "Epoch 4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.005540997025855176, 'avg_role_model_std_loss': 0.2870428302302061, 'avg_role_model_mean_pred_loss': 0.00010962202765298911, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.005540997025855176, 'n_size': 320, 'n_batch': 80, 'duration': 99.68231582641602, 'duration_batch': 1.2460289478302002, 'duration_size': 0.31150723695755006, 'avg_pred_std': 0.20185390445403756}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007617261353880167, 'avg_role_model_std_loss': 0.8548410554893053, 
'avg_role_model_mean_pred_loss': 0.00021040458378775994, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007617261353880167, 'n_size': 80, 'n_batch': 20, 'duration': 19.60577392578125, 'duration_batch': 0.9802886962890625, 'duration_size': 0.24507217407226561, 'avg_pred_std': 0.06373736902605742}\n", "Epoch 5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.004228771507609963, 'avg_role_model_std_loss': 0.8442374373333109, 'avg_role_model_mean_pred_loss': 7.225186645620774e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.004228771507609963, 'n_size': 320, 'n_batch': 80, 'duration': 99.56437826156616, 'duration_batch': 1.244554728269577, 'duration_size': 0.31113868206739426, 'avg_pred_std': 0.19006720920442605}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007288631894334685, 'avg_role_model_std_loss': 1.333325219784001, 'avg_role_model_mean_pred_loss': 0.00021377515046261815, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007288631894334685, 'n_size': 80, 'n_batch': 20, 'duration': 19.61089253425598, 'duration_batch': 0.9805446267127991, 'duration_size': 0.24513615667819977, 'avg_pred_std': 0.05554712610319257}\n", "Epoch 6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.004288671186168358, 'avg_role_model_std_loss': 0.7424422502960226, 'avg_role_model_mean_pred_loss': 7.162387234466161e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.004288671186168358, 'n_size': 320, 'n_batch': 80, 'duration': 99.62058591842651, 'duration_batch': 1.2452573239803315, 'duration_size': 0.3113143309950829, 'avg_pred_std': 0.19745058890111977}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.006442523133591749, 'avg_role_model_std_loss': 0.5144731080438725, 'avg_role_model_mean_pred_loss': 0.00011468025654934877, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.006442523133591749, 'n_size': 80, 'n_batch': 20, 'duration': 19.81056571006775, 'duration_batch': 0.9905282855033875, 'duration_size': 0.24763207137584686, 'avg_pred_std': 0.06324401344172656}\n", "Epoch 7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.003485065951008437, 'avg_role_model_std_loss': 0.5401337196294321, 'avg_role_model_mean_pred_loss': 3.538603915114859e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.003485065951008437, 'n_size': 320, 'n_batch': 80, 'duration': 99.80776119232178, 'duration_batch': 1.2475970149040223, 'duration_size': 0.3118992537260056, 'avg_pred_std': 0.18830189779400824}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.009077594889095052, 'avg_role_model_std_loss': 0.3387085039643353, 'avg_role_model_mean_pred_loss': 0.0002953080845562894, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009077594889095052, 
'n_size': 80, 'n_batch': 20, 'duration': 19.67119312286377, 'duration_batch': 0.9835596561431885, 'duration_size': 0.24588991403579713, 'avg_pred_std': 0.07019970323890448}\n", "Epoch 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0029438145053859444, 'avg_role_model_std_loss': 0.8214074499624902, 'avg_role_model_mean_pred_loss': 2.2453812508680688e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029438145053859444, 'n_size': 320, 'n_batch': 80, 'duration': 99.6238522529602, 'duration_batch': 1.2452981531620027, 'duration_size': 0.31132453829050066, 'avg_pred_std': 0.19482076268177478}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007072782384057064, 'avg_role_model_std_loss': 0.48727775096755294, 'avg_role_model_mean_pred_loss': 0.00015469519749622407, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007072782384057064, 'n_size': 80, 'n_batch': 20, 'duration': 19.500911951065063, 'duration_batch': 0.9750455975532532, 'duration_size': 0.2437613993883133, 'avg_pred_std': 0.06650436315685511}\n", "Epoch 9\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.002503083477677137, 'avg_role_model_std_loss': 0.7746123696918176, 'avg_role_model_mean_pred_loss': 1.2513642564668465e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002503083477677137, 'n_size': 320, 'n_batch': 80, 'duration': 99.6650288105011, 'duration_batch': 1.2458128601312637, 'duration_size': 
0.3114532150328159, 'avg_pred_std': 0.18602521782158873}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007042999230907299, 'avg_role_model_std_loss': 0.4210409534451173, 'avg_role_model_mean_pred_loss': 0.00016206004065456026, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007042999230907299, 'n_size': 80, 'n_batch': 20, 'duration': 19.607900619506836, 'duration_batch': 0.9803950309753418, 'duration_size': 0.24509875774383544, 'avg_pred_std': 0.06358127733692527}\n", "Epoch 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0019445733821157774, 'avg_role_model_std_loss': 0.45887769680392837, 'avg_role_model_mean_pred_loss': 2.8811827562555402e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0019445733821157774, 'n_size': 320, 'n_batch': 80, 'duration': 99.54072690010071, 'duration_batch': 1.2442590862512588, 'duration_size': 0.3110647715628147, 'avg_pred_std': 0.17861799619859084}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007491035046405159, 'avg_role_model_std_loss': 0.4072774214367428, 'avg_role_model_mean_pred_loss': 0.0001684058567391844, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007491035046405159, 'n_size': 80, 'n_batch': 20, 'duration': 19.5883047580719, 'duration_batch': 0.979415237903595, 'duration_size': 0.24485380947589874, 'avg_pred_std': 0.0683793492615223}\n", "Epoch 11\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss 
{'avg_role_model_loss': 0.0016902872650462087, 'avg_role_model_std_loss': 0.14312844078152837, 'avg_role_model_mean_pred_loss': 1.2574293510679222e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016902872650462087, 'n_size': 320, 'n_batch': 80, 'duration': 99.5855655670166, 'duration_batch': 1.2448195695877076, 'duration_size': 0.3112048923969269, 'avg_pred_std': 0.18321805561427026}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.00778989009122597, 'avg_role_model_std_loss': 0.30623174232314343, 'avg_role_model_mean_pred_loss': 0.0002285926662562332, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00778989009122597, 'n_size': 80, 'n_batch': 20, 'duration': 19.662875652313232, 'duration_batch': 0.9831437826156616, 'duration_size': 0.2457859456539154, 'avg_pred_std': 0.0688946488313377}\n", "Epoch 12\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0016234107402169685, 'avg_role_model_std_loss': 0.26942976858223344, 'avg_role_model_mean_pred_loss': 1.856327916652254e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016234107402169685, 'n_size': 320, 'n_batch': 80, 'duration': 100.64648413658142, 'duration_batch': 1.2580810517072678, 'duration_size': 0.31452026292681695, 'avg_pred_std': 0.18938408511457966}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007789343649346847, 'avg_role_model_std_loss': 0.39892919776157215, 'avg_role_model_mean_pred_loss': 
0.00019050486260980827, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007789343649346847, 'n_size': 80, 'n_batch': 20, 'duration': 19.436559677124023, 'duration_batch': 0.9718279838562012, 'duration_size': 0.2429569959640503, 'avg_pred_std': 0.06721605993807316}\n", "Epoch 13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0014992240631727326, 'avg_role_model_std_loss': 0.32569619336183353, 'avg_role_model_mean_pred_loss': 1.7869830239917398e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0014992240631727326, 'n_size': 320, 'n_batch': 80, 'duration': 99.36593246459961, 'duration_batch': 1.242074155807495, 'duration_size': 0.31051853895187376, 'avg_pred_std': 0.19987344325636514}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007941817007667851, 'avg_role_model_std_loss': 0.3285603669361308, 'avg_role_model_mean_pred_loss': 0.00020670617296367766, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007941817007667851, 'n_size': 80, 'n_batch': 20, 'duration': 19.3479266166687, 'duration_batch': 0.967396330833435, 'duration_size': 0.24184908270835875, 'avg_pred_std': 0.06898661321029068}\n", "Epoch 14\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0015280315952168166, 'avg_role_model_std_loss': 0.1436633735797855, 'avg_role_model_mean_pred_loss': 2.658071664495944e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0015280315952168166, 'n_size': 320, 'n_batch': 80, 'duration': 99.75400042533875, 'duration_batch': 1.2469250053167342, 'duration_size': 0.31173125132918356, 'avg_pred_std': 0.189214165561134}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007610848201147746, 'avg_role_model_std_loss': 0.30814517063022323, 'avg_role_model_mean_pred_loss': 0.00022678597003369382, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007610848201147746, 'n_size': 80, 'n_batch': 20, 'duration': 19.45919370651245, 'duration_batch': 0.9729596853256226, 'duration_size': 0.24323992133140565, 'avg_pred_std': 0.06733945356681943}\n", "Epoch 15\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0012301572283377026, 'avg_role_model_std_loss': 0.10807322694710982, 'avg_role_model_mean_pred_loss': 5.3408418359659e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0012301572283377026, 'n_size': 320, 'n_batch': 80, 'duration': 99.28717947006226, 'duration_batch': 1.2410897433757782, 'duration_size': 0.31027243584394454, 'avg_pred_std': 0.19792793567758055}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.006984197727433639, 'avg_role_model_std_loss': 0.247640823103211, 'avg_role_model_mean_pred_loss': 0.0001665852697917726, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.006984197727433639, 'n_size': 80, 'n_batch': 20, 
'duration': 19.52495002746582, 'duration_batch': 0.976247501373291, 'duration_size': 0.24406187534332274, 'avg_pred_std': 0.06778875123709441}\n", "Epoch 16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0010839548700573686, 'avg_role_model_std_loss': 0.07997205620506662, 'avg_role_model_mean_pred_loss': 9.080777822098943e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0010839548700573686, 'n_size': 320, 'n_batch': 80, 'duration': 99.66356492042542, 'duration_batch': 1.2457945615053176, 'duration_size': 0.3114486403763294, 'avg_pred_std': 0.19800062356516718}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008054383638955187, 'avg_role_model_std_loss': 0.28627982361469717, 'avg_role_model_mean_pred_loss': 0.00025097979236513577, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008054383638955187, 'n_size': 80, 'n_batch': 20, 'duration': 19.489768266677856, 'duration_batch': 0.9744884133338928, 'duration_size': 0.2436221033334732, 'avg_pred_std': 0.06728147398680448}\n", "Epoch 17\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0007286213950465026, 'avg_role_model_std_loss': 0.10312503655737952, 'avg_role_model_mean_pred_loss': 6.776996160624913e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0007286213950465026, 'n_size': 320, 'n_batch': 80, 'duration': 99.96071243286133, 'duration_batch': 1.2495089054107666, 'duration_size': 0.31237722635269166, 'avg_pred_std': 
0.1800330831320025}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007334689510025782, 'avg_role_model_std_loss': 0.2440898734063012, 'avg_role_model_mean_pred_loss': 0.00019387806316268908, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007334689510025782, 'n_size': 80, 'n_batch': 20, 'duration': 19.55093002319336, 'duration_batch': 0.977546501159668, 'duration_size': 0.244386625289917, 'avg_pred_std': 0.06980508081614971}\n", "Epoch 18\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0007013096967625643, 'avg_role_model_std_loss': 0.0806437349208462, 'avg_role_model_mean_pred_loss': 3.612507871512266e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0007013096967625643, 'n_size': 320, 'n_batch': 80, 'duration': 99.62839913368225, 'duration_batch': 1.245354989171028, 'duration_size': 0.311338747292757, 'avg_pred_std': 0.193802969326498}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008005985312775011, 'avg_role_model_std_loss': 0.25233456931382536, 'avg_role_model_mean_pred_loss': 0.00024166704105831326, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008005985312775011, 'n_size': 80, 'n_batch': 20, 'duration': 19.874733448028564, 'duration_batch': 0.9937366724014283, 'duration_size': 0.24843416810035707, 'avg_pred_std': 0.06946740932762623}\n", "Epoch 19\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 
0.000650154861421015, 'avg_role_model_std_loss': 0.13836135070985306, 'avg_role_model_mean_pred_loss': 6.324771196684101e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.000650154861421015, 'n_size': 320, 'n_batch': 80, 'duration': 100.67735123634338, 'duration_batch': 1.2584668904542924, 'duration_size': 0.3146167226135731, 'avg_pred_std': 0.1874849540356081}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007276729341538157, 'avg_role_model_std_loss': 0.19176392750296145, 'avg_role_model_mean_pred_loss': 0.0001885933500119812, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007276729341538157, 'n_size': 80, 'n_batch': 20, 'duration': 19.5623676776886, 'duration_batch': 0.97811838388443, 'duration_size': 0.2445295959711075, 'avg_pred_std': 0.06984148817136884}\n", "Epoch 20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005832819898387243, 'avg_role_model_std_loss': 0.04374472918600412, 'avg_role_model_mean_pred_loss': 9.791867645731814e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005832819898387243, 'n_size': 320, 'n_batch': 80, 'duration': 99.7720296382904, 'duration_batch': 1.24715037047863, 'duration_size': 0.3117875926196575, 'avg_pred_std': 0.19292465539183468}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007389668950054329, 'avg_role_model_std_loss': 0.19728227270163642, 'avg_role_model_mean_pred_loss': 0.00020410724985005512, 
'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007389668950054329, 'n_size': 80, 'n_batch': 20, 'duration': 19.41769814491272, 'duration_batch': 0.970884907245636, 'duration_size': 0.242721226811409, 'avg_pred_std': 0.06944395890459418}\n", "Epoch 21\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005352863459663126, 'avg_role_model_std_loss': 0.01976681658542425, 'avg_role_model_mean_pred_loss': 6.443602266463554e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005352863459663126, 'n_size': 320, 'n_batch': 80, 'duration': 100.6177191734314, 'duration_batch': 1.2577214896678925, 'duration_size': 0.31443037241697314, 'avg_pred_std': 0.19984434806974605}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007991572895844002, 'avg_role_model_std_loss': 0.22150783375837904, 'avg_role_model_mean_pred_loss': 0.00023175869880640577, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007991572895844002, 'n_size': 80, 'n_batch': 20, 'duration': 20.159916162490845, 'duration_batch': 1.0079958081245421, 'duration_size': 0.25199895203113554, 'avg_pred_std': 0.07001060470938683}\n", "Epoch 22\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005078368830425007, 'avg_role_model_std_loss': 0.02868203009257253, 'avg_role_model_mean_pred_loss': 8.534584292643494e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005078368830425007, 'n_size': 320, 'n_batch': 80, 'duration': 99.94173312187195, 'duration_batch': 1.2492716640233994, 'duration_size': 0.31231791600584985, 'avg_pred_std': 0.19320008249487727}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007862234280037229, 'avg_role_model_std_loss': 0.2152756543153373, 'avg_role_model_mean_pred_loss': 0.0002432201600976569, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007862234280037229, 'n_size': 80, 'n_batch': 20, 'duration': 19.478534698486328, 'duration_batch': 0.9739267349243164, 'duration_size': 0.2434816837310791, 'avg_pred_std': 0.06986867655068636}\n", "Epoch 23\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.000520391069062498, 'avg_role_model_std_loss': 0.03020876141404938, 'avg_role_model_mean_pred_loss': 4.458749571339605e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.000520391069062498, 'n_size': 320, 'n_batch': 80, 'duration': 100.07614707946777, 'duration_batch': 1.2509518384933471, 'duration_size': 0.3127379596233368, 'avg_pred_std': 0.1921757934615016}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007205282323411666, 'avg_role_model_std_loss': 0.20713103588941523, 'avg_role_model_mean_pred_loss': 0.0001827256362942009, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007205282323411666, 'n_size': 80, 'n_batch': 20, 
'duration': 19.527496099472046, 'duration_batch': 0.9763748049736023, 'duration_size': 0.24409370124340057, 'avg_pred_std': 0.06974663501605391}\n", "Epoch 24\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005546999067291836, 'avg_role_model_std_loss': 0.027485956521196276, 'avg_role_model_mean_pred_loss': 6.895572416364425e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005546999067291836, 'n_size': 320, 'n_batch': 80, 'duration': 99.80755043029785, 'duration_batch': 1.247594380378723, 'duration_size': 0.31189859509468076, 'avg_pred_std': 0.1784662834368646}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.007224415110249538, 'avg_role_model_std_loss': 0.17986427203068162, 'avg_role_model_mean_pred_loss': 0.00017955067020931638, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007224415110249538, 'n_size': 80, 'n_batch': 20, 'duration': 19.42549467086792, 'duration_batch': 0.971274733543396, 'duration_size': 0.242818683385849, 'avg_pred_std': 0.0726472232490778}\n", "Epoch 25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0005678469070971914, 'avg_role_model_std_loss': 0.0783861569140182, 'avg_role_model_mean_pred_loss': 2.3522705448089696e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005678469070971914, 'n_size': 320, 'n_batch': 80, 'duration': 99.91300010681152, 'duration_batch': 1.248912501335144, 'duration_size': 0.312228125333786, 'avg_pred_std': 
0.18876876458525657}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008989003312308341, 'avg_role_model_std_loss': 0.17780489571450744, 'avg_role_model_mean_pred_loss': 0.00034604211847950596, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008989003312308341, 'n_size': 80, 'n_batch': 20, 'duration': 19.550805807113647, 'duration_batch': 0.9775402903556824, 'duration_size': 0.2443850725889206, 'avg_pred_std': 0.0718118923716247}\n", "Epoch 26\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0004598389233777311, 'avg_role_model_std_loss': 0.03701534456806623, 'avg_role_model_mean_pred_loss': 6.740662266656723e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0004598389233777311, 'n_size': 320, 'n_batch': 80, 'duration': 99.83816456794739, 'duration_batch': 1.2479770570993423, 'duration_size': 0.31199426427483556, 'avg_pred_std': 0.19492796149570496}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008001866593258456, 'avg_role_model_std_loss': 0.1477779638089487, 'avg_role_model_mean_pred_loss': 0.0002354497060985228, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008001866593258456, 'n_size': 80, 'n_batch': 20, 'duration': 19.476824283599854, 'duration_batch': 0.9738412141799927, 'duration_size': 0.24346030354499817, 'avg_pred_std': 0.07350850850343704}\n", "Epoch 27\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 
0.0004317075494100209, 'avg_role_model_std_loss': 0.016361890759745458, 'avg_role_model_mean_pred_loss': 3.335215542313308e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0004317075494100209, 'n_size': 320, 'n_batch': 80, 'duration': 99.95542669296265, 'duration_batch': 1.2494428336620331, 'duration_size': 0.3123607084155083, 'avg_pred_std': 0.18771503504831344}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008291181290405802, 'avg_role_model_std_loss': 0.19831402401805462, 'avg_role_model_mean_pred_loss': 0.0002830501877054914, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008291181290405802, 'n_size': 80, 'n_batch': 20, 'duration': 19.60791325569153, 'duration_batch': 0.9803956627845765, 'duration_size': 0.24509891569614412, 'avg_pred_std': 0.06914721243083477}\n", "Epoch 28\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0004966751410051984, 'avg_role_model_std_loss': 0.013397608251197823, 'avg_role_model_mean_pred_loss': 1.0833947258724608e-07, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0004966751410051984, 'n_size': 320, 'n_batch': 80, 'duration': 100.00627398490906, 'duration_batch': 1.2500784248113632, 'duration_size': 0.3125196062028408, 'avg_pred_std': 0.19725441149203107}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.008539365028264, 'avg_role_model_std_loss': 0.21735998928961636, 'avg_role_model_mean_pred_loss': 0.0002757914544125217, 
'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008539365028264, 'n_size': 80, 'n_batch': 20, 'duration': 19.56255578994751, 'duration_batch': 0.9781277894973754, 'duration_size': 0.24453194737434386, 'avg_pred_std': 0.06783094126731157}\n", "Epoch 29\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: - 0.000 MB of 0.000 MB uploaded\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.00043703318435177606, 'avg_role_model_std_loss': 0.012256504303220872, 'avg_role_model_mean_pred_loss': 2.395966142935274e-08, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00043703318435177606, 'n_size': 320, 'n_batch': 80, 'duration': 99.63266062736511, 'duration_batch': 1.245408257842064, 'duration_size': 0.311352064460516, 'avg_pred_std': 0.18974662573309614}\n", "Time out: 3631.7292110919952/3600\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run history:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: 
avg_non_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test █▁▂▂▃▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train ▁█▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test ▁█▄▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test ▄▁▄▄▄▄▅▄▃▄▃▄▂▂▂▃▃▃▆▃▂█▃▃▂▃▃▄▃\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train ▆▃▄▄▃▂▃▄▃▃▂▃█▁▃▁▃▄▃█▃█▄▅▄▄▄▄▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test ▄▁▄▄▄▄▅▄▃▄▃▄▂▂▂▃▃▃▆▃▂█▃▃▂▃▃▄▃\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train ▆▃▄▄▃▂▃▄▃▃▂▃█▁▃▁▃▄▃█▃█▄▅▄▄▄▄▅\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_test ▄▁▄▄▄▄▅▄▃▄▃▄▂▂▂▃▃▃▆▃▂█▃▃▂▃▃▄▃\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_train ▆▃▄▄▃▂▃▄▃▃▂▃█▁▃▁▃▄▃█▃█▄▅▄▄▄▄▅\n", 
"\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run summary:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test 0.00854\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train 0.0005\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test 0.06783\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train 0.19725\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test 0.00854\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train 0.0005\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test 0.00028\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train 0.0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test 0.21736\n", "\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train 0.0134\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test 0.97813\n", 
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train 1.25008\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test 0.24453\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train 0.31252\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_test 19.56256\n", "\u001b[34m\u001b[1mwandb\u001b[0m: duration_train 100.00627\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test 20\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train 80\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test 80\n", "\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train 320\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: You can sync this run to the cloud by running:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[1mwandb sync /kaggle/working/eval/treatment/tvae/2/wandb/offline-run-20240229_042943-h1x5mo8p\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Find logs at: \u001b[35m\u001b[1m./wandb/offline-run-20240229_042943-h1x5mo8p/logs\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Eval loss {'role_model': 'tvae', 'n_size': 399, 'n_batch': 100, 'role_model_metrics': {'avg_loss': 0.002340294746186043, 'avg_g_mag_loss': 0.022517556619054607, 'avg_g_cos_loss': 0.06498512633163529, 'pred_duration': 2.4336907863616943, 'grad_duration': 4.428678035736084, 'total_duration': 6.862368822097778, 'pred_std': 0.06654548645019531, 'std_loss': 0.0213878583163023, 'mean_pred_loss': 1.473031716159312e-05, 'pred_rmse': 0.04837659373879433, 'pred_mae': 0.035608064383268356, 'pred_mape': 0.06660553067922592, 'grad_rmse': 0.3612377345561981, 'grad_mae': 0.13169732689857483, 'grad_mape': 1.3246768712997437}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': 
{'avg_loss': 0.002340294746186043, 'avg_g_mag_loss': 0.022517556619054607, 'avg_g_cos_loss': 0.06498512633163529, 'avg_pred_duration': 2.4336907863616943, 'avg_grad_duration': 4.428678035736084, 'avg_total_duration': 6.862368822097778, 'avg_pred_std': 0.06654548645019531, 'avg_std_loss': 0.0213878583163023, 'avg_mean_pred_loss': 1.473031716159312e-05}, 'min_metrics': {'avg_loss': 0.002340294746186043, 'avg_g_mag_loss': 0.022517556619054607, 'avg_g_cos_loss': 0.06498512633163529, 'pred_duration': 2.4336907863616943, 'grad_duration': 4.428678035736084, 'total_duration': 6.862368822097778, 'pred_std': 0.06654548645019531, 'std_loss': 0.0213878583163023, 'mean_pred_loss': 1.473031716159312e-05, 'pred_rmse': 0.04837659373879433, 'pred_mae': 0.035608064383268356, 'pred_mape': 0.06660553067922592, 'grad_rmse': 0.3612377345561981, 'grad_mae': 0.13169732689857483, 'grad_mape': 1.3246768712997437}, 'model_metrics': {'tvae': {'avg_loss': 0.002340294746186043, 'avg_g_mag_loss': 0.022517556619054607, 'avg_g_cos_loss': 0.06498512633163529, 'pred_duration': 2.4336907863616943, 'grad_duration': 4.428678035736084, 'total_duration': 6.862368822097778, 'pred_std': 0.06654548645019531, 'std_loss': 0.0213878583163023, 'mean_pred_loss': 1.473031716159312e-05, 'pred_rmse': 0.04837659373879433, 'pred_mae': 0.035608064383268356, 'pred_mape': 0.06660553067922592, 'grad_rmse': 0.3612377345561981, 'grad_mae': 0.13169732689857483, 'grad_mape': 1.3246768712997437}}}\n" ] } ], "source": [ "import torch\n", "from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n", "from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n", "from ml_utility_loss.params import GradientPenaltyMode\n", "from ml_utility_loss.util import clear_memory\n", "import time\n", "#torch.autograd.set_detect_anomaly(True)\n", "\n", "clear_memory()\n", "\n", "opt = params[\"Optim\"](model.parameters())\n", "loss = train_2(\n", " [train_set, val_set, test_set],\n", " 
preprocessor=preprocessor,\n", " whole_model=model,\n", " optim=opt,\n", " log_dir=\"logs\",\n", " checkpoint_dir=\"checkpoints\",\n", " verbose=True,\n", " allow_same_prediction=False,\n", " wandb=wandb,\n", " study_name=study_name,\n", " **params\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "9b514a07", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:32:09.773143Z", "iopub.status.busy": "2024-02-29T05:32:09.772841Z", "iopub.status.idle": "2024-02-29T05:32:09.777116Z", "shell.execute_reply": "2024-02-29T05:32:09.776290Z" }, "papermill": { "duration": 0.0282, "end_time": "2024-02-29T05:32:09.779164", "exception": false, "start_time": "2024-02-29T05:32:09.750964", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "model = loss[\"whole_model\"]\n", "opt = loss[\"optim\"]" ] }, { "cell_type": "code", "execution_count": 26, "id": "331a49e1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:32:09.820437Z", "iopub.status.busy": "2024-02-29T05:32:09.819934Z", "iopub.status.idle": "2024-02-29T05:32:10.291755Z", "shell.execute_reply": "2024-02-29T05:32:10.290586Z" }, "papermill": { "duration": 0.496243, "end_time": "2024-02-29T05:32:10.294946", "exception": false, "start_time": "2024-02-29T05:32:09.798703", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from copy import deepcopy\n", "\n", "torch.save(deepcopy(model.state_dict()), \"model.pt\")\n", "torch.save(deepcopy(opt.state_dict()), \"optim.pt\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "123b4b17", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:32:10.343563Z", "iopub.status.busy": "2024-02-29T05:32:10.342710Z", "iopub.status.idle": "2024-02-29T05:32:10.631486Z", "shell.execute_reply": "2024-02-29T05:32:10.630522Z" }, "papermill": { "duration": 0.314088, "end_time": "2024-02-29T05:32:10.634360", "exception": false, "start_time": "2024-02-29T05:32:10.320272", "status": 
"completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "history = loss[\"history\"]\n", "history.to_csv(\"history.csv\")\n", "history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()" ] }, { "cell_type": "code", "execution_count": 28, "id": "2586ba0a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:32:10.685003Z", "iopub.status.busy": "2024-02-29T05:32:10.684611Z", "iopub.status.idle": "2024-02-29T05:34:04.939597Z", "shell.execute_reply": "2024-02-29T05:34:04.938519Z" }, "papermill": { "duration": 114.280164, "end_time": "2024-02-29T05:34:04.942172", "exception": false, "start_time": "2024-02-29T05:32:10.662008", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "\n", "from ml_utility_loss.loss_learning.estimator.pipeline import eval\n", "#eval_loss = loss[\"eval_loss\"]\n", "\n", "batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n", "\n", "eval_loss = eval(\n", " test_set, model,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "187137f6", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:34:04.987806Z", "iopub.status.busy": "2024-02-29T05:34:04.987011Z", "iopub.status.idle": "2024-02-29T05:34:05.008253Z", "shell.execute_reply": "2024-02-29T05:34:05.007293Z" }, "papermill": { "duration": 0.046226, "end_time": "2024-02-29T05:34:05.010204", "exception": false, "start_time": "2024-02-29T05:34:04.963978", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg_g_cos_lossavg_g_mag_lossavg_lossgrad_durationgrad_maegrad_mapegrad_rmsemean_pred_losspred_durationpred_maepred_mapepred_rmsepred_stdstd_losstotal_duration
tvae0.0683190.0739650.002344.4347470.1316971.3246770.3612380.0000152.4250930.0356080.0666060.0483770.0665450.0213886.85984
\n", "
" ], "text/plain": [ " avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration grad_mae \\\n", "tvae 0.068319 0.073965 0.00234 4.434747 0.131697 \n", "\n", " grad_mape grad_rmse mean_pred_loss pred_duration pred_mae \\\n", "tvae 1.324677 0.361238 0.000015 2.425093 0.035608 \n", "\n", " pred_mape pred_rmse pred_std std_loss total_duration \n", "tvae 0.066606 0.048377 0.066545 0.021388 6.85984 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n", "metrics.to_csv(\"eval.csv\")\n", "metrics" ] }, { "cell_type": "code", "execution_count": 30, "id": "123d305b", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:34:05.052398Z", "iopub.status.busy": "2024-02-29T05:34:05.052094Z", "iopub.status.idle": "2024-02-29T05:34:05.621766Z", "shell.execute_reply": "2024-02-29T05:34:05.620695Z" }, "papermill": { "duration": 0.594222, "end_time": "2024-02-29T05:34:05.624883", "exception": false, "start_time": "2024-02-29T05:34:05.030661", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import clear_memory\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 31, "id": "a3eecc2a", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:34:05.672499Z", "iopub.status.busy": "2024-02-29T05:34:05.672189Z", "iopub.status.idle": "2024-02-29T05:36:07.804471Z", "shell.execute_reply": "2024-02-29T05:36:07.803670Z" }, "papermill": { "duration": 122.15973, "end_time": "2024-02-29T05:36:07.806970", "exception": false, "start_time": "2024-02-29T05:34:05.647240", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../treatment/_cache_test/tvae/all inf False\n" ] } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n", "from ml_utility_loss.util import 
stack_samples\n", "\n", "#samples = test_set[list(range(len(test_set)))]\n", "#y = {m: pred(model[m], s) for m, s in samples.items()}\n", "y = pred_2(model, test_set, batch_size=batch_size)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 32, "id": "6ab51db8", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:07.853517Z", "iopub.status.busy": "2024-02-29T05:36:07.853197Z", "iopub.status.idle": "2024-02-29T05:36:07.870127Z", "shell.execute_reply": "2024-02-29T05:36:07.869295Z" }, "papermill": { "duration": 0.042527, "end_time": "2024-02-29T05:36:07.871962", "exception": false, "start_time": "2024-02-29T05:36:07.829435", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from ml_utility_loss.util import transpose_dict\n", "\n", "os.makedirs(\"pred\", exist_ok=True)\n", "y2 = transpose_dict(y)\n", "for k, v in y2.items():\n", " df = pd.DataFrame(v)\n", " df.to_csv(f\"pred/{k}.csv\")" ] }, { "cell_type": "code", "execution_count": 33, "id": "d81a30f1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:07.914664Z", "iopub.status.busy": "2024-02-29T05:36:07.914365Z", "iopub.status.idle": "2024-02-29T05:36:07.919622Z", "shell.execute_reply": "2024-02-29T05:36:07.918622Z" }, "papermill": { "duration": 0.028815, "end_time": "2024-02-29T05:36:07.921727", "exception": false, "start_time": "2024-02-29T05:36:07.892912", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tvae': 0.5567177837355095}\n" ] } ], "source": [ "print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})" ] }, { "cell_type": "code", "execution_count": 34, "id": "3b3ff322", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:07.965945Z", "iopub.status.busy": "2024-02-29T05:36:07.965677Z", "iopub.status.idle": "2024-02-29T05:36:08.337152Z", "shell.execute_reply": "2024-02-29T05:36:08.336249Z" }, "papermill": { 
"duration": 0.396392, "end_time": "2024-02-29T05:36:08.339358", "exception": false, "start_time": "2024-02-29T05:36:07.942966", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASIAAAE8CAYAAABkYrxdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABDfUlEQVR4nO3dd3hb9b0/8PfRlrW8Z+zYcZyEELITCCFkssIFUloaCoW4aUhbnFLI5T7U7dOEUUjoBRrKpWlLwYYfIxQaRtmjJClNE0IGCSQ4yyse8ZRsWdY8398fR5ItW7YlWdKRrM/refRIOufo6Hts6aPv/nKMMQZCCBGRROwEEEIIBSJCiOgoEBFCREeBiBAiOgpEhBDRUSAihIiOAhEhRHQUiAghoqNARAgRHQUiQojoKBCRsNq7dy/uv/9+GI1GsZNC4ggFIhJWe/fuxQMPPECBiASFAhEhRHQUiEjY3H///fif//kfAEBRURE4jgPHcdBqtVi6dOmg43meR15eHr73ve95tz322GO49NJLkZaWBrVajTlz5uD111/3+34vvvgi5syZA7VajdTUVNx8882or6+PzMWRiOJoGhASLkePHsXWrVvxyiuv4Pe//z3S09MBAGfOnMGDDz6IhoYGZGdne4/fs2cPFi9ejNdee80bjPLz83H99ddj6tSpsNvt2LFjB7744gu88847uPbaa72vffjhh/Gb3/wG3//+97F48WK0trbiqaeeglarxeHDh5GcnBzVayejxAgJo//93/9lAFh1dbV3W1VVFQPAnnrqKZ9j77zzTqbVapnFYvFu6/+YMcbsdjubNm0aW7ZsmXdbTU0Nk0ql7OGHH/Y59tixY0wmkw3aTmIfFc1IxE2aNAkzZ87Eq6++6t3mcrnw+uuv47rrroNarfZu7/+4s7MTJpMJixYtwqFDh7zbd+7cCZ7n8f3vfx9tbW3eW3Z2NkpKSvDZZ59F58JI2MjETgBJDKtXr8avfvUrNDQ0IC8vD7t27UJLSwtWr17tc9w777yD3/72tzhy5AhsNpt3O8dx3senTp0CYwwlJSV+30sul0fmIkjEUCAiUbF69WqUl5fjtddew913342//e1vMBgMuPrqq73H/Otf/8L111+Pyy+/HH/84x+Rk5MDuVyOiooKvPzyy97jeJ4Hx3F4//33IZVKB72XVquNyjWR8KFARMKqf86lv6KiIsyfPx+vvvoqNmzYgJ07d2LVqlVQKpXeY/7+979DpVLhww8/9NleUVHhc67i4mIwxlBUVIRJkyZF5kJIVFEdEQkrjUYDAH47NK5evRr79u3Dc889h7a2tkHFMqlUCo7j4HK5vNtqamrw5ptv+hx34403QiqV4oEHHgAb0OjLGEN7e3t4LoZEDTXfk7A6cOAA5s+fj5UrV+Lmm2+GXC7HddddB41Gg3PnzqGgoABarRZyuRzNzc0+9Tn//Oc/sXz5cixatAi33HILWlpa8PTTTyM7OxtHjx71CTpbt25FeXk5Lr30UqxatQo6nQ7V1dV44403sH79etx7771iXD4JlYgtdmSMeuihh1heXh6TSCSDmvIXLlzIALB169b5fe2zzz7LSkpKmFKpZFOmTGEVFRVs8+bNzN9H9e9//zu77LLLmEajYRqNhk2ZMoWVlZWxqqqqSF0aiRDKERFCREd1RIQQ0VEgIoSIjgIRIUR0FIgIIaITNRAVFhZ6p4rofysrKxMzWYSQKBO1Z/WBAwd8Oq99/fXXuOKKK3DTTTeJmCpCSLTFVPP93XffjXfeeQenTp0acqhAfz
zPo7GxETqdLqDjCSHRxRhDd3c3cnNzIZEMXQCLmbFmdrsdL774IjZu3DhkULHZbD4jshsaGjB16tRoJZEQEqL6+nqMGzduyP0xE4jefPNNGI1GlJaWDnnMli1b8MADDwzaXl9fD71eH8HUEUJC0dXVhfz8fOh0umGPi5mi2VVXXQWFQoF//OMfQx4zMEfkuUiTyUSBiJAY1NXVBYPBMOJ3NCZyRLW1tfjkk0+wc+fOYY9TKpU+00MQQsaGmOhHVFFRgczMTJ/J0QkhiUP0QMTzPCoqKrBmzRrIZDGRQSOERJno3/xPPvkEdXV1WLt2bUTOzxiD0+n06a9EAiOVSiGTyahrBIk40QPRlVdeOWiWvXCx2+1oamqCxWKJyPkTQVJSEnJycqBQKMROChnDRA9EkcLzPKqrqyGVSpGbmwuFQkG/7EFgjMFut6O1tRXV1dUoKSkZtkMaIaMxZgOR3W4Hz/PIz89HUlKS2MmJS2q1GnK5HLW1tbDb7VCpVGInKWIO1nbg1HkzitI1mFeYComEfrSiacz/xNGv+Ogkyt8vPyUJzV1W7D3Tjt0nW8VOTsJJjE8ZISPI1Kuw4oIsAMCReiOaTVaRU5RYKBCRhPVtc5dPwJmWZ8AFOULv30N1nWIlKyFRIEpwhYWF2LZtm9jJiDqrw4XPvm3FjgN1aDL1erfPHp8MADjdYobVQV0+ooUCEUlIh+uMsDpcSNUokKXrq4TP1KmQrlXAxTPUtPeImMLEQoFoDLDb7WInIa5YHS5v0euSCWmDWsiK0rUAgJYu26DXkshIyEBkd/JD3pwuPuBjHQEcG4olS5Zgw4YN2LBhAwwGA9LT0/Gb3/zG2/GzsLAQDz30EG6//Xbo9XqsX78eAPD5559j0aJFUKvVyM/Px1133YWenr5f9ZaWFlx33XVQq9UoKirCSy+9FFL64t03jSbYnTzStQqUZGoH7Z+Rb8DahUW4fFKGCKlLTGO2H9Fwnv7s9JD7itI1WDUrz/v8L3vOwOHy3/N7XIoaN83N9z5/7t/V6LX71ivcc8WkkNL4/PPP48c//jG++OILfPnll1i/fj0KCgpwxx13AAAee+wxbNq0CZs3bwYAnDlzBldffTV++9vf4rnnnkNra6s3mFVUVAAASktL0djYiM8++wxyuRx33XUXWlpaQkpfvOJ5hq/qTQCAmfkpfju56lTyQdtIZCVkIIoH+fn5+P3vfw+O4zB58mQcO3YMv//9772BaNmyZfjv//5v7/Hr1q3DrbfeirvvvhsAUFJSgj/84Q9YvHgxtm/fjrq6Orz//vv44osvMG/ePADAs88+iwsuuCDq1yam6vYemHodUMmlmJIz/GRdJHoSMhCVLZ045L6BHWrXX1485LEDf0zXLiwaTbJ8XHLJJT6/1gsWLMDjjz/uHbw7d+5cn+O/+uorHD161Ke4xRjzDnU5efIkZDIZ5syZ490/ZcoUJCcnhy3N8cDpYtCpZJicrYNcOnTNRF27BYfrO5GhVeLSielRTGFiSshApJAFXjUWqWNHS6PR+Dw3m834yU9+grvuumvQsQUFBTh58mS0khbTJmfrMDFTCyc/fP2dxeHE2dYeasKPkoQMRPFg//79Ps/37duHkpISSKVSv8fPnj0bx48fx8SJ/nN7U6ZMgdPpxMGDB71Fs6qqKhiNxrCmOx5IJRykEv9/R480jTATaJvZDsYYDZiOsIRsNYsHdXV12LhxI6qqqvDKK6/gqaeewi9+8Yshj7/vvvuwd+9ebNiwAUeOHMGpU6fw1ltvYcOGDQCAyZMn4+qrr8ZPfvIT7N+/HwcPHsS6deugVqujdUmiYoyhrt0CFx/YlDOpGgUkHAe7k0e3zRnh1BEKRDHq9ttvR29vL+bPn4+ysjL84he/8DbT+zN9+nTs3r0bJ0+exKJFizBr1ixs2rQJubm53mMqKiqQm5uLxYsX48Ybb8T69euRmZkZjcsRXWu3DX8/dA6Ve2sCCkZSCYcUjdB61m
GmflqRRkWzGCWXy7Ft2zZs37590L6amhq/r5k3bx4++uijIc+ZnZ2Nd955x2fbbbfdNqp0xouT580AgGy9CtIAp/hI0yjRbrajvceGwnTNyC8gIaMcEUkIZ1qFQDTRTwfGoaRphVkp2yhHFHEUiMiY19FjR0ePHVIJh/FpgU+Sl5KkiGpLaCKjolkM2rVrl9hJGFPOunND41LUUMmHby3rryRTi0lZWmoxiwIKRGTM8xTLijMCL5YBoOlio0j0fGdDQwN++MMfIi0tDWq1GhdddBG+/PJLsZNFxgib04VmkzCKniqcY5eoOaLOzk4sXLgQS5cuxfvvv4+MjAycOnUKKSkpYiaLjCFyiQSr5+WjucsKgzqIwawuB3Dmnzhz4jAa+BQUXHIjCnNoNH6kiBqIHn30UeTn53tHhwNAUVH4xmsRIpFwyDaokG0IYgUSxoDjbwFtp+Dq7Ya8pw3cNzuB7PWDBxiSsBC1aPb2229j7ty5uOmmm5CZmYlZs2bhmWeeGfJ4m82Grq4unxshYXf+a6DtFCCRwlK4Ai5ODr6zXthOIkLUQHT27Fls374dJSUl+PDDD/Gzn/0Md911F55//nm/x2/ZsgUGg8F7y8/P93scIQDQa3fh4+PnUdXcHfhqwrwLqP6X8LhwESTjZqNRPwM2pws4R3WXkSJqIOJ5HrNnz8YjjzyCWbNmYf369bjjjjvwpz/9ye/x5eXlMJlM3lt9fX2UU0ziSXOXFV83mLDvbHvgTfAtxwGrCVBqgXFzoVXKcF47FTaeA7qbgZ62yCY6QYkaiHJycjB16lSfbRdccAHq6ur8Hq9UKqHX631uhAzFszpHlj6I+qHGw8J93hxAKodWJYNTqkKbLEfY3loV5lQSQORAtHDhQlRV+f5jT548ifHjx4uUIjKWnO8S1izLCbSi2twKmBoATgJkTwcA6JRCS1uzogA8Y0DH2YikNdGJGojuuece7Nu3D4888ghOnz6Nl19+GX/5y19QVlYWmTdkDHDao38LtH4CwAsvvIC0tDTYbL4rSKxatSphBqiGA2PM238o4Bazpq+E+/SJQtEMgEougVYpgyy1EE6eAd1NQtM+CStRm+/nzZuHN954A+Xl5XjwwQdRVFSEbdu24dZbb43MG7ocwL8ej8y5h7PovwGZIqBDb7rpJtx11114++23cdNNNwEQVt949913hx1ZT3wZLQ5YHS7IJBzStcqRX8AY0Pqt8NidGwIAjuNwx+UThP3/2QPYzEBXA5BSGJmEJyjRe1b/13/9F44dOwar1YoTJ054J4dPVGq1GrfccotP36oXX3wRBQUFWLJkiXgJizPnu4ViWaZeGdi0H10NgK1b+MFI8dOXjeMAg7uV1kiNJOGWWGPNpHIhdyLG+wbhjjvuwLx589DQ0IC8vDxUVlaitLSUBl8GobNHKD5l6ALIDQF9uaG0EkA6xNdCnwe0nADM58OQQtJfYgUijgu4iCSmWbNmYcaMGXjhhRdw5ZVX4ptvvsG7774rdrLiyoLiNMzINyCgmWEZ62sNy5gyaPc3jSYcqu3EFJUC8wChGZ+EVWIFojiybt06bNu2DQ0NDVixYgV13gxBkiLAj3dXI2DtEnKuqYOLZU4XQ5vZjvNKg7DB1g3YewAFDaINF9HriIh/t9xyC86dO4dnnnkGa9euFTs5Y1ubOzeUNtFvMVqrEgJat1MCJKUKG6l4FlYUiGKUwWDAd7/7XWi1WqxatUrs5MSVZpMVbxw+hy9rOgJ7QfsZ4T7d//LgSQphMrUemxPQuhcboB7WYUWBKIY1NDTg1ltvhVIZYIUrASAM7ahps6DB2DvywZYOIahwEiB1gt9DPEW8XrsLTO3OEVnaw5VcAqojikmdnZ3YtWsXdu3ahT/+8Y9iJyfutJuFjowB9R/y5IaS8wG5/46PGneOyMkz2JWpUAKUIwozCkQxaNasWejs7MSjjz6KyZMni52cuNPRI6y6kaoJoIW0/bRwn+Z/hVwAkE
klUMolsDl4WGQGIRBRjiisKBDFoKHWLSOBMVqEPkQpSSMEIqcNMLoHWA8TiAAgTaOA3cnDoUwWNjh6AbsFUAS+KggZGgUiMqbYnTzM7iWik5NG6EjacRZgvNAS5mkNG8LqeQV9T1R6obnf0k6BKEzGfGV1wBNiEb/i7e9n7BWKZWqFdOSlgwIolvmVlC7cU/EsbMZsIJLLhV9Di8Uickrim+fv5/l7xrpeuwsKmQQpI+WGeL6vojrYQKR2L+7Q2xl8AolfY7ZoJpVKkZycjJaWFgBAUlISjdUKAmMMFosFLS0tSE5OhlQa+MKEYhqfpsGdS4rhcI2Qk+tuFOp5ZErAMG7E837b3IUD1R3IT03CErW7h7XVFIYUE2AMByIAyM7OBgBvMCLBS05O9v4d4wXHcVDIRvjRaTsl3KdOACQjB1nPMA+dSg6kJgsbrcZRpZP0GdOBiOM45OTkIDMzEw4HTWYVLLlcHjc5oaAFWT+kUQpflR67E1BRjijcxnQg8pBKpWP3C0V8vHqgDiq5FMsvyIJWOcTHu9fo7k3NAWnFAZ3XM8zDYnMBqmRho90izMAZBzM6xLoxW1lNEo/V4UKj0YqzrT2QS4cpmnkqqQ3jALk6oHN7A5HdBSZT9vXCplxRWFAgImNGV69Q/E5SSKGUDZMDbnfXDwXRWuYZb8YzBpuT71c8M4aSVDIABSIyZnRZhUCkH26N+yB6U/cnlXBQyISvi8Xer3hGOaKwSIg6IpIYuqxCj2q9aphA1FkjrOaqTgGS0oI6f5pGAYeLh4tnfTmiXmNoiSU+RM0R3X///eA4zuc2ZcrgqToJCYSnaKZXD/P72r+1LMh+ZTfPL8BtCwqFebDVycJGKpqFheg5ogsvvBCffPKJ97lMJnqSSJwaMUfEWL9AFFhr2ZD6TxtLRk30b71MJou7DnMkNkk5oR5HpxriY93VKDS5yxRAcoH/YwKl1An3FIjCQvRAdOrUKeTm5kKlUmHBggXYsmULCgr8f0hsNpvPCqhdXV3RSiaJA9dOzxl+kK4nNxRgb+qBvqo34qtzRkzK0uGSce5AZO8R6pxCOB/pI2od0cUXX4zKykp88MEH2L59O6qrq7Fo0SJ0d/v/ldmyZQsMBoP3RitbkIE8dY1+hTra3s3h4tFutgvzHcnVgMT9O065olETNRBdc801uOmmmzB9+nRcddVVeO+992A0GvG3v/3N7/Hl5eUwmUzeW309rbhJAmTrBszuMYdDzE09Es+0IlaHS6jopuJZ2IheNOsvOTkZkyZNwunTp/3uVyqVNJE88au2vQd7TrWhIDUJiydlDD6g46xwr8sOeT0ytbt3da/DJWxQ6oSpQCgQjVpMdWg0m804c+YMcnJyxE4KiTOdFgfaum3eJvxBPIFoFK1laneOqNfeLxABFIjCQNRAdO+992L37t2oqanB3r178Z3vfAdSqRQ/+MEPxEwWiUN9fYj8NN3zPNBRLTwOsVgG9AtEDgpE4SZq0ezcuXP4wQ9+gPb2dmRkZOCyyy7Dvn37kJHhJ2tNyDC8wzv8Nd13NQhDO+QqQJcb8nt4imZ2Jw+ni4dMqRd22Kj1drREDUQ7duwQ8+3JGGJ2d2b024fIUyxLKQIkoRcClO4+Skq5FA4Xg8ybI6JANFoxVVlNSKg8K3dolX6KZp5ANIpiGSB0DVi3qN85qGgWNjFVWU1IKHieoccm1NtolAM6FtrMQHez8HiUgWgQ5YBOjSRkFIhI3LO7eKRo5FDJpdAoBmTyO92V1LosQKkN7xsrNAAnEcaw2c3hPXeCoaIZiXsquRS3Lyj0v9NbLBvlIFe3vWfacLrFjNkFKZiWZxCCm7VLyHl5pgYhQaMcERm7GOvXbF8UllP22l1oN9vR7a4ch8Kdy6Ic0ahQICJjV0+rsHaZVA7o88JySu8wD6e7TsjTS5sC0ahQ0YzEvS+qO1B1vhvT8wyYkZ/ct6OzRrhPLgjb6HiVXPjttnk6NXpzRD1hOX+ioh
wRiXsdPXa0ddtgd/G+Ozprhfvk8WF7L9Wg3tXuQGSjHNFoUCAica+vD1G/DD7vAozuQJQS/kBkdbiDnrdoRjmi0aBAROJej79A1N0EuNzzBmmzwvZePlOBAFRZHSYUiEhcY4z5zxH1rx8KcpL84ajlUmiVsr7+ShSIwoIqq0lcszl52J1CMUnjE4g8xbLCsL5fqkaBOy7v10PbWzSzCN0Fwhj0EgnliEhc8xTLlHKJdwFEuBzCiHsg7IFoEIVGCD6MBxyWyL7XGEaBiMQ1F8+QrlUgTaPo22iqFyqrlTphIcVIkkiFeiiAWs5GgYpmJK5l6lW4beDwjv7FsggUld452oiOHjuuvjAbmXqVkCuyW9z1ROGrGE8klCMiY4/JvajCaNcuG+r0vQ60m+3osVOnxnChQETGFpejb9oPw7iIvIVKRk344UZFMxLXPvi6Ca1mOy6bmI6idI2wmivvEno8R6h+aHBfIurUOFqUIyJxrd09vMO7wqunWGYYF7GmdM94s77e1ZQjGi0KRCSuDepVbTon3BsiUz8E+BmBT+PNRi1mAtHWrVvBcRzuvvtusZNC4oSr3xSxWpVMWDbIG4giUz8E9AtEdiqahUtIgejs2bNhTcSBAwfw5z//GdOnTw/recnY5hnaIZVwwppj5vNCZbVMCWgityRVkkIY5uHtQElFs1ELKRBNnDgRS5cuxYsvvgir1TqqBJjNZtx666145plnkJIS4c5nZEzxFMs0Shk4jutXP5Q/qmWDRnJBjh53XD4Byy9w9xnyBCKXQ1g/jQQtpP/WoUOHMH36dGzcuBHZ2dn4yU9+gi+++CKkBJSVleHaa6/FihUrRjzWZrOhq6vL50YSlydHpPPWD/WrqI4mmUKYBRKg4lmIQgpEM2fOxJNPPonGxkY899xzaGpqwmWXXYZp06bhiSeeQGtra0Dn2bFjBw4dOoQtW7YEdPyWLVtgMBi8t/z8/FCST8YICQekaxVI0SiEAaee+qFkET4XVDwblVHlX2UyGW688Ua89tprePTRR3H69Gnce++9yM/Px+23346mpqYhX1tfX49f/OIXeOmll6BSqQJ6v/LycphMJu+tvr5+NMkncW5ipg63LSjEFVOzAEuHMMxCIgN0ORF9X6vDhb8dqMcL/6np6zbQfxQ+CdqoAtGXX36JO++8Ezk5OXjiiSdw77334syZM/j444/R2NiIG264YcjXHjx4EC0tLZg9ezZkMhlkMhl2796NP/zhD5DJZHC5Bi9Yp1QqodfrfW6EAABMdcK9Pjds81MPRS6VoMHYi3azHTYnzdQYDiH1rH7iiSdQUVGBqqoqrFy5Ei+88AJWrlwJibuCsKioCJWVlSgsLBzyHMuXL8exY8d8tv3oRz/ClClTcN9990EqjeyHiYwxxujVD0klHBQyCexOHlaHS2jOp6LZqIQUiLZv3461a9eitLQUOTn+s8GZmZl49tlnhzyHTqfDtGnTfLZpNBqkpaUN2k6IPy/trwXPM1xzUQ7So1w/pPQGIsoRhUNIgejjjz9GQUGBNwfkwRhDfX09CgoKoFAosGbNmrAkkpCBGGPoMNvh5BnkTjNgNQnLP4dp/bKRqBVSdFudfat5UCAalZACUXFxMZqampCZmemzvaOjA0VFRX7rdwKxa9eukF5HEo/VwcPJCxXF2l53o4g2U+jMGAU0Aj+8Qqqs9rYUDGA2mwNuASNkNLptDgBCL2dpV/Sb7WkEfngFlSPauHEjAIDjOGzatAlJSUnefS6XC/v378fMmTPDmkBC/PGMMdMoZb49qqNEo5RCp5JB4hnhr3B/Fxw0iX4oggpEhw8fBiDkiI4dOwaFom+eYIVCgRkzZuDee+8NbwoJ8cNsFXpVJ8scQHebsDGKPaqXTM7Eksn9qibk7hwR7wIcvX2BiQQkqED02WefARCa2Z988knqx0NE4xnekep09+JPSusrHolBKgPkKsBhFYpnFIiCElJldUVFRbjTQUhQlHIJ0r
UKpPPnhQ1iDOsYSKF1ByIzgMiN/h+LAg5EN954IyorK6HX63HjjTcOe+zOnTtHnTBChjO7IAWzC1KAg7sBG6I+0LXZZMXuky3QKuW4drq7L51CA/S0UYV1CAIORAaDQZhqwf2YENE57f0myo9ujsjFGBqNViQn9euqQi1nIQs4EPUvjlHRjMSE7kZhhVWlDlBF98dRJRswbzXQLxBRX6JghdSPqLe3FxZL3yjj2tpabNu2DR999FHYEkbIUBwuHtt3ncFHew/AxTOhfijKzeWefkQ2pws87xmB7+7USEtPBy2kQHTDDTfghRdeAAAYjUbMnz8fjz/+OG644QZs3749rAkkZKAemxNWhwuc6ZwwEWO0J0JDXyBiDDQCPwxCnqFx0aJFAIDXX38d2dnZqK2txQsvvIA//OEPYU0gIQOZbU5wzIVUVys4cBFdsWMonhH4gL/e1VQ0C1ZIgchisUCn0wEAPvroI9x4442QSCS45JJLUFtbG9YEEjKQ2eZEkr0dKgkv9N3RpIuSjkHLCtHS0yELefL8N998E/X19fjwww9x5ZVXAgBaWlqokyOJuB6bE3pbk5AjMUS/fshD6x7m4eIHzNLo6BV6WJOAhdShcdOmTbjllltwzz33YPny5ViwYAEAIXc0a9assCaQkIG6rU7obOehUEpEqR/yWD1vQJFQphamImG8UGGt1ImTsDgUUiD63ve+h8suuwxNTU2YMWOGd/vy5cvxne98J2yJI8Qfs9UBg60ZiiRJ1PsPDUsiEYZ22MxC8YwCUcBCCkQAkJ2djezsbJ9t8+fPH3WCCBlJKtcNjcwBpVIP6LJHfkE0yfsFIhKwkAJRT08Ptm7dik8//RQtLS3ged5nf7hXgiWkv0vTLMC4ZCBlfMQnyh/OiaYuHDtnQmG6BvOLUoWNCi2AFmo5C1JIgWjdunXYvXs3brvtNuTk5HiHfhASFVFY3z4QFrsLDcZe6FT9vkbUlygkIQWi999/H++++y4WLlwY7vQQMizGGDhvIBK3fkgld/cjctJ4s9EKKRClpKQgNTU13GkhZESN55vR+G01tGoFLojSRPlD6Zsutv94M5q7OhQh9SN66KGHsGnTJp/xZoREg62lGk6eoUeZKaw5L6JB81YDlCMKUUg5oscffxxnzpxBVlYWCgsLIZfLffYfOnQooPNs374d27dvR01NDQDgwgsvxKZNm3DNNdeEkiySAJzt1cKD5OgP6xhI7Q5EvRSIRi2kQLRq1aqwvPm4ceOwdetWlJSUgDGG559/HjfccAMOHz6MCy+8MCzvQcYQxgCjMIRIklIkcmL66ohsDh48zyCRcDTMI0QhBaLNmzeH5c2vu+46n+cPP/wwtm/fjn379vkNRDabDTabzfu8q6srLOkgcaK3Ey5rN3hOCkWauC1mAKCUSSGXclDJpbC7eKgk0r4ckdMGuByAVD78SQiAEOuIAGH6j7/+9a8oLy9HR0cHAKFI1tDQENL5XC4XduzYgZ6eHu+QkYG2bNkCg8HgveXnx1CvWjKkodbBC1pnDWwOHmZFJvRJ4k9OL5Vw2LCsBOsWTfDWF0GmBCTu33fKFQUspEB09OhRTJo0CY8++igee+wxGI1GAMJc1eXl5UGd69ixY9BqtVAqlfjpT3+KN954A1OnTvV7bHl5OUwmk/dWX18fSvJJFNW1W/D2V41wuPiRDx4B66yF3emCSZULvTrkQQGRxXFUTxSCkALRxo0bUVpailOnTvms7Lpy5Urs2bMnqHNNnjwZR44cwf79+/Gzn/0Ma9aswfHjx/0eq1QqodfrfW4kdtmdPD463oyzrT04XGcEAJgsDhw9Zwz+ZIyB76yFXi2HNLUQWmWMBiKAAlEIQvpvHjhwAH/+858Hbc/Ly0Nzc3NQ51IoFJg4cSIAYM6cOThw4ACefPJJv+cn8eWbRhO6rU7o1XLMKkiG2ebEi/trwfMMEzK0wQWTnlZInb2YkpeKKZfNFwaYxoC9Z9pQ32HB3M
JUFGe4K6ppgrSghfTfVCqVfiuKT548iYyM0a3nxPO8T4U0iU+MMRypNwIA5o5PgVwqgUYhRYZWCSfPcMSdQwqYsU64NxSIOr5sIKPFgUajFaZeR99GyhEFLaRAdP311+PBBx+EwyH88TmOQ11dHe677z5897vfDfg85eXl2LNnD2pqanDs2DGUl5dj165duPXWW0NJFokhTSYrjBYHFDIJLsgRitAcx2FWQTIAYcCod9L5QHRUg2cMLAb6D/XnHeZBfYlGJaRA9Pjjj8NsNiMjIwO9vb1YvHgxJk6cCJ1Oh4cffjjg87S0tOD222/H5MmTsXz5chw4cAAffvghrrjiilCSRWLIyfPdAIDiDK13bmcAmJChhVohhdnmRG1HgD3zXU7AWIPqth5UfCsJrY4pQlQy92oetKzQqIRUR2QwGPDxxx/j3//+N7766iuYzWbMnj0bK1asCOo8zz77bChvT2IcYwynzgtfwpIsrc8+qYTD5CwdjtQbcep8N4rSA1iv3lQHuJwwMxVMXDLk0tioHwIApd/e1dSpMVhBByKe51FZWYmdO3eipqYGHMehqKgI2dnZwshomhIk4dmcPCZmatHcZcX41MH9fSZmanGk3oizbT19PZKH0yHMb3VengtwHJKTYqeToJrGm4VFUIGIMYbrr78e7733HmbMmIGLLroIjDGcOHECpaWl2LlzJ958880IJZXEC5VciqVTMofcn5eshkouhdPFo9NiR5pWOfwJO6rhYgzN0lwAQLJa3MGu/fXVEfkrmvUIw1Lox3lEQQWiyspK7NmzB59++imWLl3qs++f//wnVq1ahRdeeAG33357WBNJxhaJhMN3Z+chRaMYuZjVawR62mB18jAq86CUS7xf/ligkgvDPHwuw1M0453CUA+5yu9rSZ+g/qOvvPIKfvWrXw0KQgCwbNky/PKXv8RLL70UtsSR+OPiGc51WvqW2BlCpl4VWF1PpzDavkeZBZdUhWS1IqaK/zkGFTYsK/Fd0UMq75uihIpnAQkqEB09ehRXX331kPuvueYafPXVV6NOFIlf57useO3Lc6j4d3XAY8yGPa79DACgQylMghZL9UMAhg6KnlyRgwJRIIIKRB0dHcjKyhpyf1ZWFjo7O0edKBK/znX2AgCy9KoRcy7fNJrw//5Tg0N1Q3xmXA5vjkiWORlF6RrkJavDmt6IoQrroARVR+RyuSCTDf0SqVQKp9M56kSR+FXv7huU76e1bCC7k0eb2Y6zrT2YM97P1MMdZ4U+RCoDJhZNwMQYKpL19/6xJnRZHbj6whwYPDk2CkRBCbrVrLS0FEql/1YOGpqR2JwuHo1GIUc0LmXknMuEdC12VbWi0WiF1eHqm0rDo+2kcJ8xKaZbnppMwhCPHruzXyCiuauDEVQgWrNmzYjHUItZ4mrussLJMyQppEjTjNzEbkiSI02rQLvZjtp2CyZn91sZleeB9tMAAEfKRDjsTiQpYnPEvUouhanXQX2JRiGo/2xFRUWk0kHGgPoOITeUn5oUcMtWUboG7WY7qtvMvoHIVAc4rIAiCbXOVPxj91nkpybhe3PEn5lxIL99ieTuoikFooDETocMEvfOdQr1Q4EUyzw8Qzyq2yy+g2Bbq4T7tIlo6bYDgO9ChjHE/yT6VDQLBgUiEjaXT8rAZSXpGJ8WwPgxt1yD0Mva6nCh0STkqMC7gJYTwuOMKTjfbQUAZOtjs2Ogp27LRkWzkMXmTwyJS1l6FbKCDBYSCYcpOTrYHHzfKP3OGsDRCyiSwFIK0XysBgCQbYjNQKQcdsVXCw3zCAAFIiK6pZMHjEs7/41wnzkVxl4XrA4XpBIuoApwMXiGefjwBCLGAw5L33PiFwUiEhYHajqgU8lQlK6BUjaKGRSd9r5m+8ypqHfXO+UYVJDF0PQf/c3KT8bsghTfjRIpIFcLOTt7DwWiEcTmf5bEFZvThb2n2/H+sWbflqMgMMbQ0mVFa+3XQo9qdTKgz0Wdu4NkQQ
AdJMUy9DAPmiAtUJQjIqPW0NkLnjEY1HIY1KGNBTvWYMKnJ1owv+tzpKcycFkXAhyHC3MNSFJIA5tALdYotEBPG1VYB4ByRGTU6t3jy0aTa5mUpYOON0FirEO3zQXkzAAgNO8vm5KFzBhtMQOESdHeOtKAVw/U+Q7gpZazgFEgIqPmLT6lhR6IVHIpZsmFlTpOWNPBlPGzZp1UwuFsaw8ajVbYnENMkEaGRYGIjEqPzYm2bmGMYTAdGQdxOTFVUgMJB5zgivHawXP48JtmGC32MKU0cuRSibfrQa+d5q4OhaiBaMuWLZg3bx50Oh0yMzOxatUqVFVViZkkEiRPq1aGTjm6sWAtx6FmdozLzkKnugANnb043tiFqubuMKU0svz3rqYcUaBEDUS7d+9GWVkZ9u3bh48//hgOhwNXXnklenroHxcv2s1CjmX8KIplYAyo3w8AyJu2CP81Iw9Tc/W4YmoW5hf5mR4kBqkVQiCy2P0FImo1G4morWYffPCBz/PKykpkZmbi4MGDuPzyy0VKFQnGwonpuGicAaPqN9xxVmhdksqBnJkokatQkqUb+XUxJMkdiKhoFpqYar43mUwAgNRU/7+CNpvNZ84jf8tek+jTq0Y5fas7N4TcmXE70bxquKKZo1cYPxdDS2XHmpiprOZ5HnfffTcWLlyIadOm+T1my5YtMBgM3lt+fn6UU0n6G2mC/IB01gCdtcKXdNy80Z9PJEkKYZiHz99ErgYk7t96G/1oDidmAlFZWRm+/vpr7NixY8hjysvLYTKZvLf6+vooppD0xxhD5d4avHm4Ad1WR6gnAar3CI9zZgIqQ9jSF20Li9OxYVkJFhSn9W3kOEDpLmLa4qPSXSwxUTTbsGED3nnnHezZswfjxg098ZVSqRxymloSXec6e9HV64DN6fK2GAWt/TRgagCkMmD8gvAmMMqGXK1WqQN6OwEr5YiGI2ogYozh5z//Od544w3s2rULRUVFYiaHBOFbd7N6SaYutMGoLidw+lPhcd6cvpzDWEM5ooCIGojKysrw8ssv46233oJOp0NzczMAwGAwQK2Ok2VjEpDDxeNUi/DFmpIdYgCp3y/kFBQaoODSMKZOHKZeB3ZVtQAAbpiZ17dD5e4hToFoWKIGou3btwMAlixZ4rO9oqICpaWl0U8QCci3Td2wOXgY1PLQelObW4HavcLjicvjtqVsoLOtPZBKODDG+kbke3NEVDQbjuhFMxJfGGM4cs4IAJiRnxz88s8uJ3DiLWFd+NQJQObU8CdSBJ5+RC6ewebk+5ZGUror4CkQDStmWs1IfGgw9qKt2wa5lMOFuUEOTGUMOPWhkCNSJAFTrh0zU6j2H2/ms6yQJ0dEldXDiolWMxI/MnUqrLggC73+FkQcjqepvumoEHym/Beg1EYuoSJQyaWwO3lY7C4ke0a8eOqIHL3ChG/SUXb+HKMoEJGgKGQSXDQuyP4+Lidw5lOg4ZDwvORKIK04/IkTWZJCiq5eh+94M5lK6J7gcgoV1knxMXYu2igQkcjhXUJfoeo9wlgyjgMmrgDyZoudsohI8g58dfZt5DihnsjSLtQTUSDyiwIRCYjDxeONQw2YkqPDhbkGSD0d+Hhe+JL1tAiDOx0W4Zff1g10NQrFEUCoE5q8EkgvEe8iIkyjkEEhk8A5cOiLUucORNSEPxQKRCQg3zZ1o8HYC7PNiWm5BsDSAZz7Emg9IazdNRRFkjB8Y9w84fEYtmxKJlZMzRq8gyqsR0SBiIyIMYYj9Z0AgJm5akhOfww0HhbW7AKEClhtlvCFU2iE6S+UOkCTLmwfIy1jIxlymAd1ahwRBSIyouYuK9rMdqQ6mnBR00eA3f2FSp0AjJsLpBTSFBfDoU6NI6JAREZ0vMGE7O5jmOs8DLlaI4ySn3wNkEpjA/vzDPPgGcN3ZvUbvK1KFu57jWIkKy5QICLDcjjssH/zDxR2VSEzRw9kTxOa32U0C8JAEk4Y5iHhBgzzUCcL91
aT0J8qQYqqwaBARIZmt6D93y8iueskFHIZ9BddLVQ60xfJrySFDBwH8Iyh1+HqW0xAqRf+ZrxTmL96rM40MAo0xIP412sEDr+IJGszDHodpDNWg8ufT0FoGFIJ552byWzr15dIIhWCEUDFsyFQjogMZm4Fju4AbGbok9Mw9fLVQgsYGZFGKYPF7oLF5gL6Z3zUyULRzGoEQFMcD0Q5IuKrpw346mXAZhaCz6zbKAgFQaP0kyMCqMJ6BBSISB9LB/DVK0IHRV0Wagq+g07X2JgrKFo07nohn/FmQL8Ka2NU0xMvKBARgdMGfP13ISekzQB/0Wp8UGVC5d4aNBp7xU5d3NAohWEe/MC5tihHNCyqIyJCk/KJfwjFMqUWmL4arTYpeu0uKGQSZOspVxSoSyak4dLitMETxlGOaFiUIyJAw0Gg7ZSwBte07wJKHeo6hPFj41LUQw9dIINIJZz/WSs9OSKbWZgShPigQJToLB3A2c+Ex8XLAH0uAKDeHYgKUsf2QNWokasBmUJ4TLmiQSgQJTLGgKr3hF/olELvPEFOF4+GTqFeiAJRcFw8w1tHGvDivlrYnP0qrDkOUKcIjy0d4iQuhokaiPbs2YPrrrsOubm54DgOb775ppjJSTwtJwBjvTCD4ORrvJ0Vm0xWOHkGrVKGVI1C5ETGF6mEQ4OxF63dNpitA4pgSe5VYC3t0U9YjBM1EPX09GDGjBl4+umnxUxGYnI5+opkBQv6KlMBb/1Qfqo6+FU6CHQqYV7qLgpEARO11eyaa67BNddcI2YSEte5L4WJupQ6IP9in11zC1OQY1BBraCpPUKhV8nQ1m1Dt9Xhu4MC0ZDiqvneZrPBZrN5n3d10fwuIXE5gHNfCI+LLh+0soRSJsWEjLG1wkY06d05ou7hckQ0Ct9HXFVWb9myBQaDwXvLz6cxOyFpPir0nlbpgawLxU7NmKNTCb/vg3JE6lQh+DhtwvzexCuuAlF5eTlMJpP3Vl9fL3aS4g/PA/Xu3FD+xYNmVjxU14nPT7Wh3Wzz82ISiCHriKQyYVI5gIpnA8RV0UypVEKppAm5RqX1W2GYgVwN5MwYtPvrBhPazXZk6ZVI09LfOhQ6lQxyKQe51E/RKylN+Ptb2oGU8VFPW6yKq0BERokxoO4/wuNxcwfVDZltTrSb7eA4IJ/6D4Usx6BC2dKJ/lscNelA+xmgpzX6CYthogYis9mM06dPe59XV1fjyJEjSE1NRUFBgYgpG6M6zgLmFiEA5c0ZtNvTmzpDpwxuOWniY9guD1r3ckPm89FJTJwQNRB9+eWXWLp0qff5xo0bAQBr1qxBZWWlSKkaw+r2Cfe5M4Wi2QA0rCMKvIGohVrO+hE1EC1ZsgRs4HQJJDJM5wBjnVA5PW7+oN2MMW9HRgpEo3e4rhNfN3bhwlw9Zhek9O1QpwqDi10OoLeTlqB2i6tWMzIKntxQ1oV9C/7102lxoNvqhFTCITd5cG6JBMfm5NHWbUNb94DWR4kE0GYIj6l45kWBKBH0tAnTfHDcoF7UHl29DqgVUuQmqyGX0sditFKShDF6Rotj8E6qJxqEWs0SgSc3lF4y5PzTheka/OTyCbA6+CgmbOxKSRJaJDst9sE7tZnCfTcFIg/66RvrrCbg/DfC4/xLhj2U4zgaXxYmye4ckcXugsU+oGOjTpjzCV0NQoU1oUA05tV/ATAeSC4ADHl+D3G6eGo0CDOFTIJkd66orXtArkibJXShcNqEYjOhQDSm2cxA4xHh8fhLhzzsYG0n/vqvahyu64xOuhJEhk7omd5qtvrukEi8M2HCRMOUAApEY9u5L4RljvW5wgyMQ6hu64HZJrSYkfDJ0quQrlVAJvHzNTOME+67GqKbqBhFldVjlaMXaDgkPB5/6ZAd53psTjR3Cb/YNPVHeM0rTMW8wiH6CXkCkelc9BIUwyhHNFbV7RM6zWkzgLSJQx5W3dYDxoBsgwpaJf0uRY0+D+AkwgDYXioSUyAai6
wmYQZGAChaPOwwgjOtZgDAhHRNNFKWkFw8g905oFuETNnXeNBxNvqJijEUiMai6n8JdUPJ+cPmhmxOF+rahWEdVCyLjP+cacf2XadxpN44eGdqsXDfToGIAtFYY2oAzn8tPC5eNmxu6HSLGU6eISVJjnQtrdYRCUq5BA4X879sd+oE4d5YIxSjExgForGEdwnrlDEGZE/rayIeQo5BjdnjUzCzIIVW64iQce5xew3GXvD8gL5a2kxhxkaXE2g/7efViYMC0VhSu1foICdXA8XLRzw8VaPA4kkZmJmfHPm0Jah0rTC3k93Jo6lrQH8ijgOypgqPPb3fExQForGiswao/bfwuOQKQEFTecQCiYRDUbrwvzjTYh58QKZ78YKOswk9oT4ForGgtxM4/rZQJMuZPuLKHA4Xj/ePNeFcp4WGdkRBsbsh4EyrefDfW5sB6HOEYnXjYRFSFxsoEMU7ew9w9G/CvTYDKLlyxJccazDh2+ZufPTNeRpzGQUFaUmQSzkYLQ40mqyDDxg3T7hvOCjUFyUgCkTxrNcIHH4RsHQIlZ7TVw+aEH8gq8OF/Wc7AAgrukpoWEfEKWVSzBmfissnZSBN46d1MmOKMFmd3QKcOxD9BMYACkTxqv0McOh5dxDSAzNuFpaPHsHeM22wOlxI0yowLdcQhYQSAFhQnIY541P8L0ogkQor7gJCPZ818VYwpkAUb+w9QNUH7uKYRSiOzbotoLmPT7d046t6EwBgyaRMyg2JZFAvawDIcne3cDmA428JdUYJJCYC0dNPP43CwkKoVCpcfPHF+OKLL8ROUuyxdQPVe4D9f+6r1MybDcxe43cO6oHq2i14/1gzAGD2+BQUpFGrmhgajL144T81ON44INfDccAF1wEyhTAQ9sTbCRWMRB/l+Oqrr2Ljxo3405/+hIsvvhjbtm3DVVddhaqqKmRmZoqdPHHZLUDHGaC1Smje9XwwtZnAxBVBrRTaYOyFk2eYkKHBoon+p4slkVfXbkG31YmPj58HzxguzNX3dSZNSgWmrgK+/jvQ8q1QRJt0FaDLFjXN0cAxkdtvL774YsybNw//93//BwDgeR75+fn4+c9/jl/+8pfDvrarqwsGgwEmkwl6/ci5gpjDmDBLn9MqTNthNbqXI24DuhqF+p/+DHlCC0v6ZGFyrUGnY7A5eVjsLrR22yCVABMzdd59h+uNuCjPQJPji4jnGT78phnfNncDAPKS1Ziaq0dushp6lQwyqUSo/zv+lvDZAITZNdMnCQFJnQLIk/z+/2NRoN9RUXNEdrsdBw8eRHl5uXebRCLBihUr8J///GfQ8TabDTZb3/IsXV0BVup1VANnPvU/P7DPNjb0tpCPFe6/bjDB5uyX1WY8JC47ODAo5VJckN33TzrR3IVeu3BsryIVRs0EdCYVodecBk21DLdk9n0I3zh8Due7bEJMc/Fw9htGoFPJMCFdC4mEA8dxvutrEVFIJByunpaNFI0CX1R3oMHYiwb3OLSLJ6Ti0uJ0IK0YbVN+iGOfv40U82lw544COOo9BwOHrBQ9clK0gEQKi4Ph68ZuMJ9hOn2PcwyqkdeqC2WIz0XfD6haIBCiBqK2tja4XC5kZWX5bM/KysK333476PgtW7bggQceCP6NnDbALO5a445es/9KSgA8JwOUWkCpB9QpqLfY0OBKhlmRBadUJRzEANicgz4vNgfvDVoeCpkEKUkKZBuUsLt4qCQ0IX4s4TgOl0xIw4W5enzT2IW6dgtauq3QKfu6XtjkOhzRXg6lahZSLWehtzVBY2+HwmUBwOCy9wJ24cPA7E64ekxDv6FMDSgjUCfIwrfii+h1RMEoLy/3LksNCDmi/Pz8kV+YnC80bwP9In+/b7S/X5Ihfl1CPTZnshWu/v83jgNkKkCugkwmB/Qq765Z+TZMd+dsPKfwnGlgS9c103Lg4HlwAGRSCZIUUip6xQmdSo5LJqThkglpYIz5ZK7TNAp8b457Fkdc0LeD8e
AcvdDLXYBSAvAuyBx2FEzy01ESABiDVikFkuSDto+aInxzWIkaiNLT0yGVSnH+vO/6TufPn0d29uAKOqVSCaVSGfwbKTRAalGoyQyL7CCm+/FMuh4Iw8APGIlLHMf5/J6p5FLkD1mc8v0wKQHkpUUsaVEh6k+nQqHAnDlz8Omnn3q38TyPTz/9FAsWLBAxZYSQaBK9aLZx40asWbMGc+fOxfz587Ft2zb09PTgRz/6kdhJI4REieiBaPXq1WhtbcWmTZvQ3NyMmTNn4oMPPhhUgU0IGbtE70c0GnHfj4iQMS7Q7yg1rxBCREeBiBAiOgpEhBDRiV5ZPRqe6q2Ah3oQQqLK890cqSo6rgNRd7cwcDCg3tWEENF0d3fDYBh6Ir64bjXjeR6NjY3Q6XQRW5fLM4ykvr4+IVrmEu16AbrmSF4zYwzd3d3Izc2FZJgZA+I6RySRSDBu3LiRDwwDvV6fMB9SIPGuF6BrjpThckIeVFlNCBEdBSJCiOgoEI1AqVRi8+bNoY36j0OJdr0AXXMsiOvKakLI2EA5IkKI6CgQEUJER4GIECI6CkSEENFRIEJwK80+88wzWLRoEVJSUpCSkoIVK1bE3cq0oa6su2PHDnAch1WrVkU2gREQ7DUbjUaUlZUhJycHSqUSkyZNwnvvvRel1IZHsNe8bds2TJ48GWq1Gvn5+bjnnntgtQ4xKX+4sQS3Y8cOplAo2HPPPce++eYbdscdd7Dk5GR2/vx5v8ffcsst7Omnn2aHDx9mJ06cYKWlpcxgMLBz585FOeWhCfZ6Paqrq1leXh5btGgRu+GGG6KT2DAJ9pptNhubO3cuW7lyJfv8889ZdXU127VrFzty5EiUUx66YK/5pZdeYkqlkr300kusurqaffjhhywnJ4fdc889UUlvwgei+fPns7KyMu9zl8vFcnNz2ZYtWwJ6vdPpZDqdjj3//PORSmJYhXK9TqeTXXrppeyvf/0rW7NmTdwFomCvefv27WzChAnMbrdHK4lhF+w1l5WVsWXLlvls27hxI1u4cGFE0+mR0EUzz0qzK1as8G4bbqVZfywWCxwOB1JTUyOVzLAJ9XoffPBBZGZm4sc//nE0khlWoVzz22+/jQULFqCsrAxZWVmYNm0aHnnkEbhcLr/Hx5pQrvnSSy/FwYMHvcW3s2fP4r333sPKlSujkua4HvQ6WsGuNOvPfffdh9zcXJ9/eqwK5Xo///xzPPvsszhy5EgUUhh+oVzz2bNn8c9//hO33nor3nvvPZw+fRp33nknHA4HNm/eHI1kj0oo13zLLbegra0Nl112GRhjcDqd+OlPf4pf/epX0UgyVVaPxtatW7Fjxw688cYbUKlUI78gznR3d+O2227DM888g/T0dLGTEzU8zyMzMxN/+ctfMGfOHKxevRq//vWv8ac//UnspEXMrl278Mgjj+CPf/wjDh06hJ07d+Ldd9/FQw89FJX3T+gcUbArzfb32GOPYevWrfjkk08wffr0SCYzbIK93jNnzqCmpgbXXXeddxvPC+tmy2QyVFVVobi4OLKJHqVQ/sc5OTmQy+WQSqXebRdccAGam5tht9uhUCgimubRCuWaf/Ob3+C2227DunXrAAAXXXQRenp6sH79evz6178edi6hcEjoHFGoK83+7ne/w0MPPYQPPvgAc+fOjUZSwyLY650yZQqOHTuGI0eOeG/XX389li5diiNHjsTFzJih/I8XLlyI06dPe4MuAJw8eRI5OTkxH4SA0K7ZYrEMCjaeQMyiMRw1KlXiMWzHjh1MqVSyyspKdvz4cbZ+/XqWnJzMmpubGWOM3XbbbeyXv/yl9/itW7cyhULBXn/9ddbU1OS9dXd3i3UJQQn2egeKx1azYK+5rq6O6XQ6tmHDBlZVVcXeeecdlpmZyX7729+KdQlBC/aaN2/ezHQ6HXvllVfY2bNn2UcffcSKi4vZ97///aikN+EDEWOMPfXUU6ygoIApFAo2f/58tm/fPu++xY
sXszVr1nifjx8/ngEYdNu8eXP0Ex6iYK53oHgMRIwFf8179+5lF198MVMqlWzChAns4YcfZk6nM8qpHp1grtnhcLD777+fFRcXM5VKxfLz89mdd97JOjs7o5JWmgaEECK6hK4jIoTEBgpEhBDRUSAihIiOAhEhRHQUiAghoqNARAgRHQUiQojoKBARQkRHgYjElcrKSiQnJ3uf33///Zg5c6b3eWlpaVxOZZvoKBARv0pLS8FxHH76058O2ldWVgaO41BaWupzfLgDQGFhIbZt2+azbfXq1Th58uSQr3nyySdRWVnpfb5kyRLcfffdYU0XCT8KRGRI+fn52LFjB3p7e73brFYrXn75ZRQUFIiSJrVajczMzCH3GwwGnxwTiQ8UiMiQZs+ejfz8fOzcudO7befOnSgoKMCsWbNGdW5/OZVVq1Z5c1lLlixBbW0t7rnnHnAcB47jAAwumg3UP2dWWlqK3bt348knn/Seo7q6GhMnTsRjjz3m87ojR46A4zicPn16VNdFQkOBiAxr7dq1qKio8D5/7rnn8KMf/Sji77tz506MGzcODz74IJqamtDU1BT0OZ588kksWLAAd9xxh/ccBQUFg64JACoqKnD55Zdj4sSJ4boEEgQKRGRYP/zhD/H555+jtrYWtbW1+Pe//40f/vCHEX/f1NRUSKVS6HQ6ZGdnjzhjpj8GgwEKhQJJSUnec0ilUpSWlqKqqso7UbzD4cDLL7+MtWvXhvsySIASeqpYMrKMjAxce+21qKysBGMM1157bdzPX52bm4trr70Wzz33HObPn49//OMfsNlsuOmmm8ROWsKiHBEZ0dq1a1FZWYnnn38+bLkGiUQyaApSh8MRlnMHYt26dd6K+IqKCqxevRpJSUlRe3/iiwIRGdHVV18Nu90Oh8OBq666KiznzMjI8Kn3cblc+Prrr32OUSgUo15LbKhzrFy5EhqNBtu3b8cHH3xAxTKRUdGMjEgqleLEiRPex0MxmUyD1j9LS0vzO8n+smXLsHHjRrz77rsoLi7GE088AaPR6HNMYWEh9uzZg5tvvhlKpTKkImFhYSH279+PmpoaaLVapKamQiKReOuKysvLUVJSMuxiCSTyKEdEAqLX66HX64c9ZteuXZg1a5bP7YEHHvB77Nq1a7FmzRrcfvvtWLx4MSZMmIClS5f6HPPggw+ipqYGxcXFyMjICCnd9957L6RSKaZOnYqMjAzU1dV59/34xz+G3W6PSisgGR7NWU0S1r/+9S8sX74c9fX1g1ZFJdFFgYgkHJvNhtbWVqxZswbZ2dl46aWXxE5SwqOiGUk4r7zyCsaPHw+j0Yjf/e53YieHgHJEhJAYQDkiQojoKBARQkRHgYgQIjoKRIQQ0VEgIoSIjgIRIUR0FIgIIaKjQEQIEd3/B5bwo9EGInRUAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n", "\n", "_ = plot_pred_density_2(y)" ] }, { "cell_type": "code", "execution_count": 35, "id": "e79e4b0f", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:08.385465Z", "iopub.status.busy": "2024-02-29T05:36:08.385142Z", "iopub.status.idle": "2024-02-29T05:36:08.746209Z", "shell.execute_reply": "2024-02-29T05:36:08.745171Z" }, "papermill": { "duration": 0.38658, "end_time": "2024-02-29T05:36:08.748418", "exception": false, "start_time": "2024-02-29T05:36:08.361838", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_density_3\n", "\n", "_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 36, "id": "745adde1", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:08.796581Z", "iopub.status.busy": "2024-02-29T05:36:08.796231Z", "iopub.status.idle": "2024-02-29T05:36:08.962034Z", "shell.execute_reply": "2024-02-29T05:36:08.961064Z" }, "papermill": { "duration": 0.194298, "end_time": "2024-02-29T05:36:08.965546", "exception": false, "start_time": "2024-02-29T05:36:08.771248", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_box_3\n", "\n", "_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 37, "id": "eabe1bab", "metadata": { "execution": { "iopub.execute_input": "2024-02-29T05:36:09.024114Z", "iopub.status.busy": "2024-02-29T05:36:09.023330Z", "iopub.status.idle": "2024-02-29T05:36:09.234183Z", "shell.execute_reply": "2024-02-29T05:36:09.233273Z" }, "papermill": { "duration": 0.236662, "end_time": "2024-02-29T05:36:09.236192", "exception": false, "start_time": "2024-02-29T05:36:08.999530", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n", "import matplotlib.pyplot as plt\n", "\n", "#plot_grad_2(y, model.models)\n", "for m in model.models:\n", " ym = y[m]\n", " fig, ax = plt.subplots()\n", " plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "54c0e9f3", "metadata": { "papermill": { "duration": 0.023029, "end_time": "2024-02-29T05:36:09.282713", "exception": false, "start_time": "2024-02-29T05:36:09.259684", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "celltoolbar": "Tags", "colab": { "authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie", "gpuType": "T4", "mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg", "provenance": [] }, "kaggle": { "accelerator": "gpu", "dataSources": [], "dockerImageVersionId": 30648, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "papermill": { "default_parameters": {}, "duration": 4022.381901, "end_time": "2024-02-29T05:36:12.029238", "environment_variables": {}, "exception": null, "input_path": "eval/treatment/tvae/2/mlu-eval.ipynb", "output_path": "eval/treatment/tvae/2/mlu-eval.ipynb", "parameters": { "dataset": "treatment", "dataset_name": "treatment", "debug": false, "folder": "eval", "gp": false, "gp_multiply": false, "path": "eval/treatment/tvae/2", "path_prefix": "../../../../", "random_seed": 2, "single_model": "tvae" }, "start_time": 
"2024-02-29T04:29:09.647337", "version": "2.5.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }