import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
from scipy.signal import savgol_filter

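# Aggregate results from the grokking runs and produce per-dataset loss,
# accuracy, and MDL diagnostic plots.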
datasets = ["x_div_y", "x_minus_y", "x_plus_y", "permutation"]

folders = os.listdir("./")
final_results = {}
results_info = {}

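# Load final_info.json and all_results.npy from every run_* directory and
# average each metric over the series recorded for each dataset.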
for folder in folders:
    if folder.startswith("run") and osp.isdir(folder):
        with open(osp.join(folder, "final_info.json"), "r") as f:
            final_results[folder] = json.load(f)
        results_dict = np.load(
            osp.join(folder, "all_results.npy"), allow_pickle=True
        ).item()
        print(results_dict.keys())
        run_info = {}
        for dataset in datasets:
            run_info[dataset] = {}
            val_losses = []
            train_losses = []
            val_accs = []
            train_accs = []
            for k in results_dict.keys():
                if dataset in k and "val_info" in k:
                    run_info[dataset]["step"] = [
                        info["step"] for info in results_dict[k]
                    ]
                    val_losses.append([info["val_loss"] for info in results_dict[k]])
                    val_accs.append([info["val_accuracy"] for info in results_dict[k]])
                if dataset in k and "train_info" in k:
                    train_losses.append(
                        [info["train_loss"] for info in results_dict[k]]
                    )
                    train_accs.append(
                        [info["train_accuracy"] for info in results_dict[k]]
                    )
            mean_val_losses = np.mean(val_losses, axis=0)
            mean_train_losses = np.mean(train_losses, axis=0)
            mean_val_accs = np.mean(val_accs, axis=0)
            mean_train_accs = np.mean(train_accs, axis=0)
            if len(val_losses) > 0:
                stderr_val_losses = np.std(val_losses, axis=0) / np.sqrt(
                    len(val_losses)
                )
                stderr_train_losses = np.std(train_losses, axis=0) / np.sqrt(
                    len(train_losses)
                )
                stderr_val_accs = np.std(val_accs, axis=0) / np.sqrt(len(val_accs))
                stderr_train_accs = np.std(train_accs, axis=0) / np.sqrt(
                    len(train_accs)
                )
            else:
                stderr_val_losses = np.zeros_like(mean_val_losses)
                stderr_train_losses = np.zeros_like(mean_train_losses)
                stderr_val_accs = np.zeros_like(mean_val_accs)
                stderr_train_accs = np.zeros_like(mean_train_accs)
            run_info[dataset]["val_loss"] = mean_val_losses
            run_info[dataset]["train_loss"] = mean_train_losses
            run_info[dataset]["val_loss_sterr"] = stderr_val_losses
            run_info[dataset]["train_loss_sterr"] = stderr_train_losses
            run_info[dataset]["val_acc"] = mean_val_accs
            run_info[dataset]["train_acc"] = mean_train_accs
            run_info[dataset]["val_acc_sterr"] = stderr_val_accs
            run_info[dataset]["train_acc_sterr"] = stderr_train_accs

            mdl_data = [
                info
                for k, info in results_dict.items()
                if dataset in k and "mdl_info" in k
            ]
            if mdl_data:
                run_info[dataset]["mdl_step"] = [item["step"] for item in mdl_data[0]]
                run_info[dataset]["mdl"] = [item["mdl"] for item in mdl_data[0]]

        results_info[folder] = run_info

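# Legend labels for each run directory and a distinct color for each run.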
labels = {
    "run_0": "Baseline",
    "run_1": "MDL Tracking",
    "run_2": "MDL Analysis",
    "run_3": "Extended Analysis",
    "run_4": "Comprehensive Analysis",
}


def generate_color_palette(n):
    cmap = plt.get_cmap("tab20")
    return [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, n)]


runs = list(labels.keys())
colors = generate_color_palette(len(runs))

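# Training loss (mean with standard-error band) for every run and dataset.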
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["step"]
        mean = results_info[run][dataset]["train_loss"]
        sterr = results_info[run][dataset]["train_loss_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Training Loss Across Runs for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Training Loss")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"train_loss_{dataset}.png")
    plt.close()

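# Validation loss (mean with standard-error band) for every run and dataset.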
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["step"]
        mean = results_info[run][dataset]["val_loss"]
        sterr = results_info[run][dataset]["val_loss_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Validation Loss Across Runs for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"val_loss_{dataset}.png")
    plt.close()

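# Training accuracy (mean with standard-error band) for every run and dataset.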
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["step"]
        mean = results_info[run][dataset]["train_acc"]
        sterr = results_info[run][dataset]["train_acc_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Training Accuracy Across Runs for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Training Acc")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"train_acc_{dataset}.png")
    plt.close()

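# Validation accuracy (mean with standard-error band) for every run and dataset.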
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["step"]
        mean = results_info[run][dataset]["val_acc"]
        sterr = results_info[run][dataset]["val_acc_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Validation Accuracy Across Runs for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Validation Acc")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"val_acc_{dataset}.png")
    plt.close()

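# Overlay validation accuracy with the min-max normalized, Savitzky-Golay
# smoothed MDL curve for the non-baseline runs.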
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        if run != "run_0":
            iters = results_info[run][dataset]["step"]
            val_acc = results_info[run][dataset]["val_acc"]
            mdl_step = results_info[run][dataset]["mdl_step"]
            mdl = results_info[run][dataset]["mdl"]

            mdl_normalized = (mdl - np.min(mdl)) / (np.max(mdl) - np.min(mdl))
            mdl_smooth = savgol_filter(mdl_normalized, window_length=5, polyorder=2)

            plt.plot(iters, val_acc, label=f"{labels[run]} - Val Acc", color=colors[i])
            plt.plot(mdl_step, mdl_smooth, label=f"{labels[run]} - MDL", linestyle='--', color=colors[i])

    plt.title(f"Validation Accuracy and MDL for {dataset} Dataset")
    plt.xlabel("Update Steps")
    plt.ylabel("Validation Accuracy / Normalized MDL")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"val_acc_mdl_{dataset}.png")
    plt.close()

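# Summarize each MDL run: the transition point (step with the largest
# single-step MDL drop), the grokking point (first step with validation
# accuracy >= 0.95), and the correlation between normalized MDL and
# validation accuracy interpolated onto the MDL evaluation steps.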
mdl_analysis = {}
for dataset in datasets:
    mdl_analysis[dataset] = {}
    for run in runs:
        if run != "run_0":
            mdl = results_info[run][dataset]["mdl"]
            mdl_step = results_info[run][dataset]["mdl_step"]
            val_acc = results_info[run][dataset]["val_acc"]
            train_acc = results_info[run][dataset]["train_acc"]

            mdl_diff = np.diff(mdl)
            mdl_transition_idx = np.argmin(mdl_diff)
            mdl_transition_point = mdl_step[mdl_transition_idx]

            grokking_point = next(
                (step for step, acc in zip(results_info[run][dataset]["step"], val_acc) if acc >= 0.95),
                None,
            )

            mdl_normalized = (mdl - np.min(mdl)) / (np.max(mdl) - np.min(mdl))
            val_acc_interp = np.interp(mdl_step, results_info[run][dataset]["step"], val_acc)
            correlation = np.corrcoef(mdl_normalized, val_acc_interp)[0, 1]

            train_acc_interp = np.interp(mdl_step, results_info[run][dataset]["step"], train_acc)
            gen_gap = train_acc_interp - val_acc_interp

            mdl_analysis[dataset][run] = {
                "mdl_transition_point": mdl_transition_point,
                "grokking_point": grokking_point,
                "correlation": correlation,
                "mdl": mdl,
                "mdl_step": mdl_step,
                "val_acc": val_acc_interp,
                "gen_gap": gen_gap,
            }

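# Scatter of MDL transition point against grokking point with a y = x
# reference line.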
plt.figure(figsize=(10, 6))
for dataset in datasets:
    for run in runs:
        if run != "run_0":
            mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"]
            grok_p = mdl_analysis[dataset][run]["grokking_point"]
            if mdl_tp is not None and grok_p is not None:
                plt.scatter(mdl_tp, grok_p, label=f"{dataset} - {run}")

plt.plot([0, max(plt.xlim())], [0, max(plt.xlim())], 'k--', alpha=0.5)
plt.xlabel("MDL Transition Point")
plt.ylabel("Grokking Point")
plt.title("MDL Transition Point vs Grokking Point")
plt.legend()
plt.tight_layout()
plt.savefig("mdl_transition_vs_grokking.png")
plt.close()

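# Mean (with standard deviation) of the MDL / validation-accuracy correlation
# for each dataset.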
plt.figure(figsize=(10, 6))
for dataset in datasets:
    correlations = [mdl_analysis[dataset][run]["correlation"] for run in runs if run != "run_0"]
    plt.bar(dataset, np.mean(correlations), yerr=np.std(correlations), capsize=5)

plt.xlabel("Dataset")
plt.ylabel("Correlation")
plt.title("Correlation between MDL Reduction and Val Acc Improvement")
plt.tight_layout()
plt.savefig("mdl_val_acc_correlation.png")
plt.close()

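# MDL evolution (top panel) and generalization gap (bottom panel) per dataset.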
for dataset in datasets:
    plt.figure(figsize=(12, 8))
    for run in runs:
        if run != "run_0":
            mdl_step = mdl_analysis[dataset][run]["mdl_step"]
            mdl = mdl_analysis[dataset][run]["mdl"]
            gen_gap = mdl_analysis[dataset][run]["gen_gap"]

            plt.subplot(2, 1, 1)
            plt.plot(mdl_step, mdl, label=f"{run} - MDL")
            plt.title(f"MDL Evolution and Generalization Gap - {dataset}")
            plt.ylabel("MDL")
            plt.legend()

            plt.subplot(2, 1, 2)
            plt.plot(mdl_step, gen_gap, label=f"{run} - Gen Gap")
            plt.xlabel("Steps")
            plt.ylabel("Generalization Gap")
            plt.legend()

    plt.tight_layout()
    plt.savefig(f"mdl_gen_gap_{dataset}.png")
    plt.close()

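# Rate of change of MDL (numerical gradient) over the evaluation steps.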
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for run in runs:
        if run != "run_0":
            mdl_step = mdl_analysis[dataset][run]["mdl_step"]
            mdl = mdl_analysis[dataset][run]["mdl"]
            mdl_rate = np.gradient(mdl, mdl_step)
            plt.plot(mdl_step, mdl_rate, label=f"{run} - MDL Rate")
    plt.title(f"MDL Transition Rate - {dataset}")
    plt.xlabel("Steps")
    plt.ylabel("MDL Rate of Change")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"mdl_transition_rate_{dataset}.png")
    plt.close()

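# Transition-point vs. grokking-point scatter again, skipping runs where
# either point is undefined.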
plt.figure(figsize=(10, 6))
for dataset in datasets:
    for run in runs:
        if run != "run_0":
            mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"]
            grok_p = mdl_analysis[dataset][run]["grokking_point"]
            if mdl_tp is not None and grok_p is not None:
                plt.scatter(mdl_tp, grok_p, label=f"{dataset} - {run}")
if plt.gca().get_xlim()[1] > 0 and plt.gca().get_ylim()[1] > 0:
    plt.plot([0, max(plt.xlim())], [0, max(plt.ylim())], 'k--', alpha=0.5)
plt.xlabel("MDL Transition Point")
plt.ylabel("Grokking Point")
plt.title("MDL Transition Point vs Grokking Point")
plt.legend()
plt.tight_layout()
plt.savefig("mdl_transition_vs_grokking_scatter.png")
plt.close()

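# Print per-run MDL statistics for each dataset.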
for dataset in datasets:
    print(f"Dataset: {dataset}")
    for run in runs:
        if run != "run_0":
            analysis = mdl_analysis[dataset][run]
            print(f" Run: {run}")
            print(f" MDL Transition Point: {analysis['mdl_transition_point']}")
            print(f" Grokking Point: {analysis['grokking_point']}")
            print(f" Correlation: {analysis['correlation']:.4f}")
    print()

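# Print per-dataset averages of the MDL statistics.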
for dataset in datasets:
    mdl_tps = []
    grok_ps = []
    correlations = []
    for run in runs:
        if run != "run_0":
            mdl_tps.append(mdl_analysis[dataset][run]["mdl_transition_point"])
            grok_ps.append(mdl_analysis[dataset][run]["grokking_point"])
            correlations.append(mdl_analysis[dataset][run]["correlation"])
    valid_grok_ps = [p for p in grok_ps if p is not None]
    avg_mdl_tp = np.mean(mdl_tps) if mdl_tps else None
    avg_grok_p = np.mean(valid_grok_ps) if valid_grok_ps else None
    avg_correlation = np.mean(correlations) if correlations else None

    mdl_tp_str = f"{avg_mdl_tp:.2f}" if avg_mdl_tp is not None else "N/A"
    grok_p_str = f"{avg_grok_p:.2f}" if avg_grok_p is not None else "N/A"
    corr_str = f"{avg_correlation:.4f}" if avg_correlation is not None else "N/A"

    print(f"Dataset: {dataset}")
    print(f" Average MDL Transition Point: {mdl_tp_str}")
    print(f" Average Grokking Point: {grok_p_str}")
    if avg_mdl_tp is not None and avg_grok_p is not None:
        print(f" Difference: {abs(avg_mdl_tp - avg_grok_p):.2f}")
    else:
        print(" Difference: N/A")
    print(f" Average Correlation: {corr_str}")

    print(f" MDL Transition Points: {mdl_tps}")
    print(f" Grokking Points: {grok_ps}")
    print(f" Correlations: {correlations}")
    print()

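# Steepest MDL decrease (minimum MDL gradient) against grokking speed, defined
# as 1 / (grokking point - MDL transition point).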
try:
    plt.figure(figsize=(12, 8))
    for dataset in datasets:
        for run in runs:
            if run != "run_0":
                analysis = mdl_analysis[dataset][run]
                mdl_transition_rate = np.min(np.gradient(analysis['mdl'], analysis['mdl_step']))
                if analysis['grokking_point'] is not None and analysis['mdl_transition_point'] is not None:
                    if analysis['grokking_point'] != analysis['mdl_transition_point']:
                        grokking_speed = 1 / (analysis['grokking_point'] - analysis['mdl_transition_point'])
                    else:
                        grokking_speed = np.inf
                    plt.scatter(mdl_transition_rate, grokking_speed, label=f"{dataset} - {labels[run]}", alpha=0.7)

    plt.xlabel("MDL Transition Rate")
    plt.ylabel("Grokking Speed")
    plt.title("MDL Transition Rate vs Grokking Speed")
    plt.legend()
    plt.xscale('symlog')
    plt.yscale('symlog')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig("mdl_transition_rate_vs_grokking_speed.png")
    plt.close()
except Exception as e:
    print(f"Error plotting MDL Transition Rate vs Grokking Speed: {e}")

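# MDL and interpolated validation accuracy on a single axis, with the MDL
# transition point and grokking point marked as vertical lines.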
for dataset in datasets:
    plt.figure(figsize=(15, 10))
    for run in runs:
        if run != "run_0":
            analysis = mdl_analysis[dataset][run]
            mdl_step = analysis['mdl_step']
            mdl = analysis['mdl']
            val_acc = analysis['val_acc']

            plt.plot(mdl_step, mdl, label=f'{labels[run]} - MDL')
            plt.plot(mdl_step, val_acc, label=f'{labels[run]} - Val Acc')
            plt.axvline(x=analysis['mdl_transition_point'], color='r', linestyle='--', label='MDL Transition')
            if analysis['grokking_point'] is not None:
                plt.axvline(x=analysis['grokking_point'], color='g', linestyle='--', label='Grokking Point')

    plt.title(f"MDL Evolution and Validation Accuracy - {dataset}")
    plt.xlabel("Steps")
    plt.ylabel("MDL / Validation Accuracy")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"mdl_val_acc_evolution_{dataset}.png")
    plt.close()

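# Correlation bar chart per dataset; this reuses the mdl_val_acc_correlation.png
# filename from the chart above.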
plt.figure(figsize=(10, 6))
for dataset in datasets:
    correlations = []
    for run in runs:
        if run != "run_0":
            correlations.append(mdl_analysis[dataset][run]["correlation"])
    plt.bar(dataset, np.mean(correlations), yerr=np.std(correlations), capsize=5)

plt.xlabel("Dataset")
plt.ylabel("Correlation")
plt.title("Correlation between MDL Reduction and Validation Accuracy Improvement")
plt.tight_layout()
plt.savefig("mdl_val_acc_correlation.png")
plt.close()

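# Final textual summary of per-run statistics and per-dataset averages.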
print("\nAnalysis Results:") |
|
for dataset in datasets: |
|
print(f"\nDataset: {dataset}") |
|
for run in runs: |
|
if run != "run_0": |
|
analysis = mdl_analysis[dataset][run] |
|
print(f" Run: {labels[run]}") |
|
print(f" MDL Transition Point: {analysis['mdl_transition_point']}") |
|
print(f" Grokking Point: {analysis['grokking_point']}") |
|
print(f" Correlation: {analysis['correlation']:.4f}") |
|
|
|
|
|
print("\nAverage MDL Transition Point and Grokking Point:") |
|
for dataset in datasets: |
|
mdl_tps = [] |
|
grok_ps = [] |
|
correlations = [] |
|
for run in runs: |
|
if run != "run_0": |
|
mdl_tp = mdl_analysis[dataset][run]["mdl_transition_point"] |
|
grok_p = mdl_analysis[dataset][run]["grokking_point"] |
|
correlation = mdl_analysis[dataset][run]["correlation"] |
|
if mdl_tp is not None: |
|
mdl_tps.append(mdl_tp) |
|
if grok_p is not None: |
|
grok_ps.append(grok_p) |
|
if correlation is not None: |
|
correlations.append(correlation) |
|
|
|
avg_mdl_tp = np.mean(mdl_tps) if mdl_tps else None |
|
avg_grok_p = np.mean(grok_ps) if grok_ps else None |
|
avg_correlation = np.mean(correlations) if correlations else None |
|
|
|
print(f"\nDataset: {dataset}") |
|
print(f" Average MDL Transition Point: {avg_mdl_tp:.2f if avg_mdl_tp is not None else 'N/A'}") |
|
print(f" Average Grokking Point: {avg_grok_p:.2f if avg_grok_p is not None else 'N/A'}") |
|
if avg_mdl_tp is not None and avg_grok_p is not None: |
|
print(f" Difference: {abs(avg_mdl_tp - avg_grok_p):.2f}") |
|
else: |
|
print(" Difference: N/A") |
|
print(f" Average Correlation: {avg_correlation:.4f if avg_correlation is not None else 'N/A'}") |
|
|
|
|
|
print(f" MDL Transition Points: {mdl_tps}") |
|
print(f" Grokking Points: {grok_ps}") |
|
print(f" Correlations: {correlations}") |
|
|