import json
import os
import os.path as osp

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

# Datasets whose loss curves are aggregated and plotted.
datasets = ["shakespeare_char"]
# Every entry of the current working directory; run folders are filtered below.
folders = os.listdir("./")
# folder name -> parsed contents of that folder's final_info.json
final_results = {}
# folder name -> {dataset: aggregated loss curves}
results_info = {}


def aggregate_dataset_losses(results_dict, dataset):
    """Average the logged loss curves of one dataset over all matching entries.

    An entry of ``results_dict`` is used when its key contains both
    ``dataset`` and the substring ``"val_info"``. Each such value is iterated
    as a sequence of records providing ``"iter"``, ``"val/loss"`` and
    ``"train/loss"``.

    Args:
        results_dict: Mapping from result key to a sequence of log records.
        dataset: Dataset name to look for inside the keys.

    Returns:
        A dict with the keys ``"iters"`` (taken from the last matching
        entry), ``"val_loss"`` and ``"train_loss"`` (element-wise mean over
        matching entries), and ``"val_loss_sterr"`` and ``"train_loss_sterr"``
        (population standard deviation divided by the square root of the
        number of matching entries). When nothing matches, ``"iters"`` is an
        empty list and the four arrays are empty, instead of NaN values.
    """
    iters = []
    val_losses = []
    train_losses = []
    for key, records in results_dict.items():
        if dataset in key and "val_info" in key:
            iters = [info["iter"] for info in records]
            val_losses.append([info["val/loss"] for info in records])
            train_losses.append([info["train/loss"] for info in records])

    if val_losses:
        # NOTE(review): assumes all matching entries log the same number of
        # records; ragged curves cannot be averaged element-wise.
        mean_val_losses = np.mean(val_losses, axis=0)
        mean_train_losses = np.mean(train_losses, axis=0)
        sterr_val_losses = np.std(val_losses, axis=0) / np.sqrt(len(val_losses))
        sterr_train_losses = np.std(train_losses, axis=0) / np.sqrt(len(train_losses))
    else:
        # Nothing logged for this dataset: return empty curves rather than
        # calling np.mean on an empty list (NaN plus a RuntimeWarning).
        mean_val_losses = np.zeros(0)
        mean_train_losses = np.zeros(0)
        sterr_val_losses = np.zeros(0)
        sterr_train_losses = np.zeros(0)

    return {
        "iters": iters,
        "val_loss": mean_val_losses,
        "train_loss": mean_train_losses,
        "val_loss_sterr": sterr_val_losses,
        "train_loss_sterr": sterr_train_losses,
    }


for folder in folders:
    # Only directories whose name starts with "run" are treated as runs.
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue

    with open(osp.join(folder, "final_info.json"), "r") as f:
        final_results[folder] = json.load(f)

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load
    # all_results.npy files that come from a trusted source.
    results_dict = np.load(osp.join(folder, "all_results.npy"), allow_pickle=True).item()

    run_info = {}
    for dataset in datasets:
        run_info[dataset] = aggregate_dataset_losses(results_dict, dataset)
    results_info[folder] = run_info

# Legend label for each run folder; only the runs listed here are plotted.
labels = {
    "run_0": "Baselines",
}

def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled from the 'tab20' colormap.

    The colormap is evaluated at ``n`` evenly spaced positions between 0 and
    1 (both ends included), and each result is converted to a hex string.

    Args:
        n: Number of colours to produce.

    Returns:
        A list of ``n`` hex colour strings.
    """
    cmap = plt.get_cmap('tab20')
    return [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, n)]

# Runs to plot, in the order they appear in `labels`, with one colour each.
runs = list(labels.keys())
colors = generate_color_palette(len(runs))

# One training-loss figure per dataset, saved as train_loss_<dataset>.png.
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["iters"]
        mean = results_info[run][dataset]["train_loss"]
        sterr = results_info[run][dataset]["train_loss_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        # Shaded band of one standard error on either side of the mean curve.
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Training Loss Across Runs for {dataset} Dataset")
    plt.xlabel("Iteration")
    plt.ylabel("Training Loss")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"train_loss_{dataset}.png")
    # Close the figure so the next dataset starts from a fresh one.
    plt.close()

# One validation-loss figure per dataset, saved as val_loss_<dataset>.png.
for dataset in datasets:
    plt.figure(figsize=(10, 6))
    for i, run in enumerate(runs):
        iters = results_info[run][dataset]["iters"]
        mean = results_info[run][dataset]["val_loss"]
        sterr = results_info[run][dataset]["val_loss_sterr"]
        plt.plot(iters, mean, label=labels[run], color=colors[i])
        # Shaded band of one standard error on either side of the mean curve.
        plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)

    plt.title(f"Validation Loss Across Runs for {dataset} Dataset")
    plt.xlabel("Iteration")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig(f"val_loss_{dataset}.png")
    # Close the figure so the next dataset starts from a fresh one.
    plt.close()