pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
raw
history blame
3.85 kB
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
# LOAD FINAL RESULTS:
datasets = ["shakespeare_char"]
folders = os.listdir("./")
final_results = {}
results_info = {}
for folder in folders:
if folder.startswith("run") and osp.isdir(folder):
with open(osp.join(folder, "final_info.json"), "r") as f:
final_results[folder] = json.load(f)
results_dict = np.load(osp.join(folder, "all_results.npy"), allow_pickle=True).item()
run_info = {}
for dataset in datasets:
run_info[dataset] = {}
val_losses = []
train_losses = []
for k in results_dict.keys():
if dataset in k and "val_info" in k:
run_info[dataset]["iters"] = [info["iter"] for info in results_dict[k]]
val_losses.append([info["val/loss"] for info in results_dict[k]])
train_losses.append([info["train/loss"] for info in results_dict[k]])
mean_val_losses = np.mean(val_losses, axis=0)
mean_train_losses = np.mean(train_losses, axis=0)
if len(val_losses) > 0:
sterr_val_losses = np.std(val_losses, axis=0) / np.sqrt(len(val_losses))
stderr_train_losses = np.std(train_losses, axis=0) / np.sqrt(len(train_losses))
else:
sterr_val_losses = np.zeros_like(mean_val_losses)
stderr_train_losses = np.zeros_like(mean_train_losses)
run_info[dataset]["val_loss"] = mean_val_losses
run_info[dataset]["train_loss"] = mean_train_losses
run_info[dataset]["val_loss_sterr"] = sterr_val_losses
run_info[dataset]["train_loss_sterr"] = stderr_train_losses
results_info[folder] = run_info
# CREATE LEGEND -- ADD RUNS HERE THAT WILL BE PLOTTED
labels = {
"run_0": "Baselines",
}
# Create a programmatic color palette
def generate_color_palette(n):
cmap = plt.get_cmap('tab20')
return [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, n)]
# Get the list of runs and generate the color palette
runs = list(labels.keys())
colors = generate_color_palette(len(runs))
# Plot 1: Line plot of training loss for each dataset across the runs with labels
for dataset in datasets:
plt.figure(figsize=(10, 6))
for i, run in enumerate(runs):
iters = results_info[run][dataset]["iters"]
mean = results_info[run][dataset]["train_loss"]
sterr = results_info[run][dataset]["train_loss_sterr"]
plt.plot(iters, mean, label=labels[run], color=colors[i])
plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)
plt.title(f"Training Loss Across Runs for {dataset} Dataset")
plt.xlabel("Iteration")
plt.ylabel("Training Loss")
plt.legend()
plt.grid(True, which="both", ls="-", alpha=0.2)
plt.tight_layout()
plt.savefig(f"train_loss_{dataset}.png")
plt.close()
# Plot 2: Line plot of validation loss for each dataset across the runs with labels
for dataset in datasets:
plt.figure(figsize=(10, 6))
for i, run in enumerate(runs):
iters = results_info[run][dataset]["iters"]
mean = results_info[run][dataset]["val_loss"]
sterr = results_info[run][dataset]["val_loss_sterr"]
plt.plot(iters, mean, label=labels[run], color=colors[i])
plt.fill_between(iters, mean - sterr, mean + sterr, color=colors[i], alpha=0.2)
plt.title(f"Validation Loss Across Runs for {dataset} Dataset")
plt.xlabel("Iteration")
plt.ylabel("Validation Loss")
plt.legend()
plt.grid(True, which="both", ls="-", alpha=0.2)
plt.tight_layout()
plt.savefig(f"val_loss_{dataset}.png")
plt.close()