CoI_Agent / agents.py
jianghuyihei's picture
first commit
863d8a3
raw
history blame
16.8 kB
import json
import time
import asyncio
import os
from searcher import Result,SementicSearcher
from LLM import openai_llm
from prompts import *
from utils import extract
def get_llm(model="gpt4o-0513"):
    """Build and return an ``openai_llm`` client for ``model``.

    ``model`` defaults to ``"gpt4o-0513"``.
    """
    client = openai_llm(model)
    return client
def get_llms():
    """Return the ``(main_llm, cheap_llm)`` pair.

    The main client uses ``"gpt4o-0513"`` and the cheap one ``"gpt-4o-mini"``.
    """
    return get_llm("gpt4o-0513"), get_llm("gpt-4o-mini")
async def judge_idea(i, j, idea0, idea1, topic, llm):
    """Ask ``llm`` to compare ``idea0`` and ``idea1`` for ``topic``.

    Args:
        i, j: Caller-side indices of the two ideas; echoed back unchanged.
        idea0, idea1: The ideas being compared.
        topic: Research topic passed to the judging prompt.
        llm: Client whose ``response_async(messages)`` is awaited.

    Returns:
        ``(i, j, novelty, relevance, significance, clarity, feasibility,
        effectiveness)``. Each criterion is whatever ``extract`` pulls from
        the response for that name.
    """
    messages = [{"role": "user", "content": get_judge_idea_all_prompt(idea0, idea1, topic)}]
    response = await llm.response_async(messages)
    criteria = ("novelty", "relevance", "significance", "clarity", "feasibility", "effectiveness")
    return (i, j, *(extract(response, criterion) for criterion in criteria))
class DeepResearchAgent:
    """Generate research ideas from "chains of ideas" built over related papers.

    Pipeline:
      1. ``generate_idea_with_chain`` searches for seed papers on a topic.
      2. It grows one chain of related papers per seed
         (``deep_research_paper_with_chain``).
      3. Each chain is turned into a candidate idea.
      4. The candidate that wins the most pairwise ``judge_idea`` comparisons
         is returned.
    """

    def __init__(self, llm=None, cheap_llm=None, publicationData=None, ban_paper=None, **kwargs) -> None:
        """Create the agent.

        Args:
            llm: Main LLM client; its ``response_async(messages)`` is awaited.
            cheap_llm: Secondary LLM client, used when extracting ideas from
                full paper content.
            publicationData: Passed to ``search_async`` as ``publicationDate``.
                The expected format is defined by the searcher (TODO confirm).
            ban_paper: Forwarded to ``SementicSearcher(ban_paper=...)``.
                ``None`` (the default) forwards an empty list.
            **kwargs: Optional knobs:
                ``improve_cnt`` (default 1);
                ``max_chain_length`` (5);
                ``min_chain_length`` (3);
                ``max_chain_numbers`` (10), the maximum number of seed papers,
                i.e. of chains.
        """
        # A fresh list per instance: a ``[]`` default is one object shared by
        # every agent, so any mutation by the searcher would leak between them.
        self.reader = SementicSearcher(ban_paper=[] if ban_paper is None else ban_paper)
        self.begin_time = time.time()
        self.llm = llm
        self.cheap_llm = cheap_llm
        # Seed titles added by generate_idea_with_chain; used to skip re-reading.
        self.read_papers = set()
        self.paper_storage = []
        self.paper_info_for_refine_experiment = []
        self.search_qeuries = []  # (sic) name kept for compatibility
        self.deep_research_chains = []
        self.deep_ideas = []
        self.check_novel_results = []
        self.score_results = []
        self.topic = None
        self.publicationData = publicationData
        self.improve_cnt = kwargs.get("improve_cnt", 1)
        self.max_chain_length = kwargs.get("max_chain_length", 5)
        self.min_chain_length = kwargs.get("min_chain_length", 3)
        self.max_chain_numbers = kwargs.get("max_chain_numbers", 10)

    def wrap_messages(self, prompt):
        """Wrap ``prompt`` as a single-turn chat message list."""
        return [{"role": "user", "content": prompt}]

    async def get_openai_response_async(self, messages):
        """Send ``messages`` to the main LLM and return its response."""
        return await self.llm.response_async(messages)

    async def get_cheap_openai_response_async(self, messages):
        """Send ``messages`` to the cheap LLM with ``max_tokens=16000``."""
        return await self.cheap_llm.response_async(messages, max_tokens=16000)

    @staticmethod
    def _parse_json_list(text):
        """Parse ``text`` as a JSON list.

        Returns ``[]`` when ``text`` is missing, malformed, or not a list.
        """
        try:
            value = json.loads(text)
        except (TypeError, ValueError):  # None input / invalid JSON
            return []
        return value if isinstance(value, list) else []

    async def _judge_relevant(self, title, abstract):
        """Ask the main LLM whether the paper ``(title, abstract)`` is relevant.

        Relevance is judged against ``self.topic`` using
        ``get_deep_judge_relevant_prompt``. Returns the raw extracted
        ``"relevant"`` value.
        """
        prompt = get_deep_judge_relevant_prompt(title, abstract, self.topic)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        return extract(response, "relevant")

    async def get_search_query(self, topic=None, query=None):
        """Ask the main LLM for search queries about ``topic`` / ``query``.

        Returns the JSON list extracted from the ``"queries"`` section of the
        response. When that list is missing, malformed, not a list, or empty,
        returns a one-element list holding ``query``, or ``topic`` when no
        query was given. The old fallback was ``[None]`` in that case.
        """
        prompt = get_deep_search_query_prompt(topic, query)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        search_query = self._parse_json_list(extract(response, "queries"))
        if not search_query:
            return [query if query is not None else topic]
        self.search_qeuries.append({"query": query, "search_query": search_query})
        return search_query

    @staticmethod
    def _winner_to_scores(winner, score_1, score_2):
        """Update two running scores from one judged dimension.

        ``winner`` is the judge's raw answer:
          - ``0`` gives the first idea the point;
          - ``2``, or anything not parseable as an int, splits it 0.5 / 0.5;
          - any other integer gives the second idea the point.
        """
        try:
            winner = int(winner)
        except (TypeError, ValueError):
            return score_1 + 0.5, score_2 + 0.5
        if winner == 0:
            return score_1 + 1, score_2
        if winner == 2:
            return score_1 + 0.5, score_2 + 0.5
        return score_1, score_2 + 1

    @classmethod
    def _tally_pairwise_scores(cls, judgements, n_ideas):
        """Sum per-idea scores over all ``judge_idea`` results.

        Each judgement is ``(i, j, novelty, relevance, significance, clarity,
        feasibility, effectiveness)``. Every dimension is scored with
        ``_winner_to_scores``.
        """
        scores = [0] * n_ideas
        for i, j, novelty, relevance, significance, clarity, feasibility, effectiveness in judgements:
            for winner in (novelty, relevance, significance, clarity, feasibility, effectiveness):
                scores[i], scores[j] = cls._winner_to_scores(winner, scores[i], scores[j])
            print(f"i:{i},j:{j},novelty:{novelty},relevance:{relevance},significance:{significance},clarity:{clarity},feasibility:{feasibility},effectiveness:{effectiveness}")
        print(scores)
        return scores

    async def generate_idea_with_chain(self, topic):
        """Generate the best research idea for ``topic``.

        Steps:
          1. Seed one paper per search query, rewriting a failing query up to
             10 times.
          2. Build one idea chain per seed concurrently.
          3. Judge every ordered pair of resulting ideas with ``judge_idea``.
          4. Keep the idea with the highest total score.

        Returns:
            ``(idea, experiment, entities, idea_chain, ideas, trend, future,
            human, year)`` for the winning chain, where ``ideas`` lists all
            candidates. Returns nine ``None`` values when no chain could be
            built.
        """
        self.topic = topic
        print(f"begin to generate search query for {topic}")
        search_query = await self.get_search_query(topic=topic)
        papers = []
        for query in search_query:
            failed_query = []
            current_papers = []
            cnt = 0
            while not current_papers and cnt < 10:
                paper = await self.reader.search_async(query, 1, paper_list=self.read_papers, llm=self.llm, rerank_query=f"{topic}", publicationDate=self.publicationData)
                if paper and paper[0]:
                    self.read_papers.add(paper[0].title)
                    current_papers.append(paper[0])
                else:
                    # Let the LLM rewrite the query, given everything that failed so far.
                    failed_query.append(query)
                    prompt = get_deep_rewrite_query_prompt(failed_query, topic)
                    new_query = await self.get_openai_response_async(self.wrap_messages(prompt))
                    new_query = extract(new_query, "query")
                    print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
                    query = new_query
                cnt += 1
            papers.extend(current_papers)
            if len(papers) >= self.max_chain_numbers:
                break
        if not papers:
            print(f"failed to generate idea {topic}")
            return (None,) * 9
        results = await asyncio.gather(*(self.deep_research_paper_with_chain(paper) for paper in papers))
        results = [result for result in results if result]
        if not results:
            print(f"failed to generate idea {topic}")
            return (None,) * 9
        # Must mirror deep_research_paper_with_chain's return order:
        # (idea, idea_chains, trend, experiments, entities, future, human, years).
        # The old unpacking swapped trend/experiments/entities.
        ideas, idea_chains, trends, experiments, entities, futures, humans, years = [list(column) for column in zip(*results)]
        tasks = [
            judge_idea(i, j, idea_1, idea_2, topic, self.llm)
            for i, idea_1 in enumerate(ideas)
            for j, idea_2 in enumerate(ideas)
            if i != j
        ]
        judgements = await asyncio.gather(*tasks)
        scores = self._tally_pairwise_scores(judgements, len(ideas))
        try:
            selected = scores.index(max(scores))
        except ValueError:  # only possible with no ideas
            selected = 0
        print("successfully generated idea")
        return (
            ideas[selected],
            experiments[selected],
            entities[selected],
            idea_chains[selected],
            ideas,
            trends[selected],
            futures[selected],
            humans[selected],
            years[selected],
        )

    async def get_paper_idea_experiment_references_info(self, paper):
        """Extract idea/experiment/entities/references from ``paper``'s article.

        Uses ``read_paper_content`` and the cheap LLM. Returns ``(idea,
        experiment, entities, references, paper.title)``, or ``None`` if the
        paper has no article.
        """
        article = paper.article
        if not article:
            return None
        paper_content = self.reader.read_paper_content(article)
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = await self.get_cheap_openai_response_async(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references, paper.title

    async def get_article_idea_experiment_references_info(self, article):
        """Extract ``(idea, experiment, entities, references)`` from ``article``.

        Uses ``read_paper_content_with_ref`` and the cheap LLM. ``references``
        is the raw extracted text (expected to be a JSON list).
        """
        paper_content = self.reader.read_paper_content_with_ref(article)
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = await self.get_cheap_openai_response_async(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references

    async def _next_reference_paper(self, references, chain_length, idea_papers, current_title, current_abstract):
        """Find the next paper to prepend to the chain.

        Pops entries from ``references`` in place and searches each one until
        a relevant paper with an article is found. If the references yield
        nothing usable, it makes one ``search_related_paper_async`` call for
        the current chain start as a fallback. While the chain is shorter than
        ``min_chain_length``, relevance is not required.

        Returns ``(paper, article)``, or ``(None, None)`` when nothing usable
        was found.
        """
        below_min = chain_length < self.min_chain_length
        while references:
            reference = references.pop(0)
            if reference in self.read_papers:
                continue
            hits = await self.reader.search_async(reference, 3, llm=self.llm, publicationDate=self.publicationData, paper_list=idea_papers) or []
            if not hits:
                continue
            s_p = hits[0]
            if not s_p or s_p.title in self.read_papers:
                break  # unusable hit: go to the fallback search
            relevant = await self._judge_relevant(s_p.title, s_p.abstract)
            if relevant != "0" or below_min:
                article = s_p.article
                if article:
                    return s_p, article
                break  # acceptable but no article: go to the fallback search
            print(f"the paper {s_p.title} is not relevant")

        rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
        s_p = await self.reader.search_related_paper_async(current_title, need_citation=False, rerank_query=rerank_query, llm=self.llm, paper_list=idea_papers)
        if not s_p:
            print(f"failed to find citation paper for {current_title}")
            return None, None
        if not below_min:
            if s_p.title in self.read_papers:
                print(f"failed to find citation paper for {current_title}")
                return None, None
            relevant = await self._judge_relevant(s_p.title, s_p.abstract)
            if relevant != "1":
                print(f"failed to find citation paper for {current_title}")
                return None, None
        # Plain attribute access, as everywhere else (previously ``await s_p.article``).
        article = s_p.article
        if not article:
            return None, None
        return s_p, article

    async def deep_research_paper_with_chain(self, paper: "Result"):
        """Grow a chain of related papers around ``paper`` and derive an idea.

        The chain is kept as parallel lists (ideas, titles, experiments,
        entities, years) with one entry per paper. It is built in two phases:
          1. Append papers found by
             ``search_related_paper_async(need_reference=False)``.
          2. Prepend papers reached through the extracted references (see
             ``_next_reference_paper``).
        Both phases stop at ``max_chain_length``. The chain then drives the
        trend, future-direction and idea-generation prompts.

        Returns:
            ``(idea, idea_chains, trend, experiments, entities, future,
            human, years)``, or ``None`` when ``paper`` has no article.
        """
        print(f"begin to deep research paper {paper.title}")
        article = paper.article
        if not article:
            print(f"failed to deep research paper {paper.title}")
            return None
        idea, experiment, entities, references = await self.get_article_idea_experiment_references_info(article)
        references = self._parse_json_list(references)
        idea_chain = [idea]
        idea_papers = [paper.title]
        experiments = [experiment]
        total_entities = [entities]
        years = [paper.year]

        # Phase 1: append related papers after the current chain end.
        current_title = paper.title
        current_abstract = paper.abstract
        while len(idea_chain) < self.max_chain_length:
            rerank_query = f"{self.topic} {current_title} {current_abstract}"
            citation_paper = await self.reader.search_related_paper_async(current_title, need_reference=False, rerank_query=rerank_query, llm=self.llm, paper_list=idea_papers)
            if not citation_paper:
                print(f"failed to find citation paper for {current_title}")
                break
            title = citation_paper.title
            # Judge the candidate itself. Previously the current chain end was
            # re-judged, so the candidate was never actually checked.
            relevant = await self._judge_relevant(title, citation_paper.abstract)
            if relevant == "0":
                print(f"the paper {title} is not relevant")
                break
            result = await self.get_paper_idea_experiment_references_info(citation_paper)
            if not result:
                break
            idea, experiment, entities, _, _ = result
            idea_chain.append(idea)
            experiments.append(experiment)
            total_entities.append(entities)
            idea_papers.append(citation_paper.title)
            years.append(citation_paper.year)
            current_title = citation_paper.title
            current_abstract = citation_paper.abstract

        # Phase 2: prepend papers reached through the references.
        current_title = paper.title
        current_abstract = paper.abstract
        while len(idea_chain) < self.max_chain_length and references:
            print(f"The references find:{references}")
            cite_paper, article = await self._next_reference_paper(references, len(idea_chain), idea_papers, current_title, current_abstract)
            if not article:
                continue
            print("find the citation paper, begin to deep research")
            idea, experiment, entities, references = await self.get_article_idea_experiment_references_info(article)
            references = self._parse_json_list(references)
            current_title = cite_paper.title
            current_abstract = cite_paper.abstract
            years = [cite_paper.year] + years
            idea_chain = [idea] + idea_chain
            idea_papers = [cite_paper.title] + idea_papers
            experiments = [experiment] + experiments
            total_entities = [entities] + total_entities
            # Stop at a highly cited paper once the chain is long enough.
            # NOTE(review): ``citations_conut`` spelling must match Result's attribute.
            if len(idea_chain) >= self.min_chain_length and cite_paper.citations_conut > 1000:
                break

        print("successfully generate idea chain")
        idea_chains = "".join(
            f"{i}.Paper:{title} idea:{chain_idea}\n \n"
            for i, (chain_idea, title) in enumerate(zip(idea_chain, idea_papers))
        )
        # NOTE(review): only ``entities`` (from the paper read last) reaches the
        # trend prompt, not ``total_entities`` -- confirm this is intended.
        prompt = get_deep_trend_idea_chains_prompt(idea_chains, entities, self.topic)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        trend = extract(response, "trend")
        self.deep_research_chains.append({"idea_chains": idea_chains, "trend": trend, "topic": self.topic, "ideas": idea_chain, "experiments": experiments, "entities": total_entities, "years": years})

        prompt = f"""The current research topic is: {self.topic}. Please help me summarize and refine the following entities by merging, simplifying, or deleting them : {total_entities}
Please output strictly in the following format:
<entities> {{cleaned entities}}</entities>
"""
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        total_entities = extract(response, "entities")
        bad_case = []

        prompt = get_deep_generate_future_direciton_prompt(idea_chain, trend, self.topic, total_entities)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        future = extract(response, "future")
        human = extract(response, "human")

        prompt = get_deep_generate_idea_prompt(idea_chains, trend, self.topic, total_entities, future, bad_case)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        method = extract(response, "method")
        novelty = extract(response, "novelty")
        motivation = extract(response, "motivation")
        idea = {"motivation": motivation, "novelty": novelty, "method": method}

        prompt = get_deep_final_idea_prompt(idea_chains, trend, idea, self.topic)
        response = await self.get_openai_response_async(self.wrap_messages(prompt))
        idea = extract(response, "final_idea")
        self.deep_ideas.append(idea)
        print(f"successfully deep research paper {paper.title}")
        return idea, idea_chains, trend, experiments, total_entities, future, human, years
if __name__ == "__main__":
    # Script entry point: only constructs a searcher; nothing else is run.
    reader = SementicSearcher()