File size: 20,213 Bytes
f71c233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
from scipy.signal import savgol_filter

# LOAD FINAL RESULTS:
datasets = ["x_div_y", "x_minus_y", "x_plus_y", "permutation"]


def summarize_curves(curves):
    """Return the element-wise mean and standard error over a list of per-seed curves.

    Args:
        curves: list of equal-length sequences, one per seed.

    Returns:
        ``(mean, stderr)`` as numpy arrays, where stderr is the population standard
        deviation divided by sqrt(number of curves). With no curves, returns
        ``(nan, 0.0)`` as 0-d arrays instead of averaging an empty list.
    """
    if not curves:
        return np.array(np.nan), np.array(0.0)
    stacked = np.asarray(curves, dtype=float)
    return stacked.mean(axis=0), stacked.std(axis=0) / np.sqrt(len(stacked))


def summarize_dataset(results_dict, dataset):
    """Aggregate every seed's logs for one dataset from a run's results dict.

    Keys of ``results_dict`` are matched by substring. Values under keys that contain
    ``dataset`` and ``"val_info"`` / ``"train_info"`` / ``"mdl_info"`` are lists of
    per-evaluation dicts.

    Returns:
        A dict with ``"step"`` (taken from the last matching val_info entry; only
        present if one exists), ``"val_loss"``, ``"train_loss"``, ``"val_acc"``,
        ``"train_acc"``, each with a matching ``"<name>_sterr"``, plus ``"mdl_step"``
        and ``"mdl"`` when an mdl_info entry exists (the first one found is used).
    """
    info = {}
    val_losses, val_accs, train_losses, train_accs = [], [], [], []
    for key, records in results_dict.items():
        if dataset not in key:
            continue
        if "val_info" in key:
            # Seeds are assumed to log at the same steps; the last one seen wins.
            info["step"] = [record["step"] for record in records]
            val_losses.append([record["val_loss"] for record in records])
            val_accs.append([record["val_accuracy"] for record in records])
        if "train_info" in key:
            train_losses.append([record["train_loss"] for record in records])
            train_accs.append([record["train_accuracy"] for record in records])

    # Aggregate once, after all seeds have been collected (each metric uses its
    # own seed count for the standard error).
    for name, curves in (
        ("val_loss", val_losses),
        ("train_loss", train_losses),
        ("val_acc", val_accs),
        ("train_acc", train_accs),
    ):
        mean, sterr = summarize_curves(curves)
        info[name] = mean
        info[f"{name}_sterr"] = sterr

    # Add MDL info
    mdl_data = [
        records
        for key, records in results_dict.items()
        if dataset in key and "mdl_info" in key
    ]
    if mdl_data:
        info["mdl_step"] = [item["step"] for item in mdl_data[0]]
        info["mdl"] = [item["mdl"] for item in mdl_data[0]]
    return info


folders = os.listdir("./")
final_results = {}
results_info = {}
for folder in folders:
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue
    with open(osp.join(folder, "final_info.json"), "r") as f:
        final_results[folder] = json.load(f)
    # NOTE: allow_pickle=True unpickles the stored dict; only load run outputs you
    # produced yourself, never untrusted files.
    results_dict = np.load(
        osp.join(folder, "all_results.npy"), allow_pickle=True
    ).item()
    print(results_dict.keys())
    results_info[folder] = {
        dataset: summarize_dataset(results_dict, dataset) for dataset in datasets
    }

# CREATE LEGEND -- ADD RUNS HERE THAT WILL BE PLOTTED
# Run directory name -> display name used in plot legends. The plotting code
# iterates over these keys in insertion order.
labels = dict(
    run_0="Baseline",
    run_1="MDL Tracking",
    run_2="MDL Analysis",
    run_3="Extended Analysis",
    run_4="Comprehensive Analysis",
)


# Create a programmatic color palette
def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled from matplotlib's "tab20" colormap.

    Sample positions are ``np.linspace(0, 1, n)``, i.e. evenly spaced from 0.0 to
    1.0 inclusive.
    """
    tab20 = plt.get_cmap("tab20")
    sample_points = np.linspace(0, 1, n)
    return list(map(lambda point: mcolors.rgb2hex(tab20(point)), sample_points))


# Get the list of runs and generate the color palette
# Runs follow the insertion order of `labels`; colors[i] is the colour of runs[i].
runs = [*labels]
colors = generate_color_palette(len(runs))

# Plots 1-4: for each dataset, one line plot per training/validation metric across
# runs, with a shaded band of +/- one standard error.
def plot_metric_across_runs(metric_key, title_metric, ylabel):
    """Save ``{metric_key}_{dataset}.png`` for every dataset, comparing all runs.

    Reads ``results_info[run][dataset]["step"]``, the mean curve stored under
    ``metric_key`` and its standard error stored under ``f"{metric_key}_sterr"``.

    Args:
        metric_key: one of "train_loss", "val_loss", "train_acc", "val_acc".
        title_metric: human-readable metric name used in the figure title.
        ylabel: y-axis label.
    """
    for dataset in datasets:
        plt.figure(figsize=(10, 6))
        for color, run in zip(colors, runs):
            info = results_info[run][dataset]
            iters = info["step"]
            mean = info[metric_key]
            sterr = info[f"{metric_key}_sterr"]
            plt.plot(iters, mean, label=labels[run], color=color)
            plt.fill_between(iters, mean - sterr, mean + sterr, color=color, alpha=0.2)

        plt.title(f"{title_metric} Across Runs for {dataset} Dataset")
        plt.xlabel("Update Steps")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, which="both", ls="-", alpha=0.2)
        plt.tight_layout()
        plt.savefig(f"{metric_key}_{dataset}.png")
        plt.close()


plot_metric_across_runs("train_loss", "Training Loss", "Training Loss")
plot_metric_across_runs("val_loss", "Validation Loss", "Validation Loss")
plot_metric_across_runs("train_acc", "Training Accuracy", "Training Acc")
# Previously mis-titled "Validation Loss" although it plots validation accuracy.
plot_metric_across_runs("val_acc", "Validation Accuracy", "Validation Acc")

# Plot 5: MDL estimates alongside validation accuracy
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        if run == "run_0":  # Skip baseline run
            continue
        info = results_info[run][dataset]
        mdl = np.asarray(info["mdl"], dtype=float)

        # Min-max normalise MDL onto [0, 1] so it can share the accuracy axis.
        # A constant (or empty) series would divide by zero, so draw it as zeros.
        mdl_span = mdl.max() - mdl.min() if mdl.size else 0.0
        if mdl_span > 0:
            mdl_normalized = (mdl - mdl.min()) / mdl_span
        else:
            mdl_normalized = np.zeros_like(mdl)

        # Apply Savitzky-Golay filter to smooth MDL curve. A 5-point window needs
        # at least 5 samples, so shorter series are plotted unsmoothed.
        if len(mdl_normalized) >= 5:
            mdl_smooth = savgol_filter(mdl_normalized, window_length=5, polyorder=2)
        else:
            mdl_smooth = mdl_normalized

        plt.plot(info["step"], info["val_acc"], label=f"{labels[run]} - Val Acc", color=colors[i])
        plt.plot(info["mdl_step"], mdl_smooth, label=f"{labels[run]} - MDL", linestyle='--', color=colors[i])

    plt.title(f"Validation Accuracy and MDL for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Validation Accuracy / Normalized MDL")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"val_acc_mdl_{dataset}.png")
    plt.close()

# Calculate MDL transition point and correlation
# mdl_analysis[dataset][run] holds, for every non-baseline run:
#   mdl_transition_point, grokking_point (either may be None), correlation (may be
#   NaN), mdl, mdl_step, val_acc and gen_gap (both interpolated at mdl_step).
mdl_analysis = {}
for dataset in datasets:
    mdl_analysis[dataset] = {}
    for run in runs:
        if run == "run_0":  # Skip baseline run
            continue
        info = results_info[run][dataset]
        mdl = info["mdl"]
        mdl_step = info["mdl_step"]
        steps = info["step"]
        val_acc = info["val_acc"]
        train_acc = info["train_acc"]
        mdl_values = np.asarray(mdl, dtype=float)

        # MDL transition point: the step at which MDL drops most steeply between
        # consecutive measurements. Undefined (None) with fewer than two points.
        if len(mdl_values) >= 2:
            mdl_transition_point = mdl_step[int(np.argmin(np.diff(mdl_values)))]
        else:
            mdl_transition_point = None

        # Find grokking point (first logged step with >= 95% validation accuracy).
        grokking_point = next(
            (step for step, acc in zip(steps, val_acc) if acc >= 0.95), None
        )

        # Pearson correlation between MDL and validation accuracy sampled at the
        # MDL steps. Pearson r is unchanged by min-max scaling, so raw MDL is used.
        # A constant series has no defined correlation; report NaN instead of
        # letting np.corrcoef divide by zero.
        val_acc_interp = np.interp(mdl_step, steps, val_acc)
        if (
            len(mdl_values) >= 2
            and mdl_values.max() > mdl_values.min()
            and val_acc_interp.max() > val_acc_interp.min()
        ):
            correlation = np.corrcoef(mdl_values, val_acc_interp)[0, 1]
        else:
            correlation = np.nan

        # Generalization gap (train minus validation accuracy) at the MDL steps.
        train_acc_interp = np.interp(mdl_step, steps, train_acc)
        gen_gap = train_acc_interp - val_acc_interp

        mdl_analysis[dataset][run] = {
            "mdl_transition_point": mdl_transition_point,
            "grokking_point": grokking_point,
            "correlation": correlation,
            "mdl": mdl,
            "mdl_step": mdl_step,
            "val_acc": val_acc_interp,
            "gen_gap": gen_gap,
        }

# Plot MDL transition point vs Grokking point
plt.figure(figsize=(10, 6))
for dataset in datasets:
    for run in runs:
        if run == "run_0":
            continue
        mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"]
        grok_p = mdl_analysis[dataset][run]["grokking_point"]
        # A run that never reached the accuracy threshold (or lacks MDL data) has
        # no point to place on the axes.
        if mdl_tp is None or grok_p is None:
            continue
        plt.scatter(mdl_tp, grok_p, label=f"{dataset} - {run}")

# y = x reference line: points above it grokked after the MDL transition.
diag_max = max(plt.xlim()[1], plt.ylim()[1])
plt.plot([0, diag_max], [0, diag_max], 'k--', alpha=0.5)
plt.xlabel("MDL Transition Point")
plt.ylabel("Grokking Point")
plt.title("MDL Transition Point vs Grokking Point")
plt.legend()
plt.tight_layout()
plt.savefig("mdl_transition_vs_grokking.png")
plt.close()

# Plot correlation between MDL reduction and val acc improvement
# One bar per dataset: mean correlation over the non-baseline runs, with the
# standard deviation across those runs as the error bar.
# NOTE: the same filename is written again further down in this script, so this
# version of the figure is overwritten on disk.
plt.figure(figsize=(10, 6))
for ds in datasets:
    corr_values = []
    for run_name in runs:
        if run_name == "run_0":
            continue
        corr_values.append(mdl_analysis[ds][run_name]["correlation"])
    plt.bar(ds, np.mean(corr_values), yerr=np.std(corr_values), capsize=5)

plt.xlabel("Dataset")
plt.ylabel("Correlation")
plt.title("Correlation between MDL Reduction and Val Acc Improvement")
plt.tight_layout()
plt.savefig("mdl_val_acc_correlation.png")
plt.close()

# Plot MDL evolution and generalization gap
# Two stacked panels per dataset: raw MDL on top, train-minus-val accuracy below.
for dataset in datasets:
    plt.figure(figsize=(12, 8))
    analysed_runs = [run for run in runs if run != "run_0"]
    for run in analysed_runs:
        run_analysis = mdl_analysis[dataset][run]
        steps = run_analysis["mdl_step"]

        plt.subplot(2, 1, 1)
        plt.plot(steps, run_analysis["mdl"], label=f"{run} - MDL")

        plt.subplot(2, 1, 2)
        plt.plot(steps, run_analysis["gen_gap"], label=f"{run} - Gen Gap")

    # Decorate each panel once, after every run has been drawn on it.
    if analysed_runs:
        plt.subplot(2, 1, 1)
        plt.title(f"MDL Evolution and Generalization Gap - {dataset}")
        plt.ylabel("MDL")
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.xlabel("Steps")
        plt.ylabel("Generalization Gap")
        plt.legend()

    plt.tight_layout()
    plt.savefig(f"mdl_gen_gap_{dataset}.png")
    plt.close()

# Calculate and plot MDL transition rate
# d(MDL)/d(step) via np.gradient over the (possibly non-uniform) MDL step grid.
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for run in (r for r in runs if r != "run_0"):
        entry = mdl_analysis[dataset][run]
        rate = np.gradient(entry["mdl"], entry["mdl_step"])
        plt.plot(entry["mdl_step"], rate, label=f"{run} - MDL Rate")
    plt.title(f"MDL Transition Rate - {dataset}")
    plt.xlabel("Steps")
    plt.ylabel("MDL Rate of Change")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"mdl_transition_rate_{dataset}.png")
    plt.close()

# Scatter plot of MDL transition points vs grokking points
plt.figure(figsize=(10, 6))
for dataset in datasets:
    for run in runs:
        if run != "run_0":
            mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"]
            grok_p = mdl_analysis[dataset][run]["grokking_point"]
            # Skip runs whose transition or grokking point could not be determined.
            if mdl_tp is not None and grok_p is not None:
                plt.scatter(mdl_tp, grok_p, label=f"{dataset} - {run}")
# y = x reference line. Both endpoints must use the same limit; a segment from
# (0, 0) to (xmax, ymax) is not the identity line unless the axes happen to match.
diag_max = max(plt.xlim()[1], plt.ylim()[1])
if diag_max > 0:
    plt.plot([0, diag_max], [0, diag_max], 'k--', alpha=0.5)
plt.xlabel("MDL Transition Point")
plt.ylabel("Grokking Point")
plt.title("MDL Transition Point vs Grokking Point")
plt.legend()
plt.tight_layout()
plt.savefig("mdl_transition_vs_grokking_scatter.png")
plt.close()

# Print analysis results
# Per dataset, list each analysed run's transition point, grokking point and
# correlation, followed by a blank line.
for dataset in datasets:
    print(f"Dataset: {dataset}")
    for run in runs:
        if run == "run_0":
            continue  # the baseline has no entry in mdl_analysis
        entry = mdl_analysis[dataset][run]
        print(f"  Run: {run}")
        print(f"    MDL Transition Point: {entry['mdl_transition_point']}")
        print(f"    Grokking Point: {entry['grokking_point']}")
        print(f"    Correlation: {entry['correlation']:.4f}")
    print()

# Calculate and print average MDL transition point and grokking point for each dataset
for dataset in datasets:
    mdl_tps = []
    grok_ps = []
    correlations = []
    for run in runs:
        if run != "run_0":
            mdl_tps.append(mdl_analysis[dataset][run]["mdl_transition_point"])
            grok_ps.append(mdl_analysis[dataset][run]["grokking_point"])
            correlations.append(mdl_analysis[dataset][run]["correlation"])

    # Points that could not be determined are None; average only the defined
    # ones, because np.mean raises TypeError on a list containing None.
    known_mdl_tps = [tp for tp in mdl_tps if tp is not None]
    known_grok_ps = [gp for gp in grok_ps if gp is not None]
    avg_mdl_tp = np.mean(known_mdl_tps) if known_mdl_tps else None
    avg_grok_p = np.mean(known_grok_ps) if known_grok_ps else None
    avg_correlation = np.mean(correlations) if correlations else None

    # The None check must wrap the whole value: inside "{x:...}" everything after
    # the colon is a format spec, so "{x:.2f if ... else 'N/A'}" raises at runtime.
    mdl_tp_text = "N/A" if avg_mdl_tp is None else f"{avg_mdl_tp:.2f}"
    grok_p_text = "N/A" if avg_grok_p is None else f"{avg_grok_p:.2f}"
    correlation_text = "N/A" if avg_correlation is None else f"{avg_correlation:.4f}"

    print(f"Dataset: {dataset}")
    print(f"  Average MDL Transition Point: {mdl_tp_text}")
    print(f"  Average Grokking Point: {grok_p_text}")
    if avg_mdl_tp is not None and avg_grok_p is not None:
        print(f"  Difference: {abs(avg_mdl_tp - avg_grok_p):.2f}")
    else:
        print("  Difference: N/A")
    print(f"  Average Correlation: {correlation_text}")

    # Add these lines for debugging (raw per-run values, including None)
    print(f"  MDL Transition Points: {mdl_tps}")
    print(f"  Grokking Points: {grok_ps}")
    print(f"  Correlations: {correlations}")
    print()

# Plot MDL Transition Rate vs Grokking Speed
# Best-effort: any failure is reported and the script carries on. The figure is
# closed in every case so a failed attempt does not leave it open.
try:
    plt.figure(figsize=(12, 8))
    for dataset in datasets:
        for run in runs:
            if run == "run_0":
                continue
            analysis = mdl_analysis[dataset][run]
            grok_p = analysis['grokking_point']
            mdl_tp = analysis['mdl_transition_point']
            if grok_p is None or mdl_tp is None:
                continue
            # Steepest (most negative) rate of change of MDL over the run. It is
            # computed only for plotted runs, so a skipped run cannot abort the plot.
            mdl_transition_rate = np.min(np.gradient(analysis['mdl'], analysis['mdl_step']))
            # Inverse delay between MDL transition and grokking; infinite if equal.
            if grok_p != mdl_tp:
                grokking_speed = 1 / (grok_p - mdl_tp)
            else:
                grokking_speed = np.inf
            plt.scatter(mdl_transition_rate, grokking_speed, label=f"{dataset} - {labels[run]}", alpha=0.7)

    plt.xlabel("MDL Transition Rate")
    plt.ylabel("Grokking Speed")
    plt.title("MDL Transition Rate vs Grokking Speed")
    plt.legend()
    plt.xscale('symlog')
    plt.yscale('symlog')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig("mdl_transition_rate_vs_grokking_speed.png")
except Exception as e:
    print(f"Error plotting MDL Transition Rate vs Grokking Speed: {e}")
finally:
    plt.close()

# Plot MDL evolution and validation accuracy for all datasets
# NOTE(review): raw MDL shares the y-axis with accuracy; if MDL values are much
# larger than 1 the accuracy curves will look flat. Consider normalising MDL as
# Plot 5 does.
for dataset in datasets:
    plt.figure(figsize=(15, 10))
    for i, run in enumerate(runs):
        if run == "run_0":
            continue
        analysis = mdl_analysis[dataset][run]
        mdl_step = analysis['mdl_step']
        color = colors[i]

        plt.plot(mdl_step, analysis['mdl'], label=f'{labels[run]} - MDL', color=color)
        plt.plot(mdl_step, analysis['val_acc'], label=f'{labels[run]} - Val Acc', color=color, linestyle=':')
        # Event markers are coloured and labelled per run so the legend does not
        # repeat identical entries. Points that could not be determined (None) are
        # skipped because axvline requires a number.
        if analysis['mdl_transition_point'] is not None:
            plt.axvline(x=analysis['mdl_transition_point'], color=color, linestyle='--',
                        label=f'{labels[run]} - MDL Transition')
        if analysis['grokking_point'] is not None:
            plt.axvline(x=analysis['grokking_point'], color=color, linestyle='-.',
                        label=f'{labels[run]} - Grokking Point')

    plt.title(f"MDL Evolution and Validation Accuracy - {dataset}")
    plt.xlabel("Steps")
    plt.ylabel("MDL / Validation Accuracy")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"mdl_val_acc_evolution_{dataset}.png")
    plt.close()

# Plot correlation between MDL reduction and validation accuracy improvement
# One bar per dataset: mean correlation over the non-baseline runs, with their
# standard deviation as the error bar.
plt.figure(figsize=(10, 6))
for dataset in datasets:
    run_correlations = [
        mdl_analysis[dataset][run]["correlation"]
        for run in runs
        if run != "run_0"
    ]
    bar_height = np.mean(run_correlations)
    bar_error = np.std(run_correlations)
    plt.bar(dataset, bar_height, yerr=bar_error, capsize=5)

plt.xlabel("Dataset")
plt.ylabel("Correlation")
plt.title("Correlation between MDL Reduction and Validation Accuracy Improvement")
plt.tight_layout()
plt.savefig("mdl_val_acc_correlation.png")
plt.close()

# Print analysis results
# Same summary as before, but keyed by the human-readable run labels.
print("\nAnalysis Results:")
for dataset in datasets:
    print(f"\nDataset: {dataset}")
    for run in filter(lambda name: name != "run_0", runs):
        entry = mdl_analysis[dataset][run]
        print(f"  Run: {labels[run]}")
        print(f"    MDL Transition Point: {entry['mdl_transition_point']}")
        print(f"    Grokking Point: {entry['grokking_point']}")
        print(f"    Correlation: {entry['correlation']:.4f}")

# Calculate and print average MDL transition point and grokking point for each dataset
print("\nAverage MDL Transition Point and Grokking Point:")
for dataset in datasets:
    mdl_tps = []
    grok_ps = []
    correlations = []
    for run in runs:
        if run != "run_0":
            mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"]
            grok_p = mdl_analysis[dataset][run]["grokking_point"]
            correlation = mdl_analysis[dataset][run]["correlation"]
            # Only defined values are averaged (None marks an undetermined point).
            if mdl_tp is not None:
                mdl_tps.append(mdl_tp)
            if grok_p is not None:
                grok_ps.append(grok_p)
            if correlation is not None:
                correlations.append(correlation)

    avg_mdl_tp = np.mean(mdl_tps) if mdl_tps else None
    avg_grok_p = np.mean(grok_ps) if grok_ps else None
    avg_correlation = np.mean(correlations) if correlations else None

    # The None check must wrap the whole value: inside "{x:...}" everything after
    # the colon is a format spec, so "{x:.2f if ... else 'N/A'}" raises at runtime.
    mdl_tp_text = "N/A" if avg_mdl_tp is None else f"{avg_mdl_tp:.2f}"
    grok_p_text = "N/A" if avg_grok_p is None else f"{avg_grok_p:.2f}"
    correlation_text = "N/A" if avg_correlation is None else f"{avg_correlation:.4f}"

    print(f"\nDataset: {dataset}")
    print(f"  Average MDL Transition Point: {mdl_tp_text}")
    print(f"  Average Grokking Point: {grok_p_text}")
    if avg_mdl_tp is not None and avg_grok_p is not None:
        print(f"  Difference: {abs(avg_mdl_tp - avg_grok_p):.2f}")
    else:
        print("  Difference: N/A")
    print(f"  Average Correlation: {correlation_text}")

    # Add these lines for debugging
    print(f"  MDL Transition Points: {mdl_tps}")
    print(f"  Grokking Points: {grok_ps}")
    print(f"  Correlations: {correlations}")