# mindsearch/agent/mindsearch_agent.py
import json
import logging
import queue
import random
import re
import threading
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from dataclasses import asdict
from typing import Dict, List, Optional
from lagent.actions import ActionExecutor
from lagent.agents import BaseAgent, Internlm2Agent
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import AgentReturn, AgentStatusCode, ModelStatusCode
from termcolor import colored
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SearcherAgent(Internlm2Agent):
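    """Search agent that answers a single sub-question of the root query.

    ``template`` is used as a dict here: ``template['input']`` formats the
    sub-question together with the root topic, and an optional
    ``template['context']`` entry formats parent-node answers that are
    prepended to the message.
    """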
def __init__(self, template='{query}', **kwargs) -> None:
super().__init__(**kwargs)
self.template = template
def stream_chat(self,
question: str,
root_question: str = None,
parent_response: List[dict] = None,
**kwargs) -> AgentReturn:
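        """Stream answers for ``question``, optionally conditioned on the
        root query and the answers of parent nodes (``parent_response``)."""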
message = self.template['input'].format(question=question,
topic=root_question)
if parent_response:
if 'context' in self.template:
parent_response = [
self.template['context'].format(**item)
for item in parent_response
]
message = '\n'.join(parent_response + [message])
print(colored(f'current query: {message}', 'green'))
for agent_return in super().stream_chat(message,
session_id=random.randint(
0, 999999),
**kwargs):
agent_return.type = 'searcher'
agent_return.content = question
yield deepcopy(agent_return)
class MindSearchProtocol(Internlm2Protocol):
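    """Internlm2 chat protocol extended with a ``response_prompt`` that is fed
    back to the planner once the final response node has been added."""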
def __init__(
self,
meta_prompt: str = None,
interpreter_prompt: str = None,
plugin_prompt: str = None,
few_shot: Optional[List] = None,
response_prompt: str = None,
language: Dict = dict(
begin='',
end='',
belong='assistant',
),
tool: Dict = dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
),
execute: Dict = dict(role='execute',
begin='',
end='',
fallback_role='environment'),
) -> None:
self.response_prompt = response_prompt
super().__init__(meta_prompt=meta_prompt,
interpreter_prompt=interpreter_prompt,
plugin_prompt=plugin_prompt,
few_shot=few_shot,
language=language,
tool=tool,
execute=execute)
def format(self,
inner_step: List[Dict],
plugin_executor: ActionExecutor = None,
**kwargs) -> list:
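        """Assemble the full chat history: meta prompt, plugin/interpreter
        system prompts, few-shot examples, then the current inner steps."""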
formatted = []
if self.meta_prompt:
formatted.append(dict(role='system', content=self.meta_prompt))
if self.plugin_prompt:
plugin_prompt = self.plugin_prompt.format(tool_info=json.dumps(
plugin_executor.get_actions_info(), ensure_ascii=False))
formatted.append(
dict(role='system', content=plugin_prompt, name='plugin'))
if self.interpreter_prompt:
formatted.append(
dict(role='system',
content=self.interpreter_prompt,
name='interpreter'))
if self.few_shot:
for few_shot in self.few_shot:
formatted += self.format_sub_role(few_shot)
formatted += self.format_sub_role(inner_step)
return formatted
class WebSearchGraph:
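    """Directed graph of search nodes built by planner-generated code.

    Nodes are created with ``add_root_node``/``add_node``/``add_response_node``
    and connected with ``add_edge``. Every ``add_node`` call launches a
    ``SearcherAgent`` in a thread pool; streamed results are pushed onto
    ``searcher_resp_queue`` for ``MindSearchAgent.execute_code`` to consume.
    """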
end_signal = 'end'
searcher_cfg = dict()
def __init__(self):
self.nodes = {}
self.adjacency_list = defaultdict(list)
self.executor = ThreadPoolExecutor(max_workers=10)
self.future_to_query = dict()
self.searcher_resp_queue = queue.Queue()
def add_root_node(self, node_content, node_name='root'):
self.nodes[node_name] = dict(content=node_content, type='root')
self.adjacency_list[node_name] = []
self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
def add_node(self, node_name, node_content):
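        """Add a searcher node and launch a ``SearcherAgent`` for it in the
        thread pool, streaming its answers into ``searcher_resp_queue``."""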
self.nodes[node_name] = dict(content=node_content, type='searcher')
self.adjacency_list[node_name] = []
def model_stream_thread():
agent = SearcherAgent(**self.searcher_cfg)
try:
parent_nodes = []
                for start_node, adj in self.adjacency_list.items():
                    for neighbor in adj:
                        if (node_name == neighbor['name']
                                and start_node in self.nodes
                                and 'response' in self.nodes[start_node]):
                            parent_nodes.append(self.nodes[start_node])
parent_response = [
dict(question=node['content'], answer=node['response'])
for node in parent_nodes
]
for answer in agent.stream_chat(
node_content,
self.nodes['root']['content'],
parent_response=parent_response):
self.searcher_resp_queue.put(
deepcopy((node_name,
dict(response=answer.response,
detail=answer), [])))
self.nodes[node_name]['response'] = answer.response
self.nodes[node_name]['detail'] = answer
except Exception as e:
logger.exception(f'Error in model_stream_thread: {e}')
self.future_to_query[self.executor.submit(
model_stream_thread)] = f'{node_name}-{node_content}'
def add_response_node(self, node_name='response'):
self.nodes[node_name] = dict(type='end')
self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
def add_edge(self, start_node, end_node):
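        """Connect ``start_node`` to ``end_node`` and publish the updated
        adjacency list (initial state 2, i.e. not started)."""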
self.adjacency_list[start_node].append(
dict(id=str(uuid.uuid4()), name=end_node, state=2))
self.searcher_resp_queue.put((start_node, self.nodes[start_node],
self.adjacency_list[start_node]))
def reset(self):
self.nodes = {}
self.adjacency_list = defaultdict(list)
def node(self, node_name):
return self.nodes[node_name].copy()
class MindSearchAgent(BaseAgent):
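    """Planner agent that lets the LLM write ``WebSearchGraph`` code, executes
    it to fan out ``SearcherAgent`` sub-queries, and streams the merged graph
    state and final answer back to the caller."""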
def __init__(self,
llm,
searcher_cfg,
protocol=MindSearchProtocol(),
max_turn=10):
self.local_dict = {}
self.ptr = 0
self.llm = llm
self.max_turn = max_turn
WebSearchGraph.searcher_cfg = searcher_cfg
super().__init__(llm=llm, action_executor=None, protocol=protocol)
def stream_chat(self, message, **kwargs):
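        """Run up to ``max_turn`` planner rounds, executing any emitted graph
        code and yielding intermediate ``AgentReturn`` snapshots."""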
if isinstance(message, str):
message = [{'role': 'user', 'content': message}]
elif isinstance(message, dict):
message = [message]
as_dict = kwargs.pop('as_dict', False)
return_early = kwargs.pop('return_early', False)
self.local_dict.clear()
self.ptr = 0
inner_history = message[:]
agent_return = AgentReturn()
agent_return.type = 'planner'
agent_return.nodes = {}
agent_return.adjacency_list = {}
agent_return.inner_steps = deepcopy(inner_history)
for _ in range(self.max_turn):
prompt = self._protocol.format(inner_step=inner_history)
for model_state, response, _ in self.llm.stream_chat(
prompt, session_id=random.randint(0, 999999), **kwargs):
if model_state.value < 0:
agent_return.state = getattr(AgentStatusCode,
model_state.name)
yield deepcopy(agent_return)
return
response = response.replace('<|plugin|>', '<|interpreter|>')
_, language, action = self._protocol.parse(response)
if not language and not action:
continue
code = action['parameters']['command'] if action else ''
agent_return.state = self._determine_agent_state(
model_state, code, agent_return)
agent_return.response = language if not code else code
# if agent_return.state == AgentStatusCode.STREAM_ING:
yield deepcopy(agent_return)
inner_history.append({'role': 'language', 'content': language})
print(colored(response, 'blue'))
if code:
yield from self._process_code(agent_return, inner_history,
code, as_dict, return_early)
else:
agent_return.state = AgentStatusCode.END
yield deepcopy(agent_return)
return
agent_return.state = AgentStatusCode.END
yield deepcopy(agent_return)
def _determine_agent_state(self, model_state, code, agent_return):
        if code:
            return AgentStatusCode.PLUGIN_START
return (AgentStatusCode.ANSWER_ING
if agent_return.nodes and 'response' in agent_return.nodes else
AgentStatusCode.STREAM_ING)
def _process_code(self,
agent_return,
inner_history,
code,
as_dict=False,
return_early=False):
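        """Execute the planner-generated graph code, merge node and adjacency
        updates into ``agent_return``, and append the aggregated references to
        the inner history."""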
for node_name, node, adj in self.execute_code(
code, return_early=return_early):
if as_dict and 'detail' in node:
node['detail'] = asdict(node['detail'])
if not adj:
agent_return.nodes[node_name] = node
else:
agent_return.adjacency_list[node_name] = adj
            # state: 1 = in progress, 2 = not started, 3 = finished
for start_node, neighbors in agent_return.adjacency_list.items():
for neighbor in neighbors:
if neighbor['name'] not in agent_return.nodes:
state = 2
elif 'detail' not in agent_return.nodes[neighbor['name']]:
state = 2
elif agent_return.nodes[neighbor['name']][
'detail'].state == AgentStatusCode.END:
state = 3
else:
state = 1
neighbor['state'] = state
if not adj:
yield deepcopy((agent_return, node_name))
reference, references_url = self._generate_reference(
agent_return, code, as_dict)
inner_history.append({
'role': 'tool',
'content': code,
'name': 'plugin'
})
inner_history.append({
'role': 'environment',
'content': reference,
'name': 'plugin'
})
agent_return.inner_steps = deepcopy(inner_history)
agent_return.state = AgentStatusCode.PLUGIN_RETURN
agent_return.references.update(references_url)
yield deepcopy(agent_return)
def _generate_reference(self, agent_return, code, as_dict):
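        """Collect searcher answers referenced via ``graph.node(...)`` calls,
        shift their ``[[n]]`` citation markers by the running offset
        ``self.ptr``, and return the reference text plus an index-to-URL map."""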
node_list = [
node.strip().strip('\"') for node in re.findall(
r'graph\.node\("((?:[^"\\]|\\.)*?)"\)', code)
]
if 'add_response_node' in code:
return self._protocol.response_prompt, dict()
references = []
references_url = dict()
for node_name in node_list:
if as_dict:
ref_results = agent_return.nodes[node_name]['detail'][
'actions'][0]['result'][0]['content']
else:
ref_results = agent_return.nodes[node_name]['detail'].actions[
0].result[0]['content']
ref_results = json.loads(ref_results)
ref2url = {idx: item['url'] for idx, item in ref_results.items()}
ref = f"## {node_name}\n\n{agent_return.nodes[node_name]['response']}\n"
updated_ref = re.sub(
r'\[\[(\d+)\]\]',
lambda match: f'[[{int(match.group(1)) + self.ptr}]]', ref)
numbers = [int(n) for n in re.findall(r'\[\[(\d+)\]\]', ref)]
if numbers:
assert all(str(elem) in ref2url for elem in numbers)
references_url.update({
str(idx + self.ptr): ref2url[str(idx)]
for idx in set(numbers)
})
self.ptr += max(numbers) + 1
references.append(updated_ref)
return '\n'.join(references), references_url
def execute_code(self, command: str, return_early=False):
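        """Run the planner-generated code, which is expected to create a
        ``WebSearchGraph`` instance named ``graph``, in a background thread and
        yield ``(node_name, node, adjacency)`` updates from its queue."""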
def extract_code(text: str) -> str:
text = re.sub(r'from ([\w.]+) import WebSearchGraph', '', text)
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
if triple_match:
return triple_match.group(1)
elif single_match:
return single_match.group(1)
return text
def run_command(cmd):
try:
exec(cmd, globals(), self.local_dict)
plan_graph = self.local_dict.get('graph')
assert plan_graph is not None
for future in as_completed(plan_graph.future_to_query):
future.result()
plan_graph.future_to_query.clear()
plan_graph.searcher_resp_queue.put(plan_graph.end_signal)
except Exception as e:
logger.exception(f'Error executing code: {e}')
command = extract_code(command)
producer_thread = threading.Thread(target=run_command,
args=(command, ))
producer_thread.start()
responses = defaultdict(list)
ordered_nodes = []
active_node = None
while True:
try:
item = self.local_dict.get('graph').searcher_resp_queue.get(
timeout=60)
if item is WebSearchGraph.end_signal:
for node_name in ordered_nodes:
# resp = None
for resp in responses[node_name]:
yield deepcopy(resp)
# if resp:
# assert resp[1][
# 'detail'].state == AgentStatusCode.END
break
node_name, node, adj = item
if node_name in ['root', 'response']:
yield deepcopy((node_name, node, adj))
else:
if node_name not in ordered_nodes:
ordered_nodes.append(node_name)
responses[node_name].append((node_name, node, adj))
if not active_node and ordered_nodes:
active_node = ordered_nodes[0]
while active_node and responses[active_node]:
if return_early:
if 'detail' in responses[active_node][-1][
1] and responses[active_node][-1][1][
'detail'].state == AgentStatusCode.END:
item = responses[active_node][-1]
else:
item = responses[active_node].pop(0)
else:
item = responses[active_node].pop(0)
if 'detail' in item[1] and item[1][
'detail'].state == AgentStatusCode.END:
ordered_nodes.pop(0)
responses[active_node].clear()
active_node = None
yield deepcopy(item)
except queue.Empty:
if not producer_thread.is_alive():
break
producer_thread.join()
return
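
# Minimal usage sketch (assumptions: ``llm`` is a lagent chat model instance
# and ``searcher_cfg`` is a dict of SearcherAgent keyword arguments, both
# constructed elsewhere in the project):
#
#     agent = MindSearchAgent(llm=llm, searcher_cfg=searcher_cfg,
#                             protocol=MindSearchProtocol(), max_turn=10)
#     for agent_return in agent.stream_chat('What is MindSearch?'):
#         print(agent_return.response)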