|
import shutil |
|
import os.path as osp |
|
import subprocess |
|
from subprocess import TimeoutExpired |
|
import sys |
|
import json |
|
|
|
# Maximum attempts allowed per experiment run (and for the plotting step)
# before perform_experiments gives up on it.
MAX_ITERS = 4

# Number of experiment runs the coder may use (run_1 .. run_{MAX_RUNS}).
MAX_RUNS = 5

# Only the last MAX_STDERR_OUTPUT characters of a failing run's stderr are fed
# back to the coder, to keep prompts bounded.
MAX_STDERR_OUTPUT = 1500



# Initial instructions for the coder; formatted with title, idea, max_runs and
# baseline_results in perform_experiments.
coder_prompt = """Your goal is to implement the following idea: {title}.

The proposed experiment is as follows: {idea}.

You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.



First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.



Note that we already provide the vanilla baseline results, so you do not need to re-run it.



For reference, the baseline results are as follows:



{baseline_results}



After you complete each change, we will run the command `python experiment.py --out_dir=run_i` where i is the run number and evaluate the results.

YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.

You can then implement the next thing on your list."""
|
|
|
|
|
|
|
def run_experiment(folder_name, run_num, timeout=7200):
    """Execute one experiment run and build the follow-up prompt for the coder.

    ``experiment.py`` in ``folder_name`` is first copied to
    ``run_{run_num}.py``, then executed as
    ``python experiment.py --out_dir=run_{run_num}`` with ``folder_name`` as the
    working directory. On success, the ``"means"`` entry of every top-level key
    in ``run_{run_num}/final_info.json`` is embedded in the returned prompt.

    Args:
        folder_name: Directory containing ``experiment.py``.
        run_num: Index of this run; selects the output directory name.
        timeout: Seconds to wait before the experiment process is killed.

    Returns:
        A ``(return_code, next_prompt)`` tuple. ``return_code`` is 0 only when
        the process exited cleanly *and* its results could be loaded. On a
        non-zero exit, a timeout, or unreadable results, the run's output
        directory is removed and ``next_prompt`` describes the problem so the
        coder can fix it.
    """
    cwd = osp.abspath(folder_name)
    out_dir = osp.join(cwd, f"run_{run_num}")

    def discard_run_dir():
        # Remove partial output so a retry of this run starts from a clean slate.
        if osp.exists(out_dir):
            shutil.rmtree(out_dir)

    # Keep a snapshot of the exact code that produced this run.
    shutil.copy(
        osp.join(folder_name, "experiment.py"),
        osp.join(folder_name, f"run_{run_num}.py"),
    )

    command = [
        "python",
        "experiment.py",
        f"--out_dir=run_{run_num}",
    ]
    # Only the subprocess call can raise TimeoutExpired, so keep the try narrow.
    try:
        result = subprocess.run(
            command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
        )
    except TimeoutExpired:
        print(f"Run {run_num} timed out after {timeout} seconds")
        discard_run_dir()
        return 1, f"Run timed out after {timeout} seconds"

    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print(f"Run {run_num} failed with return code {result.returncode}")
        discard_run_dir()
        print(f"Run failed with the following error {result.stderr}")
        stderr_output = result.stderr
        # Feed back only the tail of long tracebacks to bound the prompt size.
        if len(stderr_output) > MAX_STDERR_OUTPUT:
            stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
        return (
            result.returncode,
            f"Run failed with the following error {stderr_output}",
        )

    # A clean exit is not enough: the run must also have written parseable
    # results. Previously a missing/malformed file crashed the whole pipeline;
    # now it is reported to the coder like any other failed run.
    try:
        with open(osp.join(out_dir, "final_info.json"), "r") as f:
            raw_results = json.load(f)
        results = {k: v["means"] for k, v in raw_results.items()}
    except (OSError, ValueError, AttributeError, KeyError, TypeError) as err:
        print(f"Run {run_num} did not produce readable results: {err!r}")
        discard_run_dir()
        return 1, (
            f"Run {run_num} finished without errors, but its results could not "
            f"be loaded from run_{run_num}/final_info.json "
            f"({type(err).__name__}: {err}). Make sure `experiment.py` writes "
            "`final_info.json` into the output directory as a JSON object "
            'whose values each contain a "means" entry.'
        )

    next_prompt = f"""Run {run_num} completed. Here are the results:

{results}



Decide if you need to re-plan your experiments given the result (you often will not need to).



Someone else will be using `notes.txt` to perform a writeup on this in the future.

Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.



Then, implement the next thing on your list.

We will then run the command `python experiment.py --out_dir=run_{run_num + 1}`.

YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.

If you are finished with experiments, respond with 'ALL_COMPLETED'."""
    return result.returncode, next_prompt
|
|
|
|
|
|
|
def run_plotting(folder_name, timeout=600):
    """Run ``plot.py`` inside ``folder_name`` and report the outcome.

    Args:
        folder_name: Directory containing ``plot.py``; used as the working
            directory for the plotting process.
        timeout: Seconds to wait before the plotting process is killed.

    Returns:
        A ``(return_code, next_prompt)`` tuple: ``(0, "")`` on success,
        otherwise a non-zero code and a prompt describing the failure (with
        stderr truncated to its last ``MAX_STDERR_OUTPUT`` characters) or the
        timeout.
    """
    cwd = osp.abspath(folder_name)
    command = [
        "python",
        "plot.py",
    ]
    # Only the subprocess call can raise TimeoutExpired, so keep the try narrow.
    try:
        result = subprocess.run(
            command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
        )
    except TimeoutExpired:
        print(f"Plotting timed out after {timeout} seconds")
        return 1, f"Plotting timed out after {timeout} seconds"

    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode == 0:
        return result.returncode, ""

    print(f"Plotting failed with return code {result.returncode}")
    stderr_output = result.stderr
    # Same cap as run_experiment: long tracebacks would flood the coder's prompt.
    if len(stderr_output) > MAX_STDERR_OUTPUT:
        stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
    return (
        result.returncode,
        f"Plotting failed with the following error {stderr_output}",
    )
|
|
|
|
|
|
|
def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
    """Drive the coder through the experiment, plotting and note-taking phases.

    Phase 1 asks the coder to implement up to ``MAX_RUNS`` experiment runs.
    Each run is executed via ``run_experiment``, and its result (or error) is
    fed back as the next prompt. Every run gets at most ``MAX_ITERS`` failed
    attempts; the counter resets after each successful run. The coder can stop
    early by including ``ALL_COMPLETED`` in its reply.

    Phase 2 asks the coder to edit ``plot.py`` and runs it, for up to
    ``MAX_ITERS`` attempts. If plotting still fails after the last attempt,
    the pipeline continues anyway.

    Phase 3 asks the coder to describe the plots in ``notes.txt``.

    Args:
        idea: Mapping providing at least the ``"Title"`` and ``"Experiment"``
            keys.
        folder_name: Directory holding ``experiment.py`` and ``plot.py``.
        coder: Object whose ``run(prompt)`` method returns the coder's reply
            (presumably text, since it is searched for ``ALL_COMPLETED``).
        baseline_results: Baseline results, interpolated into the first prompt.

    Returns:
        ``False`` if some run exhausted its ``MAX_ITERS`` attempts, ``True``
        otherwise.
    """
    run = 1
    # Failed attempts for the current run. Only failures are counted, so every
    # run (not just the first) gets the full MAX_ITERS attempts.
    failed_attempts = 0
    next_prompt = coder_prompt.format(
        title=idea["Title"],
        idea=idea["Experiment"],
        max_runs=MAX_RUNS,
        baseline_results=baseline_results,
    )
    while run <= MAX_RUNS:
        if failed_attempts >= MAX_ITERS:
            print("Max iterations reached")
            break
        coder_out = coder.run(next_prompt)
        print(coder_out)
        if "ALL_COMPLETED" in coder_out:
            break
        return_code, next_prompt = run_experiment(folder_name, run)
        if return_code == 0:
            run += 1
            failed_attempts = 0
        else:
            failed_attempts += 1
    if failed_attempts >= MAX_ITERS:
        print("Not all experiments completed.")
        return False

    plot_attempts = 0
    next_prompt = """

Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.



In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.



Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.



We will be running the command `python plot.py` to generate the plots.

"""
    # Plotting is best-effort: at least one attempt, at most MAX_ITERS.
    while True:
        coder.run(next_prompt)
        return_code, next_prompt = run_plotting(folder_name)
        plot_attempts += 1
        if return_code == 0 or plot_attempts >= MAX_ITERS:
            break

    next_prompt = """

Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.



Somebody else will be using `notes.txt` to write a report on this in the future.

"""
    coder.run(next_prompt)

    return True
|
|