pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
import pickle
import warnings
# LOAD FINAL RESULTS:
# Dataset names; each is used below as a key into the per-run result dicts.
datasets = ["circle", "dino", "line", "moons"]
# os.listdir returns entries in arbitrary order; sorting makes the order in
# which runs are discovered (and therefore labelled and plotted) reproducible.
folders = sorted(os.listdir("./"))
# Both dicts are keyed by run-folder name and filled in by the loading loop.
final_results = {}
train_info = {}
def smooth(x, window_len=10, window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalized window.

    The signal is extended at both ends with ``window_len - 1`` reflected
    samples before the convolution, so a smoothed result contains
    ``len(x) + window_len - 1`` values.

    Args:
        x: 1-D sequence or array of numbers.
        window_len: Number of samples in the smoothing window.
        window: ``'flat'`` for a moving average, otherwise the name of a
            numpy window function (e.g. ``'hanning'``) taking the window
            length as its only argument.

    Returns:
        A float numpy array. When smoothing is not possible (window shorter
        than 2 samples, input shorter than the window, or a window whose
        weights sum to zero) the input is returned unsmoothed.
    """
    x = np.asarray(x, dtype=float)

    # A window of fewer than two samples cannot smooth anything, and an input
    # shorter than the window cannot be reflected by window_len - 1 samples.
    if window_len < 2 or x.size < window_len:
        return x

    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)

    total = w.sum()
    # Some windows are all zeros at tiny lengths (e.g. hanning(2) == [0, 0]);
    # normalizing by a zero sum would turn the whole output into NaN.
    if total == 0:
        return x

    # Reflect the signal at both ends so the 'valid' convolution covers
    # every original sample.
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    y = np.convolve(w / total, s, mode='valid')
    return y
# Load the summary JSON and the pickled training artefacts of every run
# directory (any directory whose name starts with "run").
for folder in folders:
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue
    with open(osp.join(folder, "final_info.json"), "r") as f:
        final_results[folder] = json.load(f)
    # NOTE(security): unpickling executes arbitrary code from the file; only
    # load all_results.pkl files produced by your own trusted runs.
    # The context manager closes the file handle, which the previous bare
    # open(...) call never did.
    with open(osp.join(folder, "all_results.pkl"), "rb") as f:
        all_results = pickle.load(f)
    train_info[folder] = all_results
# CREATE LEGEND -- PLEASE FILL IN YOUR RUN NAMES HERE
# Keep the names short, as these will be in the legend.
labels = {
    "run_0": "Baseline",
    "run_1": "Dual-Expert",
    "run_2": "Enhanced Gating",
    "run_3": "Increased Capacity",
    "run_4": "Diversity Loss",
    "run_5": "Adjusted Diversity",
}
# Use the run key as the default label if not specified
runs = list(final_results.keys())
for run in runs:
    labels.setdefault(run, run)
# Every plot below looks each labelled run up in the loaded results, so a
# label whose run folder was never loaded would raise KeyError. Drop such
# labels (with a warning) instead of crashing later.
for run in [name for name in labels if name not in final_results]:
    warnings.warn(f"No results loaded for labelled run {run}; it will not be plotted.")
    del labels[run]
# CREATE PLOTS
# Create a programmatic color palette
def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled evenly from the 'tab20' colormap."""
    # 'tab20' can be swapped for other colormaps like 'Set1', 'Set2', 'Set3', etc.
    palette_map = plt.get_cmap('tab20')
    hex_codes = []
    for position in np.linspace(0, 1, n):
        rgba = palette_map(position)
        hex_codes.append(mcolors.rgb2hex(rgba))
    return hex_codes
# Collect the loaded run names, then build one palette colour per run.
runs = [*final_results]
colors = generate_color_palette(len(runs))
# Plot 1: KL Divergence comparison across runs
# Grouped bar chart: one group per dataset, one bar per run.
# Only runs whose results were actually loaded can be plotted; a label
# without a results entry would otherwise raise KeyError.
for run in labels:
    if run not in final_results:
        warnings.warn(f"No results found for {run}; skipping it in the KL divergence plot.")
kl_runs = [(run, label) for run, label in labels.items() if run in final_results]

fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(datasets))
# 0.15 per bar fits up to six runs; shrink the bars for more runs so that
# neighbouring dataset groups never overlap.
width = min(0.15, 0.9 / max(len(kl_runs), 1))
all_kl_values = []
for multiplier, (run, label) in enumerate(kl_runs):
    kl_values = []
    for dataset in datasets:
        # None (rather than 0) marks a missing value, so a genuine KL
        # divergence of exactly 0 is not reported as missing.
        kl_value = final_results[run].get(dataset, {}).get('means', {}).get('kl_divergence')
        if kl_value is None:
            warnings.warn(f"KL divergence value missing for {run} on {dataset} dataset.")
            kl_value = 0  # draw missing values as an empty bar
        kl_values.append(kl_value)
    all_kl_values.extend(kl_values)
    offset = width * multiplier
    rects = ax.bar(x + offset, kl_values, width, label=label)
    ax.bar_label(rects, padding=3, rotation=90, fmt='%.3f')
