CoI_Agent / agents.py
jianghuyihei's picture
fix
0f5901d
import json
import time
from searcher import Result,SementicSearcher
from LLM import openai_llm
from prompts import *
from utils import extract
def get_llm(model="gpt4o-0513"):
    """Return an ``openai_llm`` client for *model* (``"gpt4o-0513"`` by default)."""
    client = openai_llm(model)
    return client
def get_llms():
    """Return a ``(main_llm, cheap_llm)`` pair of clients.

    The main client is built for ``"gpt4o-0513"`` and the cheap one for
    ``"gpt-4o-mini"``, both via ``get_llm``.
    """
    return get_llm("gpt4o-0513"), get_llm("gpt-4o-mini")
def judge_idea(i, j, idea0, idea1, topic, llm):
    """Ask *llm* to judge *idea0* against *idea1* for *topic*.

    *i* and *j* are passed through unchanged (presumably indices identifying
    the compared pair — confirm with callers). Returns the tuple
    ``(i, j, novelty, relevance, significance, clarity, feasibility,
    effectiveness)``, each criterion being what ``extract`` pulls from the
    correspondingly named tag of the response.
    """
    content = get_judge_idea_all_prompt(idea0, idea1, topic)
    reply = llm.response([{"role": "user", "content": content}])
    criteria = ("novelty", "relevance", "significance", "clarity", "feasibility", "effectiveness")
    verdicts = tuple(extract(reply, criterion) for criterion in criteria)
    return (i, j) + verdicts
class DeepResearchAgent:
    """Agent that builds a chain of related papers for a topic and derives an idea from it.

    Starting from a seed paper found for the topic, the agent appends papers
    returned by ``search_related_paper(..., need_reference=False)`` to the end
    of the chain, prepends papers found from the LLM-extracted references, and
    then prompts the LLM for a trend, future directions and a final idea.
    Intermediate results are also recorded on the instance
    (``search_qeuries``, ``deep_research_chains``, ``deep_ideas``).
    """

    def __init__(self, llm=None, cheap_llm=None, publicationData=None, ban_paper=None, **kwargs) -> None:
        """Create an agent.

        Args:
            llm: main LLM wrapper; must expose ``response(messages)``.
            cheap_llm: LLM wrapper used for paper-reading prompts; must expose
                ``response(messages, max_tokens=...)``.
            publicationData: forwarded to searches as ``publicationDate``.
            ban_paper: forwarded unchanged to ``SementicSearcher(ban_paper=...)``;
                a fresh empty list is used when omitted.
            **kwargs: optional ``improve_cnt`` (default 1), ``max_chain_length``
                (5), ``min_chain_length`` (3) and ``max_chain_numbers`` (10).
        """
        # Fresh list per instance: a shared ``[]`` default would be reused by
        # every agent if the searcher ever mutated it.
        if ban_paper is None:
            ban_paper = []
        self.reader = SementicSearcher(ban_paper=ban_paper)
        self.begin_time = time.time()
        self.llm = llm
        self.cheap_llm = cheap_llm
        # Titles of seed papers judged relevant; also used to skip candidates
        # while extending the chain from references.
        self.read_papers = set()
        self.paper_storage = []
        self.paper_info_for_refine_experiment = []
        # (sic) attribute name kept as-is for compatibility with callers.
        self.search_qeuries = []
        self.deep_research_chains = []
        self.deep_ideas = []
        self.check_novel_results = []
        self.score_results = []
        self.topic = None
        self.publicationData = publicationData
        self.improve_cnt = kwargs.get("improve_cnt", 1)
        self.max_chain_length = kwargs.get("max_chain_length", 5)
        self.min_chain_length = kwargs.get("min_chain_length", 3)
        self.max_chain_numbers = kwargs.get("max_chain_numbers", 10)

    def wrap_messages(self, prompt):
        """Wrap *prompt* as a single user message list for the chat API."""
        return [{"role": "user", "content": prompt}]

    def get_openai_response(self, messages):
        """Send *messages* to the main LLM and return its response."""
        return self.llm.response(messages)

    def get_cheap_openai_response(self, messages):
        """Send *messages* to the cheap LLM with ``max_tokens=16000``."""
        return self.cheap_llm.response(messages, max_tokens=16000)

    @staticmethod
    def _parse_json_list(text):
        """Return *text* parsed as a JSON list, or ``None`` if that is not possible.

        Non-string input (e.g. ``None``, which ``extract`` may yield when a tag
        is missing — presumably) and valid JSON that is not a list both count
        as failures, because callers iterate over and ``pop(0)`` the result.
        """
        try:
            value = json.loads(text)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            return None
        return value if isinstance(value, list) else None

    def _judge_relevance(self, title, abstract):
        """Ask the main LLM whether a paper is relevant to ``self.topic``.

        Returns the raw ``<relevant>`` value; callers compare it with ``"1"``
        or ``"0"``.
        """
        prompt = get_deep_judge_relevant_prompt(title, abstract, self.topic)
        response = self.get_openai_response(self.wrap_messages(prompt))
        return extract(response, "relevant")

    def get_search_query(self, topic=None, query=None):
        """Ask the LLM for a list of search queries for *topic* / *query*.

        Returns the parsed ``<queries>`` JSON list and records it in
        ``self.search_qeuries``. If the response cannot be parsed as a list,
        returns ``[query]``, or ``[topic]`` when no *query* was given (the
        original fell back to ``[None]`` in that case).
        """
        prompt = get_deep_search_query_prompt(topic, query)
        response = self.get_openai_response(self.wrap_messages(prompt))
        search_query = self._parse_json_list(extract(response, "queries"))
        if search_query is None:
            return [query if query is not None else topic]
        self.search_qeuries.append({"query": query, "search_query": search_query})
        return search_query

    def _search_relevant_papers(self, query, topic):
        """Search for seed papers relevant to *topic*, rewriting *query* on misses.

        Makes at most 10 attempts and returns the relevant papers from the
        first attempt that yields any (possibly an empty list). Titles of
        accepted papers are added to ``self.read_papers``.
        """
        failed_query = []
        current_papers = []
        attempts = 0
        while not current_papers and attempts < 10:
            paper = self.reader.search(
                query,
                1,
                paper_list=self.read_papers,
                llm=self.llm,
                rerank_query=f"{topic}",
                publicationDate=self.publicationData,
            )
            if paper and paper[0]:
                for p in paper:
                    if self._judge_relevance(p.title, p.abstract) == "1":
                        print(f"{p.title} is relevant to the topic {self.topic},have added to the queue")
                        self.read_papers.add(p.title)
                        current_papers.append(p)
            else:
                failed_query.append(query)
                prompt = get_deep_rewrite_query_prompt(failed_query, topic)
                new_query = extract(self.get_openai_response(self.wrap_messages(prompt)), "query")
                print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
                query = new_query
            # Count every attempt (not only failed searches) so a search that
            # keeps returning only irrelevant papers cannot loop forever.
            attempts += 1
        return current_papers

    def generate_idea_with_chain(self, topic):
        """Find seed papers for *topic* and build an idea from the first usable one.

        Returns the 9-tuple ``(idea, experiments, entities, idea_chains, idea,
        trend, future, human, years)`` (``idea`` appears twice, kept for
        compatibility with existing callers), or nine ``None`` values when no
        candidate paper could be researched.
        """
        self.topic = topic
        print(f"begin to generate search query for {topic}")
        search_query = self.get_search_query(topic=topic)
        papers = []
        for query in search_query:
            papers.extend(self._search_relevant_papers(query, topic))
            if len(papers) >= self.max_chain_numbers:
                break
        # deep_research_paper_with_chain returns None for a paper without an
        # article; try the next candidate instead of crashing while unpacking.
        result = None
        for paper in papers:
            result = self.deep_research_paper_with_chain(paper)
            if result is not None:
                break
        if result is None:
            print(f"failed to generate idea {topic}")
            return None, None, None, None, None, None, None, None, None
        # Unpack in exactly the order deep_research_paper_with_chain returns.
        idea, idea_chain, trend, experiment, entities, future, human, year = result
        print("successfully generated idea")
        return idea, experiment, entities, idea_chain, idea, trend, future, human, year

    def _extract_reference_info(self, paper_content):
        """Run the deep-reference prompt on *paper_content* with the cheap LLM.

        Returns ``(idea, experiment, entities, references)`` as produced by
        ``extract``; ``references`` is still unparsed.
        """
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = self.get_cheap_openai_response(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references

    def get_paper_idea_experiment_references_info(self, paper):
        """Extract idea/experiment/entities/references for *paper*.

        Returns ``None`` if the paper has no article, otherwise
        ``(idea, experiment, entities, references, paper.title)``. Content is
        read with ``read_paper_content``.
        """
        article = paper.article
        if not article:
            return None
        paper_content = self.reader.read_paper_content(article)
        idea, experiment, entities, references = self._extract_reference_info(paper_content)
        return idea, experiment, entities, references, paper.title

    def get_article_idea_experiment_references_info(self, article):
        """Extract ``(idea, experiment, entities, references)`` from *article*.

        Content is read with ``read_paper_content_with_ref``.
        """
        paper_content = self.reader.read_paper_content_with_ref(article)
        return self._extract_reference_info(paper_content)

    def _find_reference_paper(self, references, chain_length, idea_papers, current_title, current_abstract):
        """Pick the next paper to prepend to the chain, or return ``None``.

        Pops titles from the front of *references* (the list is mutated) and
        searches for each; a candidate judged relevant (or any candidate while
        the chain is shorter than ``min_chain_length``) that has an article is
        returned. Otherwise falls back to ``search_related_paper(...,
        need_citation=False)`` for *current_title*. A returned paper always has
        a truthy ``article``.
        """
        while references:
            reference = references.pop(0)
            if reference in self.read_papers:
                continue
            found = self.reader.search(
                reference, 3, llm=self.llm, publicationDate=self.publicationData, paper_list=idea_papers
            )
            if not found:
                continue
            candidate = found[0]
            if not candidate or candidate.title in self.read_papers:
                break  # stop consuming references and use the fallback below
            relevant = self._judge_relevance(candidate.title, candidate.abstract)
            if relevant != "0" or chain_length < self.min_chain_length:
                if candidate.article:
                    return candidate
                break  # acceptable but unreadable: use the fallback below
            print(f"the paper {candidate.title} is not relevant")

        rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
        related = self.reader.search_related_paper(
            current_title, need_citation=False, rerank_query=rerank_query, llm=self.llm, paper_list=idea_papers
        )
        if not related:
            return None
        if chain_length < self.min_chain_length:
            # Chain still too short: accept any related paper with an article.
            return related if related.article else None
        if related.title in self.read_papers:
            return None
        # Judge the related paper itself (the original re-judged the current paper).
        if self._judge_relevance(related.title, related.abstract) == "1" and related.article:
            return related
        return None

    def deep_research_paper_with_chain(self, paper: "Result"):
        """Build an idea chain around *paper* and derive a final idea from it.

        Returns ``None`` if *paper* has no article. Otherwise returns
        ``(idea, idea_chains, trend, experiments, total_entities, future,
        human, years)``: ``idea_chains`` is the numbered text summary of the
        chain and ``total_entities`` the LLM-cleaned entity summary.
        """
        print(f"begin to deep research paper {paper.title}")
        article = paper.article
        if not article:
            print(f"failed to deep research paper {paper.title}")
            return None
        idea, experiment, entities, raw_references = self.get_article_idea_experiment_references_info(article)
        references = self._parse_json_list(raw_references) or []
        # Parallel per-paper lists, kept in chain order.
        idea_chain = [idea]
        idea_papers = [paper.title]
        experiments = [experiment]
        total_entities = [entities]
        years = [paper.year]

        # search before: grow the end of the chain starting from the seed paper.
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length:
            rerank_query = f"{self.topic} {current_title} {current_abstract}"
            citation_paper = self.reader.search_related_paper(
                current_title, need_reference=False, rerank_query=rerank_query, llm=self.llm, paper_list=idea_papers
            )
            if not citation_paper:
                print(f"failed to find citation paper for {current_title}")
                break
            # Judge the newly found paper; the original re-judged the current
            # paper, which is already part of the chain.
            if self._judge_relevance(citation_paper.title, citation_paper.abstract) == "0":
                print(f"the paper {citation_paper.title} is not relevant")
                break
            result = self.get_paper_idea_experiment_references_info(citation_paper)
            if not result:
                break
            idea, experiment, entities, _, _ = result
            idea_chain.append(idea)
            experiments.append(experiment)
            total_entities.append(entities)
            idea_papers.append(citation_paper.title)
            years.append(citation_paper.year)
            current_title, current_abstract = citation_paper.title, citation_paper.abstract

        # search after: grow the start of the chain, restarting from the seed paper.
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length and references:
            print(f"The references find:{references}")
            # Each call pops at least one reference, so this loop terminates.
            cite_paper = self._find_reference_paper(
                references, len(idea_chain), idea_papers, current_title, current_abstract
            )
            if cite_paper is None:
                print(f"failed to find citation paper for {current_title}")
                continue
            print("find the citation paper, begin to deep research")
            idea, experiment, entities, raw_references = self.get_article_idea_experiment_references_info(
                cite_paper.article
            )
            references = self._parse_json_list(raw_references) or []
            current_title, current_abstract = cite_paper.title, cite_paper.abstract
            years.insert(0, cite_paper.year)
            idea_chain.insert(0, idea)
            idea_papers.insert(0, cite_paper.title)
            experiments.insert(0, experiment)
            total_entities.insert(0, entities)
            # Stop once the chain is long enough and reaches a highly cited paper.
            # NOTE(review): ``citations_conut`` is spelled as in the original;
            # confirm it matches the attribute defined on searcher.Result.
            if len(idea_chain) >= self.min_chain_length and cite_paper.citations_conut > 1000:
                break

        print("successfully generate idea chain")
        idea_chains = "".join(
            f"{i}.Paper:{title} idea:{chain_idea}\n \n"
            for i, (chain_idea, title) in enumerate(zip(idea_chain, idea_papers))
        )
        # NOTE(review): only the most recently extracted paper's ``entities`` is
        # passed here, not ``total_entities``; kept as in the original — confirm.
        prompt = get_deep_trend_idea_chains_prompt(idea_chains, entities, self.topic)
        trend = extract(self.get_openai_response(self.wrap_messages(prompt)), "trend")
        self.deep_research_chains.append({
            "idea_chains": idea_chains,
            "trend": trend,
            "topic": self.topic,
            "ideas": idea_chain,
            "experiments": experiments,
            "entities": total_entities,
            "years": years,
        })

        # Let the LLM merge / simplify the per-paper entity lists into one summary.
        prompt = (
            f"The current research topic is: {self.topic}. Please help me summarize and refine "
            f"the following entities by merging, simplifying, or deleting them : {total_entities}\n"
            "Please output strictly in the following format:\n"
            "<entities> {cleaned entities}</entities>\n"
        )
        total_entities = extract(self.get_openai_response(self.wrap_messages(prompt)), "entities")

        bad_case = []
        prompt = get_deep_generate_future_direciton_prompt(idea_chain, trend, self.topic, total_entities)
        response = self.get_openai_response(self.wrap_messages(prompt))
        future = extract(response, "future")
        human = extract(response, "human")

        prompt = get_deep_generate_idea_prompt(idea_chains, trend, self.topic, total_entities, future, bad_case)
        response = self.get_openai_response(self.wrap_messages(prompt))
        method = extract(response, "method")
        novelty = extract(response, "novelty")
        motivation = extract(response, "motivation")
        idea = {"motivation": motivation, "novelty": novelty, "method": method}

        prompt = get_deep_final_idea_prompt(idea_chains, trend, idea, self.topic)
        idea = extract(self.get_openai_response(self.wrap_messages(prompt)), "final_idea")
        self.deep_ideas.append(idea)
        print(f"successfully deep research paper {paper.title}")
        return idea, idea_chains, trend, experiments, total_entities, future, human, years
if __name__ == "__main__":
    # Script entry point: only constructs a default searcher; nothing else runs.
    reader = SementicSearcher()