# Title: Weight Initialization Grokking: Assessing the impact of weight initialization strategies on the grokking phenomenon
# Experiment description: Modify the `run` function to include different weight initialization strategies (Xavier, He, orthogonal) for the Transformer model. Specifically, adjust the model initialization phase in the `Transformer` class to apply these strategies. Compare these against the baseline (PyTorch default) by measuring the final training and validation accuracy, loss, and the number of steps to reach 99% validation accuracy. Evaluate the results for each dataset and seed combination.

# Plot Descriptions:

1. Training Loss Plots (train_loss_[dataset].png): These plots show the training loss across different initialization methods for each dataset. The x-axis represents the number of update steps, and the y-axis shows the training loss. Each line represents a different initialization method, allowing for easy comparison of how quickly and effectively each method reduces the training loss over time. The shaded areas around each line represent the standard error, giving an indication of the variability across different runs.

2. Validation Loss Plots (val_loss_[dataset].png): Similar to the training loss plots, these graphs display the validation loss for each initialization method across update steps. These plots are crucial for understanding how well the model generalizes to unseen data and for detecting potential overfitting. Lower validation loss generally indicates better generalization.

3. Training Accuracy Plots (train_acc_[dataset].png): These plots illustrate the training accuracy over time for each initialization method. The x-axis shows the number of update steps, while the y-axis represents the training accuracy. These graphs help visualize how quickly and accurately each method learns the training data, with higher accuracy indicating better performance on the training set.

4. Validation Accuracy Plots (val_acc_[dataset].png): These graphs show the validation accuracy over time for each initialization method. They are crucial for understanding how well the model generalizes to unseen data. Higher validation accuracy suggests better performance on new, unseen examples. The comparison between different initialization methods can reveal which approach leads to better generalization.

5. Summary Plot (summary_plot.png): This comprehensive plot compares all initialization methods across datasets for three key metrics:
   a. Final Training Accuracy Mean: The average final training accuracy for each method across all datasets.
   b. Final Validation Accuracy Mean: The average final validation accuracy for each method across all datasets.
   c. Steps to 99% Validation Accuracy Mean: The average number of steps required to reach 99% validation accuracy for each method across all datasets.

   This plot provides a high-level overview of the performance of each initialization method, allowing for quick comparisons across different datasets and metrics. It's particularly useful for identifying which initialization methods consistently perform well across various tasks.

These plots collectively provide a comprehensive view of how different weight initialization strategies affect the learning process, generalization ability, and overall performance of the Transformer model across various arithmetic tasks. They allow for in-depth analysis of the grokking phenomenon and how it's influenced by different initialization approaches.
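
The two derived quantities used throughout these notes, steps to 99% validation accuracy and the standard error shown as shading in the plots, can be sketched as follows. This is a minimal illustration rather than the experiment's actual code: the function names and the toy curves are hypothetical, and the notes do not say how a run that never reaches 99% is recorded, so this sketch simply returns `None` in that case.

```python
import math
import statistics
from typing import List, Optional, Sequence, Tuple


def steps_to_threshold(
    steps: Sequence[int], accs: Sequence[float], threshold: float = 0.99
) -> Optional[int]:
    """Return the first logged step whose accuracy is >= threshold, else None."""
    for step, acc in zip(steps, accs):
        if acc >= threshold:
            return step
    return None  # assumption: the notes do not specify a fallback value


def mean_and_stderr(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and standard error (sample std / sqrt(n)) across seeds."""
    n = len(values)
    mean = statistics.fmean(values)
    stderr = statistics.stdev(values) / math.sqrt(n) if n > 1 else 0.0
    return mean, stderr


# Made-up validation-accuracy curves for three seeds, logged every 10 steps.
steps = [10, 20, 30, 40]
curves: List[List[float]] = [
    [0.10, 0.50, 0.995, 1.0],
    [0.10, 0.992, 1.0, 1.0],
    [0.05, 0.30, 0.80, 0.999],
]

hits = [steps_to_threshold(steps, curve) for curve in curves]
mean, stderr = mean_and_stderr(hits)
print(hits, mean, round(stderr, 3))
```

Any seed that returns `None` has to be handled before averaging (dropped, or replaced by a chosen sentinel), and that choice changes the reported mean.
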
## Run 0: Baseline
Results: {
    'x_div_y': {'final_train_loss_mean': 0.005800435319542885, 'final_val_loss_mean': 0.006530226518710454, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4200.0},
    'x_minus_y': {'final_train_loss_mean': 0.014211568981409073, 'final_val_loss_mean': 0.014943961674968401, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4720.0},
    'x_plus_y': {'final_train_loss_mean': 0.003832749711970488, 'final_val_loss_mean': 0.004045687771091859, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 2363.3333333333335},
    'permutation': {'final_train_loss_mean': 0.08011958096176386, 'final_val_loss_mean': 6.804208914438884, 'final_train_acc_mean': 0.9880208373069763, 'final_val_acc_mean': 0.035888671875, 'step_val_acc_99_mean': 7500.0}
}
Description: Baseline results.

## Run 1: Xavier (Glorot) Initialization
Description: Implemented Xavier uniform initialization for Linear and Embedding layers in the Transformer model. The LayerNorm layers were initialized with weight 1.0 and bias 0.0.
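
The description above could be realized roughly as below. This is a sketch under assumptions, not the code used in the run: the helper name `init_xavier`, the toy model, and the choice to zero Linear biases are all hypothetical (the notes only specify the Linear and Embedding weights and the LayerNorm values).

```python
import torch
import torch.nn as nn


def init_xavier(module: nn.Module) -> None:
    """Hypothetical helper: Xavier-uniform weights, LayerNorm set to weight 1 / bias 0."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            # Assumption: the notes do not say how Linear biases were initialized.
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.xavier_uniform_(module.weight)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


# Toy stand-in for the Transformer; the layer sizes are arbitrary.
torch.manual_seed(0)
toy = nn.Sequential(nn.Embedding(97, 128), nn.Linear(128, 128), nn.LayerNorm(128))
toy.apply(init_xavier)  # Module.apply visits every submodule recursively
```

One caveat with this pattern: it only touches parameters owned by `nn.Linear`, `nn.Embedding`, and `nn.LayerNorm` modules. Parameters stored directly on another module, such as the packed input projection of `nn.MultiheadAttention`, keep their default initialization unless handled separately.
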
Results:
x_div_y:
- Final train loss: 0.00703450928752621
- Final val loss: 0.008110948217411837
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 2536.6666666666665

x_minus_y:
- Final train loss: 0.005914364087705811
- Final val loss: 1.4212849920925994
- Final train accuracy: 0.9999348918596903
- Final val accuracy: 0.7403157552083334
- Steps to 99% val accuracy: 2346.6666666666665

x_plus_y:
- Final train loss: 0.004743196380635102
- Final val loss: 0.0051032428940137224
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 863.3333333333334

permutation:
- Final train loss: 0.005184388952329755
- Final val loss: 0.008002187125384808
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 5066.666666666667

Analysis: Xavier initialization reached 99% validation accuracy in fewer steps than the baseline on all four tasks. The x_plus_y task saw the largest relative reduction (about 863 steps versus 2363). On the permutation task, final validation accuracy rose from 0.036 with the baseline to 1.0. However, the x_minus_y task ended with a validation accuracy of 0.74 and a validation loss of 1.42 despite near-perfect training accuracy, which points to unstable generalization and suggests that Xavier initialization might not be optimal for all arithmetic operations.

## Run 2: He Initialization
Description: Implemented He initialization for Linear and Embedding layers in the Transformer model. The LayerNorm layers were initialized with weight 1.0 and bias 0.0.
Results:
x_div_y:
- Final train loss: 0.0057101390945414705
- Final val loss: 0.006926700938493013
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 3463.3333333333335

x_minus_y:
- Final train loss: 0.07778730530602236
- Final val loss: 0.05283881491050124
- Final train accuracy: 0.9914713501930237
- Final val accuracy: 0.9942220052083334
- Steps to 99% val accuracy: 3640.0

x_plus_y:
- Final train loss: 0.048976593650877476
- Final val loss: 0.03214737741897503
- Final train accuracy: 0.9975911577542623
- Final val accuracy: 0.9988606770833334
- Steps to 99% val accuracy: 2136.6666666666665

permutation:
- Final train loss: 0.054390662194540106
- Final val loss: 2.36757427531605
- Final train accuracy: 0.9977213541666666
- Final val accuracy: 0.6680501302083334
- Steps to 99% val accuracy: 6460.0

Analysis: He initialization showed mixed results compared to Xavier initialization. It reached 99% validation accuracy faster than the baseline on all four tasks but slower than Xavier on all four (for example, about 3463 steps versus 2537 on x_div_y). On x_plus_y, final accuracies were slightly lower than with Xavier (validation 0.9989 versus 1.0). On x_minus_y, training accuracy was slightly lower (0.9915 versus 0.9999), but validation accuracy was much higher (0.994 versus 0.740). On the permutation task, validation accuracy dropped to 0.668 from Xavier's 1.0, indicating that He initialization might not be suitable for this particular task. Overall, Xavier initialization converged faster on every task, although each method had one task where validation accuracy degraded (x_minus_y for Xavier, permutation for He).

## Run 3: Orthogonal Initialization
Description: Implemented Orthogonal initialization for Linear and Embedding layers in the Transformer model. The LayerNorm layers were initialized with weight 1.0 and bias 0.0.
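
As a quick illustration of what orthogonal initialization produces, the sketch below initializes one square weight matrix with `nn.init.orthogonal_` and checks two properties: the matrix times its transpose is the identity, and the layer preserves input norms at initialization. The layer and its size are hypothetical stand-ins, not taken from the experiment's model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for one square projection inside the model.
layer = nn.Linear(64, 64, bias=False)
nn.init.orthogonal_(layer.weight)

# For a square orthogonal matrix W, W @ W.T is the identity.
w = layer.weight.detach()
identity_error = (w @ w.T - torch.eye(64)).abs().max().item()

# Consequently the layer keeps the Euclidean norm of each input vector.
x = torch.randn(8, 64)
norm_ratio = (layer(x).norm(dim=1) / x.norm(dim=1)).detach()

print(identity_error, norm_ratio)
```

For non-square weights, such as an embedding table, only the shorter dimension can be orthonormal, so the norm-preservation property holds in one direction only.
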
Results:
x_div_y:
- Final train loss: 0.5920290009429058
- Final val loss: 0.30291892828730244
- Final train accuracy: 0.8575520912806193
- Final val accuracy: 0.9386393229166666
- Steps to 99% val accuracy: 1643.3333333333333

x_minus_y:
- Final train loss: 0.0039047593406091132
- Final val loss: 0.004387715288127462
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 1993.3333333333333

x_plus_y:
- Final train loss: 0.008580587338656187
- Final val loss: 0.009516028997798761
- Final train accuracy: 1.0
- Final val accuracy: 0.9998372395833334
- Steps to 99% val accuracy: 836.6666666666666

permutation:
- Final train loss: 0.004259653855115175
- Final val loss: 0.08990247027638058
- Final train accuracy: 1.0
- Final val accuracy: 0.9829915364583334
- Steps to 99% val accuracy: 4543.333333333333

Analysis: Orthogonal initialization showed mixed results compared to Xavier and He initializations. It performed well on the x_minus_y and x_plus_y tasks, achieving perfect or near-perfect accuracy with the fewest steps to 99% validation accuracy of any run so far (about 1993 and 837 steps). However, it struggled with the x_div_y task: although the reported mean steps to 99% validation accuracy was about 1643, final training and validation accuracy were only 0.858 and 0.939, with higher losses than previous initializations, suggesting that at least some seeds did not hold (or did not reach) that accuracy by the end of training. For the permutation task, it reached 99% validation accuracy faster than Xavier initialization (about 4543 steps versus 5067), but its final validation accuracy was lower (0.983 versus 1.0). Overall, Orthogonal initialization seems to be effective for certain arithmetic operations but may not be the best choice for all tasks in this experiment.

## Run 4: Kaiming Normal Initialization
Description: Implemented Kaiming Normal initialization for Linear and Embedding layers in the Transformer model. The LayerNorm layers were initialized with weight 1.0 and bias 0.0.
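
For reference, the weight scale set by Kaiming (He) normal initialization can be written out and compared with Xavier uniform. The formulas below follow PyTorch's `kaiming_normal_` and `xavier_uniform_` with their default arguments (`mode='fan_in'` with gain $\sqrt{2}$, and gain 1, respectively). Whether the runs used those defaults is an assumption, since the notes do not say. Here $n_\text{in}$ and $n_\text{out}$ denote the fan-in and fan-out of a weight matrix $W$.

$$
\text{Kaiming normal:}\quad W_{ij} \sim \mathcal{N}(0, \sigma^2), \qquad \sigma^2 = \frac{2}{n_\text{in}}
$$

$$
\text{Xavier uniform:}\quad W_{ij} \sim \mathcal{U}(-a, a), \qquad a = \sqrt{\frac{6}{n_\text{in} + n_\text{out}}}, \qquad \operatorname{Var}(W_{ij}) = \frac{a^2}{3} = \frac{2}{n_\text{in} + n_\text{out}}
$$

For a square layer with $n_\text{in} = n_\text{out} = n$, these give variances of $2/n$ and $1/n$, so under the stated defaults Kaiming normal starts with twice the weight variance of Xavier uniform.
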
Results:
x_div_y:
- Final train loss: 0.006882136842856805
- Final val loss: 0.008419923096274337
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 3070.0

x_minus_y:
- Final train loss: 0.40684207854792476
- Final val loss: 0.18088411757101616
- Final train accuracy: 0.9046223958333334
- Final val accuracy: 0.9646809895833334
- Steps to 99% val accuracy: 3546.6666666666665

x_plus_y:
- Final train loss: 0.005102624961485465
- Final val loss: 0.005619957422216733
- Final train accuracy: 1.0
- Final val accuracy: 1.0
- Steps to 99% val accuracy: 1966.6666666666667

permutation:
- Final train loss: 0.15046238231783113
- Final val loss: 0.07921216955098014
- Final train accuracy: 0.981640636920929
- Final val accuracy: 0.9939778645833334
- Steps to 99% val accuracy: 6296.666666666667

Analysis: Kaiming Normal initialization showed mixed results compared to previous initializations. It performed well on the x_div_y and x_plus_y tasks, achieving perfect accuracy and reaching 99% validation accuracy faster than the baseline and He initialization (3070 and about 1967 steps), though slower than Xavier and Orthogonal. However, it struggled more on the x_minus_y task, with final training and validation accuracies of 0.905 and 0.965 and higher losses. The permutation task saw good performance, with a final validation accuracy of 0.994, but slower convergence (about 6297 steps) than Xavier (5067) and Orthogonal (4543). Overall, Kaiming Normal initialization seems to be effective for certain arithmetic operations but may not be the best choice for all tasks in this experiment.

## Run 5: Uniform Initialization
Description: Implementing Uniform initialization for Linear and Embedding layers in the Transformer model. The LayerNorm layers will be initialized with weight 1.0 and bias 0.0.