ax.set_ylabel('KL Divergence')
ax.set_title('KL Divergence Comparison Across Runs')
# Centre each tick under its group of bars.
ax.set_xticks(x + width * max(len(kl_runs) - 1, 0) / 2)
ax.set_xticklabels(datasets)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
# Reuse the values gathered above; default=0 covers the no-runs case.
max_kl = max(all_kl_values, default=0)
if max_kl > 0:
    ax.set_ylim(0, max_kl * 1.2)  # headroom for the rotated bar labels
else:
    ax.set_ylim(0, 1)  # Set a default y-axis limit if all KL divergence values are 0 or missing
plt.tight_layout()
plt.savefig("kl_divergence_comparison.png")
plt.show()
# Plot 2: Generated samples comparison (focus on 'dino' dataset)
# One scatter panel per run, coloured by gating weight where available.
for run in labels:
    if run not in train_info:
        warnings.warn(f"No training info found for {run}; skipping it in the samples plot.")
sample_runs = [(run, label) for run, label in labels.items() if run in train_info]

# Size the grid from the number of runs instead of assuming exactly six;
# six runs still yield the 2x3 layout at 15x10 inches.
n_cols = 3
n_rows = max(1, (len(sample_runs) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
fig.suptitle("Generated Samples for 'dino' Dataset", fontsize=16)
for i, (run, label) in enumerate(sample_runs):
    row = i // n_cols
    col = i % n_cols
    panel = axs[row, col]
    # assumes "images" is an (N, 2) array of x/y sample coordinates -- TODO confirm
    images = train_info[run]['dino']["images"]
    gating_weights = train_info[run]['dino'].get("gating_weights")
    if gating_weights is not None:
        scatter = panel.scatter(images[:, 0], images[:, 1], c=gating_weights, cmap='coolwarm', alpha=0.5, vmin=0, vmax=1)
        fig.colorbar(scatter, ax=panel, label='Gating Weight')
    else:
        # Runs without gating weights get a plain scatter and no colorbar,
        # so no colour scale is shown for data that does not exist.
        panel.scatter(images[:, 0], images[:, 1], alpha=0.5)
    panel.set_title(label)
# Hide grid cells that have no run to show.
for unused in axs.flat[len(sample_runs):]:
    unused.set_visible(False)
plt.tight_layout()
plt.savefig("dino_generated_samples.png")
plt.show()
# Plot 3: Training loss comparison (focus on 'dino' dataset)
# One smoothed loss curve per run on a shared axis.
fig, ax = plt.subplots(figsize=(12, 6))
for run, label in labels.items():
    if run not in train_info:
        warnings.warn(f"No training info found for {run}; skipping it in the training loss plot.")
        continue
    losses = train_info[run]['dino']["train_losses"]
    mean = smooth(losses, window_len=25)
    # smooth() pads the signal with reflected samples, so its output is
    # longer than the input. Crop the centre so that index i of the curve
    # lines up with training step i on the x-axis.
    n_steps = len(losses)
    start = max(0, (len(mean) - n_steps) // 2)
    mean = mean[start:start + n_steps]
    ax.plot(mean, label=label)
ax.set_title("Training Loss for 'dino' Dataset")
ax.set_xlabel("Training Step")
ax.set_ylabel("Loss")
ax.legend()
plt.tight_layout()
plt.savefig("dino_train_loss.png")
plt.show()
# Plot 4: Gating weights histogram comparison (focus on 'dino' dataset)
# One histogram panel per run; runs without gating weights get an empty panel.
for run in labels:
    if run not in train_info:
        warnings.warn(f"No training info found for {run}; skipping it in the gating histogram plot.")
hist_runs = [(run, label) for run, label in labels.items() if run in train_info]

# Size the grid from the number of runs instead of assuming exactly six;
# six runs still yield the 2x3 layout at 15x10 inches.
n_cols = 3
n_rows = max(1, (len(hist_runs) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
fig.suptitle("Gating Weights Histogram for 'dino' Dataset", fontsize=16)
for i, (run, label) in enumerate(hist_runs):
    row = i // n_cols
    col = i % n_cols
    panel = axs[row, col]
    gating_weights = train_info[run]['dino'].get("gating_weights")
    if gating_weights is not None:
        panel.hist(gating_weights, bins=50, range=(0, 1))
    panel.set_title(label)
    panel.set_xlabel("Gating Weight")
    panel.set_ylabel("Frequency")
# Hide grid cells that have no run to show.
for unused in axs.flat[len(hist_runs):]:
    unused.set_visible(False)
plt.tight_layout()
plt.savefig("dino_gating_weights_histogram.png")
plt.show()