|
import matplotlib.pyplot as plt |
|
import matplotlib.colors as mcolors |
|
import numpy as np |
|
import json |
|
import os |
|
import os.path as osp |
|
import pickle |
|
import warnings |
|
|
|
|
|
# Per-run containers keyed by run-folder name; both start empty and are
# filled while the run folders are scanned.
train_info = dict()
final_results = dict()

# Entries of the current working directory (files and folders alike).
folders = os.listdir("./")

# Dataset names, used as keys into each run's results.
datasets = ["circle", "dino", "line", "moons"]
|
|
|
|
|
def smooth(x, window_len=10, window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalised window.

    Before the convolution the signal is extended at both ends with mirrored
    copies of its own samples. For an input at least ``window_len`` samples
    long, the result therefore has ``len(x) + window_len - 1`` samples, not
    ``len(x)``.

    Args:
        x: 1-D sequence or array of samples.
        window_len: Number of samples in the smoothing window.
        window: ``'flat'`` for a moving average; otherwise the name of a
            NumPy function that builds a window from its length
            (e.g. ``'hanning'``, ``'hamming'``, ``'bartlett'``,
            ``'blackman'``).

    Returns:
        numpy.ndarray: The smoothed signal. When the window's weights sum to
        zero (for example ``np.hanning(2)`` is ``[0, 0]``) the input is
        returned unsmoothed as a float array, because normalising such a
        window divides by zero and would turn every output sample into NaN.

    Raises:
        AttributeError: If ``window`` is not ``'flat'`` and NumPy has no
            attribute with that name.
    """
    if window == 'flat':
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)

    total = w.sum()
    if total == 0:
        # Degenerate window: there is nothing meaningful to average with.
        return np.array(x, dtype=float)

    # Mirror window_len - 1 samples onto each end so the 'valid' convolution
    # has full window support at the borders of the original signal.
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    return np.convolve(w / total, s, mode='valid')
|
|
|
|
|
# Load the outputs of every directory in the working directory whose name
# starts with "run": the final metrics (JSON) go into `final_results` and the
# pickled training outputs into `train_info`, both keyed by folder name.
for folder in folders:
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue

    with open(osp.join(folder, "final_info.json"), "r") as f:
        final_results[folder] = json.load(f)

    # SECURITY NOTE: unpickling can execute arbitrary code, so only point
    # this script at all_results.pkl files produced by trusted runs.
    # The context manager closes the handle as soon as loading finishes
    # instead of leaving it open until garbage collection.
    with open(osp.join(folder, "all_results.pkl"), "rb") as f:
        train_info[folder] = pickle.load(f)
|
|
|
|
|
|
|
# Display names for the known runs, in the order they should appear in
# legends and subplot grids.
labels = {
    "run_0": "Baseline",
    "run_1": "Dual-Expert",
    "run_2": "Enhanced Gating",
    "run_3": "Increased Capacity",
    "run_4": "Diversity Loss",
    "run_5": "Adjusted Diversity",
}

runs = list(final_results.keys())

# Keep only the names of runs that were actually loaded. The plots index
# `final_results` / `train_info` by every key of `labels`, so a name for a
# run folder that does not exist would raise KeyError there.
labels = {run: name for run, name in labels.items() if run in final_results}

# Loaded runs without a predefined name are labelled with their folder name.
for run in runs:
    labels.setdefault(run, run)
|
|
|
|
|
|
|
|
|
|
|
def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled from the 'tab20' colormap.

    The colormap is evaluated at ``n`` evenly spaced positions covering
    the interval [0, 1], and each resulting colour is converted to hex.
    """
    tab20 = plt.get_cmap('tab20')
    palette = []
    for position in np.linspace(0, 1, n):
        palette.append(mcolors.rgb2hex(tab20(position)))
    return palette
|
|
|
|
|
|
|
# Names of the loaded runs and a palette with one colour per run.
runs = [*final_results]
colors = generate_color_palette(len(runs))
|
|
|
|
|
# --- Plot 1: grouped bar chart of the KL divergence per dataset and run ---
fig, ax = plt.subplots(figsize=(12, 6))

# Only runs whose results were loaded can be drawn; indexing `final_results`
# with a label that has no results would raise KeyError.
plotted_runs = []
for run, label in labels.items():
    if run in final_results:
        plotted_runs.append((run, label))
    else:
        warnings.warn(f"No results loaded for {run}; skipping it in the KL divergence plot.")

x = np.arange(len(datasets))
# Keep the 0.15 bar width while a group of bars fits between two dataset
# ticks, and shrink it when there are so many runs that groups would overlap.
width = min(0.15, 0.9 / max(len(plotted_runs), 1))

all_kl_values = []
for position, (run, label) in enumerate(plotted_runs):
    kl_values = []
    for dataset in datasets:
        kl_value = final_results[run].get(dataset, {}).get('means', {}).get('kl_divergence')
        if kl_value is None:
            # A missing entry is drawn as a zero-height bar. Testing for None
            # (not 0) keeps a genuine 0.0 from being reported as missing.
            warnings.warn(f"KL divergence value missing for {run} on {dataset} dataset.")
            kl_value = 0
        kl_values.append(kl_value)
    all_kl_values.extend(kl_values)

    rects = ax.bar(x + width * position, kl_values, width, label=label)
    ax.bar_label(rects, padding=3, rotation=90, fmt='%.3f')

ax.set_ylabel('KL Divergence')
ax.set_title('KL Divergence Comparison Across Runs')
# Centre each dataset tick under its group of bars.
ax.set_xticks(x + width * (len(plotted_runs) - 1) / 2)
ax.set_xticklabels(datasets)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

# Leave 20% headroom above the tallest bar for the rotated value labels;
# `default=0` covers the case where no run could be plotted.
max_kl = max(all_kl_values, default=0)
if max_kl > 0:
    ax.set_ylim(0, max_kl * 1.2)
else:
    ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig("kl_divergence_comparison.png")
plt.show()
|
|
|
|
|
# --- Plot 2: generated 'dino' samples per run, coloured by gating weight ---
# Only runs with stored 'dino' outputs can be drawn; a label without them
# would raise KeyError when `train_info` is indexed.
sample_runs = []
for run, label in labels.items():
    if 'dino' in train_info.get(run, {}):
        sample_runs.append((run, label))
    else:
        warnings.warn(f"No 'dino' training outputs loaded for {run}; skipping it in the samples plot.")

# Three panels per row. Up to six runs give the usual 2x3 grid (15x10 inches);
# further runs add rows instead of indexing past the end of a fixed grid.
n_cols = 3
n_rows = max(2, (len(sample_runs) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
fig.suptitle("Generated Samples for 'dino' Dataset", fontsize=16)

for i, (run, label) in enumerate(sample_runs):
    panel = axs[i // n_cols, i % n_cols]
    # `images` is indexed as an (N, 2+) array: column 0 is x, column 1 is y.
    images = train_info[run]['dino']["images"]
    gating_weights = train_info[run]['dino'].get("gating_weights")

    if gating_weights is None:
        # No gating weights stored for this run: there is nothing to
        # colour-map, so draw plain points and omit the colourbar, which
        # would not describe the points.
        panel.scatter(images[:, 0], images[:, 1], alpha=0.5)
    else:
        scatter = panel.scatter(images[:, 0], images[:, 1], c=gating_weights, cmap='coolwarm', alpha=0.5, vmin=0, vmax=1)
        fig.colorbar(scatter, ax=panel, label='Gating Weight')
    panel.set_title(label)

plt.tight_layout()
plt.savefig("dino_generated_samples.png")
plt.show()
|
|
|
|
|
# --- Plot 3: smoothed training-loss curves on the 'dino' dataset ---
fig, ax = plt.subplots(figsize=(12, 6))

for run, label in labels.items():
    # A label without stored 'dino' outputs would raise KeyError below.
    if 'dino' not in train_info.get(run, {}):
        warnings.warn(f"No 'dino' training outputs loaded for {run}; skipping it in the loss plot.")
        continue

    losses = train_info[run]['dino']["train_losses"]
    smoothed = smooth(losses, window_len=25)

    # smooth() pads the series with mirrored samples, so its output is longer
    # than the input. Keep the centred part of the curve so that position i
    # on the x-axis corresponds to training step i.
    pad = max((len(smoothed) - len(losses)) // 2, 0)
    ax.plot(smoothed[pad:pad + len(losses)], label=label)

ax.set_title("Training Loss for 'dino' Dataset")
ax.set_xlabel("Training Step")
ax.set_ylabel("Loss")
ax.legend()

plt.tight_layout()
plt.savefig("dino_train_loss.png")
plt.show()
|
|
|
|
|
# --- Plot 4: histogram of the gating weights per run on the 'dino' dataset ---
# Only runs with stored 'dino' outputs can be drawn; a label without them
# would raise KeyError when `train_info` is indexed.
hist_runs = []
for run, label in labels.items():
    if 'dino' in train_info.get(run, {}):
        hist_runs.append((run, label))
    else:
        warnings.warn(f"No 'dino' training outputs loaded for {run}; skipping it in the gating histogram.")

# Three panels per row. Up to six runs give the usual 2x3 grid (15x10 inches);
# further runs add rows instead of indexing past the end of a fixed grid.
n_cols = 3
n_rows = max(2, (len(hist_runs) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
fig.suptitle("Gating Weights Histogram for 'dino' Dataset", fontsize=16)

for i, (run, label) in enumerate(hist_runs):
    gating_weights = train_info[run]['dino'].get("gating_weights")
    if gating_weights is None:
        # Runs without stored gating weights keep an empty panel.
        continue

    panel = axs[i // n_cols, i % n_cols]
    panel.hist(gating_weights, bins=50, range=(0, 1))
    panel.set_title(label)
    panel.set_xlabel("Gating Weight")
    panel.set_ylabel("Frequency")

plt.tight_layout()
plt.savefig("dino_gating_weights_histogram.png")
plt.show()
|
|