# Title: Impact of Data Augmentation on Grokking Dynamics in Mathematical Operations

# Experiment description: Modify AbstractDataset to include methods for operand reversal (for addition and multiplication) and operand negation (for addition, subtraction, and division) augmentations. Update the training loop in train() to apply these augmentations with varying probabilities. Run experiments with five conditions across all datasets: no augmentation (baseline), reversal augmentation, negation augmentation, combined augmentation (15% probability each), and combined augmentation (30% probability each). Track grokking behavior by measuring: 1) steps to 99% validation accuracy, 2) rate of validation accuracy increase around the grokking point, and 3) final accuracies. Plot learning curves for each condition. Compare how different augmentations affect these metrics across operation types.
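
For reference, the sketch below shows one way the two augmentations could be implemented and applied per example in the training loop. It assumes the standard grokking setup (operands are residues modulo a prime p and each example is an (x, y, result) triple); the function names, the (x, y, result) layout, and the dataset keys are illustrative assumptions rather than the actual AbstractDataset/train() interface.

```python
import random

def reverse_operands(x, y, result):
    """Operand reversal: valid only for commutative operations
    (addition, multiplication), so the label is unchanged."""
    return y, x, result

def negate_operands(x, y, result, p, operation):
    """Operand negation modulo p. Negating both operands gives
    (-x) + (-y) == -(x + y), (-x) - (-y) == -(x - y), and
    (-x) / (-y) == x / y (all mod p), so the label is adjusted accordingly."""
    neg_x, neg_y = (-x) % p, (-y) % p
    if operation in ("x_plus_y", "x_minus_y"):
        return neg_x, neg_y, (-result) % p
    if operation == "x_div_y":
        return neg_x, neg_y, result
    return x, y, result  # not applicable (e.g. permutation)

def maybe_augment(example, p, operation, reversal_prob=0.15, negation_prob=0.15):
    """Apply each augmentation independently with its own probability,
    mirroring the 15% / 30% conditions described above."""
    x, y, result = example
    if operation == "x_plus_y" and random.random() < reversal_prob:  # also x_times_y, if present
        x, y, result = reverse_operands(x, y, result)
    if operation in ("x_plus_y", "x_minus_y", "x_div_y") and random.random() < negation_prob:
        x, y, result = negate_operands(x, y, result, p, operation)
    return x, y, result
```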

# Plot Descriptions:

1. Validation Accuracy Across Runs (val_acc_[dataset].png):
   These plots show the validation accuracy over time for each dataset (x_div_y, x_minus_y, x_plus_y, permutation) across all runs. Each line represents a different augmentation strategy, allowing for direct comparison of how quickly and effectively each strategy leads to grokking.
   - X-axis: Update Steps
   - Y-axis: Validation Accuracy (0 to 1.05)
   - Lines: Baseline, Operand Reversal, Negation, Combined (15%), Combined (30%)
   - Interpretation: Steeper curves indicate faster grokking. Higher final accuracies show better overall performance. The point where curves sharply rise indicates the onset of grokking.

2. Steps to 99% Validation Accuracy (steps_to_99_acc.png):
   This bar plot compares the number of steps required to reach 99% validation accuracy across all datasets and runs. It provides a clear visualization of which augmentation strategies lead to faster grokking for each operation.
   - X-axis: Datasets (x_div_y, x_minus_y, x_plus_y, permutation)
   - Y-axis: Steps to 99% Validation Accuracy
   - Bars: Grouped by run (Baseline, Operand Reversal, Negation, Combined (15%), Combined (30%))
   - Interpretation: Shorter bars indicate faster grokking. Comparing bar heights within each dataset group shows which augmentation strategy is most effective for that operation.

3. Final Validation Accuracy (final_val_acc.png):
   This bar plot shows the final validation accuracy achieved for each dataset and run. It allows for comparison of the ultimate effectiveness of each augmentation strategy across different operations.
   - X-axis: Datasets (x_div_y, x_minus_y, x_plus_y, permutation)
   - Y-axis: Final Validation Accuracy
   - Bars: Grouped by run (Baseline, Operand Reversal, Negation, Combined (15%), Combined (30%))
   - Interpretation: Higher bars indicate better final performance. This plot helps identify which augmentation strategies lead to the best overall learning, even if they might not be the fastest to reach grokking.

These plots collectively provide a comprehensive view of how different augmentation strategies affect grokking dynamics across various mathematical operations. They allow for analysis of both the speed of grokking (from the validation accuracy curves and steps to 99% accuracy) and the quality of final learning (from the final validation accuracy plot). Researchers can use these visualizations to draw conclusions about which augmentation strategies are most effective for different types of operations and to understand the trade-offs between speed of grokking and final performance.
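
As a concrete illustration of how the steps-to-99% metric, the accuracy-increase rate, and the grouped bar plot could be produced from logged validation-accuracy traces, here is a minimal sketch (the results-dict layout, helper names, and evaluation cadence are assumptions, not the plotting code actually used):

```python
import matplotlib.pyplot as plt
import numpy as np

def step_to_threshold(steps, val_acc, threshold=0.99):
    """Return the first update step at which validation accuracy reaches
    the threshold, or None if it never does."""
    for step, acc in zip(steps, val_acc):
        if acc >= threshold:
            return step
    return None

def max_acc_increase_rate(steps, val_acc):
    """Largest per-step increase in validation accuracy between consecutive
    evaluations; a rough proxy for the steepness of the grokking transition."""
    diffs = np.diff(val_acc) / np.diff(steps)
    return float(diffs.max()) if diffs.size else 0.0

def plot_steps_to_99(results, datasets, run_labels, out_file="steps_to_99_acc.png"):
    """Grouped bar plot: one group per dataset, one bar per run.
    results[run][dataset] is assumed to hold (steps, val_acc) arrays."""
    width = 0.8 / len(run_labels)
    x = np.arange(len(datasets))
    for i, run in enumerate(run_labels):
        heights = []
        for dataset in datasets:
            steps, val_acc = results[run][dataset]
            reached = step_to_threshold(steps, val_acc)
            heights.append(reached if reached is not None else np.nan)
        plt.bar(x + i * width, heights, width, label=run)
    plt.xticks(x + 0.4 - width / 2, datasets)
    plt.xlabel("Datasets")
    plt.ylabel("Steps to 99% Validation Accuracy")
    plt.legend()
    plt.savefig(out_file)
    plt.close()
```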
## Run 0: Baseline
Results: {'x_div_y': {'final_train_loss_mean': 0.005800435319542885, 'final_val_loss_mean': 0.006530226518710454, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4200.0}, 'x_minus_y': {'final_train_loss_mean': 0.014211568981409073, 'final_val_loss_mean': 0.014943961674968401, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 4720.0}, 'x_plus_y': {'final_train_loss_mean': 0.003832749711970488, 'final_val_loss_mean': 0.004045687771091859, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 2363.3333333333335}, 'permutation': {'final_train_loss_mean': 0.08011958096176386, 'final_val_loss_mean': 6.804208914438884, 'final_train_acc_mean': 0.9880208373069763, 'final_val_acc_mean': 0.035888671875, 'step_val_acc_99_mean': 7500.0}}
Description: Baseline results.
## Run 1: Operand Reversal Augmentation
Results: {'x_div_y': {'final_train_loss_mean': 0.3073409056911866, 'final_val_loss_mean': 0.818450391292572, 'final_train_acc_mean': 0.9272135496139526, 'final_val_acc_mean': 0.807373046875, 'step_val_acc_99_mean': 4500.0, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_minus_y': {'final_train_loss_mean': 0.0060335383750498295, 'final_val_loss_mean': 0.014887654222548008, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 0.9984537760416666, 'step_val_acc_99_mean': 5160.0, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_plus_y': {'final_train_loss_mean': 0.00963700112576286, 'final_val_loss_mean': 0.00942775048315525, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 1993.3333333333333, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'permutation': {'final_train_loss_mean': 0.02911314181983471, 'final_val_loss_mean': 6.984205881754558, 'final_train_acc_mean': 0.998046875, 'final_val_acc_mean': 0.022623697916666668, 'step_val_acc_99_mean': None, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}}
Description: Run 1 implemented operand reversal augmentation for addition and multiplication operations. The results show some interesting changes compared to the baseline:
1. For x_plus_y, the step_val_acc_99_mean decreased from 2363 to 1993, indicating faster grokking.
2. For x_minus_y, the step_val_acc_99_mean increased from 4720 to 5160, suggesting slower grokking.
3. For x_div_y, the final accuracies decreased significantly (final validation accuracy fell from 1.00 at baseline to 0.81), which is unexpected and may warrant further investigation.
4. The permutation task showed little change, as expected since reversal augmentation doesn't apply to it.
These results suggest that operand reversal augmentation has different effects on different operations, potentially improving grokking for addition but hindering it for subtraction and division.

## Run 2: Negation Augmentation
Results: {'x_div_y': {'final_train_loss_mean': 0.009206565717856089, 'final_val_loss_mean': 0.01013705072303613, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 1443.3333333333333, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_minus_y': {'final_train_loss_mean': 0.007552354441334804, 'final_val_loss_mean': 0.008941326756030321, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 0.9996744791666666, 'step_val_acc_99_mean': 1343.3333333333333, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_plus_y': {'final_train_loss_mean': 0.003964016912505031, 'final_val_loss_mean': 0.004060242945949237, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 1000.0, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'permutation': {'final_train_loss_mean': 0.10073609425065418, 'final_val_loss_mean': 5.342761894067128, 'final_train_acc_mean': 0.9820312658945719, 'final_val_acc_mean': 0.3232421875, 'step_val_acc_99_mean': None, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}}
Description: Run 2 implemented negation augmentation for addition, subtraction, and division operations. The results show significant improvements compared to both the baseline (Run 0) and the operand reversal augmentation (Run 1):
1. For x_div_y, the step_val_acc_99_mean decreased dramatically from 4200 (baseline) and 4500 (reversal) to 1443, indicating much faster grokking. The final accuracies also improved to 100%, resolving the issues seen in Run 1.
2. For x_minus_y, the step_val_acc_99_mean decreased significantly from 4720 (baseline) and 5160 (reversal) to 1343, showing a substantial improvement in grokking speed.
3. For x_plus_y, the step_val_acc_99_mean further decreased from 2363 (baseline) and 1993 (reversal) to 1000, demonstrating even faster grokking.
4. The permutation task showed some improvement in validation accuracy (from 0.036 to 0.323) but still did not achieve grokking, which is expected as negation augmentation doesn't directly apply to it.
These results suggest that negation augmentation has a strong positive effect on grokking speed for all applicable operations (addition, subtraction, and division). The improvement is particularly notable for division and subtraction, which struggled with the reversal augmentation in Run 1.

## Run 3: Combined Operand Reversal and Negation Augmentation
Results: {'x_div_y': {'final_train_loss_mean': 0.007881488806257645, 'final_val_loss_mean': 0.008671872317790985, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 1766.6666666666667, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_minus_y': {'final_train_loss_mean': 0.04362078352520863, 'final_val_loss_mean': 0.036223807682593666, 'final_train_acc_mean': 0.998242199420929, 'final_val_acc_mean': 0.9984537760416666, 'step_val_acc_99_mean': 1056.6666666666667, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_plus_y': {'final_train_loss_mean': 0.3260944996339579, 'final_val_loss_mean': 0.42411381254593533, 'final_train_acc_mean': 0.929882824420929, 'final_val_acc_mean': 0.9142252604166666, 'step_val_acc_99_mean': 920.0, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'permutation': {'final_train_loss_mean': 0.02254059522723158, 'final_val_loss_mean': 1.8000340942914288, 'final_train_acc_mean': 0.998242199420929, 'final_val_acc_mean': 0.686767578125, 'step_val_acc_99_mean': 6925.0, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}}
Description: Run 3 implemented a combination of operand reversal and negation augmentations for applicable operations (addition, subtraction, and division), each with a 15% probability. The results show interesting changes compared to the previous runs:
1. For x_div_y, the step_val_acc_99_mean increased from 1443 (negation only) to 1767, indicating somewhat slower grokking. However, it still maintains 100% final accuracy.
2. For x_minus_y, the step_val_acc_99_mean decreased from 1343 (negation only) to 1057, showing faster grokking. The final accuracies remain high.
3. For x_plus_y, the step_val_acc_99_mean decreased further from 1000 (negation only) to 920, demonstrating even faster grokking. However, final validation accuracy decreased from 1.0 to 0.914.
4. The permutation task showed a large improvement in validation accuracy (from 0.323 to 0.687) and achieved grokking per the recorded step_val_acc_99_mean of 6925, despite the lower final mean accuracy. This is unexpected, as these augmentations don't directly apply to permutations, and might be due to increased regularization from the combined augmentations.
These results suggest that the combination of operand reversal and negation augmentations has mixed effects:
- It slightly slows down grokking for division but maintains high accuracy.
- It speeds up grokking for subtraction and addition, with a trade-off in final accuracy for addition (0.914 vs. 1.0 in earlier runs).
- Unexpectedly, it significantly improves performance on the permutation task, possibly due to increased regularization.
The combined augmentations seem to provide a good balance between grokking speed and final accuracy across different operations, with the added benefit of improving performance on the permutation task.

## Run 4: Increased Probability of Augmentations
Results: {'x_div_y': {'final_train_loss_mean': 0.007121656866123279, 'final_val_loss_mean': 0.00815950520336628, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 0.9998372395833334, 'step_val_acc_99_mean': 1876.6666666666667, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_minus_y': {'final_train_loss_mean': 0.0070951394736766815, 'final_val_loss_mean': 0.00798630698894461, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 1366.6666666666667, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'x_plus_y': {'final_train_loss_mean': 0.01038626270989577, 'final_val_loss_mean': 0.010425164829939604, 'final_train_acc_mean': 1.0, 'final_val_acc_mean': 1.0, 'step_val_acc_99_mean': 793.3333333333334, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}, 'permutation': {'final_train_loss_mean': 0.08731850298742454, 'final_val_loss_mean': 5.65827210744222, 'final_train_acc_mean': 0.9888672033945719, 'final_val_acc_mean': 0.12752278645833334, 'step_val_acc_99_mean': None, 'step_val_acc_95_mean': None, 'max_acc_increase_rate_mean': 0.0}}
Description: Run 4 implemented a combination of operand reversal and negation augmentations for applicable operations (addition, subtraction, and division), each with a 30% probability (increased from 15% in Run 3). The results show some interesting changes compared to Run 3:
1. For x_div_y, the step_val_acc_99_mean increased slightly from 1767 to 1877, indicating slightly slower grokking. However, it still maintains very high final accuracy (99.98%).
2. For x_minus_y, the step_val_acc_99_mean increased from 1057 to 1367, showing slower grokking. The final accuracies remain at 100%.
3. For x_plus_y, the step_val_acc_99_mean decreased from 920 to 793, demonstrating faster grokking, and final validation accuracy recovered to 100% (from 0.914 in Run 3).
4. The permutation task showed a significant decrease in validation accuracy (from 0.687 to 0.128), which is a reversal of the unexpected improvement seen in Run 3.
These results suggest that increasing the probability of augmentations to 30% has mixed effects:
- It slows down grokking for division and subtraction somewhat but maintains high accuracy.
- It speeds up grokking for addition while restoring perfect final accuracy (1.0, up from 0.914 in Run 3).
- The increased augmentation probability negatively impacts the permutation task, reversing the unexpected improvement seen with lower augmentation probabilities.
The higher augmentation rate seems to have different effects on different operations, potentially due to the increased complexity introduced in the training data. This suggests that there might be an optimal augmentation probability that balances improved grokking for some operations without negatively impacting others.