AI-Scientist / ai_scientist /.ipynb_checkpoints /perform_experiments-checkpoint.py
pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
raw
history blame
6.15 kB
import shutil
import os.path as osp
import subprocess
from subprocess import TimeoutExpired
import sys
import json
# Maximum number of consecutive failed coder attempts tolerated for a single
# experiment run, and the retry budget for the plotting phase.
MAX_ITERS = 4
# Maximum number of experiment runs (run_1 .. run_MAX_RUNS) the coder may use.
MAX_RUNS = 5
# When a run fails, only the last MAX_STDERR_OUTPUT characters of stderr are
# fed back to the coder so the prompt stays bounded.
MAX_STDERR_OUTPUT = 1500
# Initial instructions sent to the coder.
# Placeholders: {title}, {idea}, {max_runs}, {baseline_results}.
coder_prompt = """Your goal is to implement the following idea: {title}.
The proposed experiment is as follows: {idea}.
You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
Note that we already provide the vanilla baseline results, so you do not need to re-run it.
For reference, the baseline results are as follows:
{baseline_results}
After you complete each change, we will run the command `python experiment.py --out_dir=run_i` where i is the run number and evaluate the results.
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
You can then implement the next thing on your list."""
# RUN EXPERIMENT
def run_experiment(folder_name, run_num, timeout=7200):
    """Snapshot and execute ``experiment.py`` for a single run.

    The current ``experiment.py`` is copied to ``run_{run_num}.py`` so the code
    behind each run can be inspected later. Then
    ``python experiment.py --out_dir=run_{run_num}`` is executed with
    ``folder_name`` as the working directory.

    Args:
        folder_name: Directory containing ``experiment.py``; used as the
            subprocess working directory.
        run_num: Run index; determines the snapshot filename and the output
            directory ``run_{run_num}``.
        timeout: Seconds to wait for the subprocess before giving up.

    Returns:
        A ``(return_code, next_prompt)`` tuple. ``return_code`` is 0 only when
        the process exited cleanly AND ``run_{run_num}/final_info.json`` could
        be read. ``next_prompt`` is the message to send back to the coder:
        the per-metric ``"means"`` on success, or an error description
        (stderr truncated to its last ``MAX_STDERR_OUTPUT`` characters)
        otherwise.
    """
    cwd = osp.abspath(folder_name)
    out_dir = osp.join(cwd, f"run_{run_num}")
    # COPY CODE SO WE CAN SEE IT.
    shutil.copy(
        osp.join(cwd, "experiment.py"),
        osp.join(cwd, f"run_{run_num}.py"),
    )
    # LAUNCH COMMAND
    command = [
        "python",
        "experiment.py",
        f"--out_dir=run_{run_num}",
    ]
    try:
        result = subprocess.run(
            command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
        )
    except TimeoutExpired:
        print(f"Run {run_num} timed out after {timeout} seconds")
        # Discard partial output so a timed-out run is not mistaken for a result.
        if osp.exists(out_dir):
            shutil.rmtree(out_dir)
        next_prompt = f"Run timed out after {timeout} seconds"
        return 1, next_prompt

    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print(f"Run {run_num} failed with return code {result.returncode}")
        if osp.exists(out_dir):
            shutil.rmtree(out_dir)
        print(f"Run failed with the following error {result.stderr}")
        stderr_output = result.stderr
        # Keep the tail of stderr (where the traceback ends) to bound the prompt.
        if len(stderr_output) > MAX_STDERR_OUTPUT:
            stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
        next_prompt = f"Run failed with the following error {stderr_output}"
        return result.returncode, next_prompt

    # A clean exit does not guarantee usable results. A missing or malformed
    # final_info.json is reported to the coder as a failed run instead of
    # raising and aborting the whole experiment loop.
    info_path = osp.join(out_dir, "final_info.json")
    try:
        with open(info_path, "r") as f:
            results = json.load(f)
        results = {k: v["means"] for k, v in results.items()}
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as err:
        print(f"Run {run_num} produced no readable results: {err!r}")
        if osp.exists(out_dir):
            shutil.rmtree(out_dir)
        next_prompt = (
            f"Run {run_num} exited successfully, but its results could not be read "
            f"from run_{run_num}/final_info.json ({err!r}). Make sure experiment.py "
            f"writes final_info.json to the --out_dir directory as a JSON object "
            f"whose values each contain a \"means\" entry."
        )
        return 1, next_prompt

    next_prompt = f"""Run {run_num} completed. Here are the results:
{results}
Decide if you need to re-plan your experiments given the result (you often will not need to).
Someone else will be using `notes.txt` to perform a writeup on this in the future.
Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
Then, implement the next thing on your list.
We will then run the command `python experiment.py --out_dir=run_{run_num + 1}`.
YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
If you are finished with experiments, respond with 'ALL_COMPLETED'."""
    return result.returncode, next_prompt
# RUN PLOTTING
def run_plotting(folder_name, timeout=600):
    """Execute ``python plot.py`` with ``folder_name`` as the working directory.

    Args:
        folder_name: Directory containing ``plot.py``.
        timeout: Seconds to wait for the plotting script before giving up.

    Returns:
        A ``(return_code, next_prompt)`` tuple. On success ``next_prompt`` is
        the empty string. On failure or timeout it describes the problem so it
        can be sent back to the coder. As with experiment runs, long stderr is
        truncated to its last ``MAX_STDERR_OUTPUT`` characters.
    """
    cwd = osp.abspath(folder_name)
    # LAUNCH COMMAND
    command = [
        "python",
        "plot.py",
    ]
    try:
        result = subprocess.run(
            command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
        )
    except TimeoutExpired:
        print(f"Plotting timed out after {timeout} seconds")
        next_prompt = f"Plotting timed out after {timeout} seconds"
        return 1, next_prompt

    if result.stderr:
        print(result.stderr, file=sys.stderr)
    if result.returncode != 0:
        print(f"Plotting failed with return code {result.returncode}")
        stderr_output = result.stderr
        # Keep only the tail (where the traceback ends) so the prompt stays bounded.
        if len(stderr_output) > MAX_STDERR_OUTPUT:
            stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
        next_prompt = f"Plotting failed with the following error {stderr_output}"
    else:
        next_prompt = ""
    return result.returncode, next_prompt
# PERFORM EXPERIMENTS
def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
    """Drive the coder through the experiment, plotting and note-taking phases.

    1. Experiments: send ``coder_prompt`` and, after each coder reply, execute
       the next run with ``run_experiment``. A successful run advances the run
       number and resets the retry counter. A failed run feeds the error back
       to the coder and consumes one retry. The phase ends when the coder
       replies with ``ALL_COMPLETED``, when ``MAX_RUNS`` runs have succeeded,
       or when one run has failed ``MAX_ITERS`` times in a row.
    2. Plotting: ask the coder to edit ``plot.py``, then run it, retrying up
       to ``MAX_ITERS`` times. Plotting failures do not abort the pipeline.
    3. Notes: ask the coder to describe the plots in ``notes.txt``.

    Args:
        idea: Mapping providing at least the ``"Title"`` and ``"Experiment"``
            entries used to fill the initial prompt.
        folder_name: Directory holding ``experiment.py`` and ``plot.py``.
        coder: Object whose ``run(prompt)`` returns the coder's reply; the
            reply is printed and checked for the substring ``"ALL_COMPLETED"``.
        baseline_results: Inserted verbatim into the initial prompt.

    Returns:
        False if the experiment phase gave up after ``MAX_ITERS`` consecutive
        failures of a single run, otherwise True.
    """
    ## RUN EXPERIMENT
    # Number of consecutive failed attempts at the current run.
    current_iter = 0
    run = 1
    next_prompt = coder_prompt.format(
        title=idea["Title"],
        idea=idea["Experiment"],
        max_runs=MAX_RUNS,
        baseline_results=baseline_results,
    )
    while run <= MAX_RUNS:
        if current_iter >= MAX_ITERS:
            print("Max iterations reached")
            break
        coder_out = coder.run(next_prompt)
        print(coder_out)
        if "ALL_COMPLETED" in coder_out:
            break
        return_code, next_prompt = run_experiment(folder_name, run)
        if return_code == 0:
            # Move on and give the next run the full retry budget. Previously
            # the counter was reset and then immediately incremented, so every
            # run after the first got one fewer attempt.
            run += 1
            current_iter = 0
        else:
            current_iter += 1
    if current_iter >= MAX_ITERS:
        print("Not all experiments completed.")
        return False

    ## RUN PLOTTING
    current_iter = 0
    next_prompt = """
Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
We will be running the command `python plot.py` to generate the plots.
"""
    while True:
        coder.run(next_prompt)
        return_code, next_prompt = run_plotting(folder_name)
        current_iter += 1
        # Best-effort: stop on success or once the retry budget is spent.
        if return_code == 0 or current_iter >= MAX_ITERS:
            break

    ## DOCUMENT PLOTS
    next_prompt = """
Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
Somebody else will be using `notes.txt` to write a report on this in the future.
"""
    coder.run(next_prompt)
    return True