Spaces:
Running
Running
import json | |
import time | |
from searcher import Result,SementicSearcher | |
from LLM import openai_llm | |
from prompts import * | |
from utils import extract | |
def get_llm(model="gpt4o-0513"):
    """Create an ``openai_llm`` client for the given model name.

    The model name (including the default) is passed through unchanged to
    ``openai_llm``.
    """
    client = openai_llm(model)
    return client
def get_llms():
    """Return the ``(main_llm, cheap_llm)`` client pair.

    The main client targets ``gpt-4o-2024-08-06``; the cheap client targets
    ``gpt-4o-mini``. Both are created through ``get_llm`` (main first).
    """
    return get_llm("gpt-4o-2024-08-06"), get_llm("gpt-4o-mini")
def judge_idea(i, j, idea0, idea1, topic, llm):
    """Ask *llm* to compare two ideas on *topic* and collect its verdicts.

    Returns:
        ``(i, j, novelty, relevance, significance, clarity, feasibility,
        effectiveness)``. Each verdict is whatever ``extract`` pulls from the
        LLM reply for that key. ``i`` and ``j`` are echoed back unchanged.
    """
    criteria = (
        "novelty",
        "relevance",
        "significance",
        "clarity",
        "feasibility",
        "effectiveness",
    )
    request = [{"role": "user", "content": get_judge_idea_all_prompt(idea0, idea1, topic)}]
    reply = llm.response(request)
    return (i, j, *(extract(reply, key) for key in criteria))
class DeepResearchAgent:
    """Grow a chain of related papers around a topic and derive a research idea from it.

    ``llm`` is the main model: it is used for search queries, relevance
    judgements, trend analysis and idea generation. ``cheap_llm`` is used to
    summarise full paper contents.

    Optional keyword arguments read from ``**kwargs``:

    * ``improve_cnt`` (default 1): stored only; not read by the methods below.
    * ``max_chain_length`` (default 5): maximum number of papers in a chain.
    * ``min_chain_length`` (default 3): while the chain is shorter than this,
      reference candidates are accepted without a positive relevance verdict.
    * ``max_chain_numbers`` (default 10): maximum number of seed papers collected.
    """

    def __init__(self, llm=None, cheap_llm=None, publicationData=None, ban_paper=None, **kwargs) -> None:
        # ``None`` rather than a ``[]`` default so agents never share one mutable list.
        if ban_paper is None:
            ban_paper = []
        self.reader = SementicSearcher(ban_paper=ban_paper)
        self.begin_time = time.time()
        self.llm = llm
        self.cheap_llm = cheap_llm
        self.read_papers = set()  # titles of seed papers found by generate_idea_with_chain
        self.paper_storage = []
        self.paper_info_for_refine_experiment = []
        self.search_qeuries = []  # (sic) attribute name kept for external readers
        self.deep_research_chains = []
        self.deep_ideas = []
        self.check_novel_results = []
        self.score_results = []
        self.topic = None
        self.publicationData = publicationData
        self.improve_cnt = kwargs.get("improve_cnt", 1)
        self.max_chain_length = kwargs.get("max_chain_length", 5)
        self.min_chain_length = kwargs.get("min_chain_length", 3)
        self.max_chain_numbers = kwargs.get("max_chain_numbers", 10)

    def wrap_messages(self, prompt):
        """Wrap *prompt* as a single user message list."""
        return [{"role": "user", "content": prompt}]

    def get_openai_response(self, messages):
        """Send *messages* to the main LLM and return its response."""
        return self.llm.response(messages)

    def get_cheap_openai_response(self, messages):
        """Send *messages* to the cheap LLM, passing ``max_tokens=16000``."""
        return self.cheap_llm.response(messages, max_tokens=16000)

    def get_search_query(self, topic=None, query=None):
        """Ask the main LLM for a list of search queries.

        The reply is expected to contain a JSON list, retrieved with
        ``extract(response, "queries")``. Successfully parsed lists are
        recorded in ``self.search_qeuries``.

        Returns:
            The parsed list of queries. If the reply cannot be turned into a
            non-empty list, returns ``[query]``, or ``[topic]`` when ``query``
            is None, so callers never search for ``None``.
        """
        prompt = get_deep_search_query_prompt(topic, query)
        response = self.get_openai_response(self.wrap_messages(prompt))
        raw_queries = extract(response, "queries")
        try:
            search_query = json.loads(raw_queries)
        except (TypeError, ValueError):
            # extract() may return None or text that is not valid JSON.
            search_query = None
        if isinstance(search_query, str):
            # A bare JSON string is one query, not an iterable of characters.
            search_query = [search_query]
        if not isinstance(search_query, list) or not search_query:
            return [query if query is not None else topic]
        self.search_qeuries.append({"query": query, "search_query": search_query})
        return search_query

    def _search_seed_paper(self, query, topic, max_attempts=10):
        """Find one paper for *query*, asking the LLM to rewrite the query after each miss.

        Returns the first paper found (its title is added to ``read_papers``),
        or ``None`` after *max_attempts* unsuccessful searches.
        """
        failed_query = []
        for _ in range(max_attempts):
            paper = self.reader.search(
                query,
                1,
                paper_list=self.read_papers,
                llm=self.llm,
                rerank_query=f"{topic}",
                publicationDate=self.publicationData,
            )
            if paper and len(paper) > 0 and paper[0]:
                self.read_papers.add(paper[0].title)
                return paper[0]
            failed_query.append(query)
            prompt = get_deep_rewrite_query_prompt(failed_query, topic)
            new_query = extract(self.get_openai_response(self.wrap_messages(prompt)), "query")
            print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
            query = new_query
        return None

    def generate_idea_with_chain(self, topic):
        """Generate a research idea for *topic* from a chain of related papers.

        Collects up to ``max_chain_numbers`` seed papers (one per generated
        search query). It then deep-researches them in order and uses the
        first seed that succeeds.

        Returns:
            ``(idea, experiments, entities, idea_chain, idea, trend, future,
            human, years)``, or a 9-tuple of ``None`` when no seed paper could
            be found or none could be deep-researched.
        """
        self.topic = topic
        print(f"begin to generate search query for {topic}")
        search_query = self.get_search_query(topic=topic)
        papers = []
        for query in search_query:
            seed = self._search_seed_paper(query, topic)
            if seed is not None:
                papers.append(seed)
            if len(papers) >= self.max_chain_numbers:
                break
        if not papers:
            print(f"failed to generate idea {topic}")
            return (None,) * 9

        result = None
        for seed in papers:
            # deep_research_paper_with_chain returns None when the seed has no
            # article content; try the next collected seed instead of crashing.
            result = self.deep_research_paper_with_chain(seed)
            if result is not None:
                break
        if result is None:
            print(f"failed to generate idea {topic}")
            return (None,) * 9

        # Unpack in exactly the order deep_research_paper_with_chain returns.
        idea, idea_chain, trend, experiments, entities, future, human, years = result
        print("successfully generated idea")
        return idea, experiments, entities, idea_chain, idea, trend, future, human, years

    def _summarize_content(self, paper_content):
        """Ask the cheap LLM for ``(idea, experiment, entities, references)`` of *paper_content*."""
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = self.get_cheap_openai_response(self.wrap_messages(prompt))
        return (
            extract(response, "idea"),
            extract(response, "experiment"),
            extract(response, "entities"),
            extract(response, "references"),
        )

    def get_paper_idea_experiment_references_info(self, paper):
        """Summarise *paper* via ``read_paper_content``.

        Returns ``(idea, experiment, entities, references, paper.title)``, or
        ``None`` when the paper has no article.
        """
        article = paper.article
        if not article:
            return None
        paper_content = self.reader.read_paper_content(article)
        idea, experiment, entities, references = self._summarize_content(paper_content)
        return idea, experiment, entities, references, paper.title

    def get_article_idea_experiment_references_info(self, article):
        """Summarise *article* via ``read_paper_content_with_ref``.

        Returns ``(idea, experiment, entities, references)``; ``references``
        is the raw extracted text (JSON expected).
        """
        return self._summarize_content(self.reader.read_paper_content_with_ref(article))

    @staticmethod
    def _parse_references(raw_references):
        """Decode the extracted reference list; return ``[]`` when it is missing or malformed."""
        try:
            references = json.loads(raw_references)
        except (TypeError, ValueError):
            return []
        if isinstance(references, str):
            return [references]
        return references if isinstance(references, list) else []

    def _judge_relevance(self, candidate):
        """Ask the main LLM whether *candidate* is relevant to ``self.topic``.

        Returns the raw ``relevant`` value extracted from the reply
        (callers compare it to ``"0"`` / ``"1"``).
        """
        prompt = get_deep_judge_relevant_prompt(candidate.title, candidate.abstract, self.topic)
        return extract(self.get_openai_response(self.wrap_messages(prompt)), "relevant")

    def _find_reference_paper(self, references, idea_papers, chain_length, current_title, current_abstract):
        """Pick the next paper to prepend to the chain.

        Pops entries from *references* (mutating it) until one resolves to an
        unread paper that has article content and is relevant. Relevance is
        not required while the chain is shorter than ``min_chain_length``.
        If no reference qualifies, does one ``search_related_paper`` lookup as
        a fallback.

        Returns:
            ``(paper, article)`` for the chosen paper, or ``(None, None)``.
        """
        below_min = chain_length < self.min_chain_length
        while references:
            reference = references.pop(0)
            if reference in self.read_papers:
                continue
            results = self.reader.search(
                reference,
                3,
                llm=self.llm,
                publicationDate=self.publicationData,
                paper_list=idea_papers,
            )
            if not results:
                continue
            candidate = results[0]
            if not candidate or candidate.title in self.read_papers:
                continue
            # Judge the candidate itself; skip the LLM call when it cannot change the outcome.
            if not below_min and self._judge_relevance(candidate) == "0":
                print(f"the paper {candidate.title} is not relevant")
                continue
            article = candidate.article
            if article:
                return candidate, article

        rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
        candidate = self.reader.search_related_paper(
            current_title,
            need_citation=False,
            rerank_query=rerank_query,
            llm=self.llm,
            paper_list=idea_papers,
        )
        if not candidate:
            return None, None
        if not below_min and (candidate.title in self.read_papers or self._judge_relevance(candidate) != "1"):
            return None, None
        article = candidate.article
        return (candidate, article) if article else (None, None)

    def deep_research_paper_with_chain(self, paper: Result):
        """Build an idea chain around *paper* and derive a final research idea.

        The chain starts with *paper*.

        1. It is extended at the end with papers from
           ``search_related_paper(..., need_reference=False)``, as long as
           each candidate is judged relevant.
        2. It is then extended at the front with papers resolved from the
           references extracted from each paper's content.
        3. It stops at ``max_chain_length`` papers or when candidates run out.

        The chain and trend are appended to ``self.deep_research_chains``, and
        the final idea to ``self.deep_ideas``.

        Returns:
            ``(idea, idea_chains, trend, experiments, total_entities, future,
            human, years)``, or ``None`` when *paper* has no article content.
        """
        print(f"begin to deep research paper {paper.title}")
        article = paper.article
        if not article:
            print(f"failed to deep research paper {paper.title}")
            return None
        idea, experiment, entities, references = self.get_article_idea_experiment_references_info(article)
        references = self._parse_references(references)
        # Parallel per-paper lists, kept in chain order.
        idea_chain = [idea]
        idea_papers = [paper.title]
        experiments = [experiment]
        total_entities = [entities]
        years = [paper.year]

        # Extend the end of the chain from the most recently appended paper.
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length:
            rerank_query = f"{self.topic} {current_title} {current_abstract}"
            citation_paper = self.reader.search_related_paper(
                current_title,
                need_reference=False,
                rerank_query=rerank_query,
                llm=self.llm,
                paper_list=idea_papers,
            )
            if not citation_paper:
                print(f"failed to find citation paper for {current_title}")
                break
            # Judge the candidate, not the paper it was found from.
            if self._judge_relevance(citation_paper) == "0":
                print(f"the paper {citation_paper.title} is not relevant")
                break
            result = self.get_paper_idea_experiment_references_info(citation_paper)
            if not result:
                break
            idea, experiment, entities, _, _ = result
            idea_chain.append(idea)
            experiments.append(experiment)
            total_entities.append(entities)
            idea_papers.append(citation_paper.title)
            years.append(citation_paper.year)
            current_title, current_abstract = citation_paper.title, citation_paper.abstract

        # Extend the front of the chain by following the references of the earliest paper.
        current_title, current_abstract = paper.title, paper.abstract
        while len(idea_chain) < self.max_chain_length and references:
            print(f"The references find:{references}")
            cite_paper, article = self._find_reference_paper(
                references, idea_papers, len(idea_chain), current_title, current_abstract
            )
            if not article:
                # _find_reference_paper drained ``references``, so the loop ends here.
                print(f"failed to find citation paper for {current_title}")
                continue
            print("find the citation paper, begin to deep research")
            idea, experiment, entities, references = self.get_article_idea_experiment_references_info(article)
            references = self._parse_references(references)
            current_title, current_abstract = cite_paper.title, cite_paper.abstract
            years.insert(0, cite_paper.year)
            idea_chain.insert(0, idea)
            idea_papers.insert(0, cite_paper.title)
            experiments.insert(0, experiment)
            total_entities.insert(0, entities)
            # Stop early once the chain is long enough and a paper with >1000 citations is reached.
            if len(idea_chain) >= self.min_chain_length and cite_paper.citations_conut > 1000:
                break

        print("successfully generate idea chain")
        idea_chains = "".join(
            f"{i}.Paper:{title} idea:{chain_idea}\n \n"
            for i, (chain_idea, title) in enumerate(zip(idea_chain, idea_papers))
        )
        # NOTE(review): only the most recently extracted ``entities`` is passed here,
        # not ``total_entities``; confirm this is intended.
        prompt = get_deep_trend_idea_chains_prompt(idea_chains, entities, self.topic)
        trend = extract(self.get_openai_response(self.wrap_messages(prompt)), "trend")
        self.deep_research_chains.append({
            "idea_chains": idea_chains,
            "trend": trend,
            "topic": self.topic,
            "ideas": idea_chain,
            "experiments": experiments,
            "entities": total_entities,
            "years": years,
        })

        # Let the LLM condense the per-paper entity lists into one cleaned string.
        prompt = f"""The current research topic is: {self.topic}. Please help me summarize and refine the following entities by merging, simplifying, or deleting them : {total_entities}
Please output strictly in the following format:
<entities> {{cleaned entities}}</entities>
"""
        total_entities = extract(self.get_openai_response(self.wrap_messages(prompt)), "entities")

        bad_case = []
        prompt = get_deep_generate_future_direciton_prompt(idea_chain, trend, self.topic, total_entities)
        response = self.get_openai_response(self.wrap_messages(prompt))
        future = extract(response, "future")
        human = extract(response, "human")

        prompt = get_deep_generate_idea_prompt(idea_chains, trend, self.topic, total_entities, future, bad_case)
        response = self.get_openai_response(self.wrap_messages(prompt))
        method = extract(response, "method")
        novelty = extract(response, "novelty")
        motivation = extract(response, "motivation")
        draft_idea = {"motivation": motivation, "novelty": novelty, "method": method}

        prompt = get_deep_final_idea_prompt(idea_chains, trend, draft_idea, self.topic)
        idea = extract(self.get_openai_response(self.wrap_messages(prompt)), "final_idea")
        self.deep_ideas.append(idea)
        print(f"successfully deep research paper {paper.title}")
        return idea, idea_chains, trend, experiments, total_entities, future, human, years
if __name__ == "__main__": | |
reader = SementicSearcher() | |