|
import json |
|
import os |
|
import os.path as osp |
|
import time |
|
from typing import Dict, List, Union |
|
|
|
import backoff |
|
import requests |
|
from strictjson import strict_json |
|
|
|
from ai_scientist.llm import ( |
|
allchoices, |
|
extract_json_between_markers, |
|
get_response_from_llm, |
|
llm_json_auto_correct, |
|
) |
|
|
|
S2_API_KEY = os.getenv("S2_API_KEY") |
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") |
|
|
|
|
|
# Prompt for the first idea-generation round. Placeholders filled at call
# time: {task_description}, {code}, {prev_ideas_string}, {num_reflections}.
# The response is expected to contain a THOUGHT section and a fenced
# ```json ... ``` block parsed by extract_json_between_markers.
idea_first_prompt = """{task_description}
<experiment.py>
{code}
</experiment.py>

Here are the ideas that you have already generated:

'''
{prev_ideas_string}
'''

Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
Note that you will not have access to any additional resources or datasets.
Make sure any idea is not overfit the specific training dataset or model, and has wider significance.

Respond in the following format:

THOUGHT:
<THOUGHT>

NEW IDEA JSON:
```json
<JSON>
```

In <THOUGHT>, first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.

Add '```json' before the <JSON> and '```' after the <JSON> as above. In <JSON>, provide the new idea in JSON format with the following keys and values:
- "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
- "Title": A title for the idea, will be used for the report writing.
- "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
- "Interestingness": A rating from 1 to 10 (lowest to highest).
- "Feasibility": A rating from 1 to 10 (lowest to highest).
- "Novelty": A rating from 1 to 10 (lowest to highest).

Be cautious and realistic on your ratings.
This JSON will be automatically parsed, so ensure the format is precise.
You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
"""
|
|
|
# Prompt for each subsequent refinement round. Placeholders:
# {current_round}, {num_reflections}. The literal marker "I am done" is
# checked by the callers to stop reflecting early -- do not change it.
idea_reflection_prompt = """Round {current_round}/{num_reflections}.
In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
Include any other factors that you think are important in evaluating the idea.
Ensure the idea is clear and concise, and the JSON is the correct format.
Do not make things overly complicated.
In the next attempt, try and refine and improve your idea.
Stick to the spirit of the original idea unless there are glaring issues.

Respond in the exactly the same format as before:
THOUGHT:
<THOUGHT>

NEW IDEA JSON:
```json
<JSON>
```

If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES.
"""
|
|
|
|
|
|
|
def format_idea_json(text):
    """Normalize the fenced idea JSON embedded in an LLM response.

    Locates the first ```json ... ``` fenced block in ``text``, runs its
    contents through ``strict_json`` (which coerces the expected idea fields
    and types, using ``llm_json_auto_correct`` to repair malformed output),
    and returns a fresh fenced block holding only the normalized JSON.

    Args:
        text: Raw LLM response expected to contain a fenced idea JSON block.

    Returns:
        A string containing the normalized fenced JSON block, or ``text``
        unchanged when no opening fence is present.
    """
    json_start_marker = "```json"
    json_end_marker = "```"
    start_index = text.find(json_start_marker)
    if start_index != -1:
        start_index += len(json_start_marker)
        end_index = text.find(json_end_marker, start_index)
        # BUG FIX: when the closing fence is missing, find() returns -1 and
        # the original slice text[start_index:-1] silently dropped the final
        # character of the JSON payload. Fall back to end of string instead.
        if end_index == -1:
            end_index = len(text)
        json_string = text[start_index:end_index].strip()
        res = strict_json(
            system_prompt="You are a JSON formatter",
            user_prompt=json_string,
            output_format={
                "Name": "A shortened descriptor of the idea",
                "Title": "A title for the idea, will be used for the report writing",
                "Experiment": "An outline of the implementation, type: list",
                "Interestingness": "A rating from 1 to 10 (lowest to highest), type: int",
                "Feasibility": "A rating from 1 to 10 (lowest to highest), type: int",
                "Novelty": "A rating from 1 to 10 (lowest to highest), type: int",
            },
            llm=llm_json_auto_correct,
        )
        text = "```json\n" + json.dumps(res) + "```\n"
    return text
