{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "982e76f5", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:52.982920Z", "iopub.status.busy": "2024-03-03T06:20:52.982060Z", "iopub.status.idle": "2024-03-03T06:20:53.015985Z", "shell.execute_reply": "2024-03-03T06:20:53.015207Z" }, "papermill": { "duration": 0.049159, "end_time": "2024-03-03T06:20:53.018212", "exception": false, "start_time": "2024-03-03T06:20:52.969053", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import joblib\n", "\n", "#joblib.parallel_backend(\"threading\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "675f0b41", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.043661Z", "iopub.status.busy": "2024-03-03T06:20:53.043222Z", "iopub.status.idle": "2024-03-03T06:20:53.050291Z", "shell.execute_reply": "2024-03-03T06:20:53.049455Z" }, "papermill": { "duration": 0.022153, "end_time": "2024-03-03T06:20:53.052350", "exception": false, "start_time": "2024-03-03T06:20:53.030197", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "%cd /kaggle/working\n", "#!git clone https://github.com/R-N/ml-utility-loss\n", "%cd ml-utility-loss\n", "!git pull\n", "#!pip install .\n", "!pip install . 
--no-deps --force-reinstall --upgrade\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "5ae30f5c", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.076016Z", "iopub.status.busy": "2024-03-03T06:20:53.075743Z", "iopub.status.idle": "2024-03-03T06:20:53.079876Z", "shell.execute_reply": "2024-03-03T06:20:53.079050Z" }, "papermill": { "duration": 0.01836, "end_time": "2024-03-03T06:20:53.081880", "exception": false, "start_time": "2024-03-03T06:20:53.063520", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [3,3]" ] }, { "cell_type": "code", "execution_count": 4, "id": "9f42c810", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.105046Z", "iopub.status.busy": "2024-03-03T06:20:53.104804Z", "iopub.status.idle": "2024-03-03T06:20:53.108756Z", "shell.execute_reply": "2024-03-03T06:20:53.107919Z" }, "executionInfo": { "elapsed": 678, "status": "ok", "timestamp": 1696841022168, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "ns5hFcVL2yvs", "papermill": { "duration": 0.017807, "end_time": "2024-03-03T06:20:53.110664", "exception": false, "start_time": "2024-03-03T06:20:53.092857", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "datasets = [\n", " \"insurance\",\n", " \"treatment\",\n", " \"contraceptive\"\n", "]\n", "\n", "study_dir = \"./\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "85d0c8ce", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.134112Z", "iopub.status.busy": "2024-03-03T06:20:53.133799Z", "iopub.status.idle": "2024-03-03T06:20:53.139660Z", "shell.execute_reply": "2024-03-03T06:20:53.138933Z" }, "papermill": { "duration": 0.01956, "end_time": "2024-03-03T06:20:53.141460", "exception": false, "start_time": "2024-03-03T06:20:53.121900", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": 
[], "source": [ "#Parameters\n", "import os\n", "\n", "path_prefix = \"../../../../\"\n", "\n", "dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n", "dataset_name = \"treatment\"\n", "model_name=\"ml_utility_2\"\n", "models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n", "single_model = \"lct_gan\"\n", "random_seed = 42\n", "gp = True\n", "gp_multiply = True\n", "folder = \"eval\"\n", "debug = False\n", "path = None\n", "param_index = 0\n", "allow_same_prediction = True\n", "log_wandb = False" ] }, { "cell_type": "code", "execution_count": 6, "id": "6b2f5a50", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.166707Z", "iopub.status.busy": "2024-03-03T06:20:53.166324Z", "iopub.status.idle": "2024-03-03T06:20:53.171505Z", "shell.execute_reply": "2024-03-03T06:20:53.170774Z" }, "papermill": { "duration": 0.019876, "end_time": "2024-03-03T06:20:53.173357", "exception": false, "start_time": "2024-03-03T06:20:53.153481", "status": "completed" }, "tags": [ "injected-parameters" ] }, "outputs": [], "source": [ "# Parameters\n", "dataset = \"contraceptive\"\n", "dataset_name = \"contraceptive\"\n", "single_model = \"tab_ddpm_concat\"\n", "gp = True\n", "gp_multiply = False\n", "random_seed = 1\n", "debug = False\n", "folder = \"eval\"\n", "path_prefix = \"../../../../\"\n", "path = \"eval/contraceptive/tab_ddpm_concat/1\"\n", "param_index = 3\n", "allow_same_prediction = True\n", "log_wandb = False\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bd7c02d6", "metadata": { "papermill": { "duration": 0.011128, "end_time": "2024-03-03T06:20:53.195823", "exception": false, "start_time": "2024-03-03T06:20:53.184695", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "5f45b1d0", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.219625Z", "iopub.status.busy": "2024-03-03T06:20:53.219254Z", 
"iopub.status.idle": "2024-03-03T06:20:53.228812Z", "shell.execute_reply": "2024-03-03T06:20:53.228007Z" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "UdvXYv3c3LXy", "papermill": { "duration": 0.023626, "end_time": "2024-03-03T06:20:53.230684", "exception": false, "start_time": "2024-03-03T06:20:53.207058", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n", "/kaggle/working/eval/contraceptive/tab_ddpm_concat/1\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "%cd /kaggle/working/\n", "\n", "if path is None:\n", " # os.path.join only accepts str/PathLike parts; random_seed is an int, so convert it\n", " path = os.path.join(folder, dataset_name, single_model, str(random_seed))\n", "Path(path).mkdir(parents=True, exist_ok=True)\n", "\n", "%cd {path}" ] }, { "cell_type": "code", "execution_count": 8, "id": "f85bf540", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:53.254197Z", "iopub.status.busy": "2024-03-03T06:20:53.253949Z", "iopub.status.idle": "2024-03-03T06:20:55.504273Z", "shell.execute_reply": "2024-03-03T06:20:55.503354Z" }, "papermill": { "duration": 2.264428, "end_time": "2024-03-03T06:20:55.506284", "exception": false, "start_time": "2024-03-03T06:20:53.241856", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Set seed to \n" ] } ], "source": [ "from ml_utility_loss.util import seed\n", "if single_model:\n", " model_name=f\"{model_name}_{single_model}\"\n", "if random_seed is not None:\n", " seed(random_seed)\n", " # report the seed value itself, not the imported seed() function object\n", " print(\"Set seed to\", random_seed)" ] }, { "cell_type": "code", "execution_count": 9, "id": "8489feae", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:55.533351Z", "iopub.status.busy": "2024-03-03T06:20:55.532777Z", "iopub.status.idle": "2024-03-03T06:20:55.544642Z", "shell.execute_reply": 
"2024-03-03T06:20:55.543969Z" }, "papermill": { "duration": 0.027615, "end_time": "2024-03-03T06:20:55.546493", "exception": false, "start_time": "2024-03-03T06:20:55.518878", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os\n", "\n", "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", " info = json.load(f)" ] }, { "cell_type": "code", "execution_count": 10, "id": "debcc684", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:55.570029Z", "iopub.status.busy": "2024-03-03T06:20:55.569753Z", "iopub.status.idle": "2024-03-03T06:20:55.576879Z", "shell.execute_reply": "2024-03-03T06:20:55.576120Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "Vrl2QkoV3o_8", "papermill": { "duration": 0.021309, "end_time": "2024-03-03T06:20:55.579014", "exception": false, "start_time": "2024-03-03T06:20:55.557705", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "task = info[\"task\"]\n", "target = info[\"target\"]\n", "cat_features = info[\"cat_features\"]\n", "mixed_features = info[\"mixed_features\"]\n", "longtail_features = info[\"longtail_features\"]\n", "integer_features = info[\"integer_features\"]\n", "\n", "test = df.sample(frac=0.2, random_state=42)\n", "train = df[~df.index.isin(test.index)]" ] }, { "cell_type": "code", "execution_count": 11, "id": "7538184a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:55.604960Z", "iopub.status.busy": "2024-03-03T06:20:55.604181Z", "iopub.status.idle": "2024-03-03T06:20:55.706334Z", "shell.execute_reply": "2024-03-03T06:20:55.705537Z" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1696841022169, "user": { "displayName": "Rizqi Nur", "userId": 
"09644007964068789560" }, "user_tz": -420 }, "id": "TilUuFk9vqMb", "papermill": { "duration": 0.117849, "end_time": "2024-03-03T06:20:55.708735", "exception": false, "start_time": "2024-03-03T06:20:55.590886", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", "from ml_utility_loss.util import filter_dict_2, filter_dict\n", "\n", "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", "\n", "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", "tab_ddpm_normalization=\"quantile\"\n", "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", "#tab_ddpm_cat_encoding=\"one-hot\"\n", "tab_ddpm_y_policy=\"default\"\n", "tab_ddpm_is_y_cond=True" ] }, { "cell_type": "code", "execution_count": 12, "id": "cca61838", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:20:55.734744Z", "iopub.status.busy": "2024-03-03T06:20:55.734416Z", "iopub.status.idle": "2024-03-03T06:21:00.421850Z", "shell.execute_reply": "2024-03-03T06:21:00.420856Z" }, "executionInfo": { "elapsed": 3113, "status": "ok", "timestamp": 1696841025277, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "7Abt8nStvr9Z", "papermill": { "duration": 4.703061, "end_time": "2024-03-03T06:21:00.424311", "exception": false, "start_time": "2024-03-03T06:20:55.721250", "status": "completed" }, "tags": [] }, "outputs": [ { 
"name": "stderr", "output_type": "stream", "text": [ "2024-03-03 06:20:57.973451: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-03-03 06:20:57.973508: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-03-03 06:20:57.975128: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", "\n", "lct_ae = load_lct_ae(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"lct_ae\",\n", " df_name=\"df\",\n", ")\n", "lct_ae = None" ] }, { "cell_type": "code", "execution_count": 13, "id": "6f83b7b6", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:00.450585Z", "iopub.status.busy": "2024-03-03T06:21:00.449944Z", "iopub.status.idle": "2024-03-03T06:21:00.456256Z", "shell.execute_reply": "2024-03-03T06:21:00.455505Z" }, "papermill": { "duration": 0.021612, "end_time": "2024-03-03T06:21:00.458577", "exception": false, "start_time": "2024-03-03T06:21:00.436965", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", "\n", "rtf_embed = load_rtf_embed(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"realtabformer\",\n", " df_name=\"df\",\n", " ckpt_type=\"best-disc-model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "0026de74", "metadata": { "execution": { "iopub.execute_input": 
"2024-03-03T06:21:00.491218Z", "iopub.status.busy": "2024-03-03T06:21:00.490912Z", "iopub.status.idle": "2024-03-03T06:21:09.035892Z", "shell.execute_reply": "2024-03-03T06:21:09.034691Z" }, "executionInfo": { "elapsed": 20137, "status": "ok", "timestamp": 1696841045408, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "tbaguWxAvtPi", "papermill": { "duration": 8.564414, "end_time": "2024-03-03T06:21:09.038449", "exception": false, "start_time": "2024-03-03T06:21:00.474035", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n", " .fit(X)\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n", " .fit(X)\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n", "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). 
Possibly due to duplicate points in X.\n", " .fit(X)\n", "100%|██████████| 1/1 [00:00<00:00, 2.64it/s]\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.preprocessing import DataPreprocessor\n", "\n", "preprocessor = DataPreprocessor(\n", " task,\n", " target=target,\n", " cat_features=cat_features,\n", " mixed_features=mixed_features,\n", " longtail_features=longtail_features,\n", " integer_features=integer_features,\n", " lct_ae_embedding_size=lct_ae_embedding_size,\n", " lct_ae_params=lct_ae_params,\n", " lct_ae=lct_ae,\n", " tab_ddpm_normalization=tab_ddpm_normalization,\n", " tab_ddpm_cat_encoding=tab_ddpm_cat_encoding,\n", " tab_ddpm_y_policy=tab_ddpm_y_policy,\n", " tab_ddpm_is_y_cond=tab_ddpm_is_y_cond,\n", " realtabformer_embedding=rtf_embed,\n", " realtabformer_params=rtf_params,\n", ")\n", "preprocessor.fit(df)" ] }, { "cell_type": "code", "execution_count": 15, "id": "a9c9b110", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-03-03T06:21:09.067302Z", "iopub.status.busy": "2024-03-03T06:21:09.066950Z", "iopub.status.idle": "2024-03-03T06:21:09.073932Z", "shell.execute_reply": "2024-03-03T06:21:09.073059Z" }, "executionInfo": { "elapsed": 13, "status": "ok", "timestamp": 1696841045411, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "OxUH_GBEv2qK", "outputId": "76464c90-3baf-4bdc-a955-6f4fddc16b9c", "papermill": { "duration": 0.023972, "end_time": "2024-03-03T06:21:09.075866", "exception": false, "start_time": "2024-03-03T06:21:09.051894", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'tvae': 46,\n", " 'realtabformer': (24, 72, Embedding(72, 672), True),\n", " 'lct_gan': 40,\n", " 'tab_ddpm_concat': 10}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessor.adapter_sizes" ] }, { "cell_type": "code", "execution_count": 16, "id": "3cb9ed90", 
"metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:09.101444Z", "iopub.status.busy": "2024-03-03T06:21:09.101141Z", "iopub.status.idle": "2024-03-03T06:21:09.105888Z", "shell.execute_reply": "2024-03-03T06:21:09.105064Z" }, "papermill": { "duration": 0.019747, "end_time": "2024-03-03T06:21:09.107826", "exception": false, "start_time": "2024-03-03T06:21:09.088079", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_3_factory\n", "\n", "datasetsn = load_dataset_3_factory(\n", " dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n", " dataset_name=dataset_name,\n", " preprocessor=preprocessor,\n", " cache_dir=path_prefix,\n", ")\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "ad1eb833", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:09.133152Z", "iopub.status.busy": "2024-03-03T06:21:09.132866Z", "iopub.status.idle": "2024-03-03T06:21:09.164333Z", "shell.execute_reply": "2024-03-03T06:21:09.163457Z" }, "papermill": { "duration": 0.046241, "end_time": "2024-03-03T06:21:09.166272", "exception": false, "start_time": "2024-03-03T06:21:09.120031", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../contraceptive/_cache_test/tab_ddpm_concat/all inf False\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset\n", "\n", "test_set = load_dataset(\n", " dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\", \"datasets_5\", dataset_name),\n", " preprocessor=preprocessor,\n", " cache_dir=os.path.join(path_prefix, dataset_name, \"_cache_test\"),\n", " start=200,\n", " #stop=600,\n", " val=False,\n", " ratio=0,\n", " drop_first_column=True,\n", " model=single_model,\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "id": "14ff8b40", "metadata": { "execution": { "iopub.execute_input": 
"2024-03-03T06:21:09.194198Z", "iopub.status.busy": "2024-03-03T06:21:09.193902Z", "iopub.status.idle": "2024-03-03T06:21:09.512656Z", "shell.execute_reply": "2024-03-03T06:21:09.511741Z" }, "executionInfo": { "elapsed": 588, "status": "ok", "timestamp": 1696841049215, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "NgahtU1q9uLO", "papermill": { "duration": 0.334976, "end_time": "2024-03-03T06:21:09.514903", "exception": false, "start_time": "2024-03-03T06:21:09.179927", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'loss_balancer_beta': 0.6806661100374879,\n", " 'loss_balancer_r': 0.9427716710925113,\n", " 'tf_pma_low': 4,\n", " 'grad_loss_fn': torch.Tensor>,\n", " 'pma_ffn_mode': 'shared',\n", " 'patience': 10,\n", " 'inds_init_mode': 'fixnorm',\n", " 'grad_clip': 0.7494458230986923,\n", " 'gradient_penalty_mode': {'gradient_penalty': True,\n", " 'forward_once': False,\n", " 'calc_grad_m': False,\n", " 'avg_non_role_model_m': False,\n", " 'inverse_avg_non_role_model_m': False},\n", " 'dataset_size': 2048,\n", " 'batch_size': 4,\n", " 'epochs': 100,\n", " 'lr_mul': 0.07424782199493057,\n", " 'n_warmup_steps': 104,\n", " 'Optim': functools.partial(, amsgrad=True),\n", " 'fixed_role_model': 'tab_ddpm_concat',\n", " 'd_model': 128,\n", " 'attn_activation': ml_utility_loss.activations.LeakyHardtanh,\n", " 'tf_d_inner': 512,\n", " 'tf_n_layers_enc': 3,\n", " 'tf_n_head': 32,\n", " 'tf_activation': torch.nn.modules.activation.ReLU6,\n", " 'tf_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n", " 'ada_d_hid': 1024,\n", " 'ada_n_layers': 8,\n", " 'ada_activation': torch.nn.modules.activation.Softsign,\n", " 'ada_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'head_d_hid': 256,\n", " 'head_n_layers': 8,\n", " 'head_n_head': 16,\n", " 'head_activation': torch.nn.modules.activation.ReLU6,\n", " 'head_activation_final': 
ml_utility_loss.activations.LeakyHardsigmoid,\n", " 'single_model': True,\n", " 'models': ['tab_ddpm_concat'],\n", " 'max_seconds': 3600,\n", " 'Body': 'twin_encoder',\n", " 'loss_balancer_log': False,\n", " 'loss_balancer_lbtw': False,\n", " 'pma_skip_small': False,\n", " 'isab_skip_small': False,\n", " 'layer_norm': False,\n", " 'pma_layer_norm': False,\n", " 'attn_residual': True,\n", " 'tf_n_layers_dec': False,\n", " 'tf_isab_rank': 0,\n", " 'tf_layer_norm': False,\n", " 'tf_pma_start': -1,\n", " 'head_n_seeds': 0,\n", " 'dropout': 0,\n", " 'combine_mode': 'diff_left',\n", " 'tf_isab_mode': 'separate',\n", " 'bias': True,\n", " 'bias_final': True,\n", " 'synth_data': 2,\n", " 'tf_lora': False,\n", " 'tf_num_inds': 16,\n", " 'ada_n_seeds': 0,\n", " 'gradient_penalty_kwargs': {'mag_loss': True,\n", " 'mse_mag': True,\n", " 'mag_corr': False,\n", " 'seq_mag': False,\n", " 'cos_loss': False,\n", " 'mag_corr_kwargs': {'only_sign': False},\n", " 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n", " 'mse_mag_kwargs': {'target': 0.1, 'multiply': False}}}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n", "import wandb\n", "\n", "#\"\"\"\n", "param_space = {\n", " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", "}\n", "params = {\n", " **getattr(PARAMS, dataset_name).BESTS[param_index],\n", "}\n", "if gp:\n", " params[\"gradient_penalty_mode\"] = \"ALL\"\n", " params[\"mse_mag\"] = True\n", " if gp_multiply:\n", " params[\"mse_mag_multiply\"] = True\n", " params[\"mse_mag_target\"] = 1.0\n", " else:\n", " params[\"mse_mag_multiply\"] = False\n", " params[\"mse_mag_target\"] = 0.1\n", "else:\n", " params[\"gradient_penalty_mode\"] = \"NONE\"\n", " params[\"mse_mag\"] = False\n", 
"params[\"single_model\"] = False\n", "if models:\n", " params[\"models\"] = models\n", "if single_model:\n", " params[\"fixed_role_model\"] = single_model\n", " params[\"single_model\"] = True\n", " params[\"models\"] = [single_model]\n", "if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n", " params[\"batch_size\"] = 2\n", "params[\"max_seconds\"] = 3600\n", "params[\"patience\"] = 10\n", "params[\"epochs\"] = 100\n", "if debug:\n", " params[\"epochs\"] = 2\n", "with open(\"params.json\", \"w\") as f:\n", " json.dump(params, f)\n", "params = map_parameters(params, param_space=param_space)\n", "params" ] }, { "cell_type": "code", "execution_count": 19, "id": "a48bd9e9", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:09.541987Z", "iopub.status.busy": "2024-03-03T06:21:09.541668Z", "iopub.status.idle": "2024-03-03T06:21:09.612421Z", "shell.execute_reply": "2024-03-03T06:21:09.611487Z" }, "papermill": { "duration": 0.086426, "end_time": "2024-03-03T06:21:09.614383", "exception": false, "start_time": "2024-03-03T06:21:09.527957", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load_dataset_3_factory 2\n", "Caching in ../../../../contraceptive/_cache/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_2/contraceptive [80, 20]\n", "Caching in ../../../../contraceptive/_cache4/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_4/contraceptive [80, 20]\n", "Caching in ../../../../contraceptive/_cache5/tab_ddpm_concat/all inf False\n", "Splitting without random!\n", "Split with reverse index!\n", "../../../../ml-utility-loss/datasets_5/contraceptive [160, 40]\n", "[320, 80]\n", "[320, 80]\n" ] } ], "source": [ "train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], 
synth_data=params[\"synth_data\"])" ] }, { "cell_type": "code", "execution_count": 20, "id": "2fcb1418", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "execution": { "iopub.execute_input": "2024-03-03T06:21:09.642331Z", "iopub.status.busy": "2024-03-03T06:21:09.641944Z", "iopub.status.idle": "2024-03-03T06:21:10.080306Z", "shell.execute_reply": "2024-03-03T06:21:10.079378Z" }, "executionInfo": { "elapsed": 396850, "status": "error", "timestamp": 1696841446059, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "_bt1MQc5kpSk", "outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6", "papermill": { "duration": 0.454288, "end_time": "2024-03-03T06:21:10.082258", "exception": false, "start_time": "2024-03-03T06:21:09.627970", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model of type \n", "[*] Embedding False True\n", "['tab_ddpm_concat'] 1\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n", "from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n", "from ml_utility_loss.util import filter_dict, clear_memory\n", "\n", "clear_memory()\n", "\n", "params2 = remove_non_model_params(params)\n", "adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n", "\n", "model = create_model(\n", " adapters=adapters,\n", " #Body=\"twin_encoder\",\n", " **params2,\n", ")\n", "#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", "print(model.models, len(model.adapters))" ] }, { "cell_type": "code", "execution_count": 21, "id": "938f94fc", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:10.111992Z", "iopub.status.busy": "2024-03-03T06:21:10.111204Z", "iopub.status.idle": "2024-03-03T06:21:10.115437Z", "shell.execute_reply": "2024-03-03T06:21:10.114548Z" }, "papermill": { "duration": 0.021493, 
"end_time": "2024-03-03T06:21:10.117388", "exception": false, "start_time": "2024-03-03T06:21:10.095895", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "study_name=f\"{model_name}_{dataset_name}\"" ] }, { "cell_type": "code", "execution_count": 22, "id": "12fb613e", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:10.145249Z", "iopub.status.busy": "2024-03-03T06:21:10.144488Z", "iopub.status.idle": "2024-03-03T06:21:10.151622Z", "shell.execute_reply": "2024-03-03T06:21:10.150768Z" }, "papermill": { "duration": 0.023142, "end_time": "2024-03-03T06:21:10.153501", "exception": false, "start_time": "2024-03-03T06:21:10.130359", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "7827841" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "count_parameters(model)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bd386e57", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:10.180220Z", "iopub.status.busy": "2024-03-03T06:21:10.179932Z", "iopub.status.idle": "2024-03-03T06:21:10.257073Z", "shell.execute_reply": "2024-03-03T06:21:10.256169Z" }, "papermill": { "duration": 0.092939, "end_time": "2024-03-03T06:21:10.259122", "exception": false, "start_time": "2024-03-03T06:21:10.166183", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "========================================================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "========================================================================================================================\n", "MLUtilitySingle [2, 1179, 10] --\n", "├─Adapter: 1-1 [2, 1179, 10] --\n", "│ └─Sequential: 2-1 [2, 1179, 128] --\n", "│ │ └─FeedForward: 3-1 [2, 1179, 1024] --\n", 
"│ │ │ └─Linear: 4-1 [2, 1179, 1024] 11,264\n", "│ │ │ └─Softsign: 4-2 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-2 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-3 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-4 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-3 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-5 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-6 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-4 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-7 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-8 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-5 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-9 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-10 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-6 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-11 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-12 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-7 [2, 1179, 1024] --\n", "│ │ │ └─Linear: 4-13 [2, 1179, 1024] 1,049,600\n", "│ │ │ └─Softsign: 4-14 [2, 1179, 1024] --\n", "│ │ └─FeedForward: 3-8 [2, 1179, 128] --\n", "│ │ │ └─Linear: 4-15 [2, 1179, 128] 131,200\n", "│ │ │ └─LeakyHardsigmoid: 4-16 [2, 1179, 128] --\n", "├─Adapter: 1-2 [2, 294, 10] (recursive)\n", "│ └─Sequential: 2-2 [2, 294, 128] (recursive)\n", "│ │ └─FeedForward: 3-9 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-17 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-18 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-10 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-19 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-20 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-11 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-21 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-22 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-12 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-23 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-24 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-13 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-25 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-26 [2, 294, 1024] --\n", "│ │ 
└─FeedForward: 3-14 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-27 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-28 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-15 [2, 294, 1024] (recursive)\n", "│ │ │ └─Linear: 4-29 [2, 294, 1024] (recursive)\n", "│ │ │ └─Softsign: 4-30 [2, 294, 1024] --\n", "│ │ └─FeedForward: 3-16 [2, 294, 128] (recursive)\n", "│ │ │ └─Linear: 4-31 [2, 294, 128] (recursive)\n", "│ │ │ └─LeakyHardsigmoid: 4-32 [2, 294, 128] --\n", "├─TwinEncoder: 1-3 [2, 512] --\n", "│ └─Encoder: 2-3 [2, 4, 128] --\n", "│ │ └─ModuleList: 3-18 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-33 [2, 1179, 128] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 1179, 128] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 16, 128] 2,048\n", "│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 16, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-1 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-2 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-3 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 32, 16, 1179] --\n", "│ │ │ │ │ │ └─Linear: 7-5 [2, 16, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-6 [2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 1179, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-7 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-8 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-9 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 32, 1179, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 32, 1179, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-11 [2, 1179, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-12 [2, 1179, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-2 [2, 1179, 128] --\n", "│ │ │ │ │ └─Linear: 6-4 [2, 1179, 512] 66,048\n", "│ │ │ │ │ └─ReLU6: 6-5 [2, 1179, 512] --\n", "│ │ │ │ │ └─Linear: 6-6 [2, 1179, 128] 65,664\n", "│ │ │ └─EncoderLayer: 4-34 [2, 1179, 128] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 1179, 128] --\n", "│ │ │ │ │ 
└─TensorInductionPoint: 6-7 [2, 16, 128] 2,048\n", "│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 16, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-13 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-14 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-15 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 32, 16, 1179] --\n", "│ │ │ │ │ │ └─Linear: 7-17 [2, 16, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-18 [2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 1179, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-19 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-20 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-21 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 32, 1179, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 32, 1179, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-23 [2, 1179, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-24 [2, 1179, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-4 [2, 1179, 128] --\n", "│ │ │ │ │ └─Linear: 6-10 [2, 1179, 512] 66,048\n", "│ │ │ │ │ └─ReLU6: 6-11 [2, 1179, 512] --\n", "│ │ │ │ │ └─Linear: 6-12 [2, 1179, 128] 65,664\n", "│ │ │ └─EncoderLayer: 4-35 [2, 4, 128] --\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 1179, 128] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 16, 128] 2,048\n", "│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 16, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-25 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-26 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-27 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 32, 16, 1179] --\n", "│ │ │ │ │ │ └─Linear: 7-29 [2, 16, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-30 [2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 1179, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-31 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-32 [2, 16, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-33 [2, 16, 128] 
16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 32, 1179, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 32, 1179, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-35 [2, 1179, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-36 [2, 1179, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-6 [2, 1179, 128] --\n", "│ │ │ │ │ └─Linear: 6-16 [2, 1179, 512] 66,048\n", "│ │ │ │ │ └─LeakyHardtanh: 6-17 [2, 1179, 512] --\n", "│ │ │ │ │ └─Linear: 6-18 [2, 1179, 128] 65,664\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-7 [2, 4, 128] --\n", "│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 4, 128] 512\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-20 [2, 4, 128] --\n", "│ │ │ │ │ │ └─Linear: 7-37 [2, 4, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-38 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─Linear: 7-39 [2, 1179, 128] 16,384\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 32, 4, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 32, 4, 1179] --\n", "│ │ │ │ │ │ └─Linear: 7-41 [2, 4, 128] 16,512\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-42 [2, 4, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-8 [2, 4, 128] (recursive)\n", "│ │ │ │ │ └─Linear: 6-21 [2, 4, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-22 [2, 4, 512] --\n", "│ │ │ │ │ └─Linear: 6-23 [2, 4, 128] (recursive)\n", "│ └─Encoder: 2-4 [2, 4, 128] (recursive)\n", "│ │ └─ModuleList: 3-18 -- (recursive)\n", "│ │ │ └─EncoderLayer: 4-36 [2, 294, 128] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-9 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-24 [2, 16, 128] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-25 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-43 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-44 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-45 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 32, 16, 294] --\n", "│ │ │ │ │ │ └─Linear: 7-47 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-48 
[2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-26 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-49 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-50 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-51 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 32, 294, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 32, 294, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-53 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-54 [2, 294, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-10 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─Linear: 6-27 [2, 294, 512] (recursive)\n", "│ │ │ │ │ └─ReLU6: 6-28 [2, 294, 512] --\n", "│ │ │ │ │ └─Linear: 6-29 [2, 294, 128] (recursive)\n", "│ │ │ └─EncoderLayer: 4-37 [2, 294, 128] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-11 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-30 [2, 16, 128] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-31 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-55 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-56 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-57 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 32, 16, 294] --\n", "│ │ │ │ │ │ └─Linear: 7-59 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-60 [2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-32 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-61 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-62 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-63 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 32, 294, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 32, 294, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-65 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-66 [2, 294, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-12 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─Linear: 6-33 [2, 294, 512] (recursive)\n", "│ 
│ │ │ │ └─ReLU6: 6-34 [2, 294, 512] --\n", "│ │ │ │ │ └─Linear: 6-35 [2, 294, 128] (recursive)\n", "│ │ │ └─EncoderLayer: 4-38 [2, 4, 128] (recursive)\n", "│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 16, 128] (recursive)\n", "│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-67 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-68 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-69 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 32, 16, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 32, 16, 294] --\n", "│ │ │ │ │ │ └─Linear: 7-71 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-72 [2, 16, 128] --\n", "│ │ │ │ │ └─MultiHeadAttention: 6-38 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-73 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-74 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-75 [2, 16, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 32, 294, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 32, 294, 16] --\n", "│ │ │ │ │ │ └─Linear: 7-77 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-78 [2, 294, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-14 [2, 294, 128] (recursive)\n", "│ │ │ │ │ └─Linear: 6-39 [2, 294, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-40 [2, 294, 512] --\n", "│ │ │ │ │ └─Linear: 6-41 [2, 294, 128] (recursive)\n", "│ │ │ │ └─PoolingByMultiheadAttention: 5-15 [2, 4, 128] (recursive)\n", "│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 4, 128] (recursive)\n", "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-43 [2, 4, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-79 [2, 4, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-80 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─Linear: 7-81 [2, 294, 128] (recursive)\n", "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 32, 4, 4] --\n", "│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 32, 4, 294] --\n", "│ │ │ │ 
│ │ └─Linear: 7-83 [2, 4, 128] (recursive)\n", "│ │ │ │ │ │ └─LeakyHardtanh: 7-84 [2, 4, 128] --\n", "│ │ │ │ └─DoubleFeedForward: 5-16 [2, 4, 128] (recursive)\n", "│ │ │ │ │ └─Linear: 6-44 [2, 4, 512] (recursive)\n", "│ │ │ │ │ └─LeakyHardtanh: 6-45 [2, 4, 512] --\n", "│ │ │ │ │ └─Linear: 6-46 [2, 4, 128] (recursive)\n", "├─Head: 1-4 [2] --\n", "│ └─Sequential: 2-5 [2, 1] --\n", "│ │ └─FeedForward: 3-19 [2, 256] --\n", "│ │ │ └─Linear: 4-39 [2, 256] 131,328\n", "│ │ │ └─ReLU6: 4-40 [2, 256] --\n", "│ │ └─FeedForward: 3-20 [2, 256] --\n", "│ │ │ └─Linear: 4-41 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-42 [2, 256] --\n", "│ │ └─FeedForward: 3-21 [2, 256] --\n", "│ │ │ └─Linear: 4-43 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-44 [2, 256] --\n", "│ │ └─FeedForward: 3-22 [2, 256] --\n", "│ │ │ └─Linear: 4-45 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-46 [2, 256] --\n", "│ │ └─FeedForward: 3-23 [2, 256] --\n", "│ │ │ └─Linear: 4-47 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-48 [2, 256] --\n", "│ │ └─FeedForward: 3-24 [2, 256] --\n", "│ │ │ └─Linear: 4-49 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-50 [2, 256] --\n", "│ │ └─FeedForward: 3-25 [2, 256] --\n", "│ │ │ └─Linear: 4-51 [2, 256] 65,792\n", "│ │ │ └─ReLU6: 4-52 [2, 256] --\n", "│ │ └─FeedForward: 3-26 [2, 1] --\n", "│ │ │ └─Linear: 4-53 [2, 1] 257\n", "│ │ │ └─LeakyHardsigmoid: 4-54 [2, 1] --\n", "========================================================================================================================\n", "Total params: 7,827,841\n", "Trainable params: 7,827,841\n", "Non-trainable params: 0\n", "Total mult-adds (M): 30.76\n", "========================================================================================================================\n", "Input size (MB): 0.12\n", "Forward/backward pass size (MB): 260.58\n", "Params size (MB): 31.31\n", "Estimated Total Size (MB): 292.01\n", "========================================================================================================================" ] }, 
"execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchinfo import summary\n", "\n", "role_model = params[\"fixed_role_model\"]\n", "s = train_set[0][role_model]\n", "summary(model[role_model], input_size=((2, *s[0].shape), (2, *s[1].shape)), depth=9) # 8 max" ] }, { "cell_type": "code", "execution_count": 24, "id": "0f42c4d1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T06:21:10.290901Z", "iopub.status.busy": "2024-03-03T06:21:10.289996Z", "iopub.status.idle": "2024-03-03T07:22:58.625808Z", "shell.execute_reply": "2024-03-03T07:22:58.624885Z" }, "papermill": { "duration": 3708.373089, "end_time": "2024-03-03T07:22:58.647238", "exception": false, "start_time": "2024-03-03T06:21:10.274149", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "g_loss_mul 0.1\n", "Epoch 0\n", "Train loss {'avg_role_model_loss': 0.01341511887658271, 'avg_role_model_std_loss': 0.24023109539175563, 'avg_role_model_mean_pred_loss': 0.000575269428866676, 'avg_role_model_g_mag_loss': 0.004837815280916402, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.015502669336274266, 'n_size': 320, 'n_batch': 80, 'duration': 86.60318899154663, 'duration_batch': 1.082539862394333, 'duration_size': 0.27063496559858324, 'avg_pred_std': 0.11721891383640468}\n", "Val loss {'avg_role_model_loss': 0.018764166883192955, 'avg_role_model_std_loss': 0.7378902015632776, 'avg_role_model_mean_pred_loss': 0.0008810971392904321, 'avg_role_model_g_mag_loss': 0.006337779993191362, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.02165449762251228, 'n_size': 80, 'n_batch': 20, 'duration': 19.204163312911987, 'duration_batch': 0.9602081656455994, 'duration_size': 
0.24005204141139985, 'avg_pred_std': 0.08749947492033243}\n", "Epoch 1\n", "Train loss {'avg_role_model_loss': 0.01306824244238669, 'avg_role_model_std_loss': 1.023421409522797, 'avg_role_model_mean_pred_loss': 0.0006003775347632523, 'avg_role_model_g_mag_loss': 0.007479726555175148, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.014787302106560674, 'n_size': 320, 'n_batch': 80, 'duration': 87.21888208389282, 'duration_batch': 1.0902360260486603, 'duration_size': 0.2725590065121651, 'avg_pred_std': 0.1071490949485451}\n", "Val loss {'avg_role_model_loss': 0.007755115552572534, 'avg_role_model_std_loss': 2.593804732917124, 'avg_role_model_mean_pred_loss': 0.00013148071556594586, 'avg_role_model_g_mag_loss': 0.006523144242237322, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.008876382285961881, 'n_size': 80, 'n_batch': 20, 'duration': 19.415016651153564, 'duration_batch': 0.9707508325576782, 'duration_size': 0.24268770813941956, 'avg_pred_std': 0.026741976058110593}\n", "Epoch 2\n", "Train loss {'avg_role_model_loss': 0.0028183516405988485, 'avg_role_model_std_loss': 0.7737382953319809, 'avg_role_model_mean_pred_loss': 9.625664277708966e-06, 'avg_role_model_g_mag_loss': 0.00837607061257586, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0028855871143605325, 'n_size': 320, 'n_batch': 80, 'duration': 87.30349731445312, 'duration_batch': 1.091293716430664, 'duration_size': 0.272823429107666, 'avg_pred_std': 0.0851033627637662}\n", "Val loss {'avg_role_model_loss': 0.0034189124999102205, 'avg_role_model_std_loss': 2.2048701629042626, 'avg_role_model_mean_pred_loss': 1.1537741464451301e-05, 'avg_role_model_g_mag_loss': 
0.008503273967653513, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0035019157105125485, 'n_size': 80, 'n_batch': 20, 'duration': 19.296178340911865, 'duration_batch': 0.9648089170455932, 'duration_size': 0.2412022292613983, 'avg_pred_std': 0.02689541974104941}\n", "Epoch 3\n", "Train loss {'avg_role_model_loss': 0.0031308751182223204, 'avg_role_model_std_loss': 1.4198307410230258, 'avg_role_model_mean_pred_loss': 1.015549445890062e-05, 'avg_role_model_g_mag_loss': 0.008551960467593744, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0032078142976388335, 'n_size': 320, 'n_batch': 80, 'duration': 87.20788407325745, 'duration_batch': 1.0900985509157182, 'duration_size': 0.27252463772892954, 'avg_pred_std': 0.07063841636409052}\n", "Val loss {'avg_role_model_loss': 0.0025563906941897586, 'avg_role_model_std_loss': 1.7120467301925602, 'avg_role_model_mean_pred_loss': 4.022160688882392e-06, 'avg_role_model_g_mag_loss': 0.008696647034958005, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0026055111757159466, 'n_size': 80, 'n_batch': 20, 'duration': 19.3008930683136, 'duration_batch': 0.9650446534156799, 'duration_size': 0.24126116335391998, 'avg_pred_std': 0.031063590943813325}\n", "Epoch 4\n", "Train loss {'avg_role_model_loss': 0.0020325787663750816, 'avg_role_model_std_loss': 0.9927601797127522, 'avg_role_model_mean_pred_loss': 3.947625345419869e-06, 'avg_role_model_g_mag_loss': 0.008560203079832717, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002069109250442125, 'n_size': 320, 'n_batch': 80, 'duration': 
86.83163189888, 'duration_batch': 1.085395398736, 'duration_size': 0.271348849684, 'avg_pred_std': 0.08267453832668252}\n", "Val loss {'avg_role_model_loss': 0.002587378228054149, 'avg_role_model_std_loss': 1.5380892607378684, 'avg_role_model_mean_pred_loss': 5.001206458744855e-06, 'avg_role_model_g_mag_loss': 0.008236717758700251, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002642228409240488, 'n_size': 80, 'n_batch': 20, 'duration': 19.28103756904602, 'duration_batch': 0.964051878452301, 'duration_size': 0.24101296961307525, 'avg_pred_std': 0.03489238116890192}\n", "Epoch 5\n", "Train loss {'avg_role_model_loss': 0.0018827232423063833, 'avg_role_model_std_loss': 0.36071941533007446, 'avg_role_model_mean_pred_loss': 3.0995558925314963e-06, 'avg_role_model_g_mag_loss': 0.008657984499586746, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0019155390065861866, 'n_size': 320, 'n_batch': 80, 'duration': 86.8512351512909, 'duration_batch': 1.085640439391136, 'duration_size': 0.271410109847784, 'avg_pred_std': 0.08422506358474494}\n", "Val loss {'avg_role_model_loss': 0.002631575356645044, 'avg_role_model_std_loss': 1.174698173921024, 'avg_role_model_mean_pred_loss': 7.702732040915094e-06, 'avg_role_model_g_mag_loss': 0.008593563642352819, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002682924266991904, 'n_size': 80, 'n_batch': 20, 'duration': 19.29587149620056, 'duration_batch': 0.964793574810028, 'duration_size': 0.241198393702507, 'avg_pred_std': 0.03935151700861752}\n", "Epoch 6\n", "Train loss {'avg_role_model_loss': 0.00163584444890148, 'avg_role_model_std_loss': 0.32704864908948517, 'avg_role_model_mean_pred_loss': 
3.7024563248498977e-06, 'avg_role_model_g_mag_loss': 0.008764190669171511, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016616366719972575, 'n_size': 320, 'n_batch': 80, 'duration': 86.82903480529785, 'duration_batch': 1.085362935066223, 'duration_size': 0.27134073376655576, 'avg_pred_std': 0.09021815697196871}\n", "Val loss {'avg_role_model_loss': 0.0027584389958065004, 'avg_role_model_std_loss': 1.2948427033639747, 'avg_role_model_mean_pred_loss': 8.096104292887318e-06, 'avg_role_model_g_mag_loss': 0.007971254212316125, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002821568233775906, 'n_size': 80, 'n_batch': 20, 'duration': 19.269461393356323, 'duration_batch': 0.9634730696678162, 'duration_size': 0.24086826741695405, 'avg_pred_std': 0.03149411482736468}\n", "Epoch 7\n", "Train loss {'avg_role_model_loss': 0.001588404598624038, 'avg_role_model_std_loss': 0.30104452142173416, 'avg_role_model_mean_pred_loss': 2.1237631645752536e-06, 'avg_role_model_g_mag_loss': 0.008680318144615739, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0016122827923027217, 'n_size': 320, 'n_batch': 80, 'duration': 86.72441625595093, 'duration_batch': 1.0840552031993866, 'duration_size': 0.27101380079984666, 'avg_pred_std': 0.08948315244633705}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.002168826656998135, 'avg_role_model_std_loss': 1.1961765073471269, 'avg_role_model_mean_pred_loss': 3.218722386769124e-06, 'avg_role_model_g_mag_loss': 0.008597346721217036, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0022069613172789105, 'n_size': 80, 'n_batch': 20, 'duration': 19.133681535720825, 'duration_batch': 0.9566840767860413, 'duration_size': 0.23917101919651032, 'avg_pred_std': 0.03733009579591453}\n", "Epoch 8\n", "Train loss {'avg_role_model_loss': 0.001353109241972561, 'avg_role_model_std_loss': 0.2765715855718042, 'avg_role_model_mean_pred_loss': 2.087236364796191e-06, 'avg_role_model_g_mag_loss': 0.008927629049867391, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0013720194954657927, 'n_size': 320, 'n_batch': 80, 'duration': 86.93680453300476, 'duration_batch': 1.0867100566625596, 'duration_size': 0.2716775141656399, 'avg_pred_std': 0.09051385981729254}\n", "Val loss {'avg_role_model_loss': 0.0024928924576670397, 'avg_role_model_std_loss': 0.8476081335626077, 'avg_role_model_mean_pred_loss': 6.1137728412461815e-06, 'avg_role_model_g_mag_loss': 0.008560518850572407, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0025436472140427215, 'n_size': 80, 'n_batch': 20, 'duration': 19.220860481262207, 'duration_batch': 0.9610430240631104, 'duration_size': 0.2402607560157776, 'avg_pred_std': 0.03947796570137143}\n", "Epoch 9\n", "Train loss {'avg_role_model_loss': 0.0011078594061473268, 'avg_role_model_std_loss': 0.2964725857345002, 'avg_role_model_mean_pred_loss': 1.1485152926839908e-06, 'avg_role_model_g_mag_loss': 0.009000935748917981, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0011219609619729453, 'n_size': 320, 'n_batch': 80, 'duration': 86.97083497047424, 'duration_batch': 1.087135437130928, 'duration_size': 0.271783859282732, 'avg_pred_std': 0.08712627965724096}\n", "Val 
loss {'avg_role_model_loss': 0.002337238602922298, 'avg_role_model_std_loss': 0.6530551572132708, 'avg_role_model_mean_pred_loss': 6.45294174761446e-06, 'avg_role_model_g_mag_loss': 0.00837808190844953, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0023841866728616878, 'n_size': 80, 'n_batch': 20, 'duration': 19.421545267105103, 'duration_batch': 0.9710772633552551, 'duration_size': 0.24276931583881378, 'avg_pred_std': 0.04457564675249159}\n", "Epoch 10\n", "Train loss {'avg_role_model_loss': 0.0012053420616211952, 'avg_role_model_std_loss': 0.1931557533955261, 'avg_role_model_mean_pred_loss': 9.706161082351402e-07, 'avg_role_model_g_mag_loss': 0.008857980108587071, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.001221423167589819, 'n_size': 320, 'n_batch': 80, 'duration': 87.4618911743164, 'duration_batch': 1.093273639678955, 'duration_size': 0.27331840991973877, 'avg_pred_std': 0.0897115994244814}\n", "Val loss {'avg_role_model_loss': 0.002253816397569608, 'avg_role_model_std_loss': 0.8841316987068751, 'avg_role_model_mean_pred_loss': 4.704099107730175e-06, 'avg_role_model_g_mag_loss': 0.008425972750410437, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0022973355487920346, 'n_size': 80, 'n_batch': 20, 'duration': 19.446954488754272, 'duration_batch': 0.9723477244377137, 'duration_size': 0.24308693110942842, 'avg_pred_std': 0.04312506481073797}\n", "Epoch 11\n", "Train loss {'avg_role_model_loss': 0.0009794028399483067, 'avg_role_model_std_loss': 0.1359832607921138, 'avg_role_model_mean_pred_loss': 1.01823794474289e-06, 'avg_role_model_g_mag_loss': 0.009016435348894448, 'avg_role_model_g_cos_loss': 0.0, 
'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0009917656654579333, 'n_size': 320, 'n_batch': 80, 'duration': 87.36561322212219, 'duration_batch': 1.0920701652765274, 'duration_size': 0.27301754131913186, 'avg_pred_std': 0.08817402567947283}\n", "Val loss {'avg_role_model_loss': 0.0025446744170039893, 'avg_role_model_std_loss': 0.6562726007028459, 'avg_role_model_mean_pred_loss': 8.59153058119233e-06, 'avg_role_model_g_mag_loss': 0.00828958151396364, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002601461045560427, 'n_size': 80, 'n_batch': 20, 'duration': 19.171342372894287, 'duration_batch': 0.9585671186447143, 'duration_size': 0.23964177966117858, 'avg_pred_std': 0.046762616652995345}\n", "Epoch 12\n", "Train loss {'avg_role_model_loss': 0.0008729144968128821, 'avg_role_model_std_loss': 0.1683987133885907, 'avg_role_model_mean_pred_loss': 6.835936173131307e-07, 'avg_role_model_g_mag_loss': 0.009030447795521469, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008834377124003368, 'n_size': 320, 'n_batch': 80, 'duration': 87.13745355606079, 'duration_batch': 1.0892181694507599, 'duration_size': 0.27230454236268997, 'avg_pred_std': 0.09424625603714958}\n", "Val loss {'avg_role_model_loss': 0.0025075858677155336, 'avg_role_model_std_loss': 0.6813347520466777, 'avg_role_model_mean_pred_loss': 9.080454932988875e-06, 'avg_role_model_g_mag_loss': 0.007943603885360062, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002564889276982285, 'n_size': 80, 'n_batch': 20, 'duration': 19.267545700073242, 'duration_batch': 0.9633772850036622, 'duration_size': 
0.24084432125091554, 'avg_pred_std': 0.04623653790913522}\n", "Epoch 13\n", "Train loss {'avg_role_model_loss': 0.0008042755165661219, 'avg_role_model_std_loss': 0.2813704043714537, 'avg_role_model_mean_pred_loss': 3.3645215451869243e-07, 'avg_role_model_g_mag_loss': 0.009021570929326117, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008135768726788229, 'n_size': 320, 'n_batch': 80, 'duration': 86.61275219917297, 'duration_batch': 1.082659402489662, 'duration_size': 0.2706648506224155, 'avg_pred_std': 0.08749473023926839}\n", "Val loss {'avg_role_model_loss': 0.0024872228954336607, 'avg_role_model_std_loss': 0.8424709204863575, 'avg_role_model_mean_pred_loss': 1.2303503870288169e-05, 'avg_role_model_g_mag_loss': 0.00836391884367913, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00254274607723346, 'n_size': 80, 'n_batch': 20, 'duration': 19.089064598083496, 'duration_batch': 0.9544532299041748, 'duration_size': 0.2386133074760437, 'avg_pred_std': 0.04271733276546001}\n", "Epoch 14\n", "Train loss {'avg_role_model_loss': 0.0008596906129241689, 'avg_role_model_std_loss': 0.12678315311761706, 'avg_role_model_mean_pred_loss': 7.324179896302241e-07, 'avg_role_model_g_mag_loss': 0.009051537868799642, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0008700103617229615, 'n_size': 320, 'n_batch': 80, 'duration': 86.11309504508972, 'duration_batch': 1.0764136880636215, 'duration_size': 0.26910342201590537, 'avg_pred_std': 0.09114952590316534}\n", "Val loss {'avg_role_model_loss': 0.003142218668654095, 'avg_role_model_std_loss': 0.5989244450749837, 'avg_role_model_mean_pred_loss': 1.701286636457411e-05, 'avg_role_model_g_mag_loss': 
0.007768401200883091, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0032341236757929435, 'n_size': 80, 'n_batch': 20, 'duration': 18.97400975227356, 'duration_batch': 0.948700487613678, 'duration_size': 0.2371751219034195, 'avg_pred_std': 0.045658442331478}\n", "Epoch 15\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0007071739214552508, 'avg_role_model_std_loss': 0.11358937955730664, 'avg_role_model_mean_pred_loss': 8.504173889859433e-07, 'avg_role_model_g_mag_loss': 0.009156818711198866, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0007155208293625037, 'n_size': 320, 'n_batch': 80, 'duration': 85.8247458934784, 'duration_batch': 1.07280932366848, 'duration_size': 0.26820233091712, 'avg_pred_std': 0.08828951774630696}\n", "Val loss {'avg_role_model_loss': 0.0027607452910160648, 'avg_role_model_std_loss': 0.5599186637941784, 'avg_role_model_mean_pred_loss': 1.3696136823915239e-05, 'avg_role_model_g_mag_loss': 0.00771399496588856, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0028348844876745715, 'n_size': 80, 'n_batch': 20, 'duration': 18.95707416534424, 'duration_batch': 0.9478537082672119, 'duration_size': 0.23696342706680298, 'avg_pred_std': 0.04666830957867205}\n", "Epoch 16\n", "Train loss {'avg_role_model_loss': 0.0006871584620967042, 'avg_role_model_std_loss': 0.056319261586604344, 'avg_role_model_mean_pred_loss': 7.729947919388539e-07, 'avg_role_model_g_mag_loss': 0.009021204826422036, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 
0.0006950973430775775, 'n_size': 320, 'n_batch': 80, 'duration': 85.77924585342407, 'duration_batch': 1.0722405731678009, 'duration_size': 0.2680601432919502, 'avg_pred_std': 0.09376501713413746}\n", "Val loss {'avg_role_model_loss': 0.0029137152865587267, 'avg_role_model_std_loss': 0.6469025506016806, 'avg_role_model_mean_pred_loss': 9.816658314748538e-06, 'avg_role_model_g_mag_loss': 0.007721508503891528, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002987763029523194, 'n_size': 80, 'n_batch': 20, 'duration': 18.965726137161255, 'duration_batch': 0.9482863068580627, 'duration_size': 0.23707157671451567, 'avg_pred_std': 0.048026554053649306}\n", "Epoch 17\n", "Train loss {'avg_role_model_loss': 0.0006477722853333034, 'avg_role_model_std_loss': 0.08796985171551344, 'avg_role_model_mean_pred_loss': 5.949152889407616e-07, 'avg_role_model_g_mag_loss': 0.009104125620797276, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0006556250523999552, 'n_size': 320, 'n_batch': 80, 'duration': 86.42096662521362, 'duration_batch': 1.0802620828151703, 'duration_size': 0.2700655207037926, 'avg_pred_std': 0.09506104957545176}\n", "Val loss {'avg_role_model_loss': 0.003072608428192325, 'avg_role_model_std_loss': 0.6891105846539176, 'avg_role_model_mean_pred_loss': 1.3079660551437028e-05, 'avg_role_model_g_mag_loss': 0.007593843806535006, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0031587239500368014, 'n_size': 80, 'n_batch': 20, 'duration': 19.130789279937744, 'duration_batch': 0.9565394639968872, 'duration_size': 0.2391348659992218, 'avg_pred_std': 0.04223878695629537}\n", "Epoch 18\n", "Train loss {'avg_role_model_loss': 
0.0006514384161164343, 'avg_role_model_std_loss': 0.09940106079761221, 'avg_role_model_mean_pred_loss': 1.795125833634764e-07, 'avg_role_model_g_mag_loss': 0.009067119332030416, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0006588086861484044, 'n_size': 320, 'n_batch': 80, 'duration': 87.44042205810547, 'duration_batch': 1.0930052757263184, 'duration_size': 0.2732513189315796, 'avg_pred_std': 0.09078157264739276}\n", "Val loss {'avg_role_model_loss': 0.0024777524726232515, 'avg_role_model_std_loss': 0.7228478452472927, 'avg_role_model_mean_pred_loss': 1.1093569846298834e-05, 'avg_role_model_g_mag_loss': 0.008034943602979183, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002537216582277324, 'n_size': 80, 'n_batch': 20, 'duration': 19.3044171333313, 'duration_batch': 0.965220856666565, 'duration_size': 0.24130521416664125, 'avg_pred_std': 0.04896460571326315}\n", "Epoch 19\n", "Train loss {'avg_role_model_loss': 0.0006081814904973726, 'avg_role_model_std_loss': 0.09099644586050033, 'avg_role_model_mean_pred_loss': 3.1757028354558596e-07, 'avg_role_model_g_mag_loss': 0.009112988878041507, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.000615051853355908, 'n_size': 320, 'n_batch': 80, 'duration': 86.71533060073853, 'duration_batch': 1.0839416325092315, 'duration_size': 0.27098540812730787, 'avg_pred_std': 0.09587977258488536}\n", "Val loss {'avg_role_model_loss': 0.00240422225324437, 'avg_role_model_std_loss': 0.6358416218907224, 'avg_role_model_mean_pred_loss': 1.1121995321849986e-05, 'avg_role_model_g_mag_loss': 0.008140259771607816, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0024615431102574803, 'n_size': 80, 'n_batch': 20, 'duration': 19.22822141647339, 'duration_batch': 0.9614110708236694, 'duration_size': 0.24035276770591735, 'avg_pred_std': 0.04650567690841854}\n", "Epoch 20\n", "Train loss {'avg_role_model_loss': 0.0005358590460900814, 'avg_role_model_std_loss': 0.06825528511013328, 'avg_role_model_mean_pred_loss': 2.678394117848915e-07, 'avg_role_model_g_mag_loss': 0.009244002948980779, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005420027183504316, 'n_size': 320, 'n_batch': 80, 'duration': 85.95558786392212, 'duration_batch': 1.0744448482990265, 'duration_size': 0.26861121207475663, 'avg_pred_std': 0.0906242580153048}\n", "Val loss {'avg_role_model_loss': 0.0025660589744802564, 'avg_role_model_std_loss': 0.6802260238559029, 'avg_role_model_mean_pred_loss': 9.92956469314521e-06, 'avg_role_model_g_mag_loss': 0.007699522981420159, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0026287092332495376, 'n_size': 80, 'n_batch': 20, 'duration': 18.98186469078064, 'duration_batch': 0.9490932345390319, 'duration_size': 0.23727330863475798, 'avg_pred_std': 0.04403721652925015}\n", "Epoch 21\n", "Train loss {'avg_role_model_loss': 0.0004775314865582914, 'avg_role_model_std_loss': 0.06455386250157247, 'avg_role_model_mean_pred_loss': 2.1060846659059965e-07, 'avg_role_model_g_mag_loss': 0.009264782385434956, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00048304505050964506, 'n_size': 320, 'n_batch': 80, 'duration': 85.94736814498901, 'duration_batch': 1.0743421018123627, 'duration_size': 0.26858552545309067, 
'avg_pred_std': 0.09594059572555125}\n", "Val loss {'avg_role_model_loss': 0.0030469499950413594, 'avg_role_model_std_loss': 0.8013474333885384, 'avg_role_model_mean_pred_loss': 1.4471279196381915e-05, 'avg_role_model_g_mag_loss': 0.007496322155930102, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0031343029033450874, 'n_size': 80, 'n_batch': 20, 'duration': 19.01077938079834, 'duration_batch': 0.950538969039917, 'duration_size': 0.23763474225997924, 'avg_pred_std': 0.04275354435667396}\n", "Epoch 22\n", "Train loss {'avg_role_model_loss': 0.0005181060546192384, 'avg_role_model_std_loss': 0.049453351887659557, 'avg_role_model_mean_pred_loss': 3.420242802625219e-07, 'avg_role_model_g_mag_loss': 0.009284257958643138, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0005241578612640296, 'n_size': 320, 'n_batch': 80, 'duration': 85.98614287376404, 'duration_batch': 1.0748267859220504, 'duration_size': 0.2687066964805126, 'avg_pred_std': 0.09294412531889976}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Val loss {'avg_role_model_loss': 0.002864245133969234, 'avg_role_model_std_loss': 0.4667397131830512, 'avg_role_model_mean_pred_loss': 1.4686058315405149e-05, 'avg_role_model_g_mag_loss': 0.007691610814072191, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002942165515560191, 'n_size': 80, 'n_batch': 20, 'duration': 19.0608127117157, 'duration_batch': 0.9530406355857849, 'duration_size': 0.23826015889644622, 'avg_pred_std': 0.052058699540793896}\n", "Epoch 23\n", "Train loss {'avg_role_model_loss': 0.00043534455560347853, 'avg_role_model_std_loss': 0.07720435771696674, 'avg_role_model_mean_pred_loss': 
3.792708702435137e-07, 'avg_role_model_g_mag_loss': 0.009331031411420554, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00044047548610706146, 'n_size': 320, 'n_batch': 80, 'duration': 85.86611938476562, 'duration_batch': 1.0733264923095702, 'duration_size': 0.26833162307739256, 'avg_pred_std': 0.09743042136542499}\n", "Val loss {'avg_role_model_loss': 0.0030215614737244324, 'avg_role_model_std_loss': 0.6433068920836377, 'avg_role_model_mean_pred_loss': 1.4186154598316848e-05, 'avg_role_model_g_mag_loss': 0.007460616645403206, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.003107524223742075, 'n_size': 80, 'n_batch': 20, 'duration': 18.98639988899231, 'duration_batch': 0.9493199944496155, 'duration_size': 0.23732999861240386, 'avg_pred_std': 0.050000256625935435}\n", "Epoch 24\n", "Train loss {'avg_role_model_loss': 0.00042359301896794933, 'avg_role_model_std_loss': 0.04604176523210342, 'avg_role_model_mean_pred_loss': 1.3871310008198302e-07, 'avg_role_model_g_mag_loss': 0.009280809713527561, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.000428745166027511, 'n_size': 320, 'n_batch': 80, 'duration': 86.06277060508728, 'duration_batch': 1.075784632563591, 'duration_size': 0.26894615814089773, 'avg_pred_std': 0.09038203665986658}\n", "Val loss {'avg_role_model_loss': 0.0026232202184473864, 'avg_role_model_std_loss': 0.575840337219779, 'avg_role_model_mean_pred_loss': 1.1411215524770312e-05, 'avg_role_model_g_mag_loss': 0.007816769555211068, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002690732198243495, 'n_size': 
80, 'n_batch': 20, 'duration': 19.039226293563843, 'duration_batch': 0.9519613146781921, 'duration_size': 0.23799032866954803, 'avg_pred_std': 0.04816291746683419}\n", "Epoch 25\n", "Train loss {'avg_role_model_loss': 0.00037949440795728153, 'avg_role_model_std_loss': 0.024586391349637894, 'avg_role_model_mean_pred_loss': 1.68624920803169e-07, 'avg_role_model_g_mag_loss': 0.009385608148295432, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00038402225013669524, 'n_size': 320, 'n_batch': 80, 'duration': 86.02636766433716, 'duration_batch': 1.0753295958042144, 'duration_size': 0.2688323989510536, 'avg_pred_std': 0.09288427671417594}\n", "Val loss {'avg_role_model_loss': 0.0026626741448126266, 'avg_role_model_std_loss': 0.6883730123883651, 'avg_role_model_mean_pred_loss': 1.143276741413235e-05, 'avg_role_model_g_mag_loss': 0.0078037152998149395, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0027320246357703582, 'n_size': 80, 'n_batch': 20, 'duration': 19.006802558898926, 'duration_batch': 0.9503401279449463, 'duration_size': 0.23758503198623657, 'avg_pred_std': 0.04854826694354415}\n", "Epoch 26\n", "Train loss {'avg_role_model_loss': 0.00036825626919494424, 'avg_role_model_std_loss': 0.06836344334289635, 'avg_role_model_mean_pred_loss': 2.864049477936933e-07, 'avg_role_model_g_mag_loss': 0.009369707445148378, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0003726810866737651, 'n_size': 320, 'n_batch': 80, 'duration': 87.12945866584778, 'duration_batch': 1.0891182333230973, 'duration_size': 0.27227955833077433, 'avg_pred_std': 0.0932244597002864}\n", "Val loss {'avg_role_model_loss': 0.002895271330635296, 
'avg_role_model_std_loss': 0.6430459570903622, 'avg_role_model_mean_pred_loss': 1.4183609434681444e-05, 'avg_role_model_g_mag_loss': 0.0074269796255975965, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002975074338610284, 'n_size': 80, 'n_batch': 20, 'duration': 19.336661100387573, 'duration_batch': 0.9668330550193787, 'duration_size': 0.24170826375484467, 'avg_pred_std': 0.051954638119786976}\n", "Epoch 27\n", "Train loss {'avg_role_model_loss': 0.00045546451366362817, 'avg_role_model_std_loss': 0.07809749107524873, 'avg_role_model_mean_pred_loss': 1.9649450981934686e-07, 'avg_role_model_g_mag_loss': 0.009275437286123633, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00046089367940567173, 'n_size': 320, 'n_batch': 80, 'duration': 86.8785035610199, 'duration_batch': 1.0859812945127487, 'duration_size': 0.2714953236281872, 'avg_pred_std': 0.09191167326644063}\n", "Val loss {'avg_role_model_loss': 0.0028493454796262086, 'avg_role_model_std_loss': 0.6770066560213991, 'avg_role_model_mean_pred_loss': 1.2013152249146231e-05, 'avg_role_model_g_mag_loss': 0.007457524072378874, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029267561651067807, 'n_size': 80, 'n_batch': 20, 'duration': 19.251688241958618, 'duration_batch': 0.9625844120979309, 'duration_size': 0.24064610302448272, 'avg_pred_std': 0.04829146796837449}\n", "Epoch 28\n", "Train loss {'avg_role_model_loss': 0.00033174752630884543, 'avg_role_model_std_loss': 0.06612742295388116, 'avg_role_model_mean_pred_loss': 1.180179221131573e-07, 'avg_role_model_g_mag_loss': 0.009398166846949607, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 
'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0003358864170422748, 'n_size': 320, 'n_batch': 80, 'duration': 86.8982789516449, 'duration_batch': 1.0862284868955612, 'duration_size': 0.2715571217238903, 'avg_pred_std': 0.09198763176100329}\n", "Val loss {'avg_role_model_loss': 0.0027787499493570067, 'avg_role_model_std_loss': 0.7222041939938209, 'avg_role_model_mean_pred_loss': 1.1131508226305125e-05, 'avg_role_model_g_mag_loss': 0.0075377147179096935, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0028551170726132113, 'n_size': 80, 'n_batch': 20, 'duration': 19.059608459472656, 'duration_batch': 0.9529804229736328, 'duration_size': 0.2382451057434082, 'avg_pred_std': 0.04740339070558548}\n", "Epoch 29\n", "Train loss {'avg_role_model_loss': 0.0003597953202529425, 'avg_role_model_std_loss': 0.024496306587910778, 'avg_role_model_mean_pred_loss': 1.142882032886816e-07, 'avg_role_model_g_mag_loss': 0.009390755963977426, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00036418304370045007, 'n_size': 320, 'n_batch': 80, 'duration': 86.48576998710632, 'duration_batch': 1.081072124838829, 'duration_size': 0.27026803120970727, 'avg_pred_std': 0.09743952928110958}\n", "Val loss {'avg_role_model_loss': 0.0026345706894062458, 'avg_role_model_std_loss': 0.7635843992035003, 'avg_role_model_mean_pred_loss': 1.1359682959266593e-05, 'avg_role_model_g_mag_loss': 0.007559091807343066, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0027058403313276356, 'n_size': 80, 'n_batch': 20, 'duration': 18.992973566055298, 'duration_batch': 0.9496486783027649, 'duration_size': 0.23741216957569122, 'avg_pred_std': 
0.046653942298144103}\n", "Epoch 30\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train loss {'avg_role_model_loss': 0.0004039141500470578, 'avg_role_model_std_loss': 0.09576162853820311, 'avg_role_model_mean_pred_loss': 1.8995679496383588e-07, 'avg_role_model_g_mag_loss': 0.009361328079830856, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00040872053962175413, 'n_size': 320, 'n_batch': 80, 'duration': 86.09150242805481, 'duration_batch': 1.076143780350685, 'duration_size': 0.26903594508767126, 'avg_pred_std': 0.09370959255611525}\n", "Val loss {'avg_role_model_loss': 0.0027511596577824093, 'avg_role_model_std_loss': 0.6731958146176111, 'avg_role_model_mean_pred_loss': 1.417353958319545e-05, 'avg_role_model_g_mag_loss': 0.007637712243013084, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0028271504568692764, 'n_size': 80, 'n_batch': 20, 'duration': 18.946882724761963, 'duration_batch': 0.9473441362380981, 'duration_size': 0.23683603405952453, 'avg_pred_std': 0.047970831673592326}\n", "Epoch 31\n", "Train loss {'avg_role_model_loss': 0.00030792704831128503, 'avg_role_model_std_loss': 0.03007892958051279, 'avg_role_model_mean_pred_loss': 1.6044353357220544e-07, 'avg_role_model_g_mag_loss': 0.009489066211972385, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00031191278568485357, 'n_size': 320, 'n_batch': 80, 'duration': 86.08254933357239, 'duration_batch': 1.076031866669655, 'duration_size': 0.2690079666674137, 'avg_pred_std': 0.09290082987863571}\n", "Val loss {'avg_role_model_loss': 0.0027197946154046805, 'avg_role_model_std_loss': 0.6787573918956695, 'avg_role_model_mean_pred_loss': 9.246046926492114e-06, 
'avg_role_model_g_mag_loss': 0.007403567247092724, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002790288768301252, 'n_size': 80, 'n_batch': 20, 'duration': 19.061045169830322, 'duration_batch': 0.9530522584915161, 'duration_size': 0.23826306462287902, 'avg_pred_std': 0.04738053930923343}\n", "Epoch 32\n", "Train loss {'avg_role_model_loss': 0.0002918601881447103, 'avg_role_model_std_loss': 0.02865394608418743, 'avg_role_model_mean_pred_loss': 1.377780366400853e-07, 'avg_role_model_g_mag_loss': 0.009482038370333613, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00029560920602307307, 'n_size': 320, 'n_batch': 80, 'duration': 85.58978748321533, 'duration_batch': 1.0698723435401916, 'duration_size': 0.2674680858850479, 'avg_pred_std': 0.09774628963787109}\n", "Val loss {'avg_role_model_loss': 0.002647562528727576, 'avg_role_model_std_loss': 0.6132856009367742, 'avg_role_model_mean_pred_loss': 1.2823440413711751e-05, 'avg_role_model_g_mag_loss': 0.007520708185620606, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0027170190507604273, 'n_size': 80, 'n_batch': 20, 'duration': 18.80093550682068, 'duration_batch': 0.940046775341034, 'duration_size': 0.2350116938352585, 'avg_pred_std': 0.0500286478549242}\n", "Epoch 33\n", "Train loss {'avg_role_model_loss': 0.0002785044133588599, 'avg_role_model_std_loss': 0.06320022720836534, 'avg_role_model_mean_pred_loss': 4.3557023588974905e-08, 'avg_role_model_g_mag_loss': 0.009524956706445663, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0002821075783458582, 'n_size': 320, 
'n_batch': 80, 'duration': 85.059579372406, 'duration_batch': 1.063244742155075, 'duration_size': 0.26581118553876876, 'avg_pred_std': 0.09342062773648649}\n", "Time out: 3615.1168954372406/3600\n", "Eval loss {'role_model': 'tab_ddpm_concat', 'n_size': 399, 'n_batch': 100, 'role_model_metrics': {'avg_loss': 0.0013795483690064475, 'avg_g_mag_loss': 0.03192370239529646, 'avg_g_cos_loss': 0.010496431505914432, 'pred_duration': 1.911043405532837, 'grad_duration': 1.572967529296875, 'total_duration': 3.484010934829712, 'pred_std': 0.05930750072002411, 'std_loss': 0.007085928227752447, 'mean_pred_loss': 1.234028559338185e-06, 'pred_rmse': 0.03714227303862572, 'pred_mae': 0.02984810806810856, 'pred_mape': 0.06894627958536148, 'grad_rmse': 0.06608082354068756, 'grad_mae': 0.05324888974428177, 'grad_mape': 0.8990559577941895}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.0013795483690064475, 'avg_g_mag_loss': 0.03192370239529646, 'avg_g_cos_loss': 0.010496431505914432, 'avg_pred_duration': 1.911043405532837, 'avg_grad_duration': 1.572967529296875, 'avg_total_duration': 3.484010934829712, 'avg_pred_std': 0.05930750072002411, 'avg_std_loss': 0.007085928227752447, 'avg_mean_pred_loss': 1.234028559338185e-06}, 'min_metrics': {'avg_loss': 0.0013795483690064475, 'avg_g_mag_loss': 0.03192370239529646, 'avg_g_cos_loss': 0.010496431505914432, 'pred_duration': 1.911043405532837, 'grad_duration': 1.572967529296875, 'total_duration': 3.484010934829712, 'pred_std': 0.05930750072002411, 'std_loss': 0.007085928227752447, 'mean_pred_loss': 1.234028559338185e-06, 'pred_rmse': 0.03714227303862572, 'pred_mae': 0.02984810806810856, 'pred_mape': 0.06894627958536148, 'grad_rmse': 0.06608082354068756, 'grad_mae': 0.05324888974428177, 'grad_mape': 0.8990559577941895}, 'model_metrics': 
{'tab_ddpm_concat': {'avg_loss': 0.0013795483690064475, 'avg_g_mag_loss': 0.03192370239529646, 'avg_g_cos_loss': 0.010496431505914432, 'pred_duration': 1.911043405532837, 'grad_duration': 1.572967529296875, 'total_duration': 3.484010934829712, 'pred_std': 0.05930750072002411, 'std_loss': 0.007085928227752447, 'mean_pred_loss': 1.234028559338185e-06, 'pred_rmse': 0.03714227303862572, 'pred_mae': 0.02984810806810856, 'pred_mape': 0.06894627958536148, 'grad_rmse': 0.06608082354068756, 'grad_mae': 0.05324888974428177, 'grad_mape': 0.8990559577941895}}}\n" ] } ], "source": [ "import torch\n", "from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n", "from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n", "from ml_utility_loss.params import GradientPenaltyMode\n", "from ml_utility_loss.util import clear_memory\n", "import time\n", "#torch.autograd.set_detect_anomaly(True)\n", "\n", "clear_memory()\n", "\n", "opt = params[\"Optim\"](model.parameters())\n", "loss = train_2(\n", " [train_set, val_set, test_set],\n", " preprocessor=preprocessor,\n", " whole_model=model,\n", " optim=opt,\n", " log_dir=\"logs\",\n", " checkpoint_dir=\"checkpoints\",\n", " verbose=True,\n", " allow_same_prediction=allow_same_prediction,\n", " wandb=wandb if log_wandb else None,\n", " study_name=study_name,\n", " **params\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "9b514a07", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:22:58.688280Z", "iopub.status.busy": "2024-03-03T07:22:58.688002Z", "iopub.status.idle": "2024-03-03T07:22:58.692209Z", "shell.execute_reply": "2024-03-03T07:22:58.691365Z" }, "papermill": { "duration": 0.027439, "end_time": "2024-03-03T07:22:58.694177", "exception": false, "start_time": "2024-03-03T07:22:58.666738", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "model = loss[\"whole_model\"]\n", "opt = loss[\"optim\"]" ] }, { "cell_type": "code", 
"execution_count": 26, "id": "331a49e1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:22:58.736147Z", "iopub.status.busy": "2024-03-03T07:22:58.735234Z", "iopub.status.idle": "2024-03-03T07:22:58.808696Z", "shell.execute_reply": "2024-03-03T07:22:58.807708Z" }, "papermill": { "duration": 0.098126, "end_time": "2024-03-03T07:22:58.811111", "exception": false, "start_time": "2024-03-03T07:22:58.712985", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from copy import deepcopy\n", "\n", "torch.save(deepcopy(model.state_dict()), \"model.pt\")\n", "#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "123b4b17", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:22:58.854026Z", "iopub.status.busy": "2024-03-03T07:22:58.853705Z", "iopub.status.idle": "2024-03-03T07:22:59.129282Z", "shell.execute_reply": "2024-03-03T07:22:59.128422Z" }, "papermill": { "duration": 0.299231, "end_time": "2024-03-03T07:22:59.131243", "exception": false, "start_time": "2024-03-03T07:22:58.832012", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "history = loss[\"history\"]\n", "history.to_csv(\"history.csv\")\n", "history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()" ] }, { "cell_type": "code", "execution_count": 28, "id": "2586ba0a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:22:59.172775Z", "iopub.status.busy": "2024-03-03T07:22:59.172172Z", "iopub.status.idle": "2024-03-03T07:24:33.452817Z", "shell.execute_reply": "2024-03-03T07:24:33.451754Z" }, "papermill": { "duration": 94.304329, "end_time": "2024-03-03T07:24:33.455239", "exception": false, "start_time": "2024-03-03T07:22:59.150910", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "\n", "from ml_utility_loss.loss_learning.estimator.pipeline import eval\n", "#eval_loss = loss[\"eval_loss\"]\n", "\n", "batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n", "\n", "eval_loss = eval(\n", " test_set, model,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "187137f6", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:24:33.498790Z", "iopub.status.busy": "2024-03-03T07:24:33.498395Z", "iopub.status.idle": "2024-03-03T07:24:33.519717Z", "shell.execute_reply": "2024-03-03T07:24:33.518764Z" }, "papermill": { "duration": 0.045233, "end_time": "2024-03-03T07:24:33.521618", "exception": false, "start_time": "2024-03-03T07:24:33.476385", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
avg_g_cos_lossavg_g_mag_lossavg_lossgrad_durationgrad_maegrad_mapegrad_rmsemean_pred_losspred_durationpred_maepred_mapepred_rmsepred_stdstd_losstotal_duration
tab_ddpm_concat0.0210140.0352950.001381.5304480.0532490.8990560.0660810.0000011.9069040.0298480.0689460.0371420.0593080.0070863.437351
\n", "
" ], "text/plain": [ " avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration \\\n", "tab_ddpm_concat 0.021014 0.035295 0.00138 1.530448 \n", "\n", " grad_mae grad_mape grad_rmse mean_pred_loss \\\n", "tab_ddpm_concat 0.053249 0.899056 0.066081 0.000001 \n", "\n", " pred_duration pred_mae pred_mape pred_rmse pred_std \\\n", "tab_ddpm_concat 1.906904 0.029848 0.068946 0.037142 0.059308 \n", "\n", " std_loss total_duration \n", "tab_ddpm_concat 0.007086 3.437351 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n", "metrics.to_csv(\"eval.csv\")\n", "metrics" ] }, { "cell_type": "code", "execution_count": 30, "id": "123d305b", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:24:33.562297Z", "iopub.status.busy": "2024-03-03T07:24:33.561794Z", "iopub.status.idle": "2024-03-03T07:24:33.923484Z", "shell.execute_reply": "2024-03-03T07:24:33.922460Z" }, "papermill": { "duration": 0.384487, "end_time": "2024-03-03T07:24:33.925661", "exception": false, "start_time": "2024-03-03T07:24:33.541174", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import clear_memory\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 31, "id": "a3eecc2a", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:24:33.968371Z", "iopub.status.busy": "2024-03-03T07:24:33.968023Z", "iopub.status.idle": "2024-03-03T07:26:10.283377Z", "shell.execute_reply": "2024-03-03T07:26:10.282542Z" }, "papermill": { "duration": 96.339632, "end_time": "2024-03-03T07:26:10.285845", "exception": false, "start_time": "2024-03-03T07:24:33.946213", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../contraceptive/_cache_test/tab_ddpm_concat/all inf False\n" ] } ], "source": [ "#\"\"\"\n", "from 
ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n", "from ml_utility_loss.util import stack_samples\n", "\n", "#samples = test_set[list(range(len(test_set)))]\n", "#y = {m: pred(model[m], s) for m, s in samples.items()}\n", "y = pred_2(model, test_set, batch_size=batch_size)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 32, "id": "6ab51db8", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:10.329880Z", "iopub.status.busy": "2024-03-03T07:26:10.329107Z", "iopub.status.idle": "2024-03-03T07:26:10.347050Z", "shell.execute_reply": "2024-03-03T07:26:10.346351Z" }, "papermill": { "duration": 0.04205, "end_time": "2024-03-03T07:26:10.349032", "exception": false, "start_time": "2024-03-03T07:26:10.306982", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from ml_utility_loss.util import transpose_dict\n", "\n", "os.makedirs(\"pred\", exist_ok=True)\n", "y2 = transpose_dict(y)\n", "for k, v in y2.items():\n", " df = pd.DataFrame(v)\n", " df.to_csv(f\"pred/{k}.csv\")" ] }, { "cell_type": "code", "execution_count": 33, "id": "d81a30f1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:10.390598Z", "iopub.status.busy": "2024-03-03T07:26:10.390251Z", "iopub.status.idle": "2024-03-03T07:26:10.395376Z", "shell.execute_reply": "2024-03-03T07:26:10.394444Z" }, "papermill": { "duration": 0.028231, "end_time": "2024-03-03T07:26:10.397658", "exception": false, "start_time": "2024-03-03T07:26:10.369427", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tab_ddpm_concat': 0.42752763501982344}\n" ] } ], "source": [ "print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})" ] }, { "cell_type": "code", "execution_count": 34, "id": "3b3ff322", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:10.442119Z", "iopub.status.busy": "2024-03-03T07:26:10.441234Z", 
"iopub.status.idle": "2024-03-03T07:26:10.820216Z", "shell.execute_reply": "2024-03-03T07:26:10.819351Z" }, "papermill": { "duration": 0.403394, "end_time": "2024-03-03T07:26:10.822165", "exception": false, "start_time": "2024-03-03T07:26:10.418771", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASIAAAE8CAYAAABkYrxdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABHO0lEQVR4nO2deZxT5fX/Pzd7JuvsG7PAsAyrDJsiIqD8EEWFav22isqIqK1QqnztS+z3VVGpRVpFXCqtfgtIRWlrsbWiIl8VEHdAFkUZtlmYfZiZZDLJZLvP74+bZCbMlmSS3CRz3q9XXpPcPLn3JJP5zHnOc55zOMYYA0EQhIhIxDaAIAiChIggCNEhISIIQnRIiAiCEB0SIoIgRIeEiCAI0SEhIghCdEiICIIQHRIigiBEh4QozigsLMT1118f0WtwHIfHHnus33GPPfYYOI6LqC3E4ICEKMp89tlneOyxx9Da2iq2KUQc8NJLL2Hr1q1imxFxSIiizGeffYbHH3+chIgICBIigiCIKEFCFEUee+wx/OpXvwIADB06FBzHgeM4lJeXY8uWLbjqqquQkZEBpVKJMWPGYNOmTb2e64MPPsDEiROhUqkwZswY7Ny5M2h77HY7HnzwQaSnp0On0+HGG2/E+fPnexx74MABTJ06FSqVCkVFRfjzn//c4ziO47BixQps374do0aNgkqlwuTJk7F///5unwXHcSgrK8Ptt98Og8GA9PR0/OY3vwFjDFVVVVi4cCH0ej2ysrLwzDPPBP3+AOC9997DrFmzoNPpoNfrMXXqVLz++ut+Y/7xj39g8uTJUKvVSEtLw+23347q6mq/MaWlpdBqtaiursaiRYug1WqRnp6Ohx56CG63228sz/N47rnnMH78eKhUKqSnp2P+/Pk4ePCgb0wgv+/CwkJ899132Ldvn++7Mnv27JA+h5iHEVHj6NGj7NZbb2UA2LPPPsv++te/sr/+9a/MYrGwqVOnstLSUvbss8+yF154gc2bN48BYC+++KLfOQoKCtjIkSOZ0Whkq1evZhs2bGDjx49nEomEffDBB0HZc/vttzMA7LbbbmMvvvgiu+mmm9iECRMYALZmzRrfuGPHjjG1Ws3y8/PZunXr2Nq1a1lmZqZvbFcAsHHjxrG0tDT2xBNPsPXr17OCggKmVqvZ8ePHfePWrFnDALCJEyeyW2+9lb300ktswYIFDADbsGEDGzVqFPv5z3/OXnrpJTZjxgwGgO3bty+o97dlyxbGcRwbN24ce/LJJ9kf//hHtmzZMnbHHXf4jQHApk6dyp599lm2evVqplarWWFhIWtpafGNW7JkCVOpVGzs2LFs6dKlbNOmTezmm29mANhLL73kd93S0lIGgF177bVs48aN7Omnn2YLFy5kL7zwgm9MIL/vt956iw0ZMoQVFxf7vivB/o7jBRKiKPOHP/yBAWDnzp3zO261WruNveaaa9iwYcP8jhUUFDAA7J///KfvmMlkYtnZ2aykpCRgO44cOcIAsPvvv9/v+G233dZNiBYtWsRUKhWrqKjwHTtx4gSTSqU9ChEAdvDgQd+xiooKplKp2I9+9CPfMa8Q3Xvvvb5jLpeLDRkyhHEcx5566inf8ZaWFqZWq9mSJUsCfn+tra1Mp9OxSy+9lNlsNr/neJ5njDHmcDhYRkY
GGzdunN+Yd955hwFgjz76qO/YkiVLGAD2xBNP+J2rpKSETZ482ff4o48+YgDYypUru9nkvS5jgf++x44dy2bNmhXAO45vaGoWI6jVat99k8mEpqYmzJo1C2fPnoXJZPIbm5OTgx/96Ee+x3q9HnfeeSe++eYb1NXVBXS9d999FwCwcuVKv+MPPPCA32O3243du3dj0aJFyM/P9x0fPXo0rrnmmh7PPX36dEyePNn3OD8/HwsXLsTu3bu7TWOWLVvmuy+VSjFlyhQwxnD33Xf7jhuNRowaNQpnz54N6L0BwJ49e9DW1obVq1dDpVL5PedNOTh48CAaGhpw//33+41ZsGABiouLsWvXrm7n/dnPfub3eObMmX52/fOf/wTHcVizZk2313ZNdQjm9z0YICGKET799FPMnTsXGo0GRqMR6enp+PWvfw0A3b6Yw4cP75a/M3LkSABAeXl5QNerqKiARCJBUVGR3/FRo0b5PW5sbITNZsOIESO6nePisV56Gjty5EhYrVY0Njb6He8qbgBgMBigUqmQlpbW7XhLS0vvb+gizpw5AwAYN25cr2MqKioA9Pw+iouLfc978cZ7upKcnOxn15kzZ5CTk4OUlJQ+7Qvm9z0YkIltACF8ea+++moUFxdjw4YNyMvLg0KhwLvvvotnn30WPM+LbWLEkEqlAR0DACZyVePe7AqWwfz77g0SoijTUybyf/7zH9jtdrz99tt+HsLHH3/c4zlOnz4NxpjfucrKygAIKy2BUFBQAJ7ncebMGT+P4OTJk37j0tPToVarcerUqW7nuHisl57GlpWVISkpqZtHESm8nt63336L4cOH9zimoKAAgPA+rrrqKr/nTp486Xs+2Ovu3r0bzc3NvXpFwfy+B0vmOk3NooxGowEAv4RG73/arv/xTSYTtmzZ0uM5ampq8NZbb/kem81mbNu2DRMnTkRWVlZAdlx77bUAgOeff97v+MaNG/0eS6VSXHPNNfjXv/6FyspK3/Hvv/8eu3fv7vHcn3/+OQ4fPux7XFVVhX//+9+YN29e2LyK/pg3bx50Oh3WrVuHjo4Ov+e8n/OUKVOQkZGBP/3pT7Db7b7n33vvPXz//fdYsGBB0Ne9+eabwRjD448/3u0573WD+X1rNJpBkfxKHlGU8QZx/+d//gc//elPIZfLceWVV0KhUOCGG27AfffdB4vFgldeeQUZGRmora3tdo6RI0fi7rvvxtdff43MzExs3rwZ9fX1vQpXT0ycOBG33norXnrpJZhMJlx++eX48MMPcfr06W5jH3/8cbz//vuYOXMm7r//frhcLrzwwgsYO3Ysjh071m38uHHjcM0112DlypVQKpV46aWXfOeJFnq9Hs8++yyWLVuGqVOn4rbbbkNycjKOHj0Kq9WKV199FXK5HOvXr8ddd92FWbNm4dZbb0V9fT2ee+45FBYW4sEHHwz6unPmzMEdd9yB559/HqdOncL8+fPB8zw++eQTzJkzBytWrMC8efMC/n1PnjwZmzZtwm9/+1sMHz4cGRkZ3by3hEC8BbvBy9q1a1lubi6TSCS+pfy3336bTZgwgalUKlZYWMjWr1/PNm/e3G2pv6CggC1YsIDt3r2bTZgwgSmVSlZcXMz+8Y9/BG2HzWZjK1euZKmpqUyj0bAbbriBVVVVdVu+Z4yxffv2scmTJzOFQsGGDRvG/vSnP/mW4LsCgC1fvpy99tprbMSIEUypVLKSkhL28ccf+43zvraxsdHv+JIlS5hGo+lm66xZs9jYsWODfo9vv/02u/zyy5larWZ6vZ5NmzaNvfHGG35j/va3v7GSkhKmVCpZSkoKW7x4MTt//nxAdvX0GbhcLvaHP/yBFRcXM4VCwdLT09m1117LDh065GdXIL/vuro6tmDBAqbT6RiAhF3K5xijvmZE+OA4DsuXL8eLL74otilEHEExIoIgRIdiRAlIf0mNarUaBoMhStaEn8bGxm6JkV1RKBT95vEQsQUJUQKSnZ3
d5/NLliyJ69ISU6dO7ZZs2JVZs2Zh79690TOIGDAkRAnInj17+nw+JycnYteORshx+/btsNlsvT6fnJwccRuI8ELBaoIgRIeC1QRBiE5cT814nkdNTQ10Ot2gSYUniHiCMYa2tjbk5ORAIund74lrIaqpqUFeXp7YZhAE0Q9VVVUYMmRIr8/HtRDpdDoAwpvU6/UiW0MQxMWYzWbk5eX5/lZ7I66FyDsd0+v1JEQEEcP0FzqhYDVBEKIjqhC53W785je/wdChQ6FWq1FUVIS1a9eKXgCLIIjoIurUbP369di0aRNeffVVjB07FgcPHsRdd90Fg8HQrZYyQRCJi6hC9Nlnn2HhwoW+AlSFhYV444038NVXX4lpFjHIYIzB5XL1uX+N6BmpVAqZTDbg9BlRhejyyy/Hyy+/jLKyMowcORJHjx7FgQMHsGHDhh7H2+12v0p6ZrM5WqYSCYrD4UBtbS2sVqvYpsQtSUlJyM7OhkKhCPkcogrR6tWrYTabUVxcDKlUCrfbjSeffBKLFy/ucfy6deuiWuWPSGx4nse5c+cglUqRk5MDhUJBibFBwBiDw+FAY2Mjzp07hxEjRvSZtNgXogrR3//+d2zfvh2vv/46xo4diyNHjuCBBx5ATk4OlixZ0m38I488glWrVvkee3MUiMTH4eLRanMgQ6fqf3Cg53Q4wPM88vLykJSUFLbzDibUajXkcjkqKirgcDi69ZALFFGF6Fe/+hVWr16Nn/70pwCA8ePHo6KiAuvWretRiJRKJZRKZbTNJGKAz89ewIkaM26/LB86lTys5w71vzghEI7PT9TfgNVq7fYmpFLpoOzrRHTHm8bBGENNqw0dTje+PNssslVEJBBViG644QY8+eST2LVrF8rLy/HWW29hw4YNfu2UicHL3w9WYc+JerQ73LhypNAP7UStGR1OWt1KNEQVohdeeAE//vGPcf/992P06NF46KGHcN9992Ht2rVimkXEACarEzWtHThRY4ZMwiHXqEa6Tgk3z3Cyrk1s8wYVhYWF3frdhRtRY0Q6nQ4bN26M+Jsk4o9zF9oBADlGFVRyoSHh6Gw9GtsacbK+DZfkGUW0jgg3FKUjYpIKjxAVpml8x4ZnaAEAta0dND0LEofDIbYJfUJCRMQcjDHUmoQ20XnJncvqBrUcqVoFeMZQ2RzZBESHi+/15nLzAY91BjA2FGbPno0VK1ZgxYoVMBgMSEtLw29+8xtfgL+wsBBr167FnXfeCb1ej3vvvRcAcODAAcycORNqtRp5eXlYuXIl2tvbfedtaGjADTfcALVajaFDh2L79u0h2RcscV0GhEhMzDYXbA43pBIOaVr/bN0pBSngGcOQZHVEbfjjx91bb3sZmqbBopJc3+OX95+B093zRu0hyWrcMqUz123zp+dgc/h7cw/+v5Eh2fjqq6/i7rvvxldffYWDBw/i3nvvRX5+Pu655x4AwNNPP41HH30Ua9asAQCcOXMG8+fPx29/+1ts3rwZjY2NPjHztisvLS1FTU0NPv74Y8jlcqxcuRINDQ0h2RcMJEREzFFnFryhNK0SMqm/0z4mh+pOecnLy8Ozzz4LjuMwatQoHD9+HM8++6xPiK666ir893//t2/8smXLsHjxYjzwwAMAgBEjRuD555/HrFmzsGnTJlRWVuK9997DV199halTpwIA/vKXv2D06NERfy8kRETMIeGAdJ0S2cbwZVEHy/I5w3t9TnLRLpB7ryzqdezFO0aWzhg6ELP8uOyyy/y2pEyfPh3PPPOMb/PulClT/MYfPXoUx44d85tuMcZ8W13Kysogk8kwefJk3/PFxcUwGo1hs7k3SIiImGNEpg4jMnW91qW6YLGjotmKLL0KOcbITNEUssDDp5EaO1A0Go3fY4vFgvvuu6/HEjv5+fkoKyuLlmndICEiYpbeNqAeqzbhSGUrJuYZIyZE8cCXX37p9/iLL77AiBEjIJVKexw/adIknDhxAsOH9+ztFRcXw+V
y4dChQ76p2cmTJ9Ha2hpWu3uCVs2ImILnGXi+7wqduR7xqW7tvdvrYKCyshKrVq3CyZMn8cYbb+CFF17AL3/5y17HP/zww/jss8+wYsUKHDlyBKdOncK///1vrFixAgAwatQozJ8/H/fddx++/PJLHDp0CMuWLYNaHXmxJyEiYopacwde/Pg0dh4+3+sYrxfUZLHD7hq8+UR33nknbDYbpk2bhuXLl+OXv/ylb5m+JyZMmIB9+/ahrKwMM2fORElJCR599FG/FuRbtmxBTk4OZs2ahZtuugn33nsvMjIyIv5eaGpGxBQt7Q64+/GItEoZdCoZ2jpcaDDbkZcyOEt4yOVybNy4EZs2ber2XHl5eY+vmTp1Kj744INez5mVlYV33nnH79gdd9wxIDsDgTwiIqa40C5kAKdo+q72l6kXVtTqPUv9RHxDQkTEFC0BClGWQRCiOhKihICmZkRM0ewRouSkfoTI4xE1mO19jktU9u7dK7YJYYWEiIgZeJ6hrcMFADAm9V2FMVOvwi1ThoS1dCwhHiRERMzQZneBZwxSCQetsu+vpkImwZDkwRmkTkRIiIiYgTGGEZla8Kz/XulEYkFCRMQMxiQFrp+Q0/9ADy3tDhw93wqO4zDLU0qWiE9o1YyIW+wuHt9UtuJEjbnXfWlEfEBCRMQMdpc7KEFJ0yoglXDocLphsjkjaBkRaUiIiJjhrcPVePGj0yhvau9/MACZVIJ0ndDnjvKJ4hsSIiJmMHc44eIZ1Iqed4/3hDefqM5EQhTPkBARMYHTzaPdLmxgNagD7+TqzbCmrR7xDQkREROYPTEehUwCZRDFw7pmWPe3WTYgGANcjujfgoiNbdu2DampqbDb/bPKFy1aFJUNqpGAlu+JmMDsyag2qOVB5RAZk+RQyiWQSySwdLhg6Ccju1/cTuCTZwZ2jlCY+d+ArO9tLV5uueUWrFy5Em+//TZuueUWAEL3jV27dvW5sz6WIY+IiAksHiHSqYL738hxHJbOGIp7rhw2cBGKE9RqNW677TZf5w0AeO2115Cfn4/Zs2eLZ9gAII+IiAna7MLUrL+tHT3h7QQbFqRywTuJNtLgRPSee+7B1KlTUV1djdzcXGzduhWlpaVxm5FOQkTEBCkaBYZnaH11hgKCdwPlnwD13wHyJLBhs8GlDLBLBscFPEUSk5KSElxyySXYtm0b5s2bh++++w67du0S26yQISEiYoLiLD2Ks4LoWcYYcPI9oO44eMZQVn4ebd+dxPjrfw5VakHkDI0hli1bho0bN6K6uhpz585FXl5e/y+KUShGRMQnDd8DdccBjoOk+FrUyfPgcrlgPfofgA+tjXO8cdttt+H8+fN45ZVXsHTpUrHNGRAkRERMYHMEsb3D7QLOfizcL5gB5JSgY/h1cElUsJkagIbvImdoDGEwGHDzzTdDq9Vi0aJFYpszIEiICNFxunn8ad8ZvPjR6cC6cjR8B3SYAaUOyL8MAJCebECtbjzaOpxA9eEIWxw7VFdXY/HixVAqlWKbMiBIiAjRabcLS/ccByik/XwlGQPOfy3cHzLFt9qUZVChQTsKFgcPZq4GLI2RNFl0Wlpa8NZbb2Hv3r1Yvny52OYMGApWE6LjLQ+rVcr6X342nRdERioDsi/xHc7QKeGWaVAvz4PD1Qpl4w+ANnFrFJWUlKClpQXr16/HqFGjxDZnwJAQEaJj8XhEWlUAuTQNJ4Sf6cWAvLMDqVwqQZpOgZa2AljsTVA2nQSGzoyEuTFBb33L4hWamhGi452aaZX9JCbybmG1DAAyx3Z7Oi85CeqcYnASqeA12VrDbCkRKcgjIkSnzSdE/XhELeWA0wYokgBjYbenrxyZDiAdOFwAmKqB1gpAbez3+lTdcWCE4/Mjj4gQHe8+M01/HpFvWjYakPTx1TV6EhpbKvo8nVwuCJ/Vag3ITqJnvJ+f9/MMBfKICNHJTVaD44A0bR9L0DwPXDgt3M8o7vN8dl0epDwPWWuFsMr
WSwBcKpXCaDSioaEBAJCUlBS3e7XEgDEGq9WKhoYGGI1GSKWh7/kjISJEZ1J+MiblJ/c9yFwNODsAuQrQD+l12J4T9ThR3YH/srmRLbEA1mZAk9rr+KysLADwiRERPEaj0fc5hgoJEREfNJ8RfqYM63NaplFKwUOGJi4F2WgTBKwPIeI4DtnZ2cjIyIDTSQX4g0Uulw/IE/JCQkSIiptnsLvcUMulfU+LvNOylKI+z+erYc1SMB5tQFstkD2hXzukUmlY/qCI0KBgNSEqjW12/HnfWWz5tLz3QR0mYTme4wSPqA+8NaxrWTJcPC94RETMI7oQVVdX4/bbb0dqairUajXGjx+PgwcPim0WESW8yYx9du7wekP6XGHpvg+SFDIYk+RoU2QKq3GWRqH8KxHTiDo1a2lpwYwZMzBnzhy89957SE9Px6lTp5Cc3E/gkkgYfFnVfVVmbD4n/Ezte1rmJUuvwg/tGphcShgZD7TVAcb4rdUzGBBViNavX4+8vDy/2rtDhw6wwh4RV7T3J0Q8D7RWCveTCwM6Z5ZBhR/q2tCAVBSgSYgTkRDFNKJOzd5++21MmTIFt9xyCzIyMlBSUoJXXnml1/F2ux1ms9nvRsQ3vg2vvRXNt9QBLjsgUwLawJaIhyQnYUyOHhk5+Z5z1IfDVCKCiCpEZ8+exaZNmzBixAjs3r0bP//5z7Fy5Uq8+uqrPY5ft24dDAaD7xbPpTEJAe/UTKPoRYi82dHG/L6zqbuQrlPimrFZKMj3eNcWyhGKdUQVIp7nMWnSJPzud79DSUkJ7r33Xtxzzz3405/+1OP4Rx55BCaTyXerqqqKssVEuPFOzXptI9TqEaIAp2V+aDOEn9YLwoZZImYRVYiys7MxZswYv2OjR49GZWVlj+OVSiX0er3fjYhvitK1GJGphb6nNtNuF2Dy/LMxBlcQn+cZGpxKmFxSQYTam8JgLREpRA1Wz5gxAydPnvQ7VlZWhoKCwdGFgQCuGJHW+5PmakGMFBpA08e4HjjVYMG7x2txeZsKlya3A+0NgC5zgNYSkUJUj+jBBx/EF198gd/97nc4ffo0Xn/9dbz88ssJUfqSCAO+aVlBrxtXeyNdJ2ygref1QpkKCljHNKIK0dSpU/HWW2/hjTfewLhx47B27Vps3LgRixcvFtMsIkrYXW5YHa7e69n4AtXBe8hGtRwyCQezNBUdLp4C1jGO6HvNrr/+elx//fVim0GIwOkGCz74rh5D0zRYVJLr/6TbKeT/AIJHFCQSCYdUrRKWjlRYHS6oLQ19lgQhxEX0LR7E4MVbEC2pp+0d5hohyKzUAipjSOdP1ylhkxnR7uCFyo5OKoAWq5AQEaLRWTS/B8fcdF74aRgSsheTplWAl8hgYp79abRyFrOQEBGi0ec+M58QhZ606g1YN/I64YCVhChWISEiRKNXIepavsPQezXG/kjXKXHFiDSMGT4MDIw8ohiGhIgQjV43vFqbhP1lUjmgyQj5/EqZFFMLU5CVPQQcOBKiGIaEiBAFN8/Qbhe2XXSLEXmzqQ1DAt5f1idJnmRImprFLKIv3xODEzfPUJJvRLtdKBPrR9dA9QBp63Ci1qJAhs0BIwA4rP0WVyOiDwkRIQoKmQSzR/Uy7fIKkT635+eD4FxTOz78/gLmWhUwqiF4RYr8AZ+XCC80NSNiiw4T0GEGOElYhCg5SQEAuMA8K2cUJ4pJSIgIUbA6XGi397C9w1wj/NSmAzLFgK+TohHO0Qw9eMaEkiBEzEFCRIjC1+UteHn/WXxy6iIPxStEYfCGACFrWymXwCpLhs3pBtobw3JeIryQEBGi4F2611y8dO/dX6bLDst1OI5DSpICNrkRHU630PmViDlIiAhR8O4z86vMyPOdQqTPCdu1kjUK2GQG2BxuwN4GuBxhOzcRHkiICFFo6ymZ0dokFEKTKQB1StiuZVTL4ZaqYIWw5QM28opiDVq+J6IOY6wzq7qrR+SND+myw5PI6GFEpg6
pWiVyzuYCtjpheqYLrCMIER1IiIioY3W44eYZOO6i7h1hjg95SdEohNWz5gyPENHKWaxBUzMi6nRtISSVdCnx4VsxC198yI+kVOEnTc1iDvKIiKijlElQkm+ErOv0y+3sTDYMs0cEAGcaLbA1SzHSzUNBK2cxBwkREXWMSYru2zva6gDGCx07lLqwX/PzMxdgaeaRy1xQWC9Q2dgYg6ZmRGzQddk+AgJhUMvRIdOjw8UE78thCfs1iNAhISKijsnm7L69I0KBai8GtRyMk6JdohUO0PQspiAhIqLOnhP1eHn/WfxQ19Z50NvuJ0LL6gZPJ1kTvEJEK2exBAkREXUsHU4AXZIZ3c5OYdCGXpGxL7xC1AJP/IlWzmIKEiIiqjDGuteqbm8UgseKJEChjch19R4hamY6oX61tSUi1yFCIyQhOnv2bLjtIAYJdhcPp1uIDfmyqtvqhJ/azIitZHn3tLVxOrh5KgcSa4QkRMOHD8ecOXPw2muvoaOjI9w2EQmM1xtSyaWQSz1fP298SJsZsevKpRIsKsnFjdPHQ8JxQgE23h2x6xHBEZIQHT58GBMmTMCqVauQlZWF++67D1999VW4bSMSELNNiA/57bq31As/IyhEADA0TYP01FRIZAohZ8nWGtHrEYETkhBNnDgRzz33HGpqarB582bU1tbiiiuuwLhx47BhwwY0NlLxKaJn2jzlP7wxG/A80B55j8gHxwFJnp39ND2LGQYUrJbJZLjpppvwj3/8A+vXr8fp06fx0EMPIS8vD3feeSdqa2vDZSeRIKTplJhUkIxhaRrhgK1ZKP0hlQHq5Iheu97cgYPlzahzdbk2ERMMSIgOHjyI+++/H9nZ2diwYQMeeughnDlzBnv27EFNTQ0WLlwYLjuJBCHXqMaskekYl2sQDnSdloWx9EdPVDVb8cmpJlRaPXWJKKkxZghpr9mGDRuwZcsWnDx5Etdddx22bduG6667DhLPF2no0KHYunUrCgsLw2krkYhEKT4EADoV5RLFKiEJ0aZNm7B06VKUlpYiO7vnlPyMjAz85S9/GZBxROLRYO6ARilDkkIKjuO6rJhFJpGxK950gWamBTiQRxRDhCREe/bsQX5+vs8D8sIYQ1VVFfLz86FQKLBkyZKwGEkkBi43j+1fVgIAfjarCGq5xD+HKMJ4V+qaeS0Yx8A52gFnByBXRfzaRN+ENCkvKipCU1P3RnXNzc0YOnTogI0iEhPviplCJoFKLhEK2TttQjNFTXrEr69VyMBxgANyOGSettM2yrCOBUISom5N8TxYLBaoVPTfhegZs2ePmV4l85+WJaUAUnnEry+RcL5tJTaZXjhIcaKYIKip2apVqwAIvaIeffRRJCUl+Z5zu9348ssvMXHixLAaSCQOZpu3hZBHdKIYqPaiU8nQ1uFCu8QAI+opThQjBCVE33zzDQDBIzp+/DgUis6WwAqFApdccgkeeuih8FpIJAxtXo9I7fnaWTzxoSh21Jg9KgMcByQ3XQDKyyipMUYISog+/vhjAMBdd92F5557Dnq9PiJGEYlJ59TM6xFFb8XMS6beEzqwpwk/aWoWE4S0arZly5Zw20EMAlqsghAZk+TCapV3r1cUp2Y+vA0crc1UvzoGCFiIbrrpJmzduhV6vR433XRTn2N37tw5YMOIxGN8rgFZBhXStSqg3bP9R6UH5Oqo2WCyOXGqvg0SMEziuM761REo2E8ETsBCZDAYhJUOz32CCBbftg4AuBD9QDUglCH55FQTDGo5JqmMwvK9tZmESGQCFqKu0zGamhEDRoQVM0DIJQIgFO/XJYOztQhxouSCqNpB+BNSHpHNZoPVavU9rqiowMaNG/HBBx+EbMhTTz0FjuPwwAMPhHwOInYxWZ2oNdnQ4fQUIxNJiDRKKQDAxTM4lUbhIC3hi05IQrRw4UJs27YNANDa2opp06bhmWeewcKFC7Fp06agz/f111/jz3/+MyZMmBCKOUQccKLWjB1fVeG
TU01CZcQIF8vvDZlUApVcECOr1DNVpOxq0Qm5QuPMmTMBAG+++SaysrJQUVGBbdu24fnnnw/qXBaLBYsXL8Yrr7yC5OTI1qMhxKPV6gAAJCfJhdbSvBuQKQFV9OONWo9X1C71pJ+QRyQ6IQmR1WqFTicE9z744APcdNNNkEgkuOyyy1BRURHUuZYvX44FCxZg7ty5/Y612+0wm81+NyI+aLV1WbrvOi0TYdlc49nmYZZ4y4G0CJUiCdEIuXj+v/71L1RVVWH37t2YN28eAKChoSGoJMcdO3bg8OHDWLduXUDj161bB4PB4Lvl5eWFYj4RZRhjaPF4RAa1olOIdCLkD6GzjVEbnyRUhmQ80NEqii2EQEhC9Oijj+Khhx5CYWEhLr30UkyfPh2A4B2VlJQEdI6qqir88pe/xPbt2wPeKPvII4/AZDL5blVVVaGYT0QZi90Fu5OHhOOEqZlIgWovUwtTsPiyfEzIM3aWp6U4kaiElFn94x//GFdccQVqa2txySWX+I5fffXV+NGPfhTQOQ4dOoSGhgZMmjTJd8ztdmP//v148cUXYbfbIZVK/V6jVCqhVCpDMZkQkQsWwRsyJskhk3CiC1GypnOPJNQpgKVRCJ6nFoliDxGiEAFAVlYWsrL8NytOmzYt4NdfffXVOH78uN+xu+66C8XFxXj44Ye7iRARv1xoF4QoVasQpkAuByCRAkmp4hoGdNpAAWtRCUmI2tvb8dRTT+HDDz9EQ0MD+IsCfYF0gtXpdBg3bpzfMY1Gg9TU1G7HifimIDUJs0alC/3nLdXCQU2aIEYi0OF049tqExxuHpdrPHvOaPOrqIQkRMuWLcO+fftwxx13IDs727f1gyB6Ik2rRJrWM6U+J+60DACcbh6fnGqChOMwfWoyOIA8IpEJSYjee+897Nq1CzNmzAirMXv37g3r+YgYxFf6I3o1iC5G4ykZyzMGq9QADSCUrXU5AJmiv5cTESCkVbPk5GSkpKSE2xYiAbHYXfiuxoRmT5yos1h+dDOquyKRcEhSeJIaeXnn7n9aORONkIRo7dq1ePTRR/32mxFET1Q1W/HBd/XYc6IOsFsEz4PjRBUioDOp0WJ3dbagpjiRaIQ0NXvmmWdw5swZZGZmorCwEHK5f+Hzw4cPh8U4Iv6pM3cA8FRGbPPUIEpKFbZ3iIhWKUMD7Gi3u4UlfFM1xYlEJCQhWrRoUZjNIBKVBj8h+kE4qOu5KWc00ZJHFFOEJERr1qwJtx1EAuLmGRrMdgBAll4FNHg8Ir34QuQ3NdN1KRtLiELICY2tra148803cebMGfzqV79CSkoKDh8+jMzMTOTm5obTRiJOabLY4eIZlHIJjGpZ59QsBjyisTl6DEvXCIX87Z70E+sFql8tEiEJ0bFjxzB37lwYDAaUl5fjnnvuQUpKCnbu3InKykpfrSJicFPVLCxm5BrV4OwmoaurRApoxA1UA0JvNV9/NYlnv5nLLtioSOr9hURECGnVbNWqVSgtLcWpU6f8Nqxed9112L9/f9iMI+Kb8y02AMCQ5CTA7PGGtBnCjvdYQioXivgD1OdMJEL6RngrKl5Mbm4u6urqBmwUkRjMH5eF6labkFVdHTvTMkCIXx2ubIHF7sKVI9IhVacAHWYhYG2k8jLRJiSPSKlU9liUrKysDOnp6QM2ikgMVHIpitK1wh4z03nhoD5HXKM8SDjgs9MXcKSyFVZHl5UzCliLQkhCdOONN+KJJ56A0ylU3eM4DpWVlXj44Ydx8803h9VAIgFwOzszqg2x4W1wHOcrpN9ud3fuwqclfFEISYieeeYZWCwWpKenw2azYdasWRg+fDh0Oh2efPLJcNtIxCHvHKvB52cuwOZwA+ZqoQqiUidKjere8FvC9xZII49IFEKKERkMBuzZsweffvopjh49CovFgkmTJgVUd5pIfFraHThVb8EZrh0l+Uag1VNJ05gXU0vjXiFqt7uANG9SYyst4YtA0ELE8zy2bt2KnTt
3ory8HBzHYejQocjKygJjjEqCxChl9W2obrFhbK4eGbrASvOGyqkGCwBgSLJaaN3jjQ8ZhkT0usHi6+ZhdwHKFCG1gHcBHSZAbRTXuEFGUFMzxhhuvPFGLFu2DNXV1Rg/fjzGjh2LiooKlJaWBlwmloguZfVt2HWsFkeqWvHmofNo63BG7FqMMZyoMQEAirN1Qtsgs1eI8iN23VDQKLpMzSSSLvWraXoWbYLyiLZu3Yr9+/fjww8/xJw5c/ye++ijj7Bo0SJs27YNd955Z1iNJAZGc7sDEo4DzxjsTh4Hy1swpzgySYW1pg60WJ1QyCQYkaETRMjtEkptaNIics1Q8U3NHC7hgDpZ6LlmbQGoyk1UCcojeuONN/DrX/+6mwgBwFVXXYXVq1dj+/btYTOOCA+XDUtF6YxC3HCJkMPzfZ0ZTndk+nidqBHSOorStVDIJECzp2xwytCYi7sMS9dg8WX5mD/Wk9tEm19FIyghOnbsGObPn9/r89deey2OHj06YKOI8GNQy1GUroVOJYPdyfuynsOJ082jrKENgLCXC0AXIRoW9usNlCSFDBk6FdSeImlQUy6RWAQlRM3NzcjM7L3WcGZmJlpaqMpdLGF3uX33OY7D0DQNAKC8qT3s13K4eBSla5GmVWBIsloohNbmqVGdPDTs1ws75BGJRlAxIrfbDZms95dIpVK4XK4BG0WEB6vDhT/vO4tUrQKLLy2AVMJhTI4eaVolCj2CFE40ShmuGdtl9bT5jPCELhNQasN+vXBwqKIF5g4nphWmQOP1iDpMQlwr1vbEJTBBfdKMMZSWlvba5NBut4fFKCI81JmEomSMAVKJEJ/JNqiRbVBH9Lq+FI4GTyG0tJERvd5AOFLVCrPNiVGZOmgMGqF4vssh1K/W0nalaBGUEC1ZsqTfMbRiFjt4hSjLENm8IQCovGCFSi5Buk4pCJHDCrSUC09mjIn49UNFq5TCbHMKuUQcJ8SJ2uqE6RkJUdQISoi2bNkSKTuICOCtF52l9xcik9WJ8gvt0CilGJ6hG/B1GGPYW9aACxYHrhufjVFZOqDxe2Fbhy6zM/YSg/ht8wAEW9vqKGAdZULaa0bEPowxNLR5yrRe5BGdbbLgox8a8F1N9woKodDYZscFiwMyCYeC1CRhLljzjfBkZmx37e3c5uEJ6vtaUFNdomhCQpSgWB1u2BxucByQovFvGujtutpkcYTlWt4tHYVpGs+WjirA0igEe7PGh+UakULbzSPyJF22N4pk0eCEhChB8TY0NKjlkEv9f83pOkGIzDYnOpzubq8NljONghANz/CsjFV8JvzMGNvZvDBG8W7zaPcKkcYTF/LWryaiAglRgqKUSTA2R98pDl1QyaXQqYQ/wCbLwFY6m9sduGARtpAMTdMIAermc8IG0oLpAzp3NNB22+ZhFGx3O4VlfCIqUKJEgpKhV2He2N77y6frlGjrcKHJ4hBqSoeI1xvKS1FDJeGBst3CE9kTOzeRxjCZBiVuv6zAJ0iQSDv3nLU30S78KEEe0SDFGydqbBuYR1R5QejUUZSWBHz/trDapNQBQ68csI3RQCmTIl2n7NzmAXSZnjWJY9QghDyiBIQxhgvtDhjVcsikPf+v8QrRBYtdEI/aI0BLhWc6wgCZSvAM1CmAJlX440xK69ZqZ1FJLurrqpFa+y5grhI8ijE3AvLI5y5FDG+VgHYSomhBQpSAdDh5/PXzCkglHO6fXdSjGOWnJOG/puQivfFL4KuvhZyfrjg7hGqFOOd/XKERREkhLNNLbc3I8e4nk8qAMYsAY2zVHeqP4+dNaLLYMWGIAalaJa2ciQAJUQJisgmFz9Ryaa8ekVoG5J5/D2g6JRxIGQZkjRNEhpMATquwzcF6oTNe0mECHO3CrSscB6QUAUVXCd5TnPF9nRnVLTbkJqsFIeo6NaOysVGBhCgB8QqRIUne8wDGgJMeEZLKgOLrgYzR3cdd7Nm47B5hagRcDnxyqgFyjRFji0d
DZ4jd7On+8KvUCAhTUolU2Pja0RoXQfd4h4QoAWm1duYQ9Uj9t0DdcZg63DiZPhfJbAhGBHJimVLoS6bPgdXhwsEfzgJmYIJaHzbbxUDTtXY1IJSNTUoRkjLbL5AQRQFaNUtAfB5RT0LkaAdOfQAAqDZOxqctBpxpDL42UU2rsI8tTatAkiK+/59plRclNQKdcSJaOYsKJEQJSJ9CVH5AKHOhy4Ri2AwAwIX24JfwvTv7I11SJBp0bnztkmWuoYB1NInvf2VEj3iFyHhxjMjaDNQcEe4XXY1UhbDE3mxxgOcZJJLAg7I1JqHUbDRKjESaHj0ib8CalvCjAglRgsEYw/hcA1ptThjV/ptdUfWlsEyfWgQkF8DAM8gkHFw8g7nDCWOSoueTXoSbZ6j3eEQ5xkTyiHoRIp4X4kZExCAhSjA4jsOlw3pYQne0A3XfCvfzLwMASCQcUrQKNJjtaLI4AhaixjY7XDyDSi5Fcm8rc3GEQS333+YBACojIJULe85szTHXCinRIJkfLNR8I3Qx1WcDhjzf4VRNlwzrAOlwumFQy5FjVCVEZ1+phOu+zUMiAbSe3m+WenEMG0SQR5RgmGxOuHkGvUrWmczIGFB7TLifO8UvQS9NK3hBrbbAu78Wpmmw9IqhcEWoN1rMoM0ETNWCEGWOFduahIaEKMH4prIF31S2YnJBMq4c6YlztHr2kMkUQPoov/FjcwwoztZD09UbCJDesrbjkR/qzKht7cDILB1yvXEvn0dEK2eRRtRv0rp16zB16lTodDpkZGRg0aJFOHnypJgmxT09rpjVHRd+ZowR4h5dUCuk0CplAU+x3DwDS8CCYeca23GkqhW1rV0aT2o9PfxoahZxRBWiffv2Yfny5fjiiy+wZ88eOJ1OzJs3D+3t4W/+N1jolkPkcgCNnrY+YSjbeqbRgk37zuDjkw0DPlcsoVMJn1fbxStnHCcE+u0WkSwbHIg6NXv//ff9Hm/duhUZGRk4dOgQrrwyPurZxBKMMZisFwlR81lhz5TaCOhze3zd8fMmnG2yYMIQo68TbG/UtNpgdyZebMhbsbKto4sQSeVCMf32JsEritEmkYlATE3yTSahNGdKSs8bKO12O8xms9+N6MRid8HFM0g4zvcfHk2eqW7ayF53kdeZO3C2sR21JluPz3el1pdRHf+JjF3pFKKLgvbefCJLYnmAsUbMCBHP83jggQcwY8YMjBvXcwuadevWwWAw+G55eXk9jhuseKdlOpVM6OzqdgEXTgtPXhSk7kqqZ+XsQj9dPZxuHg1mYZk/EbZ2dMU3NevqEQEUJ4oSMSNEy5cvx7fffosdO3b0OuaRRx6ByWTy3aqqqqJoYezTLT7UWiHEiJTaXqdlAJAWYC5RvbkDPGPQKmXQqxJrwdXrEdkcbjhcXaaevpUz8ogiSUx8m1asWIF33nkH+/fvx5AhQ3odp1QqoVQqo2hZfJGuU+KyYam+Pyo09j8tAzo9olabEy433+uyfNcW1omQyNgVlVwKhUwCh4uHxe5CisyTZe71iGzNQj0mGX3/IoGoQsQYwy9+8Qu89dZb2Lt3L4YOHSqmOXFPhk6FDJ0ndsMYcMFTfTFtZJ+vS1JIoVZIYXO40Wx1dJ7jImp8+8sSKz7k5dZp+UhSSKGUdRFipRZQ6YEOs9CKOrlAPAMTGFGnZsuXL8drr72G119/HTqdDnV1dairq4PN1n/QlOiHtlrAYRWSGPupIc1xnK8bbF9xolyjGvkpScg1ht5+KJZJ0Sigkku7e3u6bOGnuSb6Rg0SRBWiTZs2wWQyYfbs2cjOzvbd/va3v4lpVtxSecGKVqtQ0gPNnqL3xgKh7Gk/pGkVUMgkfXZ+nVyQjJsnD0mI0h9B4Y2vtZEQRQrRp2ZEeLC73Pjn4fMAgJ/PLoKq+azwRMqwgF5/xfB0zBmVkXCxn2CoNdnwQ20bDElyTMrvUh5W7/WIasUxbBAQM6tmxMDwde5QSKGCs3MakRJY3E0
hk/QpQnWmDlgdrl6fTwRarU4cqWrF2YtL52qzhGC/vU2IFRFhh4QoQfBmVBvVcmHZnvFCAfgwFH5njOGdYzX4876zON9iHfD5YpVekxplis56RG3kFUUCEqIEwS+HyBsfCnBa5uXjHxqw7fPybmJjsjnR1uGCVML1uqKWCHRNauwWNvDGiShgHRFIiBKEVt8eM5mwvwwAkoNLh+hwunHB4kBls78QVTV31qdWyBL3K6NTyiDhOLh55r/5FehcOSOPKCIk7rdqkOH1iFIkFqH2kEQadOvnvBRhWf58s3/6xNkmYed5fkpiLtt7kUg46NXC9Mw71fWhzxF+mmuEGtZEWCEhShC8FRZT7J6pg2GIENsIgrxkQWhqTR2wu4RlfLvLjcoLgoc0PCPxd597t8eYLq5YqUkXsqrdTsBSJ4JliQ0JUYJweVEqLh2WAqPNs/8uyPgQILSoTk6Sg2cMZxqElaMzDe1w8QzJSXKkaoITtnjEW1Cu2+ZXjuv0MFtpj2O4ISFKEEZn63F5oREKi+ePJMj4kJfibKF99IlaYZn6dKPFd/7BkGN02bBU/Hx2EaYX9dAJxdt0wERCFG5IiBIJU5VQ+kOh6dw1HiSjs/WQcByqmq1oMHfg2nFZuGxYKsblGsJsbGySpJBBJe8lE93oEaLWSmEvHxE2YmL3PTEwmix2tNtdyGg8DTUgTMtC9F4Majkm5BngcjOk65TgOK5n72Awos0Sqja67EIr6hDFnugOCVEC8F2NGYcrWrDA8R1GJiHgbOremDk8DRKOGxRTsYthjGFvWSNMVifmj8vy944kEmERoPmcECciIQobNDVLAFqtDshd7dC7WgRPKLlwQOeTSSWQSAafCAFCJYLT9Raca2r35Wb54Q1Yt5yLrmEJDglRAmCyOWHsOA+VTALosoQYEREyBs/KWauth5Io3kWA1gqA771SAREcJERxDs8ztFo9QqSQhrxaRnRi9OQStbT34BHpsgC5WijBa66OsmWJCwlRnGOyOeF2u5HiOC9UFgwhf4jwx1skrsXag0fEcZ0xOO9WGmLAkBDFOc1WB7SORmglLnAyVZ9F8onA8FWrbO+lWqVX7JspThQuSIjinJZ2Bwwd56FWSIUgtYR+pQMlVSsUyG9p91S7vBjv9LetTugCSwwY+tbGOcMztJiRYhbKcwxw2Z4Q0KtkkEs5yKUSWHoqBqfUdi7dXzgTXeMSFMojinOMcjeMkhZALaf4UJjgOA7LZg6Dsq+qlWkjhV5nTWVA9oToGpiAkEcU77SUC9sNNGmAanBsw4gGPXbz6Iq3c27zOWEFjRgQJERxTIfTjXMnj8Hc4aRpWbTRpAuleHkX0EzTs4FCQhTHNJhsOH/6OM40WoDUEWKbk1C0Wh3495Fq/PPQ+Z4HcFxn40pvR10iZEiI4hhT3VnI+A4o1ZrOEhVEWJBJJTjb2I6qFiscrl4qMnqnZxdO0/RsgJAQxTH2OuE/sTR9BC3bhxmtUgatUgbGgEaLvedBumyhS4rbKQStiZChb2+8whiY58uvyhotsjGJSYZeyCeqN3f0PIDjgKxxwv2641GyKjEhIYpT3G0N4K0t4DkpjENGim1OQpKpF1onNfQmRACQ6RGi1gqhaQEREiREcYr5/LfgGdCeNAR6TWJ31xCLDJ3gETW09TI1AwC1EUguEFIo6r6NjmEJCAlRPMIYbFXCVIDLHDMoC5hFA69H1NzuQIezj5IfWeOFnzXfUGmQECEhikfaapEpt2JMXipGj5sktjUJi0YpQ5pWgRyjGjZHHwKTPhpQJAH2Ngpahwht8YhH6k9AynEw5I6GIX3gve2J3ll8aUH/1SqlMiBnElB+ADj/NZBBiwfBQh5RvMG7gYYTwn1voJSIGAGXzM0pEbrrmqoBUy9JkESvkBDFG01laDO34rSJ4aw7XWxrBg3tdhec7j5aTSu1QOZY4X75gegYlUCQEMUbNd/gQrsDR5wFON1k6388MWDe/7YOr3xyVth
K0xcFlwOcxNPlozI6xiUIJETxRPsFoKUCpg4X6rWjUZBKRfKjgTfD+mxjP0XQ1MlA9iXC/XP7qQljEJAQxRNVX8Lh5lEtyYVTrkV+CuUPRYOiDEHwzzW1w9XX9AwQvCKJTOh71vhDFKxLDEiI4oUOE1B3HC3tDlTrL0GGTiWUhyUiTpZeBa1SBoeLR/kFa9+DVXog/zLh/un/E7rCEv1CQhQvVHwOMB6VLB0WZRZGZGrFtmjQwHEcRmXpAADf1QSwjSN/ujBNs1uAMx9F2LrEgIQoHrA0ALVH4HDzOCYTsnhHZuhENmpwMS5XqH55rqkdbR099DvrilQGjJovbIqtOQI00BStP0iIYh3GgFMfCNs6jCMgTy1AlkHl60ZKRIcUjQK5yWowBnxbbe7/BcmFnVO0k7sAS2NE7Yt3KLM61qn6Sgh8SmUwjJ2HO1QGWPvabkBEjOnDUtFidWBsToC1wQtnCsmNrVXA8b8DJXcIMSSiG+QRxTItFcC5fcL94XMBtREcx0GjpP8fYpCXkoQJQ4yQBpptLZEC424GklKBDjPwzWuAtTmyRsYpJESxSlsd8O0/Ad4Ne8pIHLTn953ZS0QVh4tHTWsACaVyNXDJT4TgdYcJOLyNeqH1QEwI0R//+EcUFhZCpVLh0ksvxVdffSW2SeLSeFL47+myg9fnYpdjEj45fQG7jtWKbRkBwGRzYvuXFXjz0Hmcawqg06vKAJTcDugyAacNOPZ34Id3qUtsF0QXor/97W9YtWoV1qxZg8OHD+OSSy7BNddcg4aGBrFNiz6WBuDbncLN7YRdl4e3+RmoaHVCLuVwxYg0sS0kAOiUMqTrlHDzDP8+Uo3Pz1zo31tVaoGSO4FcT9mW2qPAF5uAst2C9zvIs7A5xsT9BC699FJMnToVL774IgCA53nk5eXhF7/4BVavXt3na81mMwwGA0wmE/T6OAoCMib8Z3RagfYmwFIvdIKwNMDNGMwdLlRqxuMzfgw6XBzkUg43XJJDWzpiCJebx4c/NOBEjbCCplZIMTJTi7zkJBSkaqCQ9fE/vrVKSHZsq+s8ptILnVh0WUJMSakTblJlXDdGCPRvVNSop8PhwKFDh/DII4/4jkkkEsydOxeff/55t/F2ux12e2emqtkcwDIqIGxCPP1/3Y/7aTAL6lhbhxMn68xgADgwYUiXcblGFXKMaoAxtDtc+K7G7HtewjvAMd57dmQb1Mg1qgGOQ5uuCH+35cHqSAUgFHD/f2Myhd72RMwgk0owb0wmClM1OHC6CWabE0erTDhaZcLPZhX5xu05UY+KC+2QcBykEg6+YppsNrR8NRam10PaehboMKOy5mu0WLu3JWKcFG6JHBPy0yCXSoX8JHAX/fQQzWqd4/8rbKuAogpRU1MT3G43MjMz/Y5nZmbihx+6J4GtW7cOjz/+ePAXctkFzyOMsA4nnNbehZC3uwGHR2ocLrjtnVsDvIvvbokcHTIDNMYC5BaPAVKHQ86UUNirMNSoxvAMLYamaagUbIzizbgenqFFVbMVZxotsNhdfltvrA4X2jpcPb7+AjLAxl4OwA20VqL22LdosldB5WqDwm2BjPf+03UDcICzK2LLO2LhWzyJq3XgRx55BKtWrfI9NpvNyMsLoLGgMQ+45KcX/bfw3A/0mN9xDgqXG/ltDsCzlOsVC47jwIGDViUFlAoAgILnUWR1do6TKcFkanASGTgOUMulgGdJXgPgrhnUPjqekEo4FKZpUJjWfeo8e1QGLhvmhptncPPdoyBSCQdwciC1CEVThyDHkyPGGADeBc7tAHgnOLcTnE7q+Tp29dTZRT+jiCJ8oQJRhSgtLQ1SqRT19fV+x+vr65GVldVtvFKphFKpDP5CCk3Ye8OrAOSmBDZWDiAjjkJYRPgwqOUwqAPLgk/ThvDdThBE9fMUCgUmT56MDz/80HeM53l8+OGHmD59uoiWEQQRTUSfmq1atQpLlizBlClTMG3aNGzcuBHt7e24666
7xDaNIIgoIboQ/eQnP0FjYyMeffRR1NXVYeLEiXj//fe7BbAJgkhcRM8jGghxm0dEEIOEQP9GY2gtkCCIwQoJEUEQokNCRBCE6IgerB4I3vBWwFs9CIKIKt6/zf5C0XEtRG1tbQAQWHY1QRCi0dbWBoOh98qWcb1qxvM8ampqoNPpRN+P5d1uUlVVRSt4XaDPpXcGw2fDGENbWxtycnIg6WOfXFx7RBKJBEOGDBHbDD/0en3CfqkGAn0uvZPon01fnpAXClYTBCE6JEQEQYgOCVGYUCqVWLNmTWjVARIY+lx6hz6bTuI6WE0QRGJAHhFBEKJDQkQQhOiQEBEEITokRARBiA4JURAE05H2lVdewcyZM5GcnIzk5GTMnTs3YTvYhtqpd8eOHeA4DosWLYqsgSIS7GfT2tqK5cuXIzs7G0qlEiNHjsS7774bJWtFhBEBsWPHDqZQKNjmzZvZd999x+655x5mNBpZfX19j+Nvu+029sc//pF988037Pvvv2elpaXMYDCw8+fPR9nyyBLs5+Ll3LlzLDc3l82cOZMtXLgwOsZGmWA/G7vdzqZMmcKuu+46duDAAXbu3Dm2d+9eduTIkShbHn1IiAJk2rRpbPny5b7Hbreb5eTksHXr1gX0epfLxXQ6HXv11VcjZaIohPK5uFwudvnll7P//d//ZUuWLElYIQr2s9m0aRMbNmwYczgc0TIxZqCpWQB4O9LOnTvXd6yvjrQ9YbVa4XQ6kZISYA+iOCDUz+WJJ55ARkYG7r777miYKQqhfDZvv/02pk+fjuXLlyMzMxPjxo3D7373O7jd7h7HJxJxvek1WgTbkbYnHn74YeTk5Ph9MeOdUD6XAwcO4C9/+QuOHDkSBQvFI5TP5uzZs/joo4+wePFivPvuuzh9+jTuv/9+OJ1OrFmzJhpmiwYJURR46qmnsGPHDuzduxcq1eDtYd/W1oY77rgDr7zyCtLS0sQ2J+bgeR4ZGRl4+eWXIZVKMXnyZFRXV+MPf/gDCRERfEfarjz99NN46qmn8H//93+YMGFCJM2MOsF+LmfOnEF5eTluuOEG3zGeF/qny2QynDx5EkVFRZE1OkqE8p3Jzs6GXC6HVCr1HRs9ejTq6urgcDigUCgiarOYUIwoAELtSPv73/8ea9euxfvvv48pU6ZEw9SoEuznUlxcjOPHj+PIkSO+24033og5c+bgyJEjCVVpM5TvzIwZM3D69GmfOANAWVkZsrOzE1qEANDyfaDs2LGDKZVKtnXrVnbixAl27733MqPRyOrq6hhjjN1xxx1s9erVvvFPPfUUUygU7M0332S1tbW+W1tbm1hvISIE+7lcTCKvmgX72VRWVjKdTsdWrFjBTp48yd555x2WkZHBfvvb34r1FqIGCVEQvPDCCyw/P58pFAo2bdo09sUXX/iemzVrFluyZInvcUFBAQPQ7bZmzZroGx5hgvlcLiaRhYix4D+bzz77jF166aVMqVSyYcOGsSeffJK5XK4oWx19qAwIQRCiQzEigiBEh4SIIAjRISEiCEJ0SIgIghAdEiKCIESHhIggCNEhISIIQnRIiAiCEB0SIiKu2Lp1K4xGo+/xY489hokTJ/oel5aWJnTp2USFhIjokdLSUnAch5/97Gfdnlu+fDk4jkNpaanf+HALQGFhITZu3Oh37Cc/+QnKysp6fc1zzz2HrVu3+h7Pnj0bDzzwQFjtIsIPCRHRK3l5edixYwdsNpvvWEdHB15//XXk5+eLYpNarUZGRkavzxsMBj+PiYgPSIiIXpk0aRLy8vKwc+dO37GdO3ciPz8fJSUlAzp3T57KokWLfF7W7NmzUVFRgQcffBAcx4HjOADdp2YX09UzKy0txb59+/Dcc8/5znHu3DkMHz4cTz/9tN/rjhw5Ao7jcPr06QG9LyI0SIiIPlm6dCm2bNnie7x582bcddddEb/uzp07MWTIEDzxxBOora1FbW1t0Od47rnnMH36dNxzzz2+c+Tn53d
7TwCwZcsWXHnllRg+fHi43gIRBCRERJ/cfvvtOHDgACoqKlBRUYFPP/0Ut99+e8Svm5KSAqlUCp1Oh6ysrH4rYfaEwWCAQqFAUlKS7xxSqRSlpaU4efKkr8eY0+nE66+/jqVLl4b7bRABQqViiT5JT0/HggULsHXrVjDGsGDBgrivN52Tk4MFCxZg8+bNmDZtGv7zn//AbrfjlltuEdu0QQt5RES/LF26FFu3bsWrr74aNq9BIpHg4lJYTqczLOcOhGXLlvkC8Vu2bMFPfvITJCUlRe36hD8kRES/zJ8/Hw6HA06nE9dcc01Yzpmenu4X93G73fj222/9xigUigH39OrtHNdddx00Gg02bdqE999/n6ZlIkNTM6JfpFIpvv/+e9/93jCZTN36laWmpvZYFP+qq67CqlWrsGvXLhQVFWHDhg1obW31G1NYWIj9+/fjpz/9KZRKZUhTwsLCQnz55ZcoLy+HVqtFSkoKJBKJL1b0yCOPYMSIEX02QSAiD3lEREDo9Xro9fo+x+zduxclJSV+t8cff7zHsUuXLsWSJUtw5513YtasWRg2bBjmzJnjN+aJJ55AeXk5ioqKkJ6eHpLdDz30EKRSKcaMGYP09HRUVlb6nrv77rvhcDiisgpI9A3VrCYGLZ988gmuvvpqVFVVdevISkQXEiJi0GG329HY2IglS5YgKysL27dvF9ukQQ9NzYhBxxtvvIGCggK0trbi97//vdjmECCPiCCIGIA8IoIgRIeEiCAI0SEhIghCdEiICIIQHRIigiBEh4SIIAjRISEiCEJ0SIgIghCd/w91JP0Hbd+jRwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n", "\n", "_ = plot_pred_density_2(y)" ] }, { "cell_type": "code", "execution_count": 35, "id": "e79e4b0f", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:10.864904Z", "iopub.status.busy": "2024-03-03T07:26:10.864289Z", "iopub.status.idle": "2024-03-03T07:26:11.171093Z", "shell.execute_reply": "2024-03-03T07:26:11.170171Z" }, "papermill": { "duration": 0.330431, "end_time": "2024-03-03T07:26:11.173164", "exception": false, "start_time": "2024-03-03T07:26:10.842733", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_density_3\n", "\n", "_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 36, "id": "745adde1", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:11.219042Z", "iopub.status.busy": "2024-03-03T07:26:11.218703Z", "iopub.status.idle": "2024-03-03T07:26:11.438988Z", "shell.execute_reply": "2024-03-03T07:26:11.438043Z" }, "papermill": { "duration": 0.245959, "end_time": "2024-03-03T07:26:11.440967", "exception": false, "start_time": "2024-03-03T07:26:11.195008", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAATgAAAEmCAYAAAD2o4yBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwvElEQVR4nO3deVgUZ54H8G9zNQ3NKdIccooBL0AhIo7ZeHBpNDpmIxojRwxOdJyN01ESNMKgTtgYDzSjso8joiYTTbLGzWYMajBmNAIqBEVFDMihEVBQaA5tG/rdP1wqabsbuxFsKH6f5+HBeuutt9+Xgq/VXVVvCRhjDIQQwkNGhu4AIYT0Fgo4QghvUcARQniLAo4QwlsUcIQQ3qKAI4TwFgUcIYS3KOAIIbxlYugO9EVKpRK3bt2ClZUVBAKBobtDCPkNxhiam5vh4uICI6Ouj9Eo4DS4desW3NzcDN0NQkgXbty4gSFDhnRZhwJOAysrKwCPfoDW1tYG7k3vUSgUOHbsGCIiImBqamro7pCnNFD2p0wmg5ubG/d32hUKOA0635ZaW1vzPuAsLCxgbW3N6z+IgWKg7U9dPj6ikwyEEN6igCOE8BYFHCGEtyjgCCG8RQFHCOEtCjhCCG/RZSKE9FNtbW24evUqt9xyX44zxeWwczgPsUioUtfPzw8WFhbPuosGRwFHSD919epVBAUFqZVv0FC3oKAAY8eO7f1O9TEUcIT0U35+figoKOCWS2saIf2iGJtfHQ1fZ1u1ugMRBRwh/ZSFhYXKUZlRVQOEp+5j+KgABHoMMmDP+g46yUAI4S0KOEIIb1HAEUJ4iwKOEMJbFHCEEN6igCOE8BYFHCGEtyjgCCG8RQFHCOEtCjhCCG9RwBFCeIsCjhDCWxRwhBDeooAjhPAWBRwhhLcMHnDbt2+Hp6cnzM3NERISgrNnz2qtm5WVBYFAoPJlbm6uUicuLk6tTlRUVG8PgxDSBxl0wsuDBw9CKpUiIyMDISEhSE9PR2RkJEpLS+Ho6KhxG2tra5SWlnLLAoFArU5UVBT27NnDLQuFQrU6hBD+M+gR3ObNm5GQkID4+HiMGDECGRkZsLCwQGZmptZtBAIBnJycuC+JRKJWRygUqtSxs7PrzWEQQvoogx3BPXz4EAUFBUhKSuLKjIyMEBYWhtzcXK3btbS0wMPDA0qlEmPHjsUHH3yAkSNHqtQ5efIkHB0dYWdnhylTpmD9+vUYNEj7FM5yuRxyuZxb
lslkAACFQgGFQtHdIfZ5nWPj8xgHkvb2du47n/epPmMzWMDV19ejo6ND7QhMIpGoPArtt3x9fZGZmQl/f380NTVh48aNmDBhAi5fvowhQ4YAePT2dM6cOfDy8kJ5eTlWrVqFadOmITc3F8bGxhrbTUtLQ2pqqlr5sWPHBsSj1o4fP27oLpAecKMFAEyQl5eHXy4Zuje9p62tTee6AsYY68W+aHXr1i24urrizJkzCA0N5coTExPxww8/ID8//4ltKBQKDB8+HPPnz8e6des01rl+/TqGDh2K7777DlOnTtVYR9MRnJubG+rr62Ftba3nyPoPhUKB48ePIzw8HKampobuDnlKF6rv4t93nceXCcEIcLc3dHd6jUwmg4ODA5qamp7492mwIzgHBwcYGxujrq5Opbyurg5OTk46tWFqaooxY8agrKxMax1vb284ODigrKxMa8AJhUKNJyJMTU0HxB/+QBkn35mYmHDf+bw/9RmbwU4ymJmZISgoCDk5OVyZUqlETk6OyhFdVzo6OlBcXAxnZ2etdW7evImGhoYu6xBC+MmgZ1GlUil27dqFvXv3oqSkBEuWLEFrayvi4+MBADExMSonIdauXYtjx47h+vXrKCwsxOuvv46qqiq8+eabAB6dgFi5ciXy8vJQWVmJnJwczJo1Cz4+PoiMjDTIGAkhhmPQ6+Cio6Nx584dJCcno7a2FoGBgcjOzuZOPFRXV8PI6NcMvnfvHhISElBbWws7OzsEBQXhzJkzGDFiBADA2NgYFy9exN69e9HY2AgXFxdERERg3bp1dC0cIQOQwU4y9GUymQw2NjY6fYjZnykUChw5cgTTp0/n9Wc2A0VRVQNm78zD4SXjef1ke33+Pg1+qxYhhPQWCjhCCG9RwBFCeIsCjhDCWxRwhBDeooAjhPAWBRwhhLcMeqEvebba2tpUZmppuS/HmeJy2Dmch1j064XQfn5+A2IWFcJ/FHADyNWrVxEUFKRWvuGx5YKCAowdO/bZdIqQXkQBN4D4+fmhoKCAWy6taYT0i2JsfnU0fJ1tVeoRwgcUcAOIhYWFypGZUVUDhKfuY/ioAF7f2kMGLjrJQAjhLQo4QghvUcARQniLAo4QwlsUcIQQ3qKAI4TwFgUcIYS3KOAIIbxFAUcI4S0KOEIIb1HAEUJ4iwKOEMJbFHCEEN4yeMBt374dnp6eMDc3R0hICM6ePau1blZWFgQCgcqXubm5Sh3GGJKTk+Hs7AyRSISwsDD8/PPPvT0MQkgfZNCAO3jwIKRSKVJSUlBYWIiAgABERkbi9u3bWrextrZGTU0N91VVVaWyfsOGDdi2bRsyMjKQn58PS0tLREZG4sGDB709HEJIH2PQgNu8eTMSEhIQHx+PESNGICMjAxYWFsjMzNS6jUAggJOTE/clkUi4dYwxpKen4/3338esWbPg7++Pffv24datWzh8+PAzGBEhpC8x2ISXDx8+REFBAZKSkrgyIyMjhIWFITc3V+t2LS0t8PDwgFKpxNixY/HBBx9g5MiRAICKigrU1tYiLCyMq29jY4OQkBDk5uZi3rx5GtuUy+WQy+XcskwmAwAoFAooFIqnGmdf1t7ezn3n8zgHioGyP/UZm8ECrr6+Hh0dHSpHYAAgkUhUHozyW76+vsjMzIS/vz+ampqwceNGTJgwAZcvX8aQIUNQW1vLtfF4m53rNElLS0Nqaqpa+bFjx3j98JUbLQBggry8PPxyydC9IU9roOzPtrY2nev2qynLQ0NDERoayi1PmDABw4cPx3/9139h3bp13W43KSkJUqmUW5bJZHBzc0NERASsra2fqs992YXqu0DxeYwfPx4B7vaG7g55SgNlf3a+w9KFwQLOwcEBxsbGqKurUymvq6uDk5OTTm2YmppizJgxKCsrAwBuu7q6Ojg7O6u0GRgYqLUdoVAIoVCoVm5qagpTU1Od+tIfmZiYcN/5PM6BYqDsT33GZrCTDGZmZggKCkJOTg5XplQqkZOT
o3KU1pWOjg4UFxdzYebl5QUnJyeVNmUyGfLz83VukxDCHwZ9iyqVShEbG4vg4GCMGzcO6enpaG1tRXx8PAAgJiYGrq6uSEtLAwCsXbsW48ePh4+PDxobG/HRRx+hqqoKb775JoBHZ1iXL1+O9evXY9iwYfDy8sKaNWvg4uKC2bNnG2qYhBADMWjARUdH486dO0hOTkZtbS0CAwORnZ3NnSSorq6GkdGvB5n37t1DQkICamtrYWdnh6CgIJw5cwYjRozg6iQmJqK1tRWLFy9GY2MjJk6ciOzsbLULggkh/CdgjDFDd6KvkclksLGxQVNTE69PMhRVNWD2zjwcXjKenovKAwNlf+rz92nwW7UIIaS3UMARQniLAo4QwlsUcIQQ3upXdzIQMtBV1LeiVd6ucV35nVbue+dFv5pYCk3g5WDZK/3rayjgCOknKupbMXnjySfWe+fL4ifW+X7FpAERchRwhPQTnUdu6dGB8HEUq6+/L8c3J3MxY1IoLEXqtx4CQNntFiw/WKT1KJBvKOAI6Wd8HMUY5WqjVq5QKFA7GBjrYcfre1H1QScZCCG8RQFHCOEtvQPu+vXrvdEPQgjpcXoHnI+PDyZPnoxPPvmEHuRCCOnT9A64wsJC+Pv7QyqVwsnJCX/4wx+6fNQfIYQYit4BFxgYiK1bt+LWrVvIzMxETU0NJk6ciFGjRmHz5s24c+dOb/STEEL01u3LRExMTDBnzhy89NJL2LFjB5KSkrBixQqsWrUKc+fOxYcffqgybTh59rq66h2gK98J/3U74M6fP4/MzEwcOHAAlpaWWLFiBRYtWoSbN28iNTUVs2bNoreuBqTrVe8AXflO+EvvgNu8eTP27NmD0tJSTJ8+Hfv27cP06dO5mXe9vLyQlZUFT0/Pnu4r0cOTrnoH6Mp3wn96B9zOnTvxxhtvIC4uTutbUEdHR+zevfupO0eenrar3gG68p3wn94Bd/z4cbi7u6s8KwEAGGO4ceMG3N3dYWZmhtjY2B7rJCGEdIfeZ1GHDh2K+vp6tfK7d+/Cy8urRzpFCCE9Qe+A0/aMmpaWFnpyFSGkT9H5LapUKgXw6NmjycnJsLCw4NZ1dHQgPz+/y6fHE0LIs6ZzwP30008AHh3BFRcXw8zMjFtnZmaGgIAArFixoud7SAgh3aRzwH3//fcAgPj4eGzdupXXzwslhPCD3p/B7dmzp0fDbfv27fD09IS5uTlCQkJ0vjj4wIEDEAgEmD17tkp5XFwcBAKByldUVFSP9ZcQ0n/odAQ3Z84cZGVlwdraGnPmzOmy7qFDh3R+8YMHD0IqlSIjIwMhISFIT09HZGQkSktL4ejoqHW7yspKrFixAi+88ILG9VFRUdizZw+3LBRqvoiVEMJvOh3B2djYQCAQcP/u6ksfmzdvRkJCAuLj4zFixAhkZGTAwsICmZmZWrfp6OjAggULkJqaCm9vb411hEIhnJycuC87Ozu9+kUI4QedjuB+ezT0238/jYcPH6KgoABJSUlcmZGREcLCwpCbm6t1u7Vr18LR0RGLFi3CqVOnNNY5efIkHB0dYWdnhylTpmD9+vUYNGiQ1jblcjnkcjm3LJPJADy60l+hUOg7tD6hvb2d+65tDJ3lXY1Rl3bIs9Eqb4GR+S8ou3cFShP1+4Lb29txq/0Wim8Xa5084fq9VhiZ/4JWeQsUCguNdfo6fX4PDfbQmfr6enR0dEAikaiUSyQSXL16VeM2p0+fxu7du1FUVKS13aioKMyZMwdeXl4oLy/HqlWrMG3aNOTm5sLY2FjjNmlpaUhNTVUrP3bsmMrlMP3JjRYAMMHp06dRpflWVM7x48d7pB3Suwqbb8HSawfWFHRdb8d3O7pcb+kFHDnTgVorlx7s3bPT1tamc12dAm7MmDHcW9QnKSws1PnF9dHc3IyFCxdi165dcHBw0Fpv3rx53L9Hjx4Nf39/DB06FCdPnsTUqVM1bpOUlMRd5wc8OoJzc3NDREREvz1bfPmW
DBuL8zBx4kSMdNE8BoVCgePHjyM8PFzrvai6tEOeDacbt7F/nzE2//toeA/WfASXn5ePkPEh2o/g7rRC+mUxpse8hLFu2j/n7ss632HpQqeAe/xMZU9wcHCAsbEx6urqVMrr6urg5OSkVr+8vByVlZWYOXMmV6ZUKgE8mpuutLQUQ4cOVdvO29sbDg4OKCsr0xpwQqFQ44kIU1PTfnsTeucvuImJyRPH0NU49WmH9C5LoRjKB67wsRuBURLNjw28YXIDox1Ha91XRu1NUD64C0uhuN/uT336rVPApaSkdLsz2piZmSEoKAg5OTlcgCqVSuTk5GDZsmVq9f38/FBcrDpv2fvvv4/m5mZs3boVbm5uGl/n5s2baGhooMk3CRmADPrgZ6lUitjYWAQHB2PcuHFIT09Ha2sr4uPjAQAxMTFwdXVFWloazM3NMWrUKJXtbW1tAYArb2lpQWpqKl555RU4OTmhvLwciYmJ8PHxQWRk5DMdGyHE8HQKOHt7e1y7dg0ODg6ws7Pr8vO4u3fv6vzi0dHRuHPnDpKTk1FbW4vAwEBkZ2dzJx6qq6vVpmXqirGxMS5evIi9e/eisbERLi4uiIiIwLp16+haOEIGIJ0CbsuWLbCysuL+resJB10sW7ZM41tS4NHlHl3JyspSWRaJRDh69GgP9ax/k3c8gJH5L6iQlcLIXPPpz87LCkrulmj9ULpC9ujSBHnHAwD6XedIiKHpFHC/nbwyLi6ut/pCetCt1ipYen2MVTrc+bYj+8mXFdxqDUQQJF3WI6Sv0fszOGNjY9TU1KjdStXQ0ABHR0d0dHT0WOdI97lYeqC14k/YGh2IoVqeydDe3o4fT/+I3038ndYjuPLbLXj7YBFcJnv0ZncJ6RV6B5y2CS/lcrnKFErEsITG5lA+cIWXtS9GDNL+TIYKkwoMtx+u9dS78kETlA/uQGhMk5mS/kfngNu2bRuARxNe/v3vf4dY/OtRQUdHB/71r3/Bz8+v53tICCHdpHPAbdmyBcCjI7iMjAyV257MzMzg6emJjIyMnu8hIYR0k84BV1FRAQCYPHkyDh06RDN0EEL6PL0/g+uc2ZcQQvo6vQPujTfe6HJ9V3O5EULIs6R3wN27d09lWaFQ4NKlS2hsbMSUKVN6rGOEEPK09A64r776Sq1MqVRiyZIlGmfzIIQQQ9H7oTMaGzEyglQq5c60EkJIX9AjAQc8mq+tc3prQgjpC/R+i/rbmW+BR9fF1dTU4J///KfKPauEEGJoegdc5xPuOxkZGWHw4MHYtGnTE8+wEkLIs0TXwRFCeKvHPoMjhJC+hgKOEMJbFHCEEN6igCOE8FaPBdzNmzexePHinmqOEEKeWo8FXENDA3bv3t1TzRFCyFOjt6iEEN6igCOE8BYFHCGEt3S+k2HOnDldrm9sbOxWB7Zv346PPvoItbW1CAgIwMcff4xx48Y9cbsDBw5g/vz5mDVrFg4fPsyVM8aQkpKCXbt2obGxEb/73e+wc+dODBs2rFv9I4T0XzofwdnY2HT55eHhgZiYGL1e/ODBg5BKpUhJSUFhYSECAgIQGRmJ27dvd7ldZWUlVqxYgRdeeEFt3YYNG7Bt2zZkZGQgPz8flpaWiIyMxIMHD/TqGyGk/9P5CG7Pnj09/uKbN29GQkIC4uPjAQAZGRn45z//iczMTLz33nsat+no6MCCBQuQmpqKU6dOqRw5MsaQnp6O999/H7NmzQIA7Nu3DxKJBIcPH8a8efN6fAyEkL5L75vte8rDhw9RUFCApKQkrszIyAhhYWHIzc3Vut3atWvh6OiIRYsW4dSpUyrrKioqUFtbi7CwMK7MxsYGISEhyM3N1RpwcrkccrmcW5bJZAAeTceuUCi6NT5D65ybr729XesYOsu7GqMu7ZBn40n7YqDsT336rXPA6ToVkq4Pnamvr0dHRwckEolKuUQiwdWrVzVuc/r0aezevRtFRUUa19fW1nJtPN5m5zpN0tLSkJqaqlZ+
7NgxWFhYdDWMPutGCwCY4PTp06gSd133+PHjPdIO6V267gu+78+2tjad6+occFlZWfDw8MCYMWPAGOtWx55Gc3MzFi5ciF27dsHBwaFH205KSlKZyFMmk8HNzQ0RERGwtrbu0dd6Vi7fkmFjcR4mTpyIkS6ax6BQKHD8+HGEh4fD1NS02+2QZ+NJ+2Kg7M/Od1i60DnglixZgs8++wwVFRWIj4/H66+/Dnt7+251EAAcHBxgbGyMuro6lfK6ujo4OTmp1S8vL0dlZSVmzpzJlSmVSgCAiYkJSktLue3q6urg7Oys0mZgYKDWvgiFQgiFQrVyU1NTrb8ofZ2JiQn3/Ulj6Gqc+rRDepeu+4Lv+1Offut8FnX79u2oqalBYmIi/vd//xdubm6YO3cujh492q0jOjMzMwQFBSEnJ4crUyqVyMnJQWhoqFp9Pz8/FBcXo6ioiPt6+eWXMXnyZBQVFcHNzQ1eXl5wcnJSaVMmkyE/P19jm4QQftPrJINQKMT8+fMxf/58VFVVISsrC0uXLkV7ezsuX74MsVi/N/VSqRSxsbEIDg7GuHHjkJ6ejtbWVu6sakxMDFxdXZGWlgZzc3OMGjVKZXtbW1sAUClfvnw51q9fj2HDhsHLywtr1qyBi4sLZs+erVffCCH9X7fPohoZGUEgEIAxho6Ojm61ER0djTt37iA5ORm1tbUIDAxEdnY2d5KguroaRkb63WyRmJiI1tZWLF68GI2NjZg4cSKys7Nhbm7erT4SQvovvQJOLpfj0KFDyMzMxOnTpzFjxgz87W9/Q1RUlN5B1GnZsmVYtmyZxnUnT57sctusrCy1MoFAgLVr12Lt2rXd6g8hhD90DrilS5fiwIEDcHNzwxtvvIHPPvusx89mkp5zX/HoqPrSL01a67Tel+P8HcCp6h4sReonWQCg7HZLr/SP6O9J+5T2pzqdAy4jIwPu7u7w9vbGDz/8gB9++EFjvUOHDvVY50j3lf//L/J7h4qfUNME+8vOPbE9S6HBrgkn/0+3fUr787d0HmVMTAwEAkFv9oX0oIiRjy6ZGeoohsjUWGOd0pomvPNlMTb9+2j4OttobctSaAIvB8te6SfR3ZP2Ke1PdXpd6Ev6D3tLM8wb595lnc7bdoYOtsQoV+1/EKRveNI+pf2pjuaDI4TwFgUcIYS3KOAIIbxFAUcI4S0KOEIIb1HAEUJ4iwKOEMJbFHCEEN6igCOE8BYFHCGEtyjgCCG8RQFHCOEtCjhCCG9RwBFCeIsCjhDCWxRwhBDeooAjhPAWBRwhhLco4AghvEUBRwjhLYMH3Pbt2+Hp6Qlzc3OEhITg7NmzWuseOnQIwcHBsLW1haWlJQIDA7F//36VOnFxcRAIBCpfUVFRvT0MQkgfZNCHIx48eBBSqRQZGRkICQlBeno6IiMjUVpaCkdHR7X69vb2WL16Nfz8/GBmZoZvvvkG8fHxcHR0RGRkJFcvKioKe/bs4ZaFQs0PwSWE8JtBj+A2b96MhIQExMfHY8SIEcjIyICFhQUyMzM11p80aRJ+//vfY/jw4Rg6dCjefvtt+Pv74/Tp0yr1hEIhnJycuC87O7tnMRxCSB9jsCO4hw8foqCgAElJSVyZkZERwsLCkJub+8TtGWM4ceIESktL8eGHH6qsO3nyJBwdHWFnZ4cpU6Zg/fr1GDRokNa25HI55HI5tyyTyQAACoUCCoVC36H1G53P0Wxvb+f1OAeKgbI/9RmbwQKuvr4eHR0dkEgkKuUSiQRXr17Vul1TUxNcXV0hl8thbGyMHTt2IDw8nFsfFRWFOXPmwMvLC+Xl5Vi1ahWmTZuG3NxcGBtrfsJ7WloaUlNT1cqPHTsGCwuLbo6w77vRAgAmyMvLwy+XDN0b8rQGyv5sa2vTua5BP4PrDisrKxQVFaGlpQU5OTmQSqXw9vbGpEmTAADz5s3j6o4ePRr+/v4YOnQoTp48ialTp2psMykpCVKplFuWyWRwc3NDREQErK2te3U8
hnSh+i5QfB7jx49HgLu9obtDntJA2Z+d77B0YbCAc3BwgLGxMerq6lTK6+rq4OTkpHU7IyMj+Pj4AAACAwNRUlKCtLQ0LuAe5+3tDQcHB5SVlWkNOKFQqPFEhKmpKUxNTXUcUf9jYmLCfefzOAeKgbI/9RmbwU4ymJmZISgoCDk5OVyZUqlETk4OQkNDdW5HqVSqfH72uJs3b6KhoQHOzs5P1V9CSP9j0LeoUqkUsbGxCA4Oxrhx45Ceno7W1lbEx8cDAGJiYuDq6oq0tDQAjz4rCw4OxtChQyGXy3HkyBHs378fO3fuBAC0tLQgNTUVr7zyCpycnFBeXo7ExET4+PioXEZCCBkYDBpw0dHRuHPnDpKTk1FbW4vAwEBkZ2dzJx6qq6thZPTrQWZrayuWLl2KmzdvQiQSwc/PD5988gmio6MBAMbGxrh48SL27t2LxsZGuLi4ICIiAuvWraNr4QgZgASMMWboTvQ1MpkMNjY2aGpq4vVJhqKqBszemYfDS8Yj0EP7ZTSkfxgo+1Ofv0+D36pFCCG9hQKOEMJbFHCEEN6igCOE8BYFHCGEtyjgCCG8RQFHCOEtCjhCCG9RwBFCeIsCjhDCWxRwhBDeooAjhPAWBRwhhLco4AghvEUBRwjhLQo4QghvUcARQniLAo4QwlsUcIQQ3qKAI4TwFgUcIYS3KOAIIbxFAUcI4S0KOEIIbxk84LZv3w5PT0+Ym5sjJCQEZ8+e1Vr30KFDCA4Ohq2tLSwtLREYGIj9+/er1GGMITk5Gc7OzhCJRAgLC8PPP//c28MghPRBBg24gwcPQiqVIiUlBYWFhQgICEBkZCRu376tsb69vT1Wr16N3NxcXLx4EfHx8YiPj8fRo0e5Ohs2bMC2bduQkZGB/Px8WFpaIjIyEg8ePHhWwyKE9BECxhgz1IuHhITg+eefx9/+9jcAgFKphJubG/70pz/hvffe06mNsWPH4qWXXsK6devAGIOLiwveeecdrFixAgDQ1NQEiUSCrKwszJs3T6c2ZTIZbGxs0NTUBGtr6+4Nrg9qa2vD1atXueXSmkZIvyjG5ldHw9fZliv38/ODhYWFAXpInkZRVQNm78zD4SXjEegxyNDd6TX6/H2aPKM+qXn48CEKCgqQlJTElRkZGSEsLAy5ublP3J4xhhMnTqC0tBQffvghAKCiogK1tbUICwvj6tnY2CAkJAS5ublaA04ul0Mul3PLMpkMAKBQKKBQKLo1vr7o0qVLCAkJUSt/ba/qcn5+PsaMGfOMekV6Snt7O/edT7+3j9NnbAYLuPr6enR0dEAikaiUSyQSlaOMxzU1NcHV1RVyuRzGxsbYsWMHwsPDAQC1tbVcG4+32blOk7S0NKSmpqqVHzt2jFdHMnK5HJs2beKWFUrg7gPA3hww/c2HFZWVlaipqTFAD8nTuNECACbIy8vDL5cM3Zve09bWpnNdgwVcd1lZWaGoqAgtLS3IycmBVCqFt7c3Jk2a1O02k5KSIJVKuWWZTAY3NzdERETw6i3q4xQKBY4fP47w8HCYmpoaujvkKV2ovgsUn8f48eMR4G5v6O70ms53WLowWMA5ODjA2NgYdXV1KuV1dXVwcnLSup2RkRF8fHwAAIGBgSgpKUFaWhomTZrEbVdXVwdnZ2eVNgMDA7W2KRQKIRQK1cpNTU0HxB/+QBkn35mYmHDf+bw/9Rmbwc6impmZISgoCDk5OVyZUqlETk4OQkNDdW5HqVRyn595eXnByclJpU2ZTIb8/Hy92iSE8INB36JKpVLExsYiODgY48aNQ3p6OlpbWxEfHw8AiImJgaurK9LS0gA8+qwsODgYQ4cOhVwux5EjR7B//37s3LkTACAQCLB8+XKsX78ew4YNg5eXF9asWQMXFxfMnj3bUMMkhBiIQQMuOjoad+7cQXJyMmpraxEYGIjs7GzuJEF1dTWMjH49yGxtbcXSpUtx8+ZNiEQi+Pn54ZNPPkF0dDRXJzExEa2trVi8eDEa
GxsxceJEZGdnw9zc/JmPjxBiWAa9Dq6v4ut1cI9TKBQ4cuQIpk+fzuvPbAYKug5OncFv1SKEkN5CAUcI4S0KOEIIb1HAEUJ4iwKOEMJb/e5WLULII5pmh5HXlqHkkgjKBluVugN1hhgKOEL6qatXryIoKEit/PHZYQCgoKAAY8eOfQa96lso4Ajpp/z8/FBQUMAtt9yX45/f5+KlyaEQi4RqdQciCjhC+ikLCwuVozKFQoF79bcROi6YLtz+f3SSgRDCWxRwhBDeooAjhPAWBRwhhLco4AghvEUBRwjhLQo4Qghv0XVwGnTOAarP03v6I4VCgba2NshkMrpuigcGyv7s/LvUZa5eCjgNmpubAQBubm4G7gkhRJvm5mbY2Nh0WYemLNdAqVTi1q1bsLKygkAgMHR3ek3n819v3LjB66nZB4qBsj8ZY2huboaLi4vKM1s0oSM4DYyMjDBkyBBDd+OZsba25vUfxEAzEPbnk47cOtFJBkIIb1HAEUJ4iwJuABMKhUhJSYFQKHxyZdLn0f5URycZCCG8RUdwhBDeooAjhPAWBRwhhLco4HpAXFwcZs+e3aNtTpo0CcuXL++yjqenJ9LT03v0dQnhEwq4x+gSLKR/+ctf/oLAwEBDd0Ojvvb71tf687Qo4Ajp5x4+fGjoLvRZFHC/ERcXhx9++AFbt26FQCCAQCBAeXk5Fi1aBC8vL4hEIvj6+mLr1q0at09NTcXgwYNhbW2Nt956S+dfvNbWVsTExEAsFsPZ2RmbNm1Sq3P79m3MnDkTIpEIXl5e+PTTT9XqCAQC7Ny5E9OmTYNIJIK3tze+/PJLbn1lZSUEAgE+//xzvPDCCxCJRHj++edx7do1nDt3DsHBwRCLxZg2bRru3Lmj408NyMzMxMiRIyEUCuHs7Ixly5Zx66qrqzFr1iyIxWJYW1tj7ty5qKur49Z3Hl3t378fnp6esLGxwbx587gJD4BH9wZv2LABPj4+EAqFcHd3x1//+ldu/bvvvovnnnsOFhYW8Pb2xpo1a6BQKAAAWVlZSE1NxYULF7h9mpWVpfPYelN3f986PxL561//ChcXF/j6+gIAzpw5g8DAQJibmyM4OBiHDx+GQCBAUVERt+2lS5cwbdo0iMViSCQSLFy4EPX19Vr7U1lZ+ax+HL2DEU5jYyMLDQ1lCQkJrKamhtXU1LAHDx6w5ORkdu7cOXb9+nX2ySefMAsLC3bw4EFuu9jYWCYWi1l0dDS7dOkS++abb9jgwYPZqlWrdHrdJUuWMHd3d/bdd9+xixcvshkzZjArKyv29ttvc3WmTZvGAgICWG5uLjt//jybMGECE4lEbMuWLVwdAGzQoEFs165drLS0lL3//vvM2NiYXblyhTHGWEVFBQPA/Pz8WHZ2Nrty5QobP348CwoKYpMmTWKnT59mhYWFzMfHh7311ls69X3Hjh3M3Nycpaens9LSUnb27FmuTx0dHSwwMJBNnDiRnT9/nuXl5bGgoCD24osvctunpKQwsVjM5syZw4qLi9m//vUv5uTkpPKzS0xMZHZ2diwrK4uVlZWxU6dOsV27dnHr161bx3788UdWUVHBvv76ayaRSNiHH37IGGOsra2NvfPOO2zkyJHcPm1ra9NpbL3taX/fFi5cyC5dusQuXbrEmpqamL29PXv99dfZ5cuX2ZEjR9hzzz3HALCffvqJMcbYvXv32ODBg1lSUhIrKSlhhYWFLDw8nE2ePFlrf9rb2w3xo+kxFHCPefHFF1WCRZM//vGP7JVXXuGWY2Njmb29PWttbeXKdu7cycRiMevo6OiyrebmZmZmZsY+//xzrqyhoYGJRCKuH6WlpQwAO3v2LFenpKSEAVALuMeDKSQkhC1ZsoQx9mvA/f3vf+fWf/bZZwwAy8nJ4crS0tKYr69vl/3u5OLiwlavXq1x3bFjx5ixsTGrrq7myi5fvqwylpSUFGZhYcFkMhlXZ+XKlSwkJIQxxphMJmNCoVAl0J7ko48+
YkFBQdxySkoKCwgI0Hn7Z6m7v28SiYTJ5XKubOfOnWzQoEHs/v37XNmuXbtUAm7dunUsIiJCpe0bN24wAKy0tFTn/vQnNJuIDrZv347MzExUV1fj/v37ePjwodqH1gEBAbCwsOCWQ0ND0dLSghs3bsDDw0Nr2+Xl5Xj48CFCQkK4Mnt7e+5tBwCUlJTAxMQEQUFBXJmfnx9sbW3V2gsNDVVb/u1bFADw9/fn/i2RSAAAo0ePVim7ffu21j53un37Nm7duoWpU6dqXF9SUgI3NzeVefVGjBgBW1tblJSU4Pnnnwfw6GywlZUVV8fZ2Zl7/ZKSEsjlcq2vAQAHDx7Etm3bUF5ejpaWFrS3t/fr2TR0+X0bPXo0zMzMuOXS0lL4+/vD3NycKxs3bpzKNhcuXMD3338PsVis9prl5eV47rnnenYgfQB9BvcEBw4cwIoVK7Bo0SIcO3YMRUVFiI+P79cf7P52ttfO+e4eL1MqlU9sRyQS9Xh/Hn/9J71Gbm4uFixYgOnTp+Obb77BTz/9hNWrV/fb/aPr75ulpaXebbe0tGDmzJkoKipS+fr555/xb//2bz01hD6FAu4xZmZm6Ojo4JZ//PFHTJgwAUuXLsWYMWPg4+OD8vJyte0uXLiA+/fvc8t5eXkQi8VPnBV46NChMDU1RX5+Pld27949XLt2jVv28/NDe3s7CgoKuLLS0lI0NjaqtZeXl6e2PHz48C770F1WVlbw9PRETk6OxvXDhw/HjRs3cOPGDa7sypUraGxsxIgRI3R6jWHDhkEkEml9jTNnzsDDwwOrV69GcHAwhg0bhqqqKpU6j+/TvqS7v2+P8/X1RXFxMeRyOVd27tw5lTpjx47F5cuX4enpCR8fH5WvzsDsyz+r7qCAe4ynpyfy8/NRWVmJ+vp6DBs2DOfPn8fRo0dx7do1rFmzRu0XB3h0qn7RokW4cuUKjhw5gpSUFCxbtuyJM46KxWIsWrQIK1euxIkTJ3Dp0iXExcWpbOfr64uoqCj84Q9/QH5+PgoKCvDmm29qPLr54osvkJmZiWvXriElJQVnz55VOavZ0/7yl79g06ZN2LZtG37++WcUFhbi448/BgCEhYVh9OjRWLBgAQoLC3H27FnExMTgxRdfRHBwsE7tm5ub491330ViYiL27duH8vJy5OXlYffu3QAeBWB1dTUOHDiA8vJybNu2DV999ZVKG56enqioqEBRURHq6+tVQsDQuvv79rjXXnsNSqUSixcvRklJCY4ePYqNGzcC+PUo/Y9//CPu3r2L+fPn49y5cygvL8fRo0cRHx/Phdrj/dHlSL5PM/SHgH1NaWkpGz9+PBOJRAwAu3r1KouLi2M2NjbM1taWLVmyhL333nsqH1rHxsayWbNmseTkZDZo0CAmFotZQkICe/DggU6v2dzczF5//XVmYWHBJBIJ27Bhg9qHvTU1Neyll15iQqGQubu7s3379jEPDw+1kwzbt29n4eHhTCgUMk9PT5Wzb50nGTo/dGaMse+//54BYPfu3ePK9uzZw2xsbHT+mWVkZDBfX19mamrKnJ2d2Z/+9CduXVVVFXv55ZeZpaUls7KyYq+++iqrra3l1ms6AbBlyxbm4eHBLXd0dLD169czDw8PZmpqytzd3dkHH3zArV+5ciX3c4+OjmZbtmxR6f+DBw/YK6+8wmxtbRkAtmfPHp3H1tue5vftcT/++CPz9/dnZmZmLCgoiP3jH//g2ux07do19vvf/57Z2toykUjE/Pz82PLly5lSqdTYn4qKil7+CfQumi6JRwQCAb766qsev22M9E+ffvop4uPj0dTU1GOfl/Y3dBaVEJ7Yt28fvL294erqigsXLuDdd9/F3LlzB2y4ARRwva66urrLD9SvXLkCd3f3Z9gj/Wi6pKDTt99+ixdeeOEZ9oZ0pba2FsnJyaitrYWzszNeffVVlTs+BiJ6i9rL2tvbu7zdxdPTEyYmfff/mbKyMq3rXF1dB/TRAen7KOAI
IbxFl4kQQniLAo4QwlsUcIQQ3qKAI4TwFgUcMai4uDhuckVTU1NIJBKEh4cjMzNTr9uEsrKyNM6u0tt643kcpOdQwBGDi4qKQk1NDSorK/Htt99i8uTJePvttzFjxgy0t7cbunukPzPkfWKEaLuvMicnhwHgJrrctGkTGzVqFLOwsGBDhgxhS5YsYc3NzYyxX++n/e1XSkoKY4yxffv2saCgICYWi5lEImHz589ndXV13OvcvXuXvfbaa8zBwYGZm5szHx8flpmZya2vrq5mr776KrOxsWF2dnbs5Zdf5u7PTElJUXvd77//vld+TqR76AiO9ElTpkxBQEAADh06BAAwMjLCtm3bcPnyZezduxcnTpxAYmIiAGDChAlIT0+HtbU1ampqUFNTgxUrVgAAFAoF1q1bhwsXLuDw4cOorKxEXFwc9zpr1qzBlStX8O2336KkpAQ7d+6Eg4MDt21kZCSsrKxw6tQp/PjjjxCLxYiKisLDhw+xYsUKzJ07lzsCrampwYQJE57tD4p0zdAJSwY2bUdwjDEWHR3Nhg8frnHdF198wQYNGsQt6zoDyrlz5xgA7uhv5syZLD4+XmPd/fv3M19fX26mDcYYk8vlTCQSsaNHjz6x/8Tw6AiO9FmMMW4us++++w5Tp06Fq6srrKyssHDhQjQ0NKCtra3LNgoKCjBz5ky4u7vDysoKL774IoBH9wgDwJIlS3DgwAEEBgYiMTERZ86c4ba9cOECysrKYGVlBbFYDLFYDHt7ezx48ECnSSiJ4VHAkT6rpKQEXl5eqKysxIwZM+Dv74///u//RkFBAbZv3w6g62eCtra2IjIyEtbW1vj0009x7tw5bjLMzu2mTZuGqqoq/PnPf+aeL9H59ralpQVBQUFqU3xfu3YNr732Wi+PnvSEvnuXNxnQTpw4geLiYvz5z39GQUEBlEolNm3axM10/Pnnn6vU1zTV9tWrV9HQ0ID//M//5KaOP3/+vNprDR48GLGxsYiNjcULL7yAlStXYuPGjRg7diwOHjwIR0dHrQ+x4dsU33xDR3DE4ORyOWpra/HLL7+gsLAQH3zwAWbNmoUZM2YgJiYGPj4+UCgU+Pjjj3H9+nXs378fGRkZKm14enqipaUFOTk5qK+vR1tbG9zd3WFmZsZt9/XXX2PdunUq2yUnJ+N//ud/UFZWhsuXL+Obb77hnmGxYMECODg4YNasWTh16hQqKipw8uRJ/Md//Adu3rzJve7FixdRWlqK+vp67oHTpI8w9IeAZGCLjY3lLrEwMTFhgwcPZmFhYSwzM1PlmbKbN29mzs7OTCQSscjISLZv3z61qdbfeustNmjQIJXLRP7xj38wT09PJhQKWWhoKPv666/VnhU6fPhwJhKJmL29PZs1axa7fv0612ZNTQ2LiYlhDg4OTCgUMm9vb5aQkMCampoYY4zdvn2bhYeHM7FYTJeJ9EE0XRIhhLfoLSohhLco4AghvEUBRwjhLQo4QghvUcARQniLAo4QwlsUcIQQ3qKAI4TwFgUcIYS3KOAIIbxFAUcI4S0KOEIIb/0fKAyGlVD+OEAAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from ml_utility_loss.loss_learning.visualization import plot_box_3\n", "\n", "_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" ] }, { "cell_type": "code", "execution_count": 37, "id": "eabe1bab", "metadata": { "execution": { "iopub.execute_input": "2024-03-03T07:26:11.486182Z", "iopub.status.busy": "2024-03-03T07:26:11.485846Z", "iopub.status.idle": "2024-03-03T07:26:11.789269Z", "shell.execute_reply": "2024-03-03T07:26:11.788356Z" }, "papermill": { "duration": 0.328419, "end_time": "2024-03-03T07:26:11.791409", "exception": false, "start_time": "2024-03-03T07:26:11.462990", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#\"\"\"\n", "from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n", "import matplotlib.pyplot as plt\n", "\n", "#plot_grad_2(y, model.models)\n", "for m in model.models:\n", " ym = y[m]\n", " fig, ax = plt.subplots()\n", " plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "54c0e9f3", "metadata": { "papermill": { "duration": 0.022718, "end_time": "2024-03-03T07:26:11.837070", "exception": false, "start_time": "2024-03-03T07:26:11.814352", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "celltoolbar": "Tags", "colab": { "authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie", "gpuType": "T4", "mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg", "provenance": [] }, "kaggle": { "accelerator": "gpu", "dataSources": [], "dockerImageVersionId": 30648, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "papermill": { "default_parameters": {}, "duration": 3923.024134, "end_time": "2024-03-03T07:26:14.584105", "environment_variables": {}, "exception": null, "input_path": "eval/contraceptive/tab_ddpm_concat/1/mlu-eval.ipynb", "output_path": "eval/contraceptive/tab_ddpm_concat/1/mlu-eval.ipynb", "parameters": { "allow_same_prediction": true, "dataset": "contraceptive", "dataset_name": "contraceptive", "debug": false, "folder": "eval", "gp": true, "gp_multiply": false, "log_wandb": false, "param_index": 3, "path": 
"eval/contraceptive/tab_ddpm_concat/1", "path_prefix": "../../../../", "random_seed": 1, "single_model": "tab_ddpm_concat" }, "start_time": "2024-03-03T06:20:51.559971", "version": "2.5.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }