# NOTE: the lines below are Hugging Face file-viewer residue, kept as comments.
# pradachan's picture
# Upload folder using huggingface_hub
# f71c233 verified
# raw
# history blame
# 5.97 kB
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import json
import os
import os.path as osp
# LOAD FINAL RESULTS:
datasets = ["shakespeare_char", "enwik8", "text8"]


def summarize_dataset_losses(results_dict, dataset):
    """Aggregate the per-seed loss curves of one dataset.

    Every entry of ``results_dict`` whose key contains both ``dataset`` and
    ``"val_info"`` is treated as one seed: a sequence of records carrying
    the ``"iter"``, ``"val/loss"`` and ``"train/loss"`` fields.

    Args:
        results_dict: Mapping of result names to lists of evaluation records.
        dataset: Dataset name used as a substring filter on the keys.

    Returns:
        A dict with the keys ``"iters"``, ``"val_loss"``, ``"train_loss"``,
        ``"val_loss_sterr"`` and ``"train_loss_sterr"``.  The loss entries
        are the element-wise mean and standard error across seeds.  When no
        key matches, ``"iters"`` is an empty list and every series is an
        empty array, so callers can plot the result without special-casing
        (averaging an empty list would produce NaN and a RuntimeWarning).
    """
    seed_records = [
        records
        for key, records in results_dict.items()
        if dataset in key and "val_info" in key
    ]

    if not seed_records:
        return {
            "iters": [],
            "val_loss": np.zeros(0),
            "train_loss": np.zeros(0),
            "val_loss_sterr": np.zeros(0),
            "train_loss_sterr": np.zeros(0),
        }

    # The iteration axis is taken from the last matching entry.
    summary = {"iters": [info["iter"] for info in seed_records[-1]]}
    n_seeds = len(seed_records)
    for split in ("val", "train"):
        losses = [[info[f"{split}/loss"] for info in records] for records in seed_records]
        summary[f"{split}_loss"] = np.mean(losses, axis=0)
        summary[f"{split}_loss_sterr"] = np.std(losses, axis=0) / np.sqrt(n_seeds)
    return summary


folders = os.listdir("./")
final_results = {}
results_info = {}
for folder in folders:
    # Only directories whose name starts with "run" hold experiment output.
    if not (folder.startswith("run") and osp.isdir(folder)):
        continue
    with open(osp.join(folder, "final_info.json"), "r") as f:
        final_results[folder] = json.load(f)
    # NOTE(security): allow_pickle=True unpickles arbitrary Python objects;
    # only run this script on result folders from a trusted source.
    results_dict = np.load(osp.join(folder, "all_results.npy"), allow_pickle=True).item()
    results_info[folder] = {
        dataset: summarize_dataset_losses(results_dict, dataset)
        for dataset in datasets
    }
# CREATE LEGEND -- ADD RUNS HERE THAT WILL BE PLOTTED
# Each pair is (run directory name, legend text); order here is plot order.
_run_label_pairs = [
    ("run_0", "Baseline"),
    ("run_1", "Multi-Style Adapter"),
    ("run_2", "Fine-tuned Multi-Style Adapter"),
    ("run_3", "Enhanced Style Consistency"),
    ("run_4", "Style Consistency Analysis"),
]
labels = dict(_run_label_pairs)
# Create a programmatic color palette
def generate_color_palette(n):
    """Return ``n`` hex colour strings sampled from the 'tab20' colormap.

    The colormap is evaluated at ``n`` evenly spaced positions from 0 to 1
    (both ends included) and each resulting colour is converted to a hex
    string.
    """
    palette = plt.get_cmap('tab20')
    hex_codes = []
    for position in np.linspace(0, 1, n):
        hex_codes.append(mcolors.rgb2hex(palette(position)))
    return hex_codes
# Get the list of runs and generate the color palette.
# Only runs whose results were actually loaded can be plotted: a label with
# no matching entry in results_info would raise KeyError in every plot below.
runs = [run for run in labels if run in results_info]
skipped_runs = [run for run in labels if run not in results_info]
if skipped_runs:
    print(f"Skipping labelled runs with no loaded results: {skipped_runs}")