|
|
|
|
|
def format_novelty_json(text):
    """Normalize the fenced novelty-query JSON in an LLM response.

    Mirrors ``format_idea_json`` but for the novelty-check loop: extracts the
    first ```json ... ``` fenced block, coerces it through ``strict_json`` to
    a single-key {"Query": ...} object, and returns a fresh fenced block.

    Args:
        text: Raw LLM response expected to contain a fenced JSON block.

    Returns:
        A string containing the normalized fenced JSON block, or ``text``
        unchanged when no opening fence is present.
    """
    json_start_marker = "```json"
    json_end_marker = "```"
    start_index = text.find(json_start_marker)
    if start_index != -1:
        start_index += len(json_start_marker)
        end_index = text.find(json_end_marker, start_index)
        # BUG FIX: same as format_idea_json -- a missing closing fence made
        # the slice text[start_index:-1], dropping the last character.
        if end_index == -1:
            end_index = len(text)
        json_string = text[start_index:end_index].strip()
        res = strict_json(
            system_prompt="You are a JSON formatter",
            user_prompt=json_string,
            output_format={
                "Query": "An optional search query to search the literature (e.g. attention is all you need)",
            },
            llm=llm_json_auto_correct,
        )
        text = "```json\n" + json.dumps(res) + "```\n"
    return text
|
|
|
|
|
|
|
def generate_ideas(
    base_dir,
    client,
    model,
    skip_generation=False,
    max_num_generations=20,
    num_reflections=5,
):
    """Generate research ideas with the LLM, seeded from seed_ideas.json.

    Reads ``experiment.py``, ``prompt.json`` and ``seed_ideas.json`` from
    ``base_dir``, asks the LLM for up to ``max_num_generations`` new ideas
    (each refined over up to ``num_reflections`` rounds), and persists the
    accumulated list to ``base_dir/ideas.json``.

    Args:
        base_dir: Template directory holding the experiment files.
        client: LLM client handle forwarded to ``get_response_from_llm``.
        model: Model identifier forwarded to ``get_response_from_llm``.
        skip_generation: If True, return previously saved ideas when a valid
            ``ideas.json`` exists; otherwise fall through to generation.
        max_num_generations: Number of new-idea attempts.
        num_reflections: Maximum refinement rounds per idea.

    Returns:
        List of idea dicts (seed ideas plus successfully generated ones).
    """
    import traceback  # hoisted: the original re-imported this on every loop pass

    if skip_generation:
        # Reuse previously generated ideas if present and parseable.
        try:
            with open(osp.join(base_dir, "ideas.json"), "r") as f:
                ideas = json.load(f)
            print("Loaded existing ideas:")
            for idea in ideas:
                print(idea)
            return ideas
        except FileNotFoundError:
            print("No existing ideas found. Generating new ideas.")
        except json.JSONDecodeError:
            print("Error decoding existing ideas. Generating new ideas.")

    # Archive holds each idea as a JSON string; seeded from seed_ideas.json.
    idea_str_archive = []
    with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
        seed_ideas = json.load(f)
    for seed_idea in seed_ideas:
        idea_str_archive.append(json.dumps(seed_idea))

    with open(osp.join(base_dir, "experiment.py"), "r") as f:
        code = f.read()

    with open(osp.join(base_dir, "prompt.json"), "r") as f:
        prompt = json.load(f)

    idea_system_prompt = prompt["system"]

    for gen_idx in range(max_num_generations):
        print()
        print(f"Generating idea {gen_idx + 1}/{max_num_generations}")
        try:
            prev_ideas_string = "\n\n".join(idea_str_archive)

            msg_history = []
            print(f"Iteration 1/{num_reflections}")
            text, msg_history = get_response_from_llm(
                idea_first_prompt.format(
                    task_description=prompt["task_description"],
                    code=code,
                    prev_ideas_string=prev_ideas_string,
                    num_reflections=num_reflections,
                ),
                client=client,
                model=model,
                system_message=idea_system_prompt,
                msg_history=msg_history,
            )
            # Normalize the fenced JSON so extraction is reliable.
            text = format_idea_json(text)
            json_output = extract_json_between_markers(text)
            assert json_output is not None, "Failed to extract JSON from LLM output"

            # Refinement rounds; the model signals convergence via "I am done".
            if num_reflections > 1:
                for j in range(num_reflections - 1):
                    print(f"Iteration {j + 2}/{num_reflections}")
                    text, msg_history = get_response_from_llm(
                        idea_reflection_prompt.format(
                            current_round=j + 2, num_reflections=num_reflections
                        ),
                        client=client,
                        model=model,
                        system_message=idea_system_prompt,
                        msg_history=msg_history,
                    )
                    text = format_idea_json(text)
                    json_output = extract_json_between_markers(text)
                    assert (
                        json_output is not None
                    ), "Failed to extract JSON from LLM output"

                    if "I am done" in text:
                        print(f"Idea generation converged after {j + 2} iterations.")
                        break

            idea_str_archive.append(json.dumps(json_output))
        except Exception as e:
            # Best-effort loop: one failed generation must not abort the rest.
            print(f"Failed to generate idea: {e}")
            traceback.print_exc()
            continue

    # Persist all ideas (seed ideas plus newly generated ones).
    ideas = [json.loads(idea_str) for idea_str in idea_str_archive]

    with open(osp.join(base_dir, "ideas.json"), "w") as f:
        json.dump(ideas, f, indent=4)

    return ideas
|
|
|
|
|
|
|
def generate_next_idea(
    base_dir,
    client,
    model,
    prev_idea_archive=None,
    num_reflections=5,
    max_attempts=10,
):
    """Generate one additional idea given the archive of previous ideas.

    On the first call (empty archive) the first seed idea from
    ``seed_ideas.json`` is appended instead of querying the LLM. Otherwise the
    LLM is prompted (with up to ``num_reflections`` refinement rounds) until
    one attempt succeeds or ``max_attempts`` attempts have failed. The updated
    archive is written to ``base_dir/ideas.json`` and returned.

    Args:
        base_dir: Template directory holding the experiment files.
        client: LLM client handle forwarded to ``get_response_from_llm``.
        model: Model identifier forwarded to ``get_response_from_llm``.
        prev_idea_archive: Previously generated ideas; the list passed in is
            extended in place. Defaults to a fresh empty list.
        num_reflections: Maximum refinement rounds per idea.
        max_attempts: Number of LLM attempts before giving up.

    Returns:
        The idea archive including the new (or seed) idea.
    """
    import traceback  # hoisted: the original re-imported this on every loop pass

    # BUG FIX: the original signature used ``prev_idea_archive=[]`` -- a
    # mutable default shared across all calls, so ideas from one run leaked
    # into the next. ``None`` sentinel preserves the aliasing behavior for
    # callers that do pass a list.
    idea_archive = [] if prev_idea_archive is None else prev_idea_archive
    original_archive_size = len(idea_archive)

    print(f"Generating idea {original_archive_size + 1}")

    if len(idea_archive) == 0:
        print(f"First iteration, taking seed ideas")
        # Bootstrap the archive with the first pre-written seed idea.
        with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
            seed_ideas = json.load(f)
        for seed_idea in seed_ideas[:1]:
            idea_archive.append(seed_idea)
    else:
        with open(osp.join(base_dir, "experiment.py"), "r") as f:
            code = f.read()
        with open(osp.join(base_dir, "prompt.json"), "r") as f:
            prompt = json.load(f)
        idea_system_prompt = prompt["system"]

        for _ in range(max_attempts):
            try:
                idea_strings = [json.dumps(idea) for idea in idea_archive]
                prev_ideas_string = "\n\n".join(idea_strings)

                msg_history = []
                print(f"Iteration 1/{num_reflections}")
                text, msg_history = get_response_from_llm(
                    idea_first_prompt.format(
                        task_description=prompt["task_description"],
                        code=code,
                        prev_ideas_string=prev_ideas_string,
                        num_reflections=num_reflections,
                    )
                    + """
Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
This is on a standard 1-10 ML conference scale.
Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
""",
                    client=client,
                    model=model,
                    system_message=idea_system_prompt,
                    msg_history=msg_history,
                )
                # Normalize the fenced JSON so extraction is reliable.
                text = format_idea_json(text)
                json_output = extract_json_between_markers(text)
                assert json_output is not None, "Failed to extract JSON from LLM output"

                # Refinement rounds; "I am done" signals convergence.
                if num_reflections > 1:
                    for j in range(num_reflections - 1):
                        print(f"Iteration {j + 2}/{num_reflections}")
                        text, msg_history = get_response_from_llm(
                            idea_reflection_prompt.format(
                                current_round=j + 2, num_reflections=num_reflections
                            ),
                            client=client,
                            model=model,
                            system_message=idea_system_prompt,
                            msg_history=msg_history,
                        )
                        text = format_idea_json(text)
                        json_output = extract_json_between_markers(text)
                        assert (
                            json_output is not None
                        ), "Failed to extract JSON from LLM output"

                        if "I am done" in text:
                            print(
                                f"Idea generation converged after {j + 2} iterations."
                            )
                            break

                idea_archive.append(json_output)
                break
            except Exception as e:
                # Retry until max_attempts; one failure must not abort the run.
                print(f"Failed to generate idea: {e}")
                traceback.print_exc()
                continue

    # Persist the updated archive.
    with open(osp.join(base_dir, "ideas.json"), "w") as f:
        json.dump(idea_archive, f, indent=4)

    return idea_archive
|
|
|
|
|
def on_backoff(details):
    """Log a notice each time the `backoff` decorator delays a retry.

    ``details`` is the standard handler dict supplied by the backoff library;
    the keys read here are "wait", "tries" and "target".
    """
    wait_s = details["wait"]
    attempt_count = details["tries"]
    target_name = details["target"].__name__
    message = (
        f"Backing off {wait_s:0.1f} seconds after {attempt_count} tries "
        f"calling function {target_name} at {time.strftime('%X')}"
    )
    print(message)
|
|
|
|
|
@backoff.on_exception(
    backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
)
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
    """Search Semantic Scholar for papers matching ``query``.

    HTTP errors are retried with exponential backoff via the decorator.
    Returns None for an empty query or when the search yields no hits;
    otherwise returns the list of paper dicts from the API response.
    """
    if not query:
        return None
    rsp = requests.get(
        "https://api.semanticscholar.org/graph/v1/paper/search",
        headers={"X-API-KEY": S2_API_KEY},
        params={
            "query": query,
            "limit": result_limit,
            "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
        },
    )
    print(f"Response Status Code: {rsp.status_code}")
    print(f"Response Content: {rsp.text[:500]}")
    rsp.raise_for_status()

    results = rsp.json()
    if not results["total"]:
        return None

    # Brief pause between calls -- presumably to respect the API rate limit.
    time.sleep(2)
    return results["data"]
|
|
|
|
|
# System prompt for the novelty-check conversation. Placeholders:
# {num_rounds}, {task_description}, {code}. The assistant is expected to
# reply per novelty_prompt, optionally issuing Semantic Scholar queries.
novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
The top 10 results for any search query will be presented to you with the abstracts.

You will be given {num_rounds} to decide on the paper, but you do not need to use them all.
At any round, you may exit early and decide on the novelty of the idea.
Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
Decide a paper idea is not novel, if you have found a paper that significantly overlaps with your idea.

{task_description}
<experiment.py>
{code}
</experiment.py>
"""
|
|
|
# Per-round user prompt for the novelty check. Placeholders: {current_round},
# {num_rounds}, {idea}, {last_query_results}. The markers "Decision made:
# novel." / "Decision made: not novel." are matched (case-insensitively) by
# check_idea_novelty to terminate the loop -- do not change them.
novelty_prompt = '''Round {current_round}/{num_rounds}.
You have this idea:

"""
{idea}
"""

The results of the last query are (empty on first round):
"""
{last_query_results}
"""

Respond in the following format:

THOUGHT:
<THOUGHT>

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, first briefly reason over the idea and identify any query that could help you make your decision.
If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.

In <JSON>, respond in JSON format with ONLY the following field:
- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.

A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
This JSON will be automatically parsed, so ensure the format is precise.
'''
|
|
|
|
|
def check_idea_novelty(
    ideas,
    base_dir,
    client,
    model,
    max_num_iterations=10,
):
    """Annotate each idea with a boolean "novel" field via LLM + literature search.

    For every idea not yet labeled, the LLM is prompted for up to
    ``max_num_iterations`` rounds; each round it either declares a decision
    (matched via the "decision made: ..." markers) or issues a Semantic
    Scholar query whose results are fed into the next round. Ideas that never
    reach a decision default to not novel. Results are saved back to
    ``base_dir/ideas.json``.

    Args:
        ideas: List of idea dicts; mutated in place (a "novel" key is added).
        base_dir: Template directory holding the experiment files.
        client: LLM client handle forwarded to ``get_response_from_llm``.
        model: Model identifier forwarded to ``get_response_from_llm``.
        max_num_iterations: Maximum query/decision rounds per idea.

    Returns:
        The same list of ideas, annotated with "novel".
    """
    with open(osp.join(base_dir, "experiment.py"), "r") as f:
        code = f.read()
    with open(osp.join(base_dir, "prompt.json"), "r") as f:
        prompt = json.load(f)
    task_description = prompt["task_description"]

    for idx, idea in enumerate(ideas):
        if "novel" in idea:
            print(f"Skipping idea {idx}, already checked.")
            continue

        print(f"\nChecking novelty of idea {idx}: {idea['Name']}")

        novel = False  # default when no decision is reached in time
        msg_history = []
        papers_str = ""

        for j in range(max_num_iterations):
            try:
                text, msg_history = get_response_from_llm(
                    novelty_prompt.format(
                        current_round=j + 1,
                        num_rounds=max_num_iterations,
                        idea=idea,
                        last_query_results=papers_str,
                    ),
                    client=client,
                    model=model,
                    system_message=novelty_system_msg.format(
                        num_rounds=max_num_iterations,
                        task_description=task_description,
                        code=code,
                    ),
                    msg_history=msg_history,
                )
                if "decision made: novel" in text.lower():
                    print("Decision made: novel after round", j)
                    novel = True
                    break
                if "decision made: not novel" in text.lower():
                    print("Decision made: not novel after round", j)
                    break

                # Normalize the fenced JSON so extraction is reliable.
                text = format_novelty_json(text)
                print("text after formating\n", text)

                json_output = extract_json_between_markers(text)
                assert json_output is not None, "Failed to extract JSON from LLM output"

                # Run the requested literature search for the next round.
                query = json_output["Query"]
                papers = search_for_papers(query, result_limit=10)
                # BUG FIX: the original fell through and iterated ``papers``
                # even when it was None, raising a TypeError that the broad
                # except below silently swallowed -- so the "No papers found."
                # message never actually reached the next round.
                if papers is None:
                    papers_str = "No papers found."
                else:
                    paper_strings = []
                    for i, paper in enumerate(papers):
                        paper_strings.append(
                            """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
                                i=i,
                                title=paper["title"],
                                authors=paper["authors"],
                                venue=paper["venue"],
                                year=paper["year"],
                                cites=paper["citationCount"],
                                abstract=paper["abstract"],
                            )
                        )
                    papers_str = "\n\n".join(paper_strings)

            except Exception as e:
                # Best-effort: a failed round should not abort the check.
                print(f"Error: {e}")
                continue

        idea["novel"] = novel

    # Save updated labels back to disk.
    results_file = osp.join(base_dir, "ideas.json")
    with open(results_file, "w") as f:
        json.dump(ideas, f, indent=4)

    return ideas
|
|
|
|
|
if __name__ == "__main__":
    MAX_NUM_GENERATIONS = 32
    NUM_REFLECTIONS = 5
    import argparse

    parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
    parser.add_argument(
        "--experiment",
        type=str,
        default="nanoGPT",
        help="Experiment to run AI Scientist on.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="deepseek-ai/DeepSeek-V2.5",
        choices=allchoices,
        help="Model to use for AI Scientist.",
    )
    parser.add_argument(
        "--skip-idea-generation",
        action="store_true",
        help="Skip idea generation and use existing ideas.",
    )
    parser.add_argument(
        "--check-novelty",
        action="store_true",
        help="Check novelty of ideas.",
    )
    args = parser.parse_args()

    # Instantiate the right API client for the requested model.
    # (Leftover debug prints from the original have been removed.)
    if args.model == "Qwen/Qwen2.5-72B-Instruct":
        import openai

        print(f"Using Hyperbolic API with model {args.model}.")
        client_model = args.model
        client = openai.OpenAI(
            api_key=os.environ["OPENAI_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
        )
    elif args.model == "claude-3-5-sonnet-20240620":
        import anthropic

        print(f"Using Anthropic API with model {args.model}.")
        client_model = "claude-3-5-sonnet-20240620"
        client = anthropic.Anthropic()
    elif args.model.startswith("bedrock") and "claude" in args.model:
        import anthropic

        # Expects model ids of the form "bedrock/<client_model>".
        client_model = args.model.split("/")[-1]
        print(f"Using Amazon Bedrock with model {client_model}.")
        client = anthropic.AnthropicBedrock()
    elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
        import openai

        print(f"Using OpenAI API with model {args.model}.")
        client_model = "gpt-4o-2024-05-13"
        client = openai.OpenAI()
    elif args.model == "deepseek-coder-v2-0724":
        import openai

        print(f"Using OpenAI API with {args.model}.")
        client_model = "deepseek-coder-v2-0724"
        # NOTE(review): this branch reads DEEPSEEK_API_KEY but points at the
        # Hyperbolic base_url -- confirm the endpoint is intentional.
        client = openai.OpenAI(
            api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.hyperbolic.xyz/v1"
        )
    elif args.model == "llama3.1-405b":
        import openai

        print(f"Using OpenAI API with {args.model}.")
        client_model = "meta-llama/llama-3.1-405b-instruct"
        client = openai.OpenAI(
            api_key=os.environ["OPENROUTER_API_KEY"],
            base_url="https://openrouter.ai/api/v1",
        )
    elif args.model.startswith("ollama"):
        import openai

        print(f"Using Ollama with {args.model}.")
        client_model = args.model.split("/")[-1]
        client = openai.OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
    else:
        raise ValueError(f"Model {args.model} not supported.")

    base_dir = osp.join("templates", args.experiment)
    results_dir = osp.join("results", args.experiment)  # NOTE(review): currently unused

    ideas = generate_ideas(
        base_dir,
        client=client,
        model=client_model,
        skip_generation=args.skip_idea_generation,
        max_num_generations=MAX_NUM_GENERATIONS,
        num_reflections=NUM_REFLECTIONS,
    )
    if args.check_novelty:
        ideas = check_idea_novelty(
            ideas,
            base_dir=base_dir,
            client=client,
            model=client_model,
        )