# Uploaded by pradachan via huggingface_hub ("Upload folder using huggingface_hub"),
# commit f71c233 (verified), 3.04 kB.
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
import pickle
# LOAD FINAL RESULTS:
# Dataset names; each run's training info is looked up under these keys,
# and they also fix the subplot order of the figures below.
datasets = ["circle", "dino", "line", "moons"]
# Candidate run directories are discovered in the current working directory.
folders = os.listdir("./")
# final_results: parsed final_info.json per run folder.
# train_info: unpickled all_results.pkl per run folder.
final_results, train_info = {}, {}
def smooth(x, window_len=10, window='hanning'):
    """Smooth a 1-D signal with a normalized window, preserving its length.

    The signal is padded at both ends with a mirror reflection of itself
    (excluding the edge sample), so the window can be centered on every
    original sample. The padded signal is then convolved with the
    normalized window.

    Args:
        x: 1-D sequence of numbers (e.g. per-step training losses).
        window_len: Window length in samples. It is clamped to ``len(x)``;
            an effective length below 3 returns an unsmoothed copy of ``x``.
        window: ``'flat'`` for a moving average, or the name of a NumPy
            window function: ``'hanning'``, ``'hamming'``, ``'bartlett'`` or
            ``'blackman'``.

    Returns:
        A float ``np.ndarray`` with the same length as ``x``. Element ``i``
        is the window-weighted average centered on ``x[i]`` (exactly
        centered for odd ``window_len``).

    Raises:
        ValueError: If ``window`` is not one of the supported names.
    """
    if window not in ('flat', 'hanning', 'hamming', 'bartlett', 'blackman'):
        raise ValueError(f"Unsupported window {window!r}")
    x = np.asarray(x, dtype=float)
    n = len(x)
    # A window longer than the signal cannot be mirror-padded correctly.
    window_len = min(window_len, n)
    # Length 1 is the identity; length-2 NumPy windows (e.g. hanning) sum to 0.
    if window_len < 3:
        return x.copy()
    # Mirror-pad window_len - 1 samples on each side of the signal.
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)
    y = np.convolve(w / w.sum(), s, mode='valid')
    # 'valid' yields n + window_len - 1 samples. Keep the n that are centered
    # on the original samples so the curve lines up with the step axis.
    offset = (window_len - 1) // 2
    return y[offset:offset + n]
# Load each run's final metrics and pickled training info. Folders are visited
# in sorted order because os.listdir order is arbitrary. This fixes the
# insertion order of final_results, and therefore the run order used for
# labels, colors and plot rows.
for folder in sorted(folders):
    if folder.startswith("run") and osp.isdir(folder):
        with open(osp.join(folder, "final_info.json"), "r") as f:
            final_results[folder] = json.load(f)
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # all_results.pkl files produced by your own runs.
        with open(osp.join(folder, "all_results.pkl"), "rb") as f:
            train_info[folder] = pickle.load(f)
# CREATE LEGEND -- PLEASE FILL IN YOUR RUN NAMES HERE
# Keep the names short, as these will be in the legend.
labels = {
    "run_0": "Baseline",
}
# Any loaded run without an explicit entry above is labelled by its folder name.
runs = list(final_results)
for run in runs:
    labels.setdefault(run, run)
# CREATE PLOTS
# Create a programmatic color palette
def generate_color_palette(n, cmap_name='tab20'):
    """Return ``n`` hex color strings sampled evenly from a Matplotlib colormap.

    Args:
        n: Number of colors to generate (one per run).
        cmap_name: Name of a registered Matplotlib colormap. The default is
            ``'tab20'``; alternatives include ``'Set1'``, ``'Set2'`` and
            ``'Set3'``.

    Returns:
        List of ``n`` ``'#rrggbb'`` strings, sampled at evenly spaced
        positions in [0, 1] of the colormap. The list is empty when ``n``
        is 0.
    """
    cmap = plt.get_cmap(cmap_name)
    return [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, n)]
# Snapshot the run order (insertion order of final_results) and assign one
# palette color per run, indexed in that same order.
runs = [*final_results]
colors = generate_color_palette(len(runs))
# Plot 1: Line plot of training loss for each dataset across the runs with labels
# Subplots are laid out two per row, and the grid grows with the number of
# datasets, so adding a dataset does not overflow a fixed 2x2 grid.
ncols = 2
nrows = (len(datasets) + ncols - 1) // ncols
# squeeze=False keeps axs 2-D, so axs[row, col] works even for a single row.
fig, axs = plt.subplots(nrows, ncols, figsize=(14, 4 * nrows), sharex=True, squeeze=False)
for j, dataset in enumerate(datasets):
    ax = axs[j // ncols, j % ncols]
    for i, run in enumerate(runs):
        # Per-step training losses, smoothed so trends are readable.
        smoothed = smooth(train_info[run][dataset]["train_losses"], window_len=25)
        ax.plot(smoothed, label=labels[run], color=colors[i])
    ax.set_title(dataset)
    ax.legend()
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
# Hide the leftover grid cell when the number of datasets is odd.
for ax in axs.flat[len(datasets):]:
    ax.set_visible(False)
plt.tight_layout()
plt.savefig("train_loss.png")
plt.show()
# Plot 2: Visualize generated samples
# One row per run and one column per dataset; if there is more than 1 run,
# each additional run is added as an extra row.
num_runs = len(runs)
if num_runs > 0:  # plt.subplots rejects a grid with 0 rows
    # squeeze=False keeps axs 2-D even for a single run, so indexing is uniform.
    fig, axs = plt.subplots(num_runs, len(datasets), figsize=(14, 3 * num_runs), squeeze=False)
    for i, run in enumerate(runs):
        for j, dataset in enumerate(datasets):
            # Assumes "images" is an array of 2-D points, shape (N, 2) -- TODO confirm.
            images = train_info[run][dataset]["images"]
            ax = axs[i, j]
            ax.scatter(images[:, 0], images[:, 1], alpha=0.2, color=colors[i])
            ax.set_title(dataset)
        # Label each row with its run name.
        axs[i, 0].set_ylabel(labels[run])
    plt.tight_layout()
    plt.savefig("generated_images.png")
    plt.show()