colors = generate_color_palette(len(runs))
def plot_loss_curves(loss_key, split_label):
    """Save one loss-curve figure per dataset, with one line per run.

    For every dataset the mean loss of each run is drawn against the
    iteration axis, surrounded by a shaded band of +/- one standard error.
    The figure is written to ``f"{loss_key}_{dataset}.png"``.

    Args:
        loss_key: Key of the mean series in ``results_info[run][dataset]``
            (``"train_loss"`` or ``"val_loss"``); the standard error is
            read from ``f"{loss_key}_sterr"``.
        split_label: Word used in the title and y-axis label
            (``"Training"`` or ``"Validation"``).
    """
    for dataset in datasets:
        plt.figure(figsize=(10, 6))
        for run, color in zip(runs, colors):
            stats = results_info[run][dataset]
            # A run with no recorded curve for this dataset has nothing to draw.
            if "iters" not in stats:
                continue
            iters = stats["iters"]
            mean = stats[loss_key]
            sterr = stats[f"{loss_key}_sterr"]
            plt.plot(iters, mean, label=labels[run], color=color)
            plt.fill_between(iters, mean - sterr, mean + sterr, color=color, alpha=0.2)
        plt.title(f"{split_label} Loss Across Runs for {dataset} Dataset")
        plt.xlabel("Iteration")
        plt.ylabel(f"{split_label} Loss")
        plt.legend()
        plt.grid(True, which="both", ls="-", alpha=0.2)
        plt.tight_layout()
        plt.savefig(f"{loss_key}_{dataset}.png")
        plt.close()


# Plot 1: Line plot of training loss for each dataset across the runs with labels
plot_loss_curves("train_loss", "Training")

# Plot 2: Line plot of validation loss for each dataset across the runs with labels
plot_loss_curves("val_loss", "Validation")
# Plot 3: Bar plot of style consistency scores for each dataset across the runs
# One group of bars per dataset, one bar per run inside each group.
plt.figure(figsize=(12, 6))
x = np.arange(len(datasets))
# max(..., 1) avoids a ZeroDivisionError when there are no runs to plot.
width = 0.8 / max(len(runs), 1)
for i, run in enumerate(runs):
    means = []
    stds = []
    for dataset in datasets:
        # Tolerate a dataset that is absent from this run's final_info.json.
        dataset_info = final_results[run].get(dataset, {})
        metric_means = dataset_info.get('means', {})
        if 'style_consistency_scores' in metric_means:
            means.append(metric_means['style_consistency_scores'].get('mean_consistency', 0))
            stds.append(dataset_info.get('stderrs', {}).get('style_consistency_scores', {}).get('mean_consistency', 0))
        else:
            # A missing metric is drawn as a zero-height bar with no error bar.
            means.append(0)
            stds.append(0)
    # Use the shared palette so each run has the same colour as in the loss plots.
    plt.bar(x + i*width, means, width, label=labels[run], yerr=stds, capsize=5, color=colors[i])
plt.xlabel('Dataset')
plt.ylabel('Style Consistency Score')
plt.title('Style Consistency Scores Across Runs and Datasets')
# Centre each tick under its group of bars.
plt.xticks(x + width*(len(runs)-1)/2, datasets)
plt.legend()
plt.tight_layout()
plt.savefig("style_consistency_scores.png")
plt.close()
# Plot 4: Bar plot of inference speed for each dataset across the runs
# One group of bars per dataset, one bar per run inside each group.
plt.figure(figsize=(12, 6))
x = np.arange(len(datasets))
# max(..., 1) avoids a ZeroDivisionError when there are no runs to plot.
width = 0.8 / max(len(runs), 1)
for i, run in enumerate(runs):
    means = []
    stds = []
    for dataset in datasets:
        # Use .get throughout (as the style-consistency plot does) so a run
        # missing the dataset, 'means' or 'stderrs' yields a zero bar rather
        # than a KeyError.
        dataset_info = final_results[run].get(dataset, {})
        metric_means = dataset_info.get('means', {})
        if 'avg_inference_tokens_per_second_mean' in metric_means:
            means.append(metric_means['avg_inference_tokens_per_second_mean'])
            stds.append(dataset_info.get('stderrs', {}).get('avg_inference_tokens_per_second_mean', 0))
        else:
            means.append(0)
            stds.append(0)
    # Use the shared palette so each run has the same colour as in the loss plots.
    plt.bar(x + i*width, means, width, label=labels[run], yerr=stds, capsize=5, color=colors[i])
plt.xlabel('Dataset')
plt.ylabel('Tokens per Second')
plt.title('Inference Speed Across Runs and Datasets')
# Centre each tick under its group of bars.
plt.xticks(x + width*(len(runs)-1)/2, datasets)
plt.legend()
plt.tight_layout()
plt.savefig("inference_speed.png")
plt.close()