Muhammad Rizqi Nur
commited on
Commit
•
82572a0
1
Parent(s):
88d6559
This view is limited to 50 files because it contains too many changes.
See raw diff
- contraceptive/lct_gan/eval.csv +2 -0
- contraceptive/lct_gan/history.csv +20 -0
- contraceptive/lct_gan/mlu-eval.ipynb +0 -0
- contraceptive/lct_gan/model.pt +3 -0
- contraceptive/lct_gan/params.json +1 -0
- contraceptive/realtabformer/eval.csv +2 -0
- contraceptive/realtabformer/history.csv +20 -0
- contraceptive/realtabformer/mlu-eval.ipynb +0 -0
- contraceptive/realtabformer/model.pt +3 -0
- contraceptive/realtabformer/params.json +1 -0
- contraceptive/tab_ddpm_concat/eval.csv +2 -0
- contraceptive/tab_ddpm_concat/history.csv +17 -0
- contraceptive/tab_ddpm_concat/mlu-eval.ipynb +0 -0
- contraceptive/tab_ddpm_concat/model.pt +3 -0
- contraceptive/tab_ddpm_concat/params.json +1 -0
- contraceptive/tvae/eval.csv +2 -0
- contraceptive/tvae/history.csv +20 -0
- contraceptive/tvae/mlu-eval.ipynb +0 -0
- contraceptive/tvae/model.pt +3 -0
- contraceptive/tvae/params.json +1 -0
- insurance/lct_gan/eval.csv +2 -0
- insurance/lct_gan/history.csv +27 -0
- insurance/lct_gan/mlu-eval.ipynb +0 -0
- insurance/lct_gan/model.pt +3 -0
- insurance/lct_gan/params.json +1 -0
- insurance/realtabformer/eval.csv +2 -0
- insurance/realtabformer/history.csv +14 -0
- insurance/realtabformer/mlu-eval.ipynb +0 -0
- insurance/realtabformer/model.pt +3 -0
- insurance/realtabformer/params.json +1 -0
- insurance/tab_ddpm_concat/eval.csv +2 -0
- insurance/tab_ddpm_concat/history.csv +19 -0
- insurance/tab_ddpm_concat/mlu-eval.ipynb +0 -0
- insurance/tab_ddpm_concat/model.pt +3 -0
- insurance/tab_ddpm_concat/params.json +1 -0
- insurance/tvae/eval.csv +2 -0
- insurance/tvae/history.csv +23 -0
- insurance/tvae/mlu-eval.ipynb +0 -0
- insurance/tvae/model.pt +3 -0
- insurance/tvae/params.json +1 -0
- treatment/lct_gan/eval.csv +2 -0
- treatment/lct_gan/history.csv +29 -0
- treatment/lct_gan/mlu-eval.ipynb +0 -0
- treatment/lct_gan/model.pt +3 -0
- treatment/lct_gan/params.json +1 -0
- treatment/realtabformer/eval.csv +2 -0
- treatment/realtabformer/history.csv +17 -0
- treatment/realtabformer/mlu-eval.ipynb +0 -0
- treatment/realtabformer/model.pt +3 -0
- treatment/realtabformer/params.json +1 -0
contraceptive/lct_gan/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
lct_gan,0.009699148273355183,,0.0012717798803132062,2.5979268550872803,0.031053613871335983,0.5863479971885681,0.0409364253282547,1.6693826410119073e-06,3.1333327293395996,0.028224041685461998,0.06742087006568909,0.03566202521324158,0.05481972172856331,0.022213930264115334,5.73125958442688
|
contraceptive/lct_gan/history.csv
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.01786272754015954,3.636569730948348,0.001167070859791755,0.0,0.0,0.0,0.0,0.0,0.01786272754015954,320,160,162.09109449386597,1.0130693405866622,0.5065346702933311,0.07982906211551608,0.0036570764575117208,8.543643573567897,1.3811252661610762e-05,0.0,0.0,0.0,0.0,0.0,0.0036570764575117208,80,40,38.24689245223999,0.9561723113059998,0.4780861556529999,0.013961730610299128
|
3 |
+
1,0.005696987385198327,3.198073918344037,6.659749155254468e-05,0.0,0.0,0.0,0.0,0.0,0.005696987385198327,320,160,161.77688694000244,1.0111055433750153,0.5055527716875077,0.06436048086907249,0.003786977470736019,3.9373093709834395,1.9986069321897836e-05,0.0,0.0,0.0,0.0,0.0,0.003786977470736019,80,40,38.12403869628906,0.9531009674072266,0.4765504837036133,0.019040833081817254
|
4 |
+
2,0.0032535013802828415,2.8198290220143503,1.6548445633081822e-05,0.0,0.0,0.0,0.0,0.0,0.0032535013802828415,320,160,158.56593680381775,0.9910371050238609,0.49551855251193044,0.05899094843493913,0.002633672622323502,4.367062080342106,6.056017390343449e-06,0.0,0.0,0.0,0.0,0.0,0.002633672622323502,80,40,36.21775007247925,0.9054437518119812,0.4527218759059906,0.029313163098959195
|
5 |
+
3,0.0023438148911395728,2.2918122927125295,7.134696727795209e-06,0.0,0.0,0.0,0.0,0.0,0.0023438148911395728,320,160,152.8709909915924,0.9554436936974525,0.47772184684872626,0.06780749239678699,0.002467950962409304,4.88692410795129,4.36867752655612e-06,0.0,0.0,0.0,0.0,0.0,0.002467950962409304,80,40,35.697052240371704,0.8924263060092926,0.4462131530046463,0.018721673299660326
|
6 |
+
4,0.002259014635501444,1.812988119399502,5.5367047818208335e-06,0.0,0.0,0.0,0.0,0.0,0.002259014635501444,320,160,151.4134497642517,0.9463340610265731,0.47316703051328657,0.06716944240579323,0.0023299435670196544,7.427999823173169,6.818181761379086e-06,0.0,0.0,0.0,0.0,0.0,0.0023299435670196544,80,40,36.091819047927856,0.9022954761981964,0.4511477380990982,0.025550102235138185
|
7 |
+
5,0.002213509789987711,2.1981050524016665,6.457744553583708e-06,0.0,0.0,0.0,0.0,0.0,0.002213509789987711,320,160,155.00266909599304,0.9687666818499565,0.48438334092497826,0.06307913716664189,0.0025578231319741463,5.699148117736866,7.032264796913435e-06,0.0,0.0,0.0,0.0,0.0,0.0025578231319741463,80,40,35.88962531089783,0.8972406327724457,0.44862031638622285,0.017301402381235675
|
8 |
+
6,0.002113006258792893,2.1890728188584476,6.114715563168734e-06,0.0,0.0,0.0,0.0,0.0,0.002113006258792893,320,160,153.502343416214,0.9593896463513374,0.4796948231756687,0.0661177773316524,0.0022539663767020103,8.082782875575992,6.678637441320801e-06,0.0,0.0,0.0,0.0,0.0,0.0022539663767020103,80,40,35.9239776134491,0.8980994403362275,0.44904972016811373,0.025397659230247883
|
9 |
+
7,0.0018896224307241027,1.3406870565765918,4.11630951652614e-06,0.0,0.0,0.0,0.0,0.0,0.0018896224307241027,320,160,151.49836015701294,0.9468647509813308,0.4734323754906654,0.07347480687003553,0.0025110580976615894,5.483385282401798,4.007785178927748e-06,0.0,0.0,0.0,0.0,0.0,0.0025110580976615894,80,40,33.77740168571472,0.844435042142868,0.422217521071434,0.01343663605657639
|
10 |
+
8,0.0019506988204057052,1.1689463564448361,4.559629974733976e-06,0.0,0.0,0.0,0.0,0.0,0.0019506988204057052,320,160,144.0004370212555,0.9000027313828468,0.4500013656914234,0.0659007933063549,0.0023398275739964446,4.280078241747558,9.40580381603908e-06,0.0,0.0,0.0,0.0,0.0,0.0023398275739964446,80,40,34.14411234855652,0.8536028087139129,0.42680140435695646,0.029161113313784882
|
11 |
+
9,0.0016659495725662055,1.259914455004442,3.6357496224758216e-06,0.0,0.0,0.0,0.0,0.0,0.0016659495725662055,320,160,146.88511276245117,0.9180319547653198,0.4590159773826599,0.06927311308209028,0.0020320220184657954,1.9301895227664638,5.797358790626817e-06,0.0,0.0,0.0,0.0,0.0,0.0020320220184657954,80,40,33.82334542274475,0.8455836355686188,0.4227918177843094,0.031894830953388006
|
12 |
+
10,0.0016075492954826132,1.0092837510064634,2.735296125088143e-06,0.0,0.0,0.0,0.0,0.0,0.0016075492954826132,320,160,143.3337082862854,0.8958356767892838,0.4479178383946419,0.07756731488425431,0.0026302734650016646,3.5820208163912866,1.3108698285635434e-05,0.0,0.0,0.0,0.0,0.0,0.0026302734650016646,80,40,33.34847378730774,0.8337118446826934,0.4168559223413467,0.026611345619312488
|
13 |
+
11,0.0013905711543884536,1.4293391600536105,1.9943211839244628e-06,0.0,0.0,0.0,0.0,0.0,0.0013905711543884536,320,160,145.80800819396973,0.9113000512123108,0.4556500256061554,0.07521452093462813,0.0022007888529515184,4.483247433313909,7.529547302986828e-06,0.0,0.0,0.0,0.0,0.0,0.0022007888529515184,80,40,34.17486619949341,0.8543716549873352,0.4271858274936676,0.025888706676232685
|
14 |
+
12,0.0012333092558833413,1.0602984449560355,1.4978945140242672e-06,0.0,0.0,0.0,0.0,0.0,0.0012333092558833413,320,160,144.26001048088074,0.9016250655055046,0.4508125327527523,0.06890693796813138,0.002605481748287275,5.8101766740395275,1.2162480100338935e-05,0.0,0.0,0.0,0.0,0.0,0.002605481748287275,80,40,34.0310595035553,0.8507764875888825,0.42538824379444123,0.030725552017582914
|
15 |
+
13,0.0011501418213924809,1.3386667872641667,1.3185595447717802e-06,0.0,0.0,0.0,0.0,0.0,0.0011501418213924809,320,160,144.17591977119446,0.9010994985699654,0.4505497492849827,0.07558018262188852,0.0022788071105424024,2.7987204812981075,7.67833690556996e-06,0.0,0.0,0.0,0.0,0.0,0.0022788071105424024,80,40,34.57782983779907,0.8644457459449768,0.4322228729724884,0.028933984229661293
|
16 |
+
14,0.0010513833095672,0.6387672975999288,1.4141241849948554e-06,0.0,0.0,0.0,0.0,0.0,0.0010513833095672,320,160,145.27083349227905,0.907942709326744,0.453971354663372,0.07942418371021631,0.0023162082296039445,3.349859969510388,8.629114276192951e-06,0.0,0.0,0.0,0.0,0.0,0.0023162082296039445,80,40,33.332454442977905,0.8333113610744476,0.4166556805372238,0.03227444588264916
|
17 |
+
15,0.0009490251221507151,0.6831669182588915,1.0006828890851693e-06,0.0,0.0,0.0,0.0,0.0,0.0009490251221507151,320,160,145.07566261291504,0.906722891330719,0.4533614456653595,0.08901288353954442,0.002519568522711779,3.413597872712374,1.3267886331692902e-05,0.0,0.0,0.0,0.0,0.0,0.002519568522711779,80,40,35.61937355995178,0.8904843389987945,0.44524216949939727,0.03244233850882665
|
18 |
+
16,0.0008506465366181715,0.6938899818249922,7.591354541367547e-07,0.0,0.0,0.0,0.0,0.0,0.0008506465366181715,320,160,141.91219687461853,0.8869512304663658,0.4434756152331829,0.07522430039280152,0.0022380605670150543,2.3211301649584923,8.706937260289163e-06,0.0,0.0,0.0,0.0,0.0,0.0022380605670150543,80,40,33.153270959854126,0.8288317739963531,0.41441588699817655,0.033504872786579654
|
19 |
+
17,0.0007984611616791426,0.8550257516943558,6.954888952128042e-07,0.0,0.0,0.0,0.0,0.0,0.0007984611616791426,320,160,141.78189277648926,0.8861368298530579,0.44306841492652893,0.07633930989904911,0.0023599293230745387,2.454739641254777,1.1268823242671644e-05,0.0,0.0,0.0,0.0,0.0,0.0023599293230745387,80,40,33.33975076675415,0.8334937691688538,0.4167468845844269,0.03295552866329672
|
20 |
+
18,0.0007362112768333872,1.1732242802433412,4.6208515920671473e-07,0.0,0.0,0.0,0.0,0.0,0.0007362112768333872,320,160,142.94105648994446,0.8933816030621529,0.44669080153107643,0.07887016502173765,0.0022324465072415477,2.106721719149411,9.551615076732606e-06,0.0,0.0,0.0,0.0,0.0,0.0022324465072415477,80,40,33.48200488090515,0.8370501220226287,0.41852506101131437,0.031034307027584872
|
contraceptive/lct_gan/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
contraceptive/lct_gan/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:12c803933371fdedc1397d36f37b854713f585ec63c4a4887467f850e2b255cf
|
3 |
+
size 41106197
|
contraceptive/lct_gan/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 8, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.775, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.075, "n_warmup_steps": 100, "Optim": "amsgradw", "loss_balancer_beta": 0.675, "loss_balancer_r": 0.95, "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 8, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation": "relu6", "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
|
contraceptive/realtabformer/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
realtabformer,0.01544937376803275,,0.0014923845270062802,2.446577548980713,0.11075861752033234,1.6372919082641602,0.24003435671329498,1.4633540104114218e-06,4.698302984237671,0.03089020401239395,0.07129628211259842,0.03863139450550079,0.05503246188163757,0.02127235010266304,7.144880533218384
|
contraceptive/realtabformer/history.csv
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.015214953777865503,3.156637710683768,0.0015362251463382382,0.0,0.0,0.0,0.0,0.0,0.015214953777865503,320,160,150.5653796195984,0.9410336226224899,0.47051681131124495,0.06246557859708446,0.0029828296956111444,3.72726428431983,7.1439977092975506e-06,0.0,0.0,0.0,0.0,0.0,0.0029828296956111444,80,40,35.29301643371582,0.8823254108428955,0.44116270542144775,0.017451438040006907
|
3 |
+
1,0.005820814266519392,3.788629152714401,0.00012637018700479842,0.0,0.0,0.0,0.0,0.0,0.005820814266519392,320,160,151.99716687202454,0.9499822929501534,0.4749911464750767,0.054377025530902755,0.003795744390345135,4.469785842269539,2.3026626298489063e-05,0.0,0.0,0.0,0.0,0.0,0.003795744390345135,80,40,35.12023663520813,0.8780059158802033,0.43900295794010163,0.025538454634443042
|
4 |
+
2,0.003438323737486826,2.1381948816152487,2.1238802406123302e-05,0.0,0.0,0.0,0.0,0.0,0.003438323737486826,320,160,151.25162959098816,0.945322684943676,0.472661342471838,0.059417175995986324,0.0027740994402847717,3.3004495511856193,8.8964574990494e-06,0.0,0.0,0.0,0.0,0.0,0.0027740994402847717,80,40,35.27286076545715,0.8818215191364288,0.4409107595682144,0.02031386639282573
|
5 |
+
3,0.003000541084111319,1.2077899141461088,1.181568532097677e-05,0.0,0.0,0.0,0.0,0.0,0.003000541084111319,320,160,151.62242531776428,0.9476401582360268,0.4738200791180134,0.0704218547190976,0.003368616230500265,3.7329810831716825,2.0277994416362245e-05,0.0,0.0,0.0,0.0,0.0,0.003368616230500265,80,40,35.269885778427124,0.8817471444606781,0.4408735722303391,0.022317990543524503
|
6 |
+
4,0.0021543572441999003,2.36580728989997,3.905246718544086e-06,0.0,0.0,0.0,0.0,0.0,0.0021543572441999003,320,160,151.42908096313477,0.9464317560195923,0.47321587800979614,0.06814513673386387,0.0023375894830166997,3.1858106834027735,5.140880977472922e-06,0.0,0.0,0.0,0.0,0.0,0.0023375894830166997,80,40,35.21497082710266,0.8803742706775666,0.4401871353387833,0.022702036050031894
|
7 |
+
5,0.0020639134972753937,1.8771122450190234,4.7617895551466114e-06,0.0,0.0,0.0,0.0,0.0,0.0020639134972753937,320,160,150.65398001670837,0.9415873751044274,0.4707936875522137,0.06839547897311604,0.002713540389231639,2.751661246550566,8.476993637529517e-06,0.0,0.0,0.0,0.0,0.0,0.002713540389231639,80,40,35.15806221961975,0.8789515554904938,0.4394757777452469,0.0216164964978816
|
8 |
+
6,0.001974041026574014,2.6116740645130543,4.903630427295745e-06,0.0,0.0,0.0,0.0,0.0,0.001974041026574014,320,160,150.98297429084778,0.9436435893177986,0.4718217946588993,0.06690819428837073,0.0026584940679640567,2.928696505277807,1.0981135633864048e-05,0.0,0.0,0.0,0.0,0.0,0.0026584940679640567,80,40,34.98962068557739,0.8747405171394348,0.4373702585697174,0.028631291013152805
|
9 |
+
7,0.0017634188792953864,1.5246528687675955,4.08627352063845e-06,0.0,0.0,0.0,0.0,0.0,0.0017634188792953864,320,160,151.10478925704956,0.9444049328565598,0.4722024664282799,0.06988477217641957,0.0025102962197934174,3.828912246527557,9.336235920053004e-06,0.0,0.0,0.0,0.0,0.0,0.0025102962197934174,80,40,35.14225649833679,0.8785564124584198,0.4392782062292099,0.025317877356610553
|
10 |
+
8,0.0017970734881543216,1.9675439566956825,4.507927208244523e-06,0.0,0.0,0.0,0.0,0.0,0.0017970734881543216,320,160,151.0646207332611,0.9441538795828819,0.47207693979144094,0.06440023747350096,0.0029548812297832683,2.1843356087258696,1.6957902914016554e-05,0.0,0.0,0.0,0.0,0.0,0.0029548812297832683,80,40,35.090479135513306,0.8772619783878326,0.4386309891939163,0.028445404235390014
|
11 |
+
9,0.0016425128245685983,2.024512654822502,3.285046532060373e-06,0.0,0.0,0.0,0.0,0.0,0.0016425128245685983,320,160,150.84772562980652,0.9427982851862907,0.47139914259314536,0.0747756733842209,0.002334522669548278,2.7116857997084027,6.992434950987836e-06,0.0,0.0,0.0,0.0,0.0,0.002334522669548278,80,40,35.245197772979736,0.8811299443244934,0.4405649721622467,0.026337641538702883
|
12 |
+
10,0.0016268517200273892,1.4290221301512553,2.7465491623539574e-06,0.0,0.0,0.0,0.0,0.0,0.0016268517200273892,320,160,150.6063461303711,0.9412896633148193,0.47064483165740967,0.07696705762027704,0.0022422952166834876,3.8237469222483185,4.734226487249082e-06,0.0,0.0,0.0,0.0,0.0,0.0022422952166834876,80,40,35.0414342880249,0.8760358572006226,0.4380179286003113,0.02215462920921709
|
13 |
+
11,0.001684735846095009,1.8539249592111176,3.617836458871815e-06,0.0,0.0,0.0,0.0,0.0,0.001684735846095009,320,160,151.0113205909729,0.9438207536935806,0.4719103768467903,0.06602817026396224,0.0023927180209284415,1.9305211880035131,8.37876484468536e-06,0.0,0.0,0.0,0.0,0.0,0.0023927180209284415,80,40,35.019232988357544,0.8754808247089386,0.4377404123544693,0.02876366543350741
|
14 |
+
12,0.001648270361597781,1.466464621800828,3.6850406459681182e-06,0.0,0.0,0.0,0.0,0.0,0.001648270361597781,320,160,150.67579007148743,0.9417236879467964,0.4708618439733982,0.07282405201085566,0.0024106145921905407,1.8117562619359986,8.840941997867446e-06,0.0,0.0,0.0,0.0,0.0,0.0024106145921905407,80,40,35.25731110572815,0.8814327776432037,0.44071638882160186,0.028296888258773835
|
15 |
+
13,0.0014949767563791738,2.0110324508543216,2.594263938809541e-06,0.0,0.0,0.0,0.0,0.0,0.0014949767563791738,320,160,150.59481382369995,0.9412175863981247,0.47060879319906235,0.07127468679950652,0.0034006003257673,2.3089766705088124,2.330103268749495e-05,0.0,0.0,0.0,0.0,0.0,0.0034006003257673,80,40,35.14684081077576,0.8786710202693939,0.43933551013469696,0.03264807362284046
|
16 |
+
14,0.0014807669692402214,1.486915388038304,2.468103663008994e-06,0.0,0.0,0.0,0.0,0.0,0.0014807669692402214,320,160,150.8731460571289,0.9429571628570557,0.47147858142852783,0.06568786399493547,0.002275905742408213,2.4593560876942546,5.3657121410810585e-06,0.0,0.0,0.0,0.0,0.0,0.002275905742408213,80,40,35.17850923538208,0.879462730884552,0.439731365442276,0.028419802509597504
|
17 |
+
15,0.0013959772029608075,1.5552767487895864,2.032669835160851e-06,0.0,0.0,0.0,0.0,0.0,0.0013959772029608075,320,160,145.97118186950684,0.9123198866844178,0.4561599433422089,0.0743303620764891,0.0021730058373577777,3.265071701107611,5.249506042540042e-06,0.0,0.0,0.0,0.0,0.0,0.0021730058373577777,80,40,32.64319920539856,0.816079980134964,0.408039990067482,0.02472462045188877
|
18 |
+
16,0.0013986327389147845,1.0404113516639566,2.472861648920811e-06,0.0,0.0,0.0,0.0,0.0,0.0013986327389147845,320,160,142.02302479743958,0.8876439049839974,0.4438219524919987,0.0766021506049583,0.0021833765216797475,1.9410757024995746,5.971337235900851e-06,0.0,0.0,0.0,0.0,0.0,0.0021833765216797475,80,40,33.03069186210632,0.8257672965526581,0.41288364827632906,0.029009606450563295
|
19 |
+
17,0.0013606195244022957,1.2219324406304888,2.558043809543567e-06,0.0,0.0,0.0,0.0,0.0,0.0013606195244022957,320,160,142.88847541809082,0.8930529713630676,0.4465264856815338,0.07105417825841868,0.0025992857859819195,4.460525457162658,1.0201318598324071e-05,0.0,0.0,0.0,0.0,0.0,0.0025992857859819195,80,40,32.683080196380615,0.8170770049095154,0.4085385024547577,0.02676110131196765
|
20 |
+
18,0.0013010777075521673,0.9639209467314742,2.0914453604690185e-06,0.0,0.0,0.0,0.0,0.0,0.0013010777075521673,320,160,141.19299387931824,0.882456211745739,0.4412281058728695,0.07996181348535174,0.0021221282840997446,2.53250820556988,6.020639418136131e-06,0.0,0.0,0.0,0.0,0.0,0.0021221282840997446,80,40,32.62534475326538,0.8156336188316345,0.40781680941581727,0.028227102017262952
|
contraceptive/realtabformer/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
contraceptive/realtabformer/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4cc609d6c10be1c4d1dc21442aa2a3a961cbd654795f076579f799dd433c742
|
3 |
+
size 43889419
|
contraceptive/realtabformer/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 8, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.775, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.075, "n_warmup_steps": 100, "Optim": "amsgradw", "loss_balancer_beta": 0.675, "loss_balancer_r": 0.95, "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 8, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation": "relu6", "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}
|
contraceptive/tab_ddpm_concat/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
tab_ddpm_concat,0.004686037855951365,0.016378744196746925,0.0026090041155702806,3.8580918312072754,0.06953004002571106,0.8769555687904358,0.09042102098464966,1.2404520020936616e-05,1.3648459911346436,0.03967232629656792,0.0928136557340622,0.05107840895652771,0.06657693535089493,7.981087151165411e-07,5.222937822341919
|
contraceptive/tab_ddpm_concat/history.csv
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.017509687443816802,0.25492126335856824,0.000845979718699752,0.0,0.0,0.0,0.0,0.0,0.017509687443816802,320,80,74.85561537742615,0.9356951922178268,0.2339237980544567,0.12247967834118753,0.03225579813006334,0.3893812867692759,0.0018670448790572892,0.0,0.0,0.0,0.0,0.0,0.03225579813006334,80,20,16.974793434143066,0.8487396717071534,0.21218491792678834,0.11351076629944146
|
3 |
+
1,0.015060588270716834,0.5294836329319879,0.0005396637374993886,0.0,0.0,0.0,0.0,0.0,0.015060588270716834,320,80,74.75039911270142,0.9343799889087677,0.23359499722719193,0.10789964701980352,0.017869547638110817,3.109420410258463,0.0007849385737095816,0.0,0.0,0.0,0.0,0.0,0.017869547638110817,80,20,17.005717754364014,0.8502858877182007,0.21257147192955017,0.032646807050332426
|
4 |
+
2,0.007901813013450009,0.43976500204076957,0.00010274562480983643,0.0,0.0,0.0,0.0,0.0,0.007901813013450009,320,80,74.73881554603577,0.9342351943254471,0.23355879858136178,0.09000834664329886,0.006841135048307479,1.7945492254511919,8.046349178982836e-05,0.0,0.0,0.0,0.0,0.0,0.006841135048307479,80,20,16.90992760658264,0.8454963803291321,0.21137409508228303,0.052934233518317345
|
5 |
+
3,0.005526901292250841,0.4796540130246029,5.3269670587949184e-05,0.0,0.0,0.0,0.0,0.0,0.005526901292250841,320,80,74.69570064544678,0.9336962580680848,0.2334240645170212,0.09062463160371408,0.004396481180447154,1.441257982449315,1.9934287330158895e-05,0.0,0.0,0.0,0.0,0.0,0.004396481180447154,80,20,16.831034421920776,0.8415517210960388,0.2103879302740097,0.034367192443460225
|
6 |
+
4,0.003952335390204098,0.6800493642964284,3.501202619586863e-05,0.0,0.0,0.0,0.0,0.0,0.003952335390204098,320,80,74.82494044303894,0.9353117555379867,0.23382793888449668,0.08449668972752988,0.0030278531834483148,1.419872753619893,1.2108854497228094e-05,0.0,0.0,0.0,0.0,0.0,0.0030278531834483148,80,20,16.83168315887451,0.8415841579437255,0.21039603948593139,0.04602950892876834
|
7 |
+
5,0.003957326662930427,0.3222652507973578,1.749390277871613e-05,0.0,0.0,0.0,0.0,0.0,0.003957326662930427,320,80,74.4447557926178,0.9305594474077225,0.2326398618519306,0.09507375009125099,0.003036417685507331,1.8372398112704105,1.3633386806094494e-05,0.0,0.0,0.0,0.0,0.0,0.003036417685507331,80,20,16.73884344100952,0.836942172050476,0.209235543012619,0.03600916846189648
|
8 |
+
6,0.0028476251969550503,0.2714852012659293,1.3130850977921548e-05,0.0,0.0,0.0,0.0,0.0,0.0028476251969550503,320,80,75.16001582145691,0.9395001977682114,0.23487504944205284,0.09719131344463676,0.0032441405899589883,2.57110884013091,1.2244734485200582e-05,0.0,0.0,0.0,0.0,0.0,0.0032441405899589883,80,20,16.80543065071106,0.840271532535553,0.21006788313388824,0.034793011099100116
|
9 |
+
7,0.002179265605263936,0.2912421387520652,5.355932938649715e-06,0.0,0.0,0.0,0.0,0.0,0.002179265605263936,320,80,75.13848423957825,0.9392310529947281,0.23480776324868202,0.09003764551598578,0.002960549862473272,1.5272669666737784,1.1626163023858993e-05,0.0,0.0,0.0,0.0,0.0,0.002960549862473272,80,20,16.840531826019287,0.8420265913009644,0.2105066478252411,0.048068627482280135
|
10 |
+
8,0.0019942367394833126,0.8764173788223844,5.071989225379896e-06,0.0,0.0,0.0,0.0,0.0,0.0019942367394833126,320,80,74.68324661254883,0.9335405826568604,0.2333851456642151,0.0842734721081797,0.003437347624276299,1.6128702243404405,2.132158708016774e-05,0.0,0.0,0.0,0.0,0.0,0.003437347624276299,80,20,16.703343152999878,0.8351671576499939,0.20879178941249849,0.04150933439377695
|
11 |
+
9,0.001910439515268081,0.5296208621499737,3.807974610924676e-06,0.0,0.0,0.0,0.0,0.0,0.001910439515268081,320,80,74.98403477668762,0.9373004347085953,0.2343251086771488,0.0939617162453942,0.0029005830438109115,1.4297594713909347,9.182002216068242e-06,0.0,0.0,0.0,0.0,0.0,0.0029005830438109115,80,20,16.704230070114136,0.8352115035057068,0.2088028758764267,0.040789688983932135
|
12 |
+
10,0.002364715466683265,0.23049106563653615,1.08880691394031e-05,0.0,0.0,0.0,0.0,0.0,0.002364715466683265,320,80,74.73340845108032,0.934167605638504,0.233541901409626,0.09482922677416354,0.0025556790380505843,1.4470476474137044,9.480406802708785e-06,0.0,0.0,0.0,0.0,0.0,0.0025556790380505843,80,20,16.752872228622437,0.8376436114311219,0.20941090285778047,0.04689050167798996
|
13 |
+
11,0.001990120611480961,0.20637534558109127,5.3637341571725695e-06,0.0,0.0,0.0,0.0,0.0,0.001990120611480961,320,80,74.96153736114502,0.9370192170143128,0.2342548042535782,0.09388396987924352,0.0026564171042991803,1.8061061197324306,1.1139397728709977e-05,0.0,0.0,0.0,0.0,0.0,0.0026564171042991803,80,20,16.754515647888184,0.8377257823944092,0.2094314455986023,0.046416288684122266
|
14 |
+
12,0.0018798561781295576,0.3383319207922398,4.4709591399128e-06,0.0,0.0,0.0,0.0,0.0,0.0018798561781295576,320,80,74.77418828010559,0.9346773535013199,0.23366933837532997,0.0905981837247964,0.0026210575761069776,2.1850536189552257,5.391822381461964e-06,0.0,0.0,0.0,0.0,0.0,0.0026210575761069776,80,20,16.748401641845703,0.8374200820922851,0.20935502052307128,0.03947516868356615
|
15 |
+
13,0.0018263132704305462,0.4754561664466223,3.583063472956116e-06,0.0,0.0,0.0,0.0,0.0,0.0018263132704305462,320,80,74.7354383468628,0.9341929793357849,0.23354824483394623,0.0871307724271901,0.002742944849887863,2.296998679006356,8.635369825379935e-06,0.0,0.0,0.0,0.0,0.0,0.002742944849887863,80,20,16.92655324935913,0.8463276624679565,0.21158191561698914,0.04202657011337578
|
16 |
+
14,0.0017013913855407736,0.2523655897014123,2.6748898654643803e-06,0.0,0.0,0.0,0.0,0.0,0.0017013913855407736,320,80,75.46495079994202,0.9433118849992752,0.2358279712498188,0.09202192013035529,0.00283771293470636,1.8566916088265089,1.0541326899016213e-05,0.0,0.0,0.0,0.0,0.0,0.00283771293470636,80,20,17.10892963409424,0.8554464817047119,0.21386162042617798,0.05005494304932654
|
17 |
+
15,0.0015473646866666969,0.26786694799376515,3.7729344502605496e-06,0.0,0.0,0.0,0.0,0.0,0.0015473646866666969,320,80,74.63104557991028,0.9328880697488785,0.23322201743721963,0.09204029910615645,0.00325308749161195,1.8810293299167824,1.5295879933319156e-05,0.0,0.0,0.0,0.0,0.0,0.00325308749161195,80,20,16.859638690948486,0.8429819345474243,0.21074548363685608,0.04019828836899251
|
contraceptive/tab_ddpm_concat/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
contraceptive/tab_ddpm_concat/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ec4624dc56aa9865a93acbdcdeae70f85f9456a946fb2e5ed9cd8b5dc9f4c19
|
3 |
+
size 45181003
|
contraceptive/tab_ddpm_concat/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 8, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mse", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.74, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.075, "n_warmup_steps": 100, "Optim": "amsgradw", "loss_balancer_beta": 0.675, "loss_balancer_r": 0.95, "fixed_role_model": "tab_ddpm_concat", "mse_mag": false, "mse_mag_target": 1.0, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation": "tanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation": "softsign", "head_activation_final": "leakyhardsigmoid", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
|
contraceptive/tvae/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
tvae,0.014026033173837974,,0.0012284577070317652,2.706435203552246,0.03164428099989891,0.6164292693138123,0.0400746688246727,9.105955882660055e-07,3.2376320362091064,0.02793470025062561,0.06429528445005417,0.03504936024546623,0.057808149605989456,0.011085247620940208,5.9440672397613525
|
contraceptive/tvae/history.csv
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.02241888580356317,1.4905488636076916,0.0030782962097319457,0.0,0.0,0.0,0.0,0.0,0.02241888580356317,320,160,142.862961769104,0.8928935110569001,0.44644675552845003,0.0961261961127093,0.007483271204910125,7.415935450342414,9.700103195409149e-05,0.0,0.0,0.0,0.0,0.0,0.007483271204910125,80,40,32.79719305038452,0.8199298262596131,0.40996491312980654,0.030276008496821306
|
3 |
+
1,0.004170822119840522,2.427984166694133,5.6218089488430103e-05,0.0,0.0,0.0,0.0,0.0,0.004170822119840522,320,160,140.515380859375,0.8782211303710937,0.43911056518554686,0.06761386157022571,0.0027781161175880697,5.97971810359972,8.342038965802879e-06,0.0,0.0,0.0,0.0,0.0,0.0027781161175880697,80,40,32.894118309020996,0.8223529577255249,0.41117647886276243,0.024398993137219806
|
4 |
+
2,0.0032142372105653295,3.1775326946756253,9.600223262965901e-06,0.0,0.0,0.0,0.0,0.0,0.0032142372105653295,320,160,144.2568175792694,0.9016051098704339,0.4508025549352169,0.06473975269825587,0.002925431027142622,6.108939102519116,8.98504544019768e-06,0.0,0.0,0.0,0.0,0.0,0.002925431027142622,80,40,35.61764717102051,0.8904411792755127,0.44522058963775635,0.028732989538184484
|
5 |
+
3,0.003577234379551553,2.9195899648459376,4.348199776086897e-05,0.0,0.0,0.0,0.0,0.0,0.003577234379551553,320,160,150.63309359550476,0.9414568349719048,0.4707284174859524,0.06363038770923594,0.0032030581480285035,5.47701075857707,1.4309088606778708e-05,0.0,0.0,0.0,0.0,0.0,0.0032030581480285035,80,40,32.49244570732117,0.8123111426830292,0.4061555713415146,0.021739204511686695
|
6 |
+
4,0.002611448067193578,1.8328958396290929,8.810737564255572e-06,0.0,0.0,0.0,0.0,0.0,0.002611448067193578,320,160,143.45851230621338,0.8966157019138337,0.44830785095691683,0.07633217830557441,0.0030096135813437288,5.497269465320635,8.17212617150176e-06,0.0,0.0,0.0,0.0,0.0,0.0030096135813437288,80,40,34.19865131378174,0.8549662828445435,0.42748314142227173,0.01728544359702937
|
7 |
+
5,0.002066187719401569,1.418562725057735,4.818185051591941e-06,0.0,0.0,0.0,0.0,0.0,0.002066187719401569,320,160,141.1322615146637,0.8820766344666481,0.44103831723332404,0.07070515162549781,0.002357064618456661,3.038762490750969,4.724545272427605e-06,0.0,0.0,0.0,0.0,0.0,0.002357064618456661,80,40,32.73992657661438,0.8184981644153595,0.40924908220767975,0.019966062564344612
|
8 |
+
6,0.0018150892569863686,1.9370867185192977,4.895915542963466e-06,0.0,0.0,0.0,0.0,0.0,0.0018150892569863686,320,160,142.5343050956726,0.8908394068479538,0.4454197034239769,0.06771241171363726,0.002098434802974225,2.5296149099483842,6.119135131865683e-06,0.0,0.0,0.0,0.0,0.0,0.002098434802974225,80,40,33.130537033081055,0.8282634258270264,0.4141317129135132,0.036213114765996576
|
9 |
+
7,0.0017754018189464205,1.0608564720709155,4.110462002784865e-06,0.0,0.0,0.0,0.0,0.0,0.0017754018189464205,320,160,150.93800163269043,0.9433625102043152,0.4716812551021576,0.07812782935689029,0.002651070246429299,5.481469099184153,8.274616622792885e-06,0.0,0.0,0.0,0.0,0.0,0.002651070246429299,80,40,36.68884253501892,0.917221063375473,0.4586105316877365,0.0201021930330171
|
10 |
+
8,0.0016320147636861293,1.572569604070008,3.280935549836465e-06,0.0,0.0,0.0,0.0,0.0,0.0016320147636861293,320,160,152.10332083702087,0.9506457552313805,0.47532287761569025,0.07706006977591642,0.0021084082123252303,4.821968620683037,7.140439676618648e-06,0.0,0.0,0.0,0.0,0.0,0.0021084082123252303,80,40,35.41457486152649,0.8853643715381623,0.44268218576908114,0.032409553838078864
|
11 |
+
9,0.0014390503529739362,1.1130779689192227,1.9856122689985296e-06,0.0,0.0,0.0,0.0,0.0,0.0014390503529739362,320,160,148.04402089118958,0.9252751305699348,0.4626375652849674,0.07865846673303167,0.002113264991157848,2.781704166371675,4.972169583404757e-06,0.0,0.0,0.0,0.0,0.0,0.002113264991157848,80,40,33.858819007873535,0.8464704751968384,0.4232352375984192,0.02809684935346013
|
12 |
+
10,0.001374225791067829,1.163778030275184,1.7583196497888975e-06,0.0,0.0,0.0,0.0,0.0,0.001374225791067829,320,160,145.61796760559082,0.9101122975349426,0.4550561487674713,0.072609595393169,0.0023332797987222877,2.608034198338737,1.0012352827615257e-05,0.0,0.0,0.0,0.0,0.0,0.0023332797987222877,80,40,33.24501919746399,0.8311254799365997,0.41556273996829984,0.03622158533107722
|
13 |
+
11,0.0013136249404567478,1.105370874132261,2.0836581030414523e-06,0.0,0.0,0.0,0.0,0.0,0.0013136249404567478,320,160,144.06322360038757,0.9003951475024223,0.45019757375121117,0.07791482849102067,0.0020939761153385915,7.381703513306002,3.89069975454473e-06,0.0,0.0,0.0,0.0,0.0,0.0020939761153385915,80,40,33.347615242004395,0.8336903810501098,0.4168451905250549,0.01896329457867978
|
14 |
+
12,0.0013007374736943688,0.8016777972321465,1.6661171600203944e-06,0.0,0.0,0.0,0.0,0.0,0.0013007374736943688,320,160,143.7109453678131,0.8981934085488319,0.44909670427441595,0.07403813572964282,0.0021091806715048734,2.195618169948898,6.716770201746968e-06,0.0,0.0,0.0,0.0,0.0,0.0021091806715048734,80,40,33.3885440826416,0.8347136020660401,0.41735680103302003,0.03174531738768564
|
15 |
+
13,0.0011258274745216568,0.9933245053406304,1.2559000061217402e-06,0.0,0.0,0.0,0.0,0.0,0.0011258274745216568,320,160,143.82260847091675,0.8988913029432297,0.44944565147161486,0.07367398725546082,0.002973305231353152,2.332661612354639,1.9470425126388857e-05,0.0,0.0,0.0,0.0,0.0,0.002973305231353152,80,40,33.963603019714355,0.8490900754928589,0.42454503774642943,0.035911593766650186
|
16 |
+
14,0.0010081856245165,1.8124559902420032,1.1379342504010126e-06,0.0,0.0,0.0,0.0,0.0,0.0010081856245165,320,160,144.61974716186523,0.9038734197616577,0.45193670988082885,0.07099440268893886,0.002199153335527626,2.3161073656544886,7.577638564465472e-06,0.0,0.0,0.0,0.0,0.0,0.002199153335527626,80,40,33.57773303985596,0.8394433259963989,0.41972166299819946,0.02861409220568021
|
17 |
+
15,0.0010586415690795547,1.0111776571605908,1.3155730025755604e-06,0.0,0.0,0.0,0.0,0.0,0.0010586415690795547,320,160,143.99448657035828,0.8999655410647392,0.4499827705323696,0.07437738951684877,0.0024275374956232556,2.65694752669535,1.2085388890881177e-05,0.0,0.0,0.0,0.0,0.0,0.0024275374956232556,80,40,33.66980719566345,0.8417451798915863,0.42087258994579313,0.034823847954976374
|
18 |
+
16,0.0008538674877796026,1.5765032050085541,5.374222879224455e-07,0.0,0.0,0.0,0.0,0.0,0.0008538674877796026,320,160,143.34916639328003,0.8959322899580002,0.4479661449790001,0.08046830528701321,0.0022838078687641428,2.0230109165978774,9.47888851559331e-06,0.0,0.0,0.0,0.0,0.0,0.0022838078687641428,80,40,33.02995800971985,0.8257489502429962,0.4128744751214981,0.030621698120376094
|
19 |
+
17,0.0008248503026644372,0.35676954403032646,7.142933498651121e-07,0.0,0.0,0.0,0.0,0.0,0.0008248503026644372,320,160,143.77559542655945,0.8985974714159966,0.4492987357079983,0.08080137882643612,0.0023548384517198427,4.805163549015765,1.1883367263940125e-05,0.0,0.0,0.0,0.0,0.0,0.0023548384517198427,80,40,33.76482057571411,0.8441205143928527,0.42206025719642637,0.03005029430896684
|
20 |
+
18,0.0007936206464748352,1.0760348675862972,7.879423535514518e-07,0.0,0.0,0.0,0.0,0.0,0.0007936206464748352,320,160,144.43712854385376,0.902732053399086,0.451366026699543,0.0707601236276787,0.0022624223027378322,4.505487352633622,1.0153568444593031e-05,0.0,0.0,0.0,0.0,0.0,0.0022624223027378322,80,40,33.66985249519348,0.841746312379837,0.4208731561899185,0.03378975939194788
|
contraceptive/tvae/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
contraceptive/tvae/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36d7ca6e137da75d51ef8ea5a79b9d92615b195432807d283d4d140daf9b0271
|
3 |
+
size 41130645
|
contraceptive/tvae/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 8, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.775, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.075, "n_warmup_steps": 100, "Optim": "amsgradw", "loss_balancer_beta": 0.675, "loss_balancer_r": 0.95, "fixed_role_model": "tvae", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 8, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation": "relu6", "head_activation_final": "leakyhardsigmoid", "models": ["tvae"], "max_seconds": 3600}
|
insurance/lct_gan/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
lct_gan,0.07935211913926261,0.1306050524190255,0.0007405445374794104,0.5596821308135986,0.03549230098724365,0.7786027193069458,0.054690830409526825,9.762088666320778e-07,0.8873205184936523,0.020644349977374077,0.3238646686077118,0.027212947607040405,0.15517421066761017,1.0350788215873763e-05,1.447002649307251
|
insurance/lct_gan/history.csv
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.07503844532329822,3.76196769104788,0.039352887886764186,0.0,0.0,0.0,0.0,0.0,0.07503844532329822,320,40,42.94257736206055,1.0735644340515136,0.1341955542564392,0.13469835626892745,0.008877724732155912,3.33818835544007,2.9330407841143823e-06,0.0,0.0,0.0,0.0,0.0,0.008877724732155912,80,10,9.165157794952393,0.9165157794952392,0.1145644724369049,0.037592777702957395
|
3 |
+
1,0.02608004305366194,2.0130314562969813,0.005282479949307373,0.0,0.0,0.0,0.0,0.0,0.02608004305366194,320,40,42.331037521362305,1.0582759380340576,0.1322844922542572,0.10117223353590817,0.005078719957964495,0.2115427433644072,3.133896269957859e-05,0.0,0.0,0.0,0.0,0.0,0.005078719957964495,80,10,9.217864990234375,0.9217864990234375,0.11522331237792968,0.11006196644157171
|
4 |
+
2,0.007246867158391979,3.0597032676599154,9.03313408463391e-05,0.0,0.0,0.0,0.0,0.0,0.007246867158391979,320,40,42.494264125823975,1.0623566031455993,0.13279457539319992,0.07669011817779392,0.0074968072643969205,2.55860044410183,0.00017504165297967944,0.0,0.0,0.0,0.0,0.0,0.0074968072643969205,80,10,9.166083335876465,0.9166083335876465,0.11457604169845581,0.09575922545045615
|
5 |
+
3,0.006846483054687269,2.2157006146561513,0.0007793176604440372,0.0,0.0,0.0,0.0,0.0,0.006846483054687269,320,40,42.40406584739685,1.0601016461849213,0.13251270577311516,0.08761264618951828,0.0020610511739505453,0.43892243231239264,7.913704455120296e-06,0.0,0.0,0.0,0.0,0.0,0.0020610511739505453,80,10,9.511794328689575,0.9511794328689576,0.1188974291086197,0.08304880987852811
|
6 |
+
4,0.002302334751948365,0.6305191363795544,1.4314678949647008e-05,0.0,0.0,0.0,0.0,0.0,0.002302334751948365,320,40,43.27270483970642,1.0818176209926604,0.13522720262408255,0.08970329142175615,0.004081522431806661,0.026551660033874214,2.7209868290256623e-05,0.0,0.0,0.0,0.0,0.0,0.004081522431806661,80,10,9.23831558227539,0.923831558227539,0.11547894477844238,0.11656681830063462
|
7 |
+
5,0.0013704985673030023,0.32384693517769847,7.155363357548236e-06,0.0,0.0,0.0,0.0,0.0,0.0013704985673030023,320,40,42.453025341033936,1.0613256335258483,0.13266570419073104,0.09745097612030804,0.0033582281379494817,0.619899958840142,2.357043052825247e-06,0.0,0.0,0.0,0.0,0.0,0.0033582281379494817,80,10,9.264163732528687,0.9264163732528686,0.11580204665660858,0.06247350247576833
|
8 |
+
6,0.0025572317323167225,1.3281221274127346,4.255108958927875e-06,0.0,0.0,0.0,0.0,0.0,0.0025572317323167225,320,40,42.2554829120636,1.05638707280159,0.13204838410019876,0.08552498414646834,0.001253202352381777,0.32618888739889373,2.2838767717248133e-06,0.0,0.0,0.0,0.0,0.0,0.001253202352381777,80,10,9.205610513687134,0.9205610513687134,0.11507013142108917,0.07330623050220311
|
9 |
+
7,0.0017415513455489417,0.39128143024250334,1.98570894781383e-06,0.0,0.0,0.0,0.0,0.0,0.0017415513455489417,320,40,42.568729639053345,1.0642182409763337,0.1330272801220417,0.08582097220933065,0.002552773474599235,0.38444976235623474,9.21505393498695e-06,0.0,0.0,0.0,0.0,0.0,0.002552773474599235,80,10,9.204639673233032,0.9204639673233033,0.11505799591541291,0.09078566757962107
|
10 |
+
8,0.001197526408395788,0.5185240543789404,6.638705742609621e-06,0.0,0.0,0.0,0.0,0.0,0.001197526408395788,320,40,42.46436643600464,1.061609160900116,0.1327011451125145,0.0921793600777164,0.0011125051009003074,0.1445327332803572,3.7028677351003125e-06,0.0,0.0,0.0,0.0,0.0,0.0011125051009003074,80,10,9.251813411712646,0.9251813411712646,0.11564766764640808,0.08213724349625409
|
11 |
+
9,0.0011423767211454106,0.24242009912380066,5.1029623472642616e-06,0.0,0.0,0.0,0.0,0.0,0.0011423767211454106,320,40,42.30474233627319,1.0576185584068298,0.13220231980085373,0.086794763058424,0.0021649388814694247,1.9614427807347965,1.2043508045372908e-05,0.0,0.0,0.0,0.0,0.0,0.0021649388814694247,80,10,9.148065567016602,0.9148065567016601,0.11435081958770751,0.0991100890096277
|
12 |
+
10,0.0008830112970827031,0.4022787155074184,5.028790277500362e-07,0.0,0.0,0.0,0.0,0.0,0.0008830112970827031,320,40,42.54185461997986,1.0635463654994965,0.13294329568743707,0.09175913570215925,0.0017155635854578578,1.2685116354904722,2.10015661070706e-06,0.0,0.0,0.0,0.0,0.0,0.0017155635854578578,80,10,9.253859281539917,0.9253859281539917,0.11567324101924896,0.0978053328813985
|
13 |
+
11,0.001937569323945354,0.6675240401481404,3.123376906712105e-06,0.0,0.0,0.0,0.0,0.0,0.001937569323945354,320,40,42.27454662322998,1.0568636655807495,0.1321079581975937,0.09221202009357513,0.0027879032801138236,0.09599128968548029,1.215036173931594e-05,0.0,0.0,0.0,0.0,0.0,0.0027879032801138236,80,10,9.17811369895935,0.917811369895935,0.11472642123699188,0.11065587596967816
|
14 |
+
12,0.0013319606783625203,0.16566794978843974,6.807524444540914e-07,0.0,0.0,0.0,0.0,0.0,0.0013319606783625203,320,40,42.627525091171265,1.0656881272792815,0.1332110159099102,0.10023473438341171,0.0013272355950903147,0.5792492911004956,6.966086060211652e-08,0.0,0.0,0.0,0.0,0.0,0.0013272355950903147,80,10,9.216788053512573,0.9216788053512573,0.11520985066890717,0.07118493653833866
|
15 |
+
13,0.0007021169698418816,0.11813758928258485,2.759127278142287e-06,0.0,0.0,0.0,0.0,0.0,0.0007021169698418816,320,40,42.255826234817505,1.0563956558704377,0.1320494569838047,0.09224181645549834,0.0008028386626392602,0.06161341504857774,6.375153506842091e-07,0.0,0.0,0.0,0.0,0.0,0.0008028386626392602,80,10,9.26439380645752,0.9264393806457519,0.11580492258071899,0.0892744664568454
|
16 |
+
14,0.0006613305880819098,0.18104081245551243,6.271647721827747e-07,0.0,0.0,0.0,0.0,0.0,0.0006613305880819098,320,40,42.603567361831665,1.0650891840457917,0.13313614800572396,0.09589193011634052,0.0005620345647912473,0.01514620759198806,4.4228354818542924e-07,0.0,0.0,0.0,0.0,0.0,0.0005620345647912473,80,10,9.197651863098145,0.9197651863098144,0.1149706482887268,0.09222434270195663
|
17 |
+
15,0.0007842243476261501,0.3221512252403457,1.0985442627384213e-06,0.0,0.0,0.0,0.0,0.0,0.0007842243476261501,320,40,42.46329689025879,1.0615824222564698,0.13269780278205873,0.09709920620080084,0.001301504473667592,2.9538385085063057,4.959147403504893e-07,0.0,0.0,0.0,0.0,0.0,0.001301504473667592,80,10,9.24899411201477,0.924899411201477,0.11561242640018463,0.1033931726939045
|
18 |
+
16,0.013570401900506113,0.36384937064023576,0.0018209778673778428,0.0,0.0,0.0,0.0,0.0,0.013570401900506113,320,40,42.59553337097168,1.064888334274292,0.1331110417842865,0.12845125668682159,0.16798840463161469,0.3928926819935441,0.06907801991328597,0.0,0.0,0.0,0.0,0.0,0.16798840463161469,80,10,9.341678619384766,0.9341678619384766,0.11677098274230957,0.28858067095279694
|
19 |
+
17,0.22693241573870182,0.459286569285905,0.10300876491237432,0.0,0.0,0.0,0.0,0.0,0.22693241573870182,320,40,42.3380069732666,1.058450174331665,0.13230627179145812,0.32175534069538114,0.4309734970331192,0.6810152728110552,0.33519675582647324,0.0,0.0,0.0,0.0,0.0,0.4309734970331192,80,10,9.211932897567749,0.921193289756775,0.11514916121959687,0.35243880599737165
|
20 |
+
18,0.2022466917289421,0.7070020012586611,0.11807100524520138,0.0,0.0,0.0,0.0,0.0,0.2022466917289421,320,40,42.41661763191223,1.0604154407978057,0.1325519300997257,0.34771558828651905,0.023606129095423967,0.21133080043364316,0.0015780384962681636,0.0,0.0,0.0,0.0,0.0,0.023606129095423967,80,10,9.170459747314453,0.9170459747314453,0.11463074684143067,0.16341671436093747
|
21 |
+
19,0.005479642776481342,0.5328093900557633,8.112759967222604e-05,0.0,0.0,0.0,0.0,0.0,0.005479642776481342,320,40,42.17777228355408,1.0544443070888518,0.13180553838610648,0.1161046927794814,0.0007801399522577412,0.05935813882520051,5.190229413365443e-07,0.0,0.0,0.0,0.0,0.0,0.0007801399522577412,80,10,9.320383310317993,0.9320383310317993,0.11650479137897492,0.0829056172631681
|
22 |
+
20,0.0016211183414270637,0.5920635455369594,2.3483921812016834e-06,0.0,0.0,0.0,0.0,0.0,0.0016211183414270637,320,40,42.64041256904602,1.0660103142261506,0.13325128927826882,0.08777563656913116,0.0008631729724584147,0.09347999086021445,1.679538957422011e-06,0.0,0.0,0.0,0.0,0.0,0.0008631729724584147,80,10,9.24825143814087,0.9248251438140869,0.11560314297676086,0.07839330667629837
|
23 |
+
21,0.0005253879406154737,0.10636353976760801,2.6855408757735233e-07,0.0,0.0,0.0,0.0,0.0,0.0005253879406154737,320,40,42.83327293395996,1.070831823348999,0.13385397791862488,0.09573199808364734,0.0005467957133078016,0.23533951704739592,2.0549961012861218e-07,0.0,0.0,0.0,0.0,0.0,0.0005467957133078016,80,10,9.286863803863525,0.9286863803863525,0.11608579754829407,0.08401567195542156
|
24 |
+
22,0.0006425162649065896,0.10582681066216537,6.702316121443182e-07,0.0,0.0,0.0,0.0,0.0,0.0006425162649065896,320,40,42.88295650482178,1.0720739126205445,0.13400923907756807,0.09285907801240682,0.003129586172872223,0.14476592368632737,2.8466294405604663e-05,0.0,0.0,0.0,0.0,0.0,0.003129586172872223,80,10,9.417258024215698,0.9417258024215698,0.11771572530269622,0.10630810875445604
|
25 |
+
23,0.0012267698623873002,0.3422558290479515,5.092052370217958e-06,0.0,0.0,0.0,0.0,0.0,0.0012267698623873002,320,40,42.59943246841431,1.0649858117103577,0.1331232264637947,0.0916012367233634,0.001950129849865334,1.6070946810563327,1.2394382595015685e-05,0.0,0.0,0.0,0.0,0.0,0.001950129849865334,80,10,9.218074083328247,0.9218074083328247,0.11522592604160309,0.06952399904839694
|
26 |
+
24,0.0015580077400954907,0.5951609104130398,1.101507371748138e-05,0.0,0.0,0.0,0.0,0.0,0.0015580077400954907,320,40,42.20673155784607,1.0551682889461518,0.13189603611826897,0.08924343169201165,0.0014282745076343417,1.1416928244575502,5.442704249958296e-07,0.0,0.0,0.0,0.0,0.0,0.0014282745076343417,80,10,9.161205768585205,0.9161205768585206,0.11451507210731507,0.08108144407160581
|
27 |
+
25,0.001158999443759967,0.673237547548581,1.6942943679773905e-06,0.0,0.0,0.0,0.0,0.0,0.001158999443759967,320,40,42.81976580619812,1.070494145154953,0.13381176814436913,0.09180414919974282,0.003335691889151349,0.1234515317961268,9.410348545632954e-05,0.0,0.0,0.0,0.0,0.0,0.003335691889151349,80,10,9.541585206985474,0.9541585206985473,0.11926981508731842,0.10400733416900039
|
insurance/lct_gan/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
insurance/lct_gan/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:707ced24eb7135366f3aa72529a23efe2f8ab8103b2b6ab6bb57b1624efbb5ff
|
3 |
+
size 38580983
|
insurance/lct_gan/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.77, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "loss_balancer_beta": 0.75, "loss_balancer_r": 0.95, "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "relu6", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "rrelu", "head_activation_final": "softsign", "models": ["lct_gan"], "max_seconds": 3600}
|
insurance/realtabformer/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
realtabformer,0.014167995679946173,0.02218004682639329,0.0011680017544638207,1.6954426765441895,0.22907589375972748,3.489088535308838,0.4298049509525299,1.9553140191419516e-06,2.115530014038086,0.024390142410993576,0.45029416680336,0.034176040440797806,0.16980427503585815,0.00019665222498588264,3.8109726905822754
|
insurance/realtabformer/history.csv
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.06551592258038,3.8791051520893234,0.03416722755192225,0.0,0.0,0.0,0.0,0.0,0.06551592258038,320,40,41.27148151397705,1.0317870378494263,0.12897337973117828,0.12673317359294742,0.0045693422085605565,0.18476937365267077,6.355656341838767e-05,0.0,0.0,0.0,0.0,0.0,0.0045693422085605565,80,10,8.436395406723022,0.8436395406723023,0.10545494258403779,0.08324137404561043
|
3 |
+
1,0.0037559630061878126,1.925213829313543,3.748016030745149e-05,0.0,0.0,0.0,0.0,0.0,0.0037559630061878126,320,40,40.829859256744385,1.0207464814186096,0.1275933101773262,0.09227103746379725,0.0010763755642983596,0.15519529518205671,2.772719910471011e-06,0.0,0.0,0.0,0.0,0.0,0.0010763755642983596,80,10,8.419125080108643,0.8419125080108643,0.10523906350135803,0.09969702027738095
|
4 |
+
2,0.0023816580101993167,3.9052163640380853,1.0410949786183142e-05,0.0,0.0,0.0,0.0,0.0,0.0023816580101993167,320,40,41.17148947715759,1.0292872369289399,0.12866090461611748,0.08443290112772957,0.010231703845784068,24.354415035247804,2.082776591478819e-05,0.0,0.0,0.0,0.0,0.0,0.010231703845784068,80,10,8.513915061950684,0.8513915061950683,0.10642393827438354,0.01369248509290628
|
5 |
+
3,0.005280107025464531,12.423834150836047,1.5615434550020346e-05,0.0,0.0,0.0,0.0,0.0,0.005280107025464531,320,40,41.08780074119568,1.027195018529892,0.1283993773162365,0.06296185727987905,0.0026845919943298212,1.0221173237751728,4.954857529959611e-06,0.0,0.0,0.0,0.0,0.0,0.0026845919943298212,80,10,8.557486295700073,0.8557486295700073,0.10696857869625091,0.06754178307019174
|
6 |
+
4,0.001628522769169649,0.9669525724625203,2.4910537018897618e-06,0.0,0.0,0.0,0.0,0.0,0.001628522769169649,320,40,41.08520555496216,1.027130138874054,0.12839126735925674,0.08523798966780305,0.0011392909364076331,0.8447293579599318,8.642555209048553e-07,0.0,0.0,0.0,0.0,0.0,0.0011392909364076331,80,10,8.54101276397705,0.8541012763977051,0.10676265954971313,0.07724661021493376
|
7 |
+
5,0.0006698742679873248,0.5631417240535356,1.3532459000285823e-07,0.0,0.0,0.0,0.0,0.0,0.0006698742679873248,320,40,41.17172598838806,1.0292931497097015,0.1286616437137127,0.09377104000886902,0.00028418309084372596,0.9654686861199593,6.203523792436271e-08,0.0,0.0,0.0,0.0,0.0,0.00028418309084372596,80,10,8.475411891937256,0.8475411891937256,0.1059426486492157,0.08174715298227966
|
8 |
+
6,0.00026526139699853955,0.0404354411696886,5.110972021091231e-08,0.0,0.0,0.0,0.0,0.0,0.00026526139699853955,320,40,41.210866928100586,1.0302716732025146,0.12878395915031432,0.1002270121127367,0.0003043986107513774,0.7276606579284817,4.158757311856221e-08,0.0,0.0,0.0,0.0,0.0,0.0003043986107513774,80,10,8.365734815597534,0.8365734815597534,0.10457168519496918,0.08551097614690661
|
9 |
+
7,0.00033921911108336644,0.04215667733975863,2.0074933999580934e-07,0.0,0.0,0.0,0.0,0.0,0.00033921911108336644,320,40,41.27673935890198,1.0319184839725495,0.12898981049656869,0.09141667010262608,0.0003641421761130914,2.5711033316561953,8.640791372971357e-08,0.0,0.0,0.0,0.0,0.0,0.0003641421761130914,80,10,8.418156147003174,0.8418156147003174,0.10522695183753968,0.07711024282034487
|
10 |
+
8,0.00027859737192557075,0.6936592234017018,1.7161798243723202e-08,0.0,0.0,0.0,0.0,0.0,0.00027859737192557075,320,40,40.9585645198822,1.0239641129970551,0.1279955141246319,0.09465919948415831,0.00026416685177537146,2.1159255215665325,2.42559791252539e-08,0.0,0.0,0.0,0.0,0.0,0.00026416685177537146,80,10,8.478416442871094,0.8478416442871094,0.10598020553588867,0.07725258702412248
|
11 |
+
9,0.00025029900834852017,0.03145681113393835,3.047866507614738e-08,0.0,0.0,0.0,0.0,0.0,0.00025029900834852017,320,40,40.887590169906616,1.0221897542476654,0.12777371928095818,0.0979282318148762,0.00020534966315608472,2.1954285900741297,7.391519909029712e-09,0.0,0.0,0.0,0.0,0.0,0.00020534966315608472,80,10,8.422897815704346,0.8422897815704345,0.10528622269630432,0.08292618948034942
|
12 |
+
10,0.00018911582246801119,0.10153866822858788,1.738624569006149e-08,0.0,0.0,0.0,0.0,0.0,0.00018911582246801119,320,40,41.04085445404053,1.0260213613510132,0.12825267016887665,0.09834078068379312,0.00036368721775943414,2.7509145542862825,1.876272959222547e-08,0.0,0.0,0.0,0.0,0.0,0.00036368721775943414,80,10,8.390719890594482,0.8390719890594482,0.10488399863243103,0.07442375177051871
|
13 |
+
11,0.00028841259018008714,0.5210410886732475,1.396120100700081e-07,0.0,0.0,0.0,0.0,0.0,0.00028841259018008714,320,40,40.944926261901855,1.0236231565475464,0.1279528945684433,0.09693868652684615,0.0009965902085241397,2.239691164344549,6.967871311047701e-08,0.0,0.0,0.0,0.0,0.0,0.0009965902085241397,80,10,8.430182218551636,0.8430182218551636,0.10537727773189545,0.06215766463428736
|
14 |
+
12,0.0009654337161919102,0.6124627925846198,8.020637164636666e-07,0.0,0.0,0.0,0.0,0.0,0.0009654337161919102,320,40,41.00464582443237,1.0251161456108093,0.12813951820135117,0.09093999108299614,0.0008928512179409154,1.408968701583035,4.1629629343731264e-08,0.0,0.0,0.0,0.0,0.0,0.0008928512179409154,80,10,8.480279445648193,0.8480279445648193,0.10600349307060242,0.06699345875531434
|
insurance/realtabformer/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
insurance/realtabformer/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d22b595a7793ec36a90bb1106eecd973db22fe0139551b119854e40360fd7e7
|
3 |
+
size 43505805
|
insurance/realtabformer/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["realtabformer"], "max_seconds": 3600}
|
insurance/tab_ddpm_concat/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
tab_ddpm_concat,5.952381614060002e-08,0.609262997868067,0.01993643540660279,0.559147834777832,0.19419053196907043,0.9970712065696716,0.2823074758052826,1.8548176740296185e-05,0.8766729831695557,0.0972040519118309,0.7692358493804932,0.14119644463062286,0.053181055933237076,0.7849618196487427,1.4358208179473877
|
insurance/tab_ddpm_concat/history.csv
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.0265493107464863,9.705848431753656,0.0019637997826472465,0.0,0.0,0.0,0.0,0.0,0.0265493107464863,320,40,39.08715486526489,0.9771788716316223,0.12214735895395279,0.04609664692543447,0.012864274116873275,8.93672634124523,3.463389237516879e-05,0.0,0.0,0.0,0.0,0.0,0.012864274116873275,80,10,8.234524965286255,0.8234524965286255,0.10293156206607819,0.023089123656973243
|
3 |
+
1,0.013430703204357996,10.238072396071818,0.0001760885078965657,0.0,0.0,0.0,0.0,0.0,0.013430703204357996,320,40,38.923088788986206,0.9730772197246551,0.12163465246558189,0.027457697270438074,0.01386686596670188,9.424022936335371,5.71949209714262e-05,0.0,0.0,0.0,0.0,0.0,0.01386686596670188,80,10,8.236119270324707,0.8236119270324707,0.10295149087905883,0.019944945629686118
|
4 |
+
2,0.013098158335196786,6.953670260656827,7.627181049958409e-05,0.0,0.0,0.0,0.0,0.0,0.013098158335196786,320,40,38.896809816360474,0.9724202454090118,0.12155253067612648,0.03701225146651268,0.011231413613131735,4.642900250397725,1.232088975626766e-05,0.0,0.0,0.0,0.0,0.0,0.011231413613131735,80,10,8.272239923477173,0.8272239923477173,0.10340299904346466,0.031016640178859235
|
5 |
+
3,0.013012661421089432,6.77741541211999,0.00014677781123761946,0.0,0.0,0.0,0.0,0.0,0.013012661421089432,320,40,39.03108096122742,0.9757770240306854,0.12197212800383568,0.040795679786242545,0.010680149483960122,5.439762359634369,8.51207419643174e-06,0.0,0.0,0.0,0.0,0.0,0.010680149483960122,80,10,8.236795425415039,0.8236795425415039,0.10295994281768799,0.02782872337847948
|
6 |
+
4,0.012592662169481628,6.8064604322151805,0.00012719917820476213,0.0,0.0,0.0,0.0,0.0,0.012592662169481628,320,40,38.966336727142334,0.9741584181785583,0.12176980227231979,0.03671876427251845,0.012881963208201341,16.157494982505522,0.00010115250418607502,0.0,0.0,0.0,0.0,0.0,0.012881963208201341,80,10,8.331452369689941,0.8331452369689941,0.10414315462112426,0.012491705431602895
|
7 |
+
5,0.013670370759791694,10.748200260194086,0.0001568969438597634,0.0,0.0,0.0,0.0,0.0,0.013670370759791694,320,40,38.94208788871765,0.9735521972179413,0.12169402465224266,0.029897483938839287,0.014085652580251917,22.363185199221174,0.00020219407759825003,0.0,0.0,0.0,0.0,0.0,0.014085652580251917,80,10,8.2787184715271,0.82787184715271,0.10348398089408875,0.009641142934560776
|
8 |
+
6,0.014017040852922946,10.649183725507465,0.00013577813718335108,0.0,0.0,0.0,0.0,0.0,0.014017040852922946,320,40,38.94879508018494,0.9737198770046234,0.12171498462557792,0.028363983915187418,0.01068424858385697,3.8434145080467714,1.0552424407705985e-05,0.0,0.0,0.0,0.0,0.0,0.01068424858385697,80,10,8.305310726165771,0.8305310726165771,0.10381638407707214,0.03533868733793497
|
9 |
+
7,0.011766438081394881,8.660977102358947,8.090821406305792e-05,0.0,0.0,0.0,0.0,0.0,0.011766438081394881,320,40,38.78416681289673,0.9696041703224182,0.12120052129030227,0.04158601735252887,0.012133054883452132,20.211999930033198,2.2262640635517526e-05,0.0,0.0,0.0,0.0,0.0,0.012133054883452132,80,10,8.369733810424805,0.8369733810424804,0.10462167263031005,0.010681234044022858
|
10 |
+
8,0.012191647826693953,7.005204355998285,9.821474643096905e-05,0.0,0.0,0.0,0.0,0.0,0.012191647826693953,320,40,38.88823890686035,0.9722059726715088,0.1215257465839386,0.03872000898700208,0.014966235030442476,9.767283525761012,0.0001517352883070089,0.0,0.0,0.0,0.0,0.0,0.014966235030442476,80,10,8.22826075553894,0.8228260755538941,0.10285325944423676,0.01799462023191154
|
11 |
+
9,0.012526353562134319,6.590273188782885,7.7691583878714e-05,0.0,0.0,0.0,0.0,0.0,0.012526353562134319,320,40,38.93899869918823,0.9734749674797059,0.12168437093496323,0.03674360387958586,0.012331876624375581,18.443907407086634,6.805989072731223e-05,0.0,0.0,0.0,0.0,0.0,0.012331876624375581,80,10,8.421772003173828,0.8421772003173829,0.10527215003967286,0.01039172657765448
|
12 |
+
10,0.012064280622871593,9.317451603279006,3.690295125249321e-05,0.0,0.0,0.0,0.0,0.0,0.012064280622871593,320,40,39.005112171173096,0.9751278042793274,0.12189097553491593,0.0359303968725726,0.01261272220290266,10.194672084533522,5.446935015456234e-05,0.0,0.0,0.0,0.0,0.0,0.01261272220290266,80,10,8.333169937133789,0.8333169937133789,0.10416462421417236,0.01722581619396806
|
13 |
+
11,0.012482693148194812,8.178162423045615,9.780007754767173e-05,0.0,0.0,0.0,0.0,0.0,0.012482693148194812,320,40,38.96896147727966,0.9742240369319916,0.12177800461649894,0.03824995262548327,0.012514100689440966,19.314230701327325,7.543949816977147e-05,0.0,0.0,0.0,0.0,0.0,0.012514100689440966,80,10,8.239241361618042,0.8239241361618042,0.10299051702022552,0.009454242698848248
|
14 |
+
12,0.01332451379566919,10.310542043212262,0.0003665929893701819,0.0,0.0,0.0,0.0,0.0,0.01332451379566919,320,40,39.00809144973755,0.9752022862434387,0.12190028578042984,0.027350465022027492,0.010987071882118471,4.729085849918556,8.189743033426566e-06,0.0,0.0,0.0,0.0,0.0,0.010987071882118471,80,10,8.261511325836182,0.8261511325836182,0.10326889157295227,0.03069485481828451
|
15 |
+
13,0.013592794616124592,7.457387926033698,0.000220215535729551,0.0,0.0,0.0,0.0,0.0,0.013592794616124592,320,40,38.93746519088745,0.9734366297721863,0.12167957872152328,0.03546805907972157,0.011548876191955059,6.165951245542237,1.5450504935188292e-05,0.0,0.0,0.0,0.0,0.0,0.011548876191955059,80,10,8.323935985565186,0.8323935985565185,0.10404919981956481,0.026838560402393342
|
16 |
+
14,0.013447031378746033,8.19890535405798,0.00016373289685844838,0.0,0.0,0.0,0.0,0.0,0.013447031378746033,320,40,38.8496150970459,0.9712403774261474,0.12140504717826843,0.029747568373568355,0.011828925088047981,5.351523938098455,2.9337766557091526e-05,0.0,0.0,0.0,0.0,0.0,0.011828925088047981,80,10,8.288572311401367,0.8288572311401368,0.1036071538925171,0.0307698548771441
|
17 |
+
15,0.01384369531297125,8.665561918970889,0.00016956335028766033,0.0,0.0,0.0,0.0,0.0,0.01384369531297125,320,40,38.970547676086426,0.9742636919021607,0.12178296148777008,0.03315324831055477,0.01150583740673028,6.560287872780464,1.5260961676233363e-05,0.0,0.0,0.0,0.0,0.0,0.01150583740673028,80,10,8.322679042816162,0.8322679042816162,0.10403348803520203,0.0259027692489326
|
18 |
+
16,0.012172109389211982,7.0008499470219245,7.735751830111326e-05,0.0,0.0,0.0,0.0,0.0,0.012172109389211982,320,40,39.07010316848755,0.9767525792121887,0.12209407240152359,0.03525363316293806,0.012191956081369425,7.897130101547532,2.3637969795231585e-05,0.0,0.0,0.0,0.0,0.0,0.012191956081369425,80,10,8.288463592529297,0.8288463592529297,0.10360579490661621,0.022672764584422113
|
19 |
+
17,0.012383807837613859,4.774037581340053,0.00011500121783720729,0.0,0.0,0.0,0.0,0.0,0.012383807837613859,320,40,38.94200682640076,0.9735501706600189,0.12169377133250237,0.04339534998871386,0.012275835702894256,9.627914267603774,4.899254280417153e-05,0.0,0.0,0.0,0.0,0.0,0.012275835702894256,80,10,8.323761701583862,0.8323761701583863,0.10404702126979828,0.018597377510741354
|
insurance/tab_ddpm_concat/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
insurance/tab_ddpm_concat/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:11257d861ce29a41abfcdf4cdcbf17964078480c9fc41b54d90eec20c5e9e4e8
|
3 |
+
size 38511671
|
insurance/tab_ddpm_concat/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.77, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "loss_balancer_beta": 0.75, "loss_balancer_r": 0.95, "fixed_role_model": "tab_ddpm_concat", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "relu6", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "rrelu", "head_activation_final": "softsign", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
|
insurance/tvae/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
tvae,0.13478357570810726,0.03636767101219685,0.000275009540043874,0.5683860778808594,0.018824299797415733,0.6825659275054932,0.0344335213303566,1.3921320984877639e-08,0.8837041854858398,0.01289679016917944,0.1385168433189392,0.016583411023020744,0.15040378272533417,0.0008385563851334155,1.4520902633666992
|
insurance/tvae/history.csv
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.05458167113538366,4.561985811768864,0.023550471702759836,0.0,0.0,0.0,0.0,0.0,0.05458167113538366,320,40,39.33485436439514,0.9833713591098785,0.12292141988873481,0.12334999229060487,0.01106511988909915,7.3775769050087545,0.00039846873109325995,0.0,0.0,0.0,0.0,0.0,0.01106511988909915,80,10,8.30074167251587,0.8300741672515869,0.10375927090644836,0.04176213040482253
|
3 |
+
1,0.010921533098735382,3.7708608118317897,0.0006090835865870864,0.0,0.0,0.0,0.0,0.0,0.010921533098735382,320,40,38.92951965332031,0.9732379913330078,0.12165474891662598,0.07502402040408924,0.002461973318713717,0.2656627141033823,8.382165523856955e-06,0.0,0.0,0.0,0.0,0.0,0.002461973318713717,80,10,8.352109670639038,0.8352109670639039,0.10440137088298798,0.07963283583521844
|
4 |
+
2,0.004752650485897902,4.5005246672456,7.41932757047259e-05,0.0,0.0,0.0,0.0,0.0,0.004752650485897902,320,40,39.09458088874817,0.9773645222187042,0.12217056527733802,0.0816895533236675,0.0009612412060960196,0.23112409779214432,2.9920761611550857e-06,0.0,0.0,0.0,0.0,0.0,0.0009612412060960196,80,10,8.301867723464966,0.8301867723464966,0.10377334654331208,0.08093988439068198
|
5 |
+
3,0.0029934452861198222,1.4091149369219238,3.706777377407988e-05,0.0,0.0,0.0,0.0,0.0,0.0029934452861198222,320,40,39.057528257369995,0.9764382064342498,0.12205477580428123,0.08644149880856275,0.0017080451536457986,0.5054739748910834,2.0698145459556274e-06,0.0,0.0,0.0,0.0,0.0,0.0017080451536457986,80,10,8.39680528640747,0.839680528640747,0.10496006608009338,0.0637943553738296
|
6 |
+
4,0.0022114409464847997,1.4571088086362807,9.073904502000795e-06,0.0,0.0,0.0,0.0,0.0,0.0022114409464847997,320,40,38.94040822982788,0.973510205745697,0.12168877571821213,0.08093992052599788,0.0034676186623983085,0.354912094264597,1.1135635656955855e-05,0.0,0.0,0.0,0.0,0.0,0.0034676186623983085,80,10,8.30032467842102,0.830032467842102,0.10375405848026276,0.10819654231891036
|
7 |
+
5,0.0016322427756676916,0.8344269889868698,2.8054938205387956e-06,0.0,0.0,0.0,0.0,0.0,0.0016322427756676916,320,40,39.14137244224548,0.9785343110561371,0.12231678888201714,0.09135764897800983,0.0034494245337555185,2.7931900787574704,6.050434956339501e-06,0.0,0.0,0.0,0.0,0.0,0.0034494245337555185,80,10,8.366892337799072,0.8366892337799072,0.1045861542224884,0.055338869569823146
|
8 |
+
6,0.002849590677578817,0.8129531741204119,4.8211906484207924e-05,0.0,0.0,0.0,0.0,0.0,0.002849590677578817,320,40,38.968292236328125,0.9742073059082031,0.1217759132385254,0.0901852805633098,0.0025212633569026365,0.6178526908131061,1.4414640320481454e-06,0.0,0.0,0.0,0.0,0.0,0.0025212633569026365,80,10,8.266654014587402,0.8266654014587402,0.10333317518234253,0.05773084256798029
|
9 |
+
7,0.0034268524424987843,1.5629836895840525,1.5161425290398687e-05,0.0,0.0,0.0,0.0,0.0,0.0034268524424987843,320,40,39.01256036758423,0.9753140091896058,0.12191425114870072,0.0829970414401032,0.0014418774226214737,0.05386366389284376,2.4331560492640845e-06,0.0,0.0,0.0,0.0,0.0,0.0014418774226214737,80,10,8.32570481300354,0.832570481300354,0.10407131016254426,0.08880755109712482
|
10 |
+
8,0.0016761758448410546,0.571136603817564,7.011435811053887e-06,0.0,0.0,0.0,0.0,0.0,0.0016761758448410546,320,40,38.852670669555664,0.9713167667388916,0.12141459584236144,0.09045831263065338,0.0006263804327318212,0.24181758030463243,6.295153740953907e-07,0.0,0.0,0.0,0.0,0.0,0.0006263804327318212,80,10,8.345160722732544,0.8345160722732544,0.1043145090341568,0.08191414531320333
|
11 |
+
9,0.0008744017197386711,0.17949836104246067,4.63043962907906e-07,0.0,0.0,0.0,0.0,0.0,0.0008744017197386711,320,40,39.11712980270386,0.9779282450675965,0.12224103063344956,0.09466907754540443,0.0011390350133297033,0.004834387120854444,3.0137808032293377e-06,0.0,0.0,0.0,0.0,0.0,0.0011390350133297033,80,10,8.316069841384888,0.8316069841384888,0.1039508730173111,0.0990539627149701
|
12 |
+
10,0.0004748740824652486,0.1777749692730623,2.3089836414527056e-08,0.0,0.0,0.0,0.0,0.0,0.0004748740824652486,320,40,39.06293201446533,0.9765733003616333,0.12207166254520416,0.09201494687004015,0.00032443252712255344,0.0010629200933180982,3.426863805611191e-07,0.0,0.0,0.0,0.0,0.0,0.00032443252712255344,80,10,8.351998329162598,0.8351998329162598,0.10439997911453247,0.0884638118557632
|
13 |
+
11,0.00030916042924218347,0.04881817966124018,2.0088672352989394e-08,0.0,0.0,0.0,0.0,0.0,0.00030916042924218347,320,40,38.86811137199402,0.9717027842998505,0.12146284803748131,0.10129309091717005,0.00028257269877940416,1.0737754437432159,2.8357685949442768e-08,0.0,0.0,0.0,0.0,0.0,0.00028257269877940416,80,10,8.252684354782104,0.8252684354782105,0.10315855443477631,0.08038602282758803
|
14 |
+
12,0.0013487103491570452,0.43372808683234754,1.0047366970687786e-06,0.0,0.0,0.0,0.0,0.0,0.0013487103491570452,320,40,39.09407997131348,0.9773519992828369,0.12216899991035461,0.0899976636399515,0.003439919964876026,0.015614798056776635,2.0251521429592856e-05,0.0,0.0,0.0,0.0,0.0,0.003439919964876026,80,10,8.342942476272583,0.8342942476272583,0.1042867809534073,0.11240037991665304
|
15 |
+
13,0.0008618889094577753,0.1384840221481113,4.446653539750059e-07,0.0,0.0,0.0,0.0,0.0,0.0008618889094577753,320,40,38.886531829833984,0.9721632957458496,0.1215204119682312,0.09283134532161057,0.000532695987567422,0.6308531300281175,1.360833180625437e-06,0.0,0.0,0.0,0.0,0.0,0.000532695987567422,80,10,8.28925633430481,0.828925633430481,0.10361570417881012,0.0891546759288758
|
16 |
+
14,0.00030356911156559363,0.3619111133062688,5.198837278813596e-08,0.0,0.0,0.0,0.0,0.0,0.00030356911156559363,320,40,38.988784074783325,0.9747196018695832,0.1218399502336979,0.09746413570828735,0.0005432431978988461,0.001004549844947178,2.781989758560144e-07,0.0,0.0,0.0,0.0,0.0,0.0005432431978988461,80,10,8.42578673362732,0.842578673362732,0.1053223341703415,0.09348368076607586
|
17 |
+
15,0.00029625174347529536,0.0572095896306493,6.03426183574306e-08,0.0,0.0,0.0,0.0,0.0,0.00029625174347529536,320,40,39.27268958091736,0.981817239522934,0.12272715494036675,0.09890737304231152,0.00036838351952610536,0.7212186768025276,2.6941624464704718e-08,0.0,0.0,0.0,0.0,0.0,0.00036838351952610536,80,10,8.407346963882446,0.8407346963882446,0.10509183704853058,0.08277125156018883
|
18 |
+
16,0.0005824315209792986,0.32841089839253074,9.46431564320671e-08,0.0,0.0,0.0,0.0,0.0,0.0005824315209792986,320,40,39.24979209899902,0.9812448024749756,0.12265560030937195,0.09515800991794095,0.0007761759276036173,1.2551941490234082,2.9720144345546373e-08,0.0,0.0,0.0,0.0,0.0,0.0007761759276036173,80,10,8.318324089050293,0.8318324089050293,0.10397905111312866,0.09093907248461619
|
19 |
+
17,0.0012332158104982228,0.6590279597393532,8.36102873709775e-06,0.0,0.0,0.0,0.0,0.0,0.0012332158104982228,320,40,39.119892835617065,0.9779973208904267,0.12224966511130334,0.0890660552540794,0.001825959722918924,1.9564546512207017,1.1242688799484313e-06,0.0,0.0,0.0,0.0,0.0,0.001825959722918924,80,10,8.359159708023071,0.8359159708023072,0.1044894963502884,0.06533113070763648
|
20 |
+
18,0.001502491006613127,0.4666562590015076,3.179587947280127e-06,0.0,0.0,0.0,0.0,0.0,0.001502491006613127,320,40,39.09075927734375,0.9772689819335938,0.12215862274169922,0.09002331190858967,0.0008729565364774316,0.20973070683976403,2.998759428507469e-06,0.0,0.0,0.0,0.0,0.0,0.0008729565364774316,80,10,8.317886352539062,0.8317886352539062,0.10397357940673828,0.08415974881500006
|
21 |
+
19,0.001246437881127349,0.6120949116166994,2.0787087329172948e-06,0.0,0.0,0.0,0.0,0.0,0.001246437881127349,320,40,39.05946326255798,0.9764865815639496,0.1220608226954937,0.09171894917380996,0.002248370127927046,4.686978222953622,2.895269359082242e-06,0.0,0.0,0.0,0.0,0.0,0.002248370127927046,80,10,8.319932222366333,0.8319932222366333,0.10399915277957916,0.0715100662317127
|
22 |
+
20,0.0029011325517785736,0.9372176351432528,1.1204305509332328e-05,0.0,0.0,0.0,0.0,0.0,0.0029011325517785736,320,40,38.93023109436035,0.9732557773590088,0.1216569721698761,0.08712862803367898,0.0013807336257741555,2.265311992234274,1.8232192309453056e-06,0.0,0.0,0.0,0.0,0.0,0.0013807336257741555,80,10,8.43259859085083,0.843259859085083,0.10540748238563538,0.07126395150553436
|
23 |
+
21,0.0005831055974340416,0.7094658932399625,5.484142581780628e-08,0.0,0.0,0.0,0.0,0.0,0.0005831055974340416,320,40,38.932344913482666,0.9733086228370667,0.12166357785463333,0.09576541467686184,0.00040745751502981877,0.01960964320030456,8.632079813164495e-08,0.0,0.0,0.0,0.0,0.0,0.00040745751502981877,80,10,8.343457698822021,0.8343457698822021,0.10429322123527526,0.0816122055053711
|
insurance/tvae/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
insurance/tvae/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5eada8e849eb7c6638d89cc73f312038c150117726abd4743d18461204c5e8d3
|
3 |
+
size 38609591
|
insurance/tvae/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.77, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "loss_balancer_beta": 0.75, "loss_balancer_r": 0.95, "fixed_role_model": "tvae", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": false, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "relu6", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "rrelu", "head_activation_final": "softsign", "models": ["tvae"], "max_seconds": 3600}
|
treatment/lct_gan/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
lct_gan,0.0,4.3308387261260425e-08,0.0023591548839308565,4.458594799041748,0.011746696196496487,0.16063837707042694,0.01545888464897871,4.918354079563869e-06,2.3799874782562256,0.03671034052968025,0.06946055591106415,0.048571132123470306,0.07008553296327591,0.010716128163039684,6.838582277297974
|
treatment/lct_gan/history.csv
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.22352998348069378,80.13195664776967,0.09520609854183135,0.0,0.0,0.0,0.0,0.0,0.22352998348069378,320,80,103.12237620353699,1.2890297025442123,0.3222574256360531,0.0988283173581749,0.01245800144970417,1.6089594815347597,0.00019736949793127678,0.0,0.0,0.0,0.0,0.0,0.01245800144970417,80,20,20.080093383789062,1.004004669189453,0.25100116729736327,0.07783476307522505
|
3 |
+
1,0.008753520530444803,0.34713386677235575,0.00011480308697381734,0.0,0.0,0.0,0.0,0.0,0.008753520530444803,320,80,103.45079112052917,1.2931348890066148,0.3232837222516537,0.19368809863808564,0.007529928936855867,3.185711348353652,0.00010870498384889516,0.0,0.0,0.0,0.0,0.0,0.007529928936855867,80,20,20.23315191268921,1.0116575956344604,0.2529143989086151,0.043344746553339066
|
4 |
+
2,0.006802400949891307,0.3522556849198281,0.0001498469549909341,0.0,0.0,0.0,0.0,0.0,0.006802400949891307,320,80,103.24245595932007,1.290530699491501,0.3226326748728752,0.18977850895607845,0.007473361069423845,4.433246247235365,0.00010783077031044641,0.0,0.0,0.0,0.0,0.0,0.007473361069423845,80,20,20.317150354385376,1.0158575177192688,0.2539643794298172,0.043154743919149044
|
5 |
+
3,0.008915021323991823,0.718201418556464,8.08643664615523e-05,0.0,0.0,0.0,0.0,0.0,0.008915021323991823,320,80,103.76086711883545,1.2970108389854431,0.3242527097463608,0.18378724994836376,0.015541119105182587,4.1963203363062345,0.0004786531295351892,0.0,0.0,0.0,0.0,0.0,0.015541119105182587,80,20,20.26967740058899,1.0134838700294495,0.2533709675073624,0.03725434660445899
|
6 |
+
4,0.006931817024451448,0.35298403866354533,7.142922729861737e-05,0.0,0.0,0.0,0.0,0.0,0.006931817024451448,320,80,103.68316864967346,1.2960396081209182,0.32400990203022956,0.1754610677191522,0.006783264396653976,2.084030251805075,0.00010564910719947917,0.0,0.0,0.0,0.0,0.0,0.006783264396653976,80,20,20.532756090164185,1.0266378045082092,0.2566594511270523,0.046784830396063626
|
7 |
+
5,0.005099834372958867,0.0875181610394841,9.274947589943961e-05,0.0,0.0,0.0,0.0,0.0,0.005099834372958867,320,80,103.51723384857178,1.2939654231071471,0.3234913557767868,0.19479809664189815,0.006807428642059676,6.448283100366529,8.50121301514406e-05,0.0,0.0,0.0,0.0,0.0,0.006807428642059676,80,20,20.114439249038696,1.005721962451935,0.2514304906129837,0.04373221881105564
|
8 |
+
6,0.00480109396030457,0.9528829011607514,5.0811379982423e-05,0.0,0.0,0.0,0.0,0.0,0.00480109396030457,320,80,103.59416389465332,1.2949270486831665,0.32373176217079164,0.178886199297267,0.006987621737061999,4.279015872643868,9.801611965674085e-05,0.0,0.0,0.0,0.0,0.0,0.006987621737061999,80,20,20.27844500541687,1.0139222502708436,0.2534805625677109,0.04849170843372121
|
9 |
+
7,0.00464072308905088,0.1356045210827208,6.834771834407436e-05,0.0,0.0,0.0,0.0,0.0,0.00464072308905088,320,80,103.63158965110779,1.2953948706388474,0.32384871765971185,0.19017753867083229,0.006893738606595434,3.8557679186087626,8.831684561805275e-05,0.0,0.0,0.0,0.0,0.0,0.006893738606595434,80,20,20.594661712646484,1.0297330856323241,0.25743327140808103,0.04350378216477111
|
10 |
+
8,0.004308880444841634,0.06846157661166216,5.0194533013741347e-05,0.0,0.0,0.0,0.0,0.0,0.004308880444841634,320,80,103.90516233444214,1.2988145291805266,0.32470363229513166,0.18733386998064816,0.006929469533497467,3.2670128452096834,7.968755999527843e-05,0.0,0.0,0.0,0.0,0.0,0.006929469533497467,80,20,20.433101177215576,1.021655058860779,0.2554137647151947,0.04181941950228065
|
11 |
+
9,0.004508837422326906,0.08966287379656705,8.609927907704219e-05,0.0,0.0,0.0,0.0,0.0,0.004508837422326906,320,80,103.64076137542725,1.2955095171928406,0.32387737929821014,0.18831826079403982,0.0070835084756254215,4.297422426286471,0.00011362928610676448,0.0,0.0,0.0,0.0,0.0,0.0070835084756254215,80,20,20.304687023162842,1.015234351158142,0.2538085877895355,0.0519675396499224
|
12 |
+
10,0.004514601970731747,0.1381884859496653,8.4598984369378e-05,0.0,0.0,0.0,0.0,0.0,0.004514601970731747,320,80,103.90474796295166,1.2988093495368958,0.32470233738422394,0.17621430779545336,0.005505984234332573,5.328735202133521,3.370500902892815e-05,0.0,0.0,0.0,0.0,0.0,0.005505984234332573,80,20,20.54746174812317,1.0273730874061584,0.2568432718515396,0.051376725709997115
|
13 |
+
11,0.004419548849091371,0.07128540304256603,6.623427232033962e-05,0.0,0.0,0.0,0.0,0.0,0.004419548849091371,320,80,103.64773535728455,1.2955966919660569,0.3238991729915142,0.18890725779347123,0.005045573477400467,4.834901636225572,2.2054046640818116e-05,0.0,0.0,0.0,0.0,0.0,0.005045573477400467,80,20,20.591118812561035,1.0295559406280517,0.2573889851570129,0.051232723612338306
|
14 |
+
12,0.004236863098185495,0.04387387494761443,8.839865267179564e-05,0.0,0.0,0.0,0.0,0.0,0.004236863098185495,320,80,103.94386696815491,1.2992983371019364,0.3248245842754841,0.1935716205276549,0.008129652022034861,5.03432033594964,0.00015811533519221043,0.0,0.0,0.0,0.0,0.0,0.008129652022034861,80,20,20.881214380264282,1.044060719013214,0.2610151797533035,0.04673133364703972
|
15 |
+
13,0.004276435652536747,0.0654336524629164,4.581930034836763e-05,0.0,0.0,0.0,0.0,0.0,0.004276435652536747,320,80,103.00667309761047,1.287583413720131,0.32189585343003274,0.18084844152908772,0.007877110847039149,2.788165700972968,0.0001813338503336759,0.0,0.0,0.0,0.0,0.0,0.007877110847039149,80,20,20.325989723205566,1.0162994861602783,0.2540748715400696,0.04980120111722499
|
16 |
+
14,0.0038373295942619734,0.0791764844770995,4.599695211216274e-05,0.0,0.0,0.0,0.0,0.0,0.0038373295942619734,320,80,103.24337792396545,1.290542224049568,0.322635556012392,0.18313483651727439,0.0050060375331668185,3.8634611237153877,2.9480706338880223e-05,0.0,0.0,0.0,0.0,0.0,0.0050060375331668185,80,20,20.198811054229736,1.0099405527114869,0.2524851381778717,0.058232192660216245
|
17 |
+
15,0.003900589924239739,0.10312648875277333,2.1983545730452913e-05,0.0,0.0,0.0,0.0,0.0,0.003900589924239739,320,80,103.35287404060364,1.2919109255075454,0.32297773137688635,0.19720614301040768,0.009117326361592858,1.760603382944828,0.00027292398581290066,0.0,0.0,0.0,0.0,0.0,0.009117326361592858,80,20,20.001904249191284,1.0000952124595641,0.25002380311489103,0.048848784435540436
|
18 |
+
16,0.0018584069126518442,0.04251637269783757,1.0266101569923053e-05,0.0,0.0,0.0,0.0,0.0,0.0018584069126518442,320,80,102.94779467582703,1.286847433447838,0.3217118583619595,0.18552915730979294,0.007703973473689984,1.3023191384279245,0.00017489787961899594,0.0,0.0,0.0,0.0,0.0,0.007703973473689984,80,20,20.37928342819214,1.018964171409607,0.25474104285240173,0.05299922423437238
|
19 |
+
17,0.0008170823710770492,0.018880133842297652,5.023515140430146e-07,0.0,0.0,0.0,0.0,0.0,0.0008170823710770492,320,80,102.69966387748718,1.2837457984685898,0.32093644961714746,0.1925133554963395,0.00931795308351866,2.071352872970965,0.0002587945756321958,0.0,0.0,0.0,0.0,0.0,0.00931795308351866,80,20,19.96341824531555,0.9981709122657776,0.2495427280664444,0.053129641944542526
|
20 |
+
18,0.0003576292982074847,0.012536979420885785,4.498594921749366e-08,0.0,0.0,0.0,0.0,0.0,0.0003576292982074847,320,80,100.82452273368835,1.2603065341711044,0.3150766335427761,0.19631691183894873,0.00876514861229225,1.4645986258908124,0.0002133372895583463,0.0,0.0,0.0,0.0,0.0,0.00876514861229225,80,20,19.451266765594482,0.9725633382797241,0.24314083456993102,0.05004617176018655
|
21 |
+
19,0.0002510753680212474,0.008272775144363465,3.3253448800756014e-08,0.0,0.0,0.0,0.0,0.0,0.0002510753680212474,320,80,102.30520558357239,1.2788150697946548,0.3197037674486637,0.19034422542899848,0.00618170693560387,1.5110018466951716,7.781338554512241e-05,0.0,0.0,0.0,0.0,0.0,0.00618170693560387,80,20,19.740620136260986,0.9870310068130493,0.24675775170326233,0.055138330021873114
|
22 |
+
20,0.00028350398017664704,0.019687367545015277,2.482109546082703e-07,0.0,0.0,0.0,0.0,0.0,0.00028350398017664704,320,80,99.86743569374084,1.2483429461717606,0.31208573654294014,0.1864254915737547,0.007587061779486248,1.3006179411045196,0.00014909935981792798,0.0,0.0,0.0,0.0,0.0,0.007587061779486248,80,20,20.21121335029602,1.0105606675148011,0.2526401668787003,0.05435547353699803
|
23 |
+
21,0.00025932476952164054,0.007903821782640907,9.45318477345975e-10,0.0,0.0,0.0,0.0,0.0,0.00025932476952164054,320,80,103.72837948799133,1.2966047435998918,0.32415118589997294,0.18518321572337298,0.006430168935912662,1.9212478918598208,9.073310130256474e-05,0.0,0.0,0.0,0.0,0.0,0.006430168935912662,80,20,19.81023097038269,0.9905115485191345,0.24762788712978362,0.05623150994069874
|
24 |
+
22,0.00033484522286357786,0.012371280277519502,1.6720436467560026e-08,0.0,0.0,0.0,0.0,0.0,0.00033484522286357786,320,80,102.39797282218933,1.2799746602773667,0.3199936650693417,0.1927722441148944,0.0066237156010174655,1.626585018528567,9.447487505163111e-05,0.0,0.0,0.0,0.0,0.0,0.0066237156010174655,80,20,20.115633487701416,1.0057816743850707,0.2514454185962677,0.054200840881094337
|
25 |
+
23,0.00015932412852635026,0.025226200853674642,3.639718875007858e-08,0.0,0.0,0.0,0.0,0.0,0.00015932412852635026,320,80,103.57392120361328,1.294674015045166,0.3236685037612915,0.1952298643416725,0.0065823450138850605,1.5178716897570212,9.93379512048212e-05,0.0,0.0,0.0,0.0,0.0,0.0065823450138850605,80,20,21.15859341621399,1.0579296708106996,0.2644824177026749,0.0543066727463156
|
26 |
+
24,0.00012041164002258853,0.006688666018078895,1.577604157902094e-08,0.0,0.0,0.0,0.0,0.0,0.00012041164002258853,320,80,107.04027843475342,1.3380034804344176,0.3345008701086044,0.18686838666908442,0.007150729962449987,1.6130354540884582,0.00011762036998774761,0.0,0.0,0.0,0.0,0.0,0.007150729962449987,80,20,21.25710916519165,1.0628554582595826,0.26571386456489565,0.050873439060524106
|
27 |
+
25,8.769563716697349e-05,0.005887569199180521,3.494414894220867e-09,0.0,0.0,0.0,0.0,0.0,8.769563716697349e-05,320,80,106.57941937446594,1.3322427421808243,0.33306068554520607,0.18599217470618895,0.00832268671510974,1.4815574481464182,0.00019259734704257792,0.0,0.0,0.0,0.0,0.0,0.00832268671510974,80,20,20.612091302871704,1.0306045651435851,0.2576511412858963,0.05257055321708322
|
28 |
+
26,8.882110219019524e-05,0.002085926350794054,2.3958830939162234e-09,0.0,0.0,0.0,0.0,0.0,8.882110219019524e-05,320,80,106.64562821388245,1.3330703526735306,0.33326758816838264,0.20409600271377712,0.006672654673457146,1.4586342717834213,9.671619951117094e-05,0.0,0.0,0.0,0.0,0.0,0.006672654673457146,80,20,20.21643877029419,1.0108219385147095,0.25270548462867737,0.0529700854793191
|
29 |
+
27,7.043356762892472e-05,0.005430679229713497,4.15760022837944e-10,0.0,0.0,0.0,0.0,0.0,7.043356762892472e-05,320,80,102.27982902526855,1.278497862815857,0.31962446570396424,0.2028546938439831,0.006019308035320137,1.3263217964900833,7.28448786667002e-05,0.0,0.0,0.0,0.0,0.0,0.006019308035320137,80,20,20.296292066574097,1.0148146033287049,0.2537036508321762,0.05590685121715069
|
treatment/lct_gan/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
treatment/lct_gan/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3fa44e8b1c64400dcb7cbae969a7e28bf2478f52136ab2aacaa8dd8cf8014335
|
3 |
+
size 74778241
|
treatment/lct_gan/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.2, "mse_mag_multiply": false, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
|
treatment/realtabformer/eval.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
|
2 |
+
realtabformer,0.0,,0.0015596782029655273,2.3078653812408447,0.30178436636924744,5.190309524536133,0.41850993037223816,5.543264705920592e-06,10.77073621749878,0.02881322056055069,0.054603252559900284,0.039492759853601456,0.07855503261089325,8.974096999736503e-05,13.078601598739624
|
treatment/realtabformer/history.csv
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
|
2 |
+
0,0.03644084440955699,6.612403465312778,0.0045574880142667225,0.0,0.0,0.0,0.0,0.0,0.03644084440955699,320,160,166.3736925125122,1.0398355782032014,0.5199177891016007,0.07979056920851804,0.007586553805595031,3.274879661123663,0.00010986064107032512,0.0,0.0,0.0,0.0,0.0,0.007586553805595031,80,40,34.049020767211914,0.8512255191802979,0.42561275959014894,0.04524085290104267
|
3 |
+
1,0.010942080118709896,3.6911100070230938,0.00037560693868070696,0.0,0.0,0.0,0.0,0.0,0.010942080118709896,320,160,166.92550325393677,1.0432843953371047,0.5216421976685524,0.14872968647239873,0.026652724126324755,7.903173122670358,0.0026581332727844825,0.0,0.0,0.0,0.0,0.0,0.026652724126324755,80,40,33.64800572395325,0.8412001430988312,0.4206000715494156,0.02281103633413295
|
4 |
+
2,0.010455712701009312,3.9791094253597508,0.0003505937737199728,0.0,0.0,0.0,0.0,0.0,0.010455712701009312,320,160,166.77985906600952,1.0423741191625595,0.5211870595812798,0.15552815834055309,0.012862359535210999,5.169579194596042,0.000530584767799982,0.0,0.0,0.0,0.0,0.0,0.012862359535210999,80,40,33.91816258430481,0.8479540646076202,0.4239770323038101,0.026617402605188543
|
5 |
+
3,0.007001378159810656,2.796910241492641,0.00011505681204335635,0.0,0.0,0.0,0.0,0.0,0.007001378159810656,320,160,166.83584594726562,1.04272403717041,0.521362018585205,0.13855754688775052,0.004590427323637414,6.018633032932655,3.031255504705524e-05,0.0,0.0,0.0,0.0,0.0,0.004590427323637414,80,40,33.972458839416504,0.8493114709854126,0.4246557354927063,0.04482766297978742
|
6 |
+
4,0.0072793409375947245,2.296859505424894,0.0001573822781525827,0.0,0.0,0.0,0.0,0.0,0.0072793409375947245,320,160,167.01165390014648,1.0438228368759155,0.5219114184379577,0.16224303948229135,0.004423623792263243,6.739300958566085,6.933249757632432e-05,0.0,0.0,0.0,0.0,0.0,0.004423623792263243,80,40,33.651365756988525,0.8412841439247132,0.4206420719623566,0.026270963538991055
|
7 |
+
5,0.006381618711196779,2.8839405857504543,0.00012163765120135503,0.0,0.0,0.0,0.0,0.0,0.006381618711196779,320,160,170.09552717208862,1.063097044825554,0.531548522412777,0.14938008590495427,0.0050863039046817,7.422684189675602,4.899889018377124e-05,0.0,0.0,0.0,0.0,0.0,0.0050863039046817,80,40,37.50921392440796,0.937730348110199,0.4688651740550995,0.029683142538488028
|
8 |
+
6,0.006045459082537263,2.934065179698594,8.634548304524475e-05,0.0,0.0,0.0,0.0,0.0,0.006045459082537263,320,160,173.5575668811798,1.0847347930073739,0.5423673965036869,0.128542572739525,0.004006205599580426,6.11909287295653,2.419293156137453e-05,0.0,0.0,0.0,0.0,0.0,0.004006205599580426,80,40,36.28923416137695,0.9072308540344238,0.4536154270172119,0.02928218668603222
|
9 |
+
7,0.005807553204306259,3.233931762049849,6.728753389705139e-05,0.0,0.0,0.0,0.0,0.0,0.005807553204306259,320,160,173.0330581665039,1.0814566135406494,0.5407283067703247,0.13020459838949136,0.004725078083447442,5.915451907179363,4.036447823969336e-05,0.0,0.0,0.0,0.0,0.0,0.004725078083447442,80,40,35.27379822731018,0.8818449556827546,0.4409224778413773,0.03375284938047116
|
10 |
+
8,0.005597162095671138,2.4687882317225887,9.451023661321612e-05,0.0,0.0,0.0,0.0,0.0,0.005597162095671138,320,160,171.33222126960754,1.070826382935047,0.5354131914675235,0.16350435888944048,0.005341960320765793,4.580611580479113,5.5816291564272924e-05,0.0,0.0,0.0,0.0,0.0,0.005341960320765793,80,40,35.73819422721863,0.8934548556804657,0.44672742784023284,0.028984847072570118
|
11 |
+
9,0.005475803721810735,2.9333888558279755,5.944834627596224e-05,0.0,0.0,0.0,0.0,0.0,0.005475803721810735,320,160,181.32838702201843,1.1333024188876153,0.5666512094438076,0.14286351032374114,0.00505017776886234,5.1855854878382335,4.6662756986093344e-05,0.0,0.0,0.0,0.0,0.0,0.00505017776886234,80,40,39.708540201187134,0.9927135050296784,0.4963567525148392,0.028717920677536313
|
12 |
+
10,0.005619597199086002,2.6044400273167847,0.0001110976432221262,0.0,0.0,0.0,0.0,0.0,0.005619597199086002,320,160,191.31377625465393,1.1957111015915871,0.5978555507957936,0.1449692299744129,0.003595894821683032,4.273874850746305,3.0524735783286906e-05,0.0,0.0,0.0,0.0,0.0,0.003595894821683032,80,40,40.169761657714844,1.004244041442871,0.5021220207214355,0.038896191439198445
|
13 |
+
11,0.005651603202302624,2.1689565299521996,8.184791047085491e-05,0.0,0.0,0.0,0.0,0.0,0.005651603202302624,320,160,191.00746726989746,1.1937966704368592,0.5968983352184296,0.13332524879906488,0.003731331395374582,4.425117476265046,2.9901788849429067e-05,0.0,0.0,0.0,0.0,0.0,0.003731331395374582,80,40,40.135857343673706,1.0033964335918426,0.5016982167959213,0.04076760089155869
|
14 |
+
12,0.005675629577149266,1.5336641537275533,8.70407526315881e-05,0.0,0.0,0.0,0.0,0.0,0.005675629577149266,320,160,191.66054034233093,1.1978783771395682,0.5989391885697841,0.16240290804776122,0.0038316375943395543,4.709503752670197,2.2613220733549987e-05,0.0,0.0,0.0,0.0,0.0,0.0038316375943395543,80,40,40.268686056137085,1.0067171514034272,0.5033585757017136,0.028456049150554462
|
15 |
+
13,0.00614657213177452,2.04597273792412,7.693676020577578e-05,0.0,0.0,0.0,0.0,0.0,0.00614657213177452,320,160,189.26595377922058,1.1829122111201287,0.5914561055600643,0.15553884600512902,0.0042953793235938065,4.08081223766667,3.2513099143405276e-05,0.0,0.0,0.0,0.0,0.0,0.0042953793235938065,80,40,40.096407890319824,1.0024101972579955,0.5012050986289978,0.02533484929444967
|
16 |
+
14,0.005231396525572052,1.73261989282967,7.520394053998295e-05,0.0,0.0,0.0,0.0,0.0,0.005231396525572052,320,160,191.00612378120422,1.1937882736325265,0.5968941368162632,0.1496530410122432,0.003655807935683697,4.01203602381832,2.1951282996873765e-05,0.0,0.0,0.0,0.0,0.0,0.003655807935683697,80,40,40.18678307533264,1.004669576883316,0.502334788441658,0.03820230678429652
|
17 |
+
15,0.005476373057177852,2.6803855865461252,8.069490082747689e-05,0.0,0.0,0.0,0.0,0.0,0.005476373057177852,320,160,191.20353507995605,1.1950220942497254,0.5975110471248627,0.12890588956015564,0.0038699784756772715,5.125400327792415,2.3996304280442248e-05,0.0,0.0,0.0,0.0,0.0,0.0038699784756772715,80,40,39.97018790245056,0.9992546975612641,0.49962734878063203,0.03100545472552767
|
treatment/realtabformer/mlu-eval.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
treatment/realtabformer/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2575e1d5e72b45904d2984f2ebd0d340d2dcad8b7d40a0cab678e9c61c484c17
|
3 |
+
size 78481207
|
treatment/realtabformer/params.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "NONE", "synth_data": 2, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.2, "mse_mag_multiply": false, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}
|