# Title: Minimal Description Length and Grokking: An Information-Theoretic Perspective on Sudden Generalization
# Experiment description: Implement a function estimate_mdl(model) using weight pruning to approximate the model's description length. Prune weights below a threshold and count remaining non-zero weights. Modify the training loop to compute MDL every 500 steps. Run experiments on ModDivisionDataset and PermutationGroup, including a baseline without MDL tracking. Plot MDL estimates alongside validation accuracy. Define the 'MDL transition point' as the step with the steepest decrease in MDL. Compare this point with the grokking point (95% validation accuracy). Analyze the correlation between MDL reduction and improvement in validation accuracy. Compare MDL evolution between grokking and non-grokking (baseline) scenarios.

## Run 0: Baseline
Results: {'x_div_y': {'final_train_loss_mean': 0.005800435319542885, 'final_val_loss_mean': 0.006530226518710454, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4200.0}, 'x_minus_y': {'final_train_loss_mean': 0.014211568981409073, 'final_val_loss_mean': 0.014943961674968401, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4720.0}, 'x_plus_y': {'final_train_loss_mean': 0.003832749711970488, 'final_val_loss_mean': 0.004045687771091859, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 2363.3333333333335}, 'permutation': {'final_train_loss_mean': 0.08011958096176386, 'final_val_loss_mean': 6.804208914438884, 'final_train_acc_mean': 0.9880208373069763, 'final_val_acc_mean': 0.035888671875, 'step_val_acc_99_mean': 7500.0}}
Description: Baseline results.

## Run 1: MDL Tracking Implementation
Experiment description: Implement MDL estimation and tracking for ModDivisionDataset and PermutationGroup. Modify the training loop to compute MDL every 500 steps. Plot MDL estimates alongside validation accuracy.
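
The pruning-based estimate described above can be sketched roughly as follows. This is a minimal, dependency-free illustration rather than the code used in these runs: the model is represented as a hypothetical mapping from parameter names to flat lists of floats, and the threshold value of `1e-2` is an assumption, since the notes do not state which threshold was used.

```python
from typing import Mapping, Sequence


def estimate_mdl(weights: Mapping[str, Sequence[float]], threshold: float = 1e-2) -> int:
    """Approximate description length as the number of weights that survive pruning.

    A weight is treated as pruned when its magnitude is below `threshold`
    (assumed value; not taken from the experiment notes).
    """
    return sum(
        1
        for values in weights.values()
        for w in values
        if abs(w) >= threshold
    )


# Hypothetical toy "model": parameter name -> flattened weights.
toy_model = {
    "layer1.weight": [0.5, -0.003, 0.02, 0.0],
    "layer1.bias": [0.009, -1.2],
}

mdl = estimate_mdl(toy_model, threshold=1e-2)
print(mdl)  # 0.5, 0.02 and -1.2 survive pruning
```

With a PyTorch model, the same count could presumably be taken by iterating over `model.parameters()` and summing the entries whose absolute value is at least the threshold.
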
Results: {'x_div_y': {'final_train_loss_mean': 0.1435996194680532, 'final_val_loss_mean': 0.08725565796097119, 'final_train_acc_mean': 0.9878255327542623, 'final_val_acc_mean': 0.9969889322916666, 'step_val_acc_99_mean': 4503.333333333333, 'step_val_acc_95_mean': 4306.666666666667}, 'x_minus_y': {'final_train_loss_mean': 0.005183443737526734, 'final_val_loss_mean': 0.007959973920757571, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4413.333333333333, 'step_val_acc_95_mean': 4140.0}, 'x_plus_y': {'final_train_loss_mean': 0.14411260839551687, 'final_val_loss_mean': 0.0721069195618232, 'final_train_acc_mean': 0.983007808526357, 'final_val_acc_mean': 0.9964192708333334, 'step_val_acc_99_mean': 2386.6666666666665, 'step_val_acc_95_mean': 2203.3333333333335}, 'permutation': {'final_train_loss_mean': 0.05460624893506368, 'final_val_loss_mean': 6.463259855906169, 'final_train_acc_mean': 0.994726558526357, 'final_val_acc_mean': 0.045003255208333336, 'step_val_acc_99_mean': 7500.0, 'step_val_acc_95_mean': 7500.0}}
Description: In this run, we implemented MDL estimation and tracking for both ModDivisionDataset and PermutationGroup. The results show that all datasets except permutation achieved high validation accuracy (>99%). The permutation dataset still struggles to generalize, with a final validation accuracy of only 4.5%. The MDL tracking implementation allows us to observe the relationship between MDL and validation accuracy, which will be analyzed in the next run.

## Run 2: MDL Analysis and Correlation with Grokking
Experiment description: Analyze the relationship between MDL and grokking by comparing the MDL transition point with the grokking point (95% validation accuracy). Calculate the correlation between MDL reduction and improvement in validation accuracy. Compare MDL evolution between grokking and non-grokking scenarios.
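
The two events compared in this run can be located with a sketch like the one below. It assumes MDL and validation accuracy are logged at the same checkpoints; the function names, the choice to report the step at the end of the steepest-falling interval, and the toy curves are all illustrative assumptions, not values from these experiments.

```python
from typing import Optional, Sequence


def mdl_transition_point(steps: Sequence[int], mdl: Sequence[float]) -> int:
    """Return the step that ends the interval with the steepest MDL decrease."""
    if len(steps) != len(mdl) or len(steps) < 2:
        raise ValueError("need at least two aligned (step, mdl) checkpoints")
    # Slope of MDL between consecutive checkpoints (negative means MDL is falling).
    slopes = [
        (mdl[i + 1] - mdl[i]) / (steps[i + 1] - steps[i])
        for i in range(len(steps) - 1)
    ]
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    return steps[steepest + 1]


def grokking_point(
    steps: Sequence[int], val_acc: Sequence[float], threshold: float = 0.95
) -> Optional[int]:
    """Return the first step whose validation accuracy reaches `threshold`, or None."""
    for step, acc in zip(steps, val_acc):
        if acc >= threshold:
            return step
    return None


# Made-up curves sampled every 500 steps, for illustration only.
steps = [0, 500, 1000, 1500, 2000, 2500]
mdl_curve = [1000, 990, 970, 600, 580, 575]
val_acc_curve = [0.01, 0.02, 0.10, 0.60, 0.97, 1.00]

transition = mdl_transition_point(steps, mdl_curve)
grok = grokking_point(steps, val_acc_curve)
print(transition, grok)
```

Returning `None` when the threshold is never reached keeps "did not grok" distinct from "grokked at the final step", a distinction that is lost if unreached runs are recorded at the step limit.
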
Results: {'x_div_y': {'final_train_loss_mean': 0.43493669751721126, 'final_val_loss_mean': 0.5295936224671701, 'final_train_acc_mean': 0.8995442787806193, 'final_val_acc_mean': 0.8900553385416666, 'step_val_acc_99_mean': 4166.666666666667, 'step_val_acc_95_mean': 4010.0}, 'x_minus_y': {'final_train_loss_mean': 0.005072474246844649, 'final_val_loss_mean': 0.006565834938858946, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4483.333333333333, 'step_val_acc_95_mean': 4333.333333333333}, 'x_plus_y': {'final_train_loss_mean': 0.005436841010426481, 'final_val_loss_mean': 0.005422429492076238, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 2730.0, 'step_val_acc_95_mean': 2596.6666666666665}, 'permutation': {'final_train_loss_mean': 0.06365254408835123, 'final_val_loss_mean': 4.672119283738236, 'final_train_acc_mean': 0.9929036498069763, 'final_val_acc_mean': 0.3465169270833333, 'step_val_acc_99_mean': 7403.333333333333, 'step_val_acc_95_mean': 7363.333333333333}}
Description: In this run, we focused on analyzing the relationship between MDL and grokking across all datasets. The results show varying degrees of grokking and MDL evolution:
1. x_div_y: Final performance dropped compared to previous runs, with final validation accuracy around 89% (down from 99.7% in Run 1). The grokking point (95% validation accuracy) was nevertheless reached at step 4010 on average.
2. x_minus_y and x_plus_y: Both datasets achieved perfect validation accuracy (100%), with 95% validation accuracy reached at steps 4333 and 2597, respectively.
3. Permutation: Validation accuracy improved substantially (34.65%, up from 4.5% in Run 1), but the dataset still failed to achieve high generalization. The mean step to 95% validation accuracy (7363) sits just under the 7500-step limit, which suggests that at most a subset of the runs crossed the threshold, and only late in training.

The MDL transition points, correlations between MDL reduction and validation accuracy improvement, and MDL evolution plots will provide deeper insights into the relationship between MDL and grokking. These analyses will be crucial for understanding the information-theoretic perspective on sudden generalization.

## Run 3: Extended Analysis of MDL and Grokking Relationship
Experiment description: This run focuses on further analyzing the relationship between Minimal Description Length (MDL) and grokking across all datasets. We continue to track MDL estimates, grokking points, and the correlation between MDL reduction and validation accuracy improvement. Additionally, we implement more comprehensive visualization techniques to better understand the MDL transition points and their relationship to grokking.
Results: {'x_div_y': {'final_train_loss_mean': 0.007205140932152669, 'final_val_loss_mean': 0.008706816316892704, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 0.9998372395833334, 'step_val_acc_99_mean': 4570.0, 'step_val_acc_95_mean': 4363.333333333333}, 'x_minus_y': {'final_train_loss_mean': 0.004773730955397089, 'final_val_loss_mean': 0.005530588639279206, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4393.333333333333, 'step_val_acc_95_mean': 4173.333333333333}, 'x_plus_y': {'final_train_loss_mean': 0.06498558493331075, 'final_val_loss_mean': 0.039060630525151886, 'final_train_acc_mean': 0.9936849077542623, 'final_val_acc_mean': 0.997314453125, 'step_val_acc_99_mean': 2096.6666666666665, 'step_val_acc_95_mean': 1960.0}, 'permutation': {'final_train_loss_mean': 0.10378322905550401, 'final_val_loss_mean': 6.62765375773112, 'final_train_acc_mean': 0.9820963740348816, 'final_val_acc_mean': 0.031168619791666668, 'step_val_acc_99_mean': 7500.0, 'step_val_acc_95_mean': 7500.0}}
Description: In this run, we observe consistent performance across the datasets, with some notable improvements and insights:
1. x_div_y: This dataset showed significant improvement compared to Run 2, achieving near-perfect validation accuracy (99.98%). The grokking point (95% validation accuracy) was reached at step 4363 on average, which is slightly later than in Run 2 (4010).
2. x_minus_y: Performance remained excellent, with perfect validation accuracy (100%). The grokking point was reached at step 4173, which is earlier than in Run 2 (4333).
3. x_plus_y: This dataset maintained high performance with a slight decrease in final validation accuracy (99.73%). Notably, the grokking point was reached at step 1960, earlier than in Runs 1 and 2 (2203 and 2597).
4. Permutation: This dataset continues to struggle with generalization, achieving only 3.12% validation accuracy. The grokking point was not reached within the 7500 steps, in line with Runs 0 and 1; the partial improvement seen in Run 2 (34.65%) did not recur.

The extended analysis in this run, including new visualization techniques such as MDL Transition Point vs Grokking Point scatter plots and MDL Transition Rate plots, will provide deeper insights into the relationship between MDL and grokking. These results will help us better understand the information-theoretic perspective on sudden generalization and the differences in learning dynamics across the datasets.

## Run 4: Comprehensive MDL and Grokking Analysis
Experiment description: This run focuses on a comprehensive analysis of the relationship between Minimal Description Length (MDL) and grokking across all datasets. We continue to track MDL estimates, grokking points, and the correlation between MDL reduction and validation accuracy improvement. The run also includes the implementation of additional visualization techniques and analysis methods to provide a more in-depth understanding of the MDL transition points, their relationship to grokking, and the overall learning dynamics of each dataset.
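
The correlation between MDL reduction and validation accuracy improvement could be computed along the following lines. The notes do not say how the two quantities were paired, so this sketch makes an assumption: it correlates the per-interval MDL drop with the per-interval accuracy gain using a plain Pearson coefficient. Function names and inputs are hypothetical.

```python
import math
from typing import Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient; returns NaN if either input is constant."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("need two aligned sequences with at least two points")
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return float("nan")
    return cov / math.sqrt(var_x * var_y)


def mdl_accuracy_correlation(mdl: Sequence[float], val_acc: Sequence[float]) -> float:
    """Correlate per-interval MDL reduction with per-interval accuracy gain (assumed pairing)."""
    reductions = [mdl[i] - mdl[i + 1] for i in range(len(mdl) - 1)]
    gains = [val_acc[i + 1] - val_acc[i] for i in range(len(val_acc) - 1)]
    return pearson(reductions, gains)


# Made-up curves: accuracy gains grow in step with MDL drops.
r_aligned = mdl_accuracy_correlation([100, 90, 70, 40], [0.0, 0.1, 0.3, 0.6])
# Made-up curves: the largest MDL drop coincides with the smallest accuracy gain.
r_opposed = mdl_accuracy_correlation([100, 90, 70, 40], [0.0, 0.3, 0.5, 0.6])
print(r_aligned, r_opposed)
```

A value near 1 would indicate that intervals with large MDL drops are also the intervals with large accuracy gains; correlating raw MDL levels with raw accuracy instead would answer a different question and usually give a negative coefficient.
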
Results: {'x_div_y': {'final_train_loss_mean': 0.005417962558567524, 'final_val_loss_mean': 0.006445834257950385, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4173.333333333333, 'step_val_acc_95_mean': 3983.3333333333335}, 'x_minus_y': {'final_train_loss_mean': 0.01458902300025026, 'final_val_loss_mean': 0.015689253496627014, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 0.9998372395833334, 'step_val_acc_99_mean': 4610.0, 'step_val_acc_95_mean': 4403.333333333333}, 'x_plus_y': {'final_train_loss_mean': 0.005431499797850847, 'final_val_loss_mean': 0.005901952274143696, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 2573.3333333333335, 'step_val_acc_95_mean': 2350.0}, 'permutation': {'final_train_loss_mean': 0.0076007782481610775, 'final_val_loss_mean': 5.415488628049691, 'final_train_acc_mean': 0.9999348918596903, 'final_val_acc_mean': 0.3392740885416667, 'step_val_acc_99_mean': 7390.0, 'step_val_acc_95_mean': 7346.666666666667}}
Description: In this run, we observe consistent and improved performance across most datasets, with some notable insights:
1. x_div_y: This dataset achieved perfect validation accuracy (100%), improving upon the previous run (99.98%). The grokking point (95% validation accuracy) was reached earlier at step 3983, showing faster learning compared to Run 3 (4363).
2. x_minus_y: Performance remained excellent, with near-perfect validation accuracy (99.98%). The grokking point was reached at step 4403, which is later than in Run 3 (4173) but still consistent with previous observations.
3. x_plus_y: This dataset showed perfect validation accuracy (100%), improving upon Run 3 (99.73%). The grokking point was reached at step 2350, which is later than in Run 3 (1960) but still significantly earlier than the other datasets.
4. Permutation: Validation accuracy (33.93%) improved markedly over Run 3 (3.12%) and is comparable to Run 2 (34.65%), although the dataset still struggles with generalization. The mean step to 95% validation accuracy (7347) sits just under the 7500-step limit, which suggests that at most a subset of the runs crossed the threshold late in training; on average the model did not generalize.

The comprehensive analysis in this run, including the new visualization techniques and analysis methods, provides deeper insights into the relationship between MDL and grokking:
1. MDL Transition Points: The analysis of MDL transition points across datasets reveals the varying speeds at which models compress their representations of the data.
2. Correlation between MDL and Validation Accuracy: The correlation analysis helps quantify the relationship between MDL reduction and improvement in validation accuracy, potentially revealing the strength of the connection between compression and generalization.
3. MDL Evolution: Comparing the MDL evolution between grokking and non-grokking scenarios (especially in the permutation dataset) may provide insights into the differences in learning dynamics.
4. MDL Transition Rate: The analysis of MDL transition rates offers a new perspective on the speed of compression and its potential relationship to grokking speed.
5. Generalization Gap: The examination of the generalization gap (difference between training and validation accuracy) in relation to MDL provides insights into how compression relates to a model's ability to generalize.

These results and analyses contribute to our understanding of the information-theoretic perspective on sudden generalization, highlighting the complex relationship between MDL, grokking, and the learning dynamics across different types of datasets. The partial improvement on the permutation dataset relative to Run 3, while still not achieving high generalization, suggests that the relationship between MDL and grokking may be more nuanced for more complex tasks.

Plot Descriptions:
1. Training Loss Plots (train_loss_{dataset}.png): These plots show the training loss over time for each dataset. They help visualize how quickly and effectively the model learns to fit the training data. A decreasing trend indicates improved performance on the training set.
2. Validation Loss Plots (val_loss_{dataset}.png): Similar to the training loss plots, these show the validation loss over time. They help identify overfitting (if validation loss increases while training loss decreases) and assess the model's generalization capabilities.
3. Training Accuracy Plots (train_acc_{dataset}.png): These plots display the training accuracy over time, showing how well the model performs on the training data. An increasing trend indicates improved performance.
4. Validation Accuracy Plots (val_acc_{dataset}.png): These plots show the validation accuracy over time, indicating how well the model generalizes to unseen data. The point where the validation accuracy reaches 95% is considered the "grokking point".
5. Validation Accuracy and MDL Plots (val_acc_mdl_{dataset}.png): These plots combine validation accuracy and normalized MDL estimates over time. They help visualize the relationship between model compression (MDL reduction) and improved generalization (increased validation accuracy).
6. MDL Transition vs Grokking Point Scatter Plot (mdl_transition_vs_grokking_scatter.png): This scatter plot compares the MDL transition point (steepest decrease in MDL) with the grokking point (95% validation accuracy) for all datasets and runs. It helps identify any correlation between these two events, potentially revealing insights into the relationship between compression and generalization.
7. MDL-Validation Accuracy Correlation Plot (mdl_val_acc_correlation.png): This bar plot shows the correlation between MDL reduction and validation accuracy improvement for each dataset. Higher correlation values suggest a stronger link between compression and generalization.
8. MDL Evolution and Generalization Gap Plots (mdl_gen_gap_{dataset}.png): These plots show the MDL evolution and generalization gap (difference between training and validation accuracy) over time for each dataset. They help visualize how compression relates to the model's ability to generalize.
9. MDL Transition Rate Plots (mdl_transition_rate_{dataset}.png): These plots display the rate of change in MDL over time for each dataset. They provide insights into the speed of compression and how it might relate to the learning dynamics.
10. MDL and Validation Accuracy Combined Plots (val_acc_mdl_{dataset}.png): These plots combine MDL estimates and validation accuracy over time for each dataset. They provide a comprehensive view of the relationship between compression, generalization, and key learning events, helping to visualize how MDL reduction correlates with improved validation accuracy.

These plots collectively offer a multi-faceted view of the learning dynamics, compression, and generalization processes across different datasets and experimental runs. They help in understanding the complex interplay between Minimal Description Length (MDL) and grokking, providing valuable insights into the information-theoretic perspective on sudden generalization in neural networks.
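
The quantities behind the generalization gap and MDL transition rate plots can be derived from logged curves with a sketch such as the one below. It assumes MDL and both accuracies are recorded at the same checkpoints; the function names and the sample values are illustrative and not taken from the runs above.

```python
from typing import List, Sequence


def mdl_transition_rate(steps: Sequence[int], mdl: Sequence[float]) -> List[float]:
    """Rate of change of MDL per training step between consecutive checkpoints."""
    if len(steps) != len(mdl):
        raise ValueError("steps and mdl must have the same length")
    return [
        (mdl[i + 1] - mdl[i]) / (steps[i + 1] - steps[i])
        for i in range(len(steps) - 1)
    ]


def generalization_gap(train_acc: Sequence[float], val_acc: Sequence[float]) -> List[float]:
    """Training accuracy minus validation accuracy at each checkpoint."""
    if len(train_acc) != len(val_acc):
        raise ValueError("train_acc and val_acc must have the same length")
    return [t - v for t, v in zip(train_acc, val_acc)]


# Made-up checkpoints: the model memorizes first, then generalizes.
rates = mdl_transition_rate([0, 500, 1000], [1000, 900, 400])
gaps = generalization_gap([0.5, 1.0, 1.0], [0.25, 0.25, 1.0])
print(rates, gaps)
```

In a grokking run one would expect the gap to open while training accuracy saturates and then close as validation accuracy catches up; whether the most negative transition rate falls in that same window is the question the scatter plot in item 6 addresses.
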