import json
import time

from searcher import Result, SementicSearcher
from LLM import openai_llm
from prompts import *
from utils import extract


def get_llm(model="gpt4o-0513"):
    """Return an ``openai_llm`` client for *model*."""
    return openai_llm(model)


def get_llms():
    """Return ``(main_llm, cheap_llm)``.

    The main model drives query generation and idea synthesis; the cheap model
    is used by ``DeepResearchAgent`` to read full paper contents.
    """
    main_llm = get_llm("gpt-4o-2024-08-06")
    cheap_llm = get_llm("gpt-4o-mini")
    return main_llm, cheap_llm


def judge_idea(i, j, idea0, idea1, topic, llm):
    """Ask *llm* to compare ``idea0`` and ``idea1`` for *topic* on several criteria.

    Args:
        i, j: opaque identifiers, returned unchanged.
        idea0, idea1: the two ideas passed to ``get_judge_idea_all_prompt``.
        topic: research topic passed to the prompt.
        llm: client exposing ``response(messages)``.

    Returns:
        ``(i, j, novelty, relevance, significance, clarity, feasibility,
        effectiveness)`` where each judgement is whatever ``extract`` returns
        for the correspondingly named tag in the LLM response.
    """
    prompt = get_judge_idea_all_prompt(idea0, idea1, topic)
    messages = [{"role": "user", "content": prompt}]
    response = llm.response(messages)
    novelty = extract(response, "novelty")
    relevance = extract(response, "relevance")
    significance = extract(response, "significance")
    clarity = extract(response, "clarity")
    feasibility = extract(response, "feasibility")
    effectiveness = extract(response, "effectiveness")
    return i, j, novelty, relevance, significance, clarity, feasibility, effectiveness


class DeepResearchAgent:
    """Build a chain of papers for a topic and derive a new research idea from it.

    Starting from a paper found for the topic, the agent repeatedly follows the
    references of the oldest paper in the chain to prepend earlier work. It then
    asks the main LLM for the trend along the chain, future directions, and a
    final idea.
    """

    # How many times the initial search for one query is retried (with an
    # LLM-rewritten query) before giving up on that query.
    _MAX_QUERY_RETRIES = 10

    def __init__(self, llm=None, cheap_llm=None, publicationData=None, ban_paper=None, **kwargs) -> None:
        """Create an agent.

        Args:
            llm: main LLM client exposing ``response(messages)``.
            cheap_llm: cheaper client exposing ``response(messages, max_tokens=...)``.
            publicationData: forwarded as ``publicationDate`` to paper searches.
            ban_paper: forwarded to ``SementicSearcher``; defaults to a fresh
                empty list (a shared mutable default is avoided on purpose).
            **kwargs: optional ``improve_cnt`` (default 1), ``max_chain_length``
                (5), ``min_chain_length`` (3) and ``max_chain_numbers`` (10).
        """
        if ban_paper is None:
            ban_paper = []
        self.reader = SementicSearcher(ban_paper=ban_paper)
        self.begin_time = time.time()
        self.llm = llm
        self.cheap_llm = cheap_llm
        self.read_papers = set()  # titles of papers already picked as chain starts
        self.paper_storage = []
        self.paper_info_for_refine_experiment = []
        self.search_qeuries = []  # attribute name kept as-is for external readers
        self.deep_research_chains = []
        self.deep_ideas = []
        self.check_novel_results = []
        self.score_results = []
        self.topic = None
        self.publicationData = publicationData
        self.improve_cnt = kwargs.get("improve_cnt", 1)
        self.max_chain_length = kwargs.get("max_chain_length", 5)
        self.min_chain_length = kwargs.get("min_chain_length", 3)
        self.max_chain_numbers = kwargs.get("max_chain_numbers", 10)

    def wrap_messages(self, prompt):
        """Wrap *prompt* as a single-user-message chat list."""
        return [{"role": "user", "content": prompt}]

    def get_openai_response(self, messages):
        """Send *messages* to the main LLM and return its response."""
        return self.llm.response(messages)

    def get_cheap_openai_response(self, messages):
        """Send *messages* to the cheap LLM with ``max_tokens=16000``."""
        return self.cheap_llm.response(messages, max_tokens=16000)

    @staticmethod
    def _parse_json_list(text):
        """Decode an LLM-produced JSON array; return ``[]`` if missing, malformed or not a list."""
        try:
            value = json.loads(text)
        except (TypeError, ValueError):  # TypeError: extract() gave None
            return []
        return value if isinstance(value, list) else []

    def get_search_query(self, topic=None, query=None):
        """Ask the main LLM for a JSON list of search queries.

        A successfully parsed, non-empty list is recorded in
        ``self.search_qeuries`` and returned. Otherwise a one-element list with
        *query* is returned, or *topic* when no query was given, so callers
        never search for ``None``.
        """
        prompt = get_deep_search_query_prompt(topic, query)
        response = self.get_openai_response(self.wrap_messages(prompt))
        raw_queries = extract(response, "queries")
        try:
            search_query = json.loads(raw_queries)
        except (TypeError, ValueError):
            search_query = None
        if isinstance(search_query, list) and search_query:
            self.search_qeuries.append({"query": query, "search_query": search_query})
            return search_query
        return [query if query is not None else topic]

    def generate_idea_with_chain(self, topic):
        """Search papers for *topic* and deep-research the first one into an idea.

        Returns:
            9-tuple ``(idea, experiments, entities, idea_chains, idea, trend,
            future, human, years)`` built from ``deep_research_paper_with_chain``;
            all nine entries are ``None`` if no usable starting paper was found.
        """
        self.topic = topic
        print(f"begin to generate search query for {topic}")
        search_query = self.get_search_query(topic=topic)
        papers = []
        for query in search_query:
            failed_query = []
            current_papers = []
            cnt = 0
            while not current_papers and cnt < self._MAX_QUERY_RETRIES:
                paper = self.reader.search(query, 1, paper_list=self.read_papers, llm=self.llm,
                                           rerank_query=f"{topic}", publicationDate=self.publicationData)
                if paper and len(paper) > 0 and paper[0]:
                    self.read_papers.add(paper[0].title)
                    current_papers.append(paper[0])
                else:
                    # Let the LLM rewrite the query, given every query that failed so far.
                    failed_query.append(query)
                    prompt = get_deep_rewrite_query_prompt(failed_query, topic)
                    new_query = extract(self.get_openai_response(self.wrap_messages(prompt)), "query")
                    print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
                    query = new_query or query  # never retry with None
                cnt += 1
            papers.extend(current_papers)
            if len(papers) >= self.max_chain_numbers:
                break

        if not papers:
            print(f"failed to generate idea {topic}")
            return (None,) * 9

        result = self.deep_research_paper_with_chain(papers[0])
        if result is None:  # starting paper had no readable article
            print(f"failed to generate idea {topic}")
            return (None,) * 9
        # Unpack in the order deep_research_paper_with_chain actually returns.
        idea, idea_chain, trend, experiment, entities, future, human, year = result
        print("successfully generated idea")
        return idea, experiment, entities, idea_chain, idea, trend, future, human, year

    def get_paper_idea_experiment_references_info(self, paper):
        """Read *paper*'s article with the cheap LLM.

        Returns ``(idea, experiment, entities, references, title)`` with the
        raw extracted strings, or ``None`` if the paper has no article.
        """
        article = paper.article
        if not article:
            return None
        paper_content = self.reader.read_paper_content(article)
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = self.get_cheap_openai_response(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references, paper.title

    def get_article_idea_experiment_references_info(self, article):
        """Read *article* (including its references) with the cheap LLM.

        Returns ``(idea, experiment, entities, references)`` as raw extracted
        strings; ``references`` is expected to be a JSON list of titles.
        """
        paper_content = self.reader.read_paper_content_with_ref(article)
        prompt = get_deep_reference_prompt(paper_content, self.topic)
        response = self.get_cheap_openai_response(self.wrap_messages(prompt))
        entities = extract(response, "entities")
        idea = extract(response, "idea")
        experiment = extract(response, "experiment")
        references = extract(response, "references")
        return idea, experiment, entities, references

    def _judge_relevant(self, candidate):
        """Ask the main LLM whether *candidate* is relevant to the topic; return the raw "relevant" tag."""
        prompt = get_deep_judge_relevant_prompt(candidate.title, candidate.abstract, self.topic)
        response = self.get_openai_response(self.wrap_messages(prompt))
        return extract(response, "relevant")

    def _find_cited_paper(self, references, current_title, current_abstract, idea_papers, chain_length):
        """Pick the next (older) paper to prepend to the chain, or return ``None``.

        First walks *references* (consumed in place with ``pop(0)``), searching
        each title. The walk stops at the first usable hit, or falls through to
        a related-paper search for the current paper. While the chain is
        shorter than ``min_chain_length`` relevance is not required.

        The returned paper always has a truthy ``article``.
        """
        short_chain = chain_length < self.min_chain_length
        while references:
            reference = references.pop(0)
            if reference in self.read_papers:
                continue
            candidates = self.reader.search(reference, 3, llm=self.llm, publicationDate=self.publicationData,
                                            paper_list=idea_papers) or []
            if not candidates:
                continue
            candidate = candidates[0]
            if not candidate or candidate.title in self.read_papers:
                break  # unusable hit: stop walking and try the related-paper fallback
            if short_chain or self._judge_relevant(candidate) != "0":
                if candidate.article:
                    return candidate
                break  # relevant but unreadable: try the fallback
            print(f"the paper {candidate.title} is not relevant")

        # Fallback: ask the searcher for a related (cited) paper of the current one.
        rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
        candidate = self.reader.search_related_paper(current_title, need_citation=False, rerank_query=rerank_query,
                                                     llm=self.llm, paper_list=idea_papers)
        if not candidate:
            return None
        if not short_chain:
            # Long enough chain: only extend with an unseen paper judged strictly relevant.
            if candidate.title in self.read_papers or self._judge_relevant(candidate) != "1":
                return None
        return candidate if candidate.article else None

    def deep_research_paper_with_chain(self, paper: Result):
        """Build an idea chain backwards from *paper* and synthesize a final idea.

        Returns:
            ``(idea, idea_chains, trend, experiments, total_entities, future,
            human, years)`` where ``idea_chains`` is a numbered text summary of
            the chain (oldest first), or ``None`` if *paper* has no article.
            Also appends to ``self.deep_research_chains`` and ``self.deep_ideas``.
        """
        print(f"begin to deep research paper {paper.title}")
        article = paper.article
        if not article:
            print(f"failed to deep research paper {paper.title}")
            return None

        idea, experiment, entities, references = self.get_article_idea_experiment_references_info(article)
        references = self._parse_json_list(references)
        # Parallel lists, ordered oldest paper first (new papers are prepended).
        idea_chain = [idea]
        idea_papers = [paper.title]
        experiments = [experiment]
        total_entities = [entities]
        years = [paper.year]
        current_title = paper.title
        current_abstract = paper.abstract

        # search before: each iteration consumes at least one reference, so this terminates
        while len(idea_chain) < self.max_chain_length and len(references) > 0:
            print(f"The references find:{references}")
            cite_paper = self._find_cited_paper(references, current_title, current_abstract,
                                                idea_papers, len(idea_chain))
            if cite_paper is None:
                print(f"failed to find citation paper for {current_title}")
                continue

            print("find the citation paper, begin to deep research")
            idea, experiment, entities, references = self.get_article_idea_experiment_references_info(
                cite_paper.article)
            references = self._parse_json_list(references)
            current_title = cite_paper.title
            current_abstract = cite_paper.abstract
            years = [cite_paper.year] + years
            idea_chain = [idea] + idea_chain
            idea_papers = [cite_paper.title] + idea_papers
            experiments = [experiment] + experiments
            total_entities = [entities] + total_entities
            # Stop early at a highly cited paper once the chain is long enough.
            # (`citations_conut` is the attribute name used on Result.)
            if len(idea_chain) >= self.min_chain_length and (cite_paper.citations_conut or 0) > 1000:
                break

        print("successfully generate idea chain")
        idea_chains = "".join(
            f"{i}.Paper:{title} idea:{chain_idea}\n \n"
            for i, (chain_idea, title) in enumerate(zip(idea_chain, idea_papers))
        )

        # NOTE(review): only the most recently read (oldest) paper's `entities` is
        # passed here, not `total_entities` — confirm this is intended.
        prompt = get_deep_trend_idea_chains_prompt(idea_chains, entities, self.topic)
        response = self.get_openai_response(self.wrap_messages(prompt))
        trend = extract(response, "trend")
        self.deep_research_chains.append({
            "idea_chains": idea_chains,
            "trend": trend,
            "topic": self.topic,
            "ideas": idea_chain,
            "experiments": experiments,
            "entities": total_entities,
            "years": years,
        })

        # The <entities> tags are required: the answer is parsed with extract(..., "entities").
        prompt = f"""The current research topic is: {self.topic}. Please help me summarize and refine the following entities by merging, simplifying, or deleting them : {total_entities}
Please output strictly in the following format:
<entities> {{cleaned entities}} </entities>
"""
        response = self.get_openai_response(self.wrap_messages(prompt))
        total_entities = extract(response, "entities")

        bad_case = []
        prompt = get_deep_generate_future_direciton_prompt(idea_chain, trend, self.topic, total_entities)
        response = self.get_openai_response(self.wrap_messages(prompt))
        future = extract(response, "future")
        human = extract(response, "human")

        prompt = get_deep_generate_idea_prompt(idea_chains, trend, self.topic, total_entities, future, bad_case)
        response = self.get_openai_response(self.wrap_messages(prompt))
        method = extract(response, "method")
        novelty = extract(response, "novelty")
        motivation = extract(response, "motivation")
        draft_idea = {"motivation": motivation, "novelty": novelty, "method": method}

        prompt = get_deep_final_idea_prompt(idea_chains, trend, draft_idea, self.topic)
        response = self.get_openai_response(self.wrap_messages(prompt))
        idea = extract(response, "final_idea")
        self.deep_ideas.append(idea)
        print(f"successfully deep research paper {paper.title}")
        return idea, idea_chains, trend, experiments, total_entities, future, human, years


if __name__ == "__main__":
    reader = SementicSearcher()