# Spaces:
# Runtime error
# Runtime error
import glob
import os

import matplotlib.pyplot as plt
# Line colors handed out to experiments in first-seen order and cycled when
# there are more experiments than colors. (The previous list repeated "gold",
# so two experiments could end up with the same color.)
COLORS = ["green", "blue", "orange", "black", "purple", "gray", "gold", "red", "brown", "yellow"]

# Directory holding the per-checkpoint result files; plots are written here too.
EVAL_DIR = "./eval_results"

# Result files whose names contain any of these substrings are ignored.
# NOTE(review): "ig" is a very broad substring match (it also hits names such
# as "big" or "config"); presumably it targets one experiment family -- confirm.
SKIP_MARKERS = ("ig", "0713", "0819")


def parse_result_name(name):
    """Parse one result file name into ``(task, exp, step, score)``.

    The name is split on "_". The first field is the task; ``"ok"`` is
    mapped to ``"okvqa"`` and the second field is then dropped (presumably
    such names start with "ok_vqa_"). The last three fields are the step,
    an unused field, and the score. Everything in between, re-joined with
    "_", is the experiment name.

    Returns ``None`` for files that should be skipped:
      * names containing any ``SKIP_MARKERS`` substring,
      * names with fewer than five "_"-separated fields,
      * experiments whose name contains "fix".

    Raises ValueError if the step field is not an int or the score field is
    not a float.
    """
    if any(marker in name for marker in SKIP_MARKERS):
        return None
    items = name.split("_")
    if len(items) < 5:
        return None
    task = items[0]
    if task == "ok":
        task = "okvqa"
        exp = "_".join(items[2:-3])
    else:
        exp = "_".join(items[1:-3])
    # NOTE(review): the copy this was restored from had lost its indentation.
    # "fix" experiments are assumed to be skipped entirely, because `step` was
    # only assigned for names without "fix".
    if "fix" in exp:
        return None
    step = int(items[-3])
    # NOTE(review): steps of experiments whose name contains "13" are halved,
    # presumably to align their x-axis with the other runs -- confirm why.
    if "13" in exp:
        step //= 2
    score = float(items[-1])
    return task, exp, step, score


def load_results(result_dir=EVAL_DIR):
    """Collect every parseable result file in *result_dir*.

    Returns ``(data, color)`` where ``data`` maps
    task -> experiment -> list of ``[step, score]`` sorted by step, and
    ``color`` maps each experiment to a line color from ``COLORS``. The same
    experiment gets the same color in every task's plot.
    """
    data = {}
    color = {}
    # Sorted so that color and legend assignment does not depend on the
    # filesystem's directory-listing order.
    for path in sorted(glob.glob(os.path.join(result_dir, "*"))):
        parsed = parse_result_name(os.path.basename(path))
        if parsed is None:
            continue
        task, exp, step, score = parsed
        data.setdefault(task, {}).setdefault(exp, []).append([step, score])
        if exp not in color:
            # Cycle instead of raising IndexError once COLORS is exhausted.
            color[exp] = COLORS[len(color) % len(COLORS)]
    for runs in data.values():
        for points in runs.values():
            points.sort(key=lambda point: point[0])
    return data, color


def plot_results(data, color, out_dir=EVAL_DIR, xlim=(0, 15000)):
    """Save one ``<task>.jpg`` line plot per task into *out_dir*.

    Each experiment is drawn as score vs. step using its color from *color*.
    *xlim* fixes the x-axis range (defaults to the previous hard-coded 0..15000).
    """
    for task, runs in data.items():
        fig = plt.figure()
        plt.title(f"{task} evaluation")
        for exp, points in runs.items():
            steps = [step for step, _ in points]
            scores = [score for _, score in points]
            plt.plot(steps, scores, "-o", color=color[exp], label=exp)
        plt.grid()
        plt.legend()
        plt.xlabel("step")
        plt.xlim(*xlim)
        plt.savefig(os.path.join(out_dir, f"{task}.jpg"))
        # Release the figure; otherwise every task's figure stays open.
        plt.close(fig)


if __name__ == "__main__":
    plot_results(*load_results())