Spaces:
Runtime error
Runtime error
import json | |
import logging | |
import queue | |
import random | |
import re | |
import threading | |
import uuid | |
from collections import defaultdict | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
from copy import deepcopy | |
from dataclasses import asdict | |
from typing import Dict, List, Optional | |
from lagent.actions import ActionExecutor | |
from lagent.agents import BaseAgent, Internlm2Agent | |
from lagent.agents.internlm2_agent import Internlm2Protocol | |
from lagent.schema import AgentReturn, AgentStatusCode, ModelStatusCode | |
from termcolor import colored | |
# 初始化日志记录 | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
class SearcherAgent(Internlm2Agent): | |
def __init__(self, template='{query}', **kwargs) -> None: | |
super().__init__(**kwargs) | |
self.template = template | |
def stream_chat(self, | |
question: str, | |
root_question: str = None, | |
parent_response: List[dict] = None, | |
**kwargs) -> AgentReturn: | |
message = self.template['input'].format(question=question, | |
topic=root_question) | |
if parent_response: | |
if 'context' in self.template: | |
parent_response = [ | |
self.template['context'].format(**item) | |
for item in parent_response | |
] | |
message = '\n'.join(parent_response + [message]) | |
print(colored(f'current query: {message}', 'green')) | |
for agent_return in super().stream_chat(message, | |
session_id=random.randint( | |
0, 999999), | |
**kwargs): | |
agent_return.type = 'searcher' | |
agent_return.content = question | |
yield deepcopy(agent_return) | |
class MindSearchProtocol(Internlm2Protocol): | |
def __init__( | |
self, | |
meta_prompt: str = None, | |
interpreter_prompt: str = None, | |
plugin_prompt: str = None, | |
few_shot: Optional[List] = None, | |
response_prompt: str = None, | |
language: Dict = dict( | |
begin='', | |
end='', | |
belong='assistant', | |
), | |
tool: Dict = dict( | |
begin='{start_token}{name}\n', | |
start_token='<|action_start|>', | |
name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'), | |
belong='assistant', | |
end='<|action_end|>\n', | |
), | |
execute: Dict = dict(role='execute', | |
begin='', | |
end='', | |
fallback_role='environment'), | |
) -> None: | |
self.response_prompt = response_prompt | |
super().__init__(meta_prompt=meta_prompt, | |
interpreter_prompt=interpreter_prompt, | |
plugin_prompt=plugin_prompt, | |
few_shot=few_shot, | |
language=language, | |
tool=tool, | |
execute=execute) | |
def format(self, | |
inner_step: List[Dict], | |
plugin_executor: ActionExecutor = None, | |
**kwargs) -> list: | |
formatted = [] | |
if self.meta_prompt: | |
formatted.append(dict(role='system', content=self.meta_prompt)) | |
if self.plugin_prompt: | |
plugin_prompt = self.plugin_prompt.format(tool_info=json.dumps( | |
plugin_executor.get_actions_info(), ensure_ascii=False)) | |
formatted.append( | |
dict(role='system', content=plugin_prompt, name='plugin')) | |
if self.interpreter_prompt: | |
formatted.append( | |
dict(role='system', | |
content=self.interpreter_prompt, | |
name='interpreter')) | |
if self.few_shot: | |
for few_shot in self.few_shot: | |
formatted += self.format_sub_role(few_shot) | |
formatted += self.format_sub_role(inner_step) | |
return formatted | |
class WebSearchGraph: | |
end_signal = 'end' | |
searcher_cfg = dict() | |
def __init__(self): | |
self.nodes = {} | |
self.adjacency_list = defaultdict(list) | |
self.executor = ThreadPoolExecutor(max_workers=10) | |
self.future_to_query = dict() | |
self.searcher_resp_queue = queue.Queue() | |
def add_root_node(self, node_content, node_name='root'): | |
self.nodes[node_name] = dict(content=node_content, type='root') | |
self.adjacency_list[node_name] = [] | |
self.searcher_resp_queue.put((node_name, self.nodes[node_name], [])) | |
def add_node(self, node_name, node_content): | |
self.nodes[node_name] = dict(content=node_content, type='searcher') | |
self.adjacency_list[node_name] = [] | |
def model_stream_thread(): | |
agent = SearcherAgent(**self.searcher_cfg) | |
try: | |
parent_nodes = [] | |
for start_node, adj in self.adjacency_list.items(): | |
for neighbor in adj: | |
if node_name == neighbor[ | |
'name'] and start_node in self.nodes and 'response' in self.nodes[ | |
start_node]: | |
parent_nodes.append(self.nodes[start_node]) | |
parent_response = [ | |
dict(question=node['content'], answer=node['response']) | |
for node in parent_nodes | |
] | |
for answer in agent.stream_chat( | |
node_content, | |
self.nodes['root']['content'], | |
parent_response=parent_response): | |
self.searcher_resp_queue.put( | |
deepcopy((node_name, | |
dict(response=answer.response, | |
detail=answer), []))) | |
self.nodes[node_name]['response'] = answer.response | |
self.nodes[node_name]['detail'] = answer | |
except Exception as e: | |
logger.exception(f'Error in model_stream_thread: {e}') | |
self.future_to_query[self.executor.submit( | |
model_stream_thread)] = f'{node_name}-{node_content}' | |
def add_response_node(self, node_name='response'): | |
self.nodes[node_name] = dict(type='end') | |
self.searcher_resp_queue.put((node_name, self.nodes[node_name], [])) | |
def add_edge(self, start_node, end_node): | |
self.adjacency_list[start_node].append( | |
dict(id=str(uuid.uuid4()), name=end_node, state=2)) | |
self.searcher_resp_queue.put((start_node, self.nodes[start_node], | |
self.adjacency_list[start_node])) | |
def reset(self): | |
self.nodes = {} | |
self.adjacency_list = defaultdict(list) | |
def node(self, node_name): | |
return self.nodes[node_name].copy() | |
class MindSearchAgent(BaseAgent): | |
def __init__(self, | |
llm, | |
searcher_cfg, | |
protocol=MindSearchProtocol(), | |
max_turn=10): | |
self.local_dict = {} | |
self.ptr = 0 | |
self.llm = llm | |
self.max_turn = max_turn | |
WebSearchGraph.searcher_cfg = searcher_cfg | |
super().__init__(llm=llm, action_executor=None, protocol=protocol) | |
def stream_chat(self, message, **kwargs): | |
if isinstance(message, str): | |
message = [{'role': 'user', 'content': message}] | |
elif isinstance(message, dict): | |
message = [message] | |
as_dict = kwargs.pop('as_dict', False) | |
return_early = kwargs.pop('return_early', False) | |
self.local_dict.clear() | |
self.ptr = 0 | |
inner_history = message[:] | |
agent_return = AgentReturn() | |
agent_return.type = 'planner' | |
agent_return.nodes = {} | |
agent_return.adjacency_list = {} | |
agent_return.inner_steps = deepcopy(inner_history) | |
for _ in range(self.max_turn): | |
prompt = self._protocol.format(inner_step=inner_history) | |
for model_state, response, _ in self.llm.stream_chat( | |
prompt, session_id=random.randint(0, 999999), **kwargs): | |
if model_state.value < 0: | |
agent_return.state = getattr(AgentStatusCode, | |
model_state.name) | |
yield deepcopy(agent_return) | |
return | |
response = response.replace('<|plugin|>', '<|interpreter|>') | |
_, language, action = self._protocol.parse(response) | |
if not language and not action: | |
continue | |
code = action['parameters']['command'] if action else '' | |
agent_return.state = self._determine_agent_state( | |
model_state, code, agent_return) | |
agent_return.response = language if not code else code | |
# if agent_return.state == AgentStatusCode.STREAM_ING: | |
yield deepcopy(agent_return) | |
inner_history.append({'role': 'language', 'content': language}) | |
print(colored(response, 'blue')) | |
if code: | |
yield from self._process_code(agent_return, inner_history, | |
code, as_dict, return_early) | |
else: | |
agent_return.state = AgentStatusCode.END | |
yield deepcopy(agent_return) | |
return | |
agent_return.state = AgentStatusCode.END | |
yield deepcopy(agent_return) | |
def _determine_agent_state(self, model_state, code, agent_return): | |
if code: | |
return (AgentStatusCode.PLUGIN_START if model_state | |
== ModelStatusCode.END else AgentStatusCode.PLUGIN_START) | |
return (AgentStatusCode.ANSWER_ING | |
if agent_return.nodes and 'response' in agent_return.nodes else | |
AgentStatusCode.STREAM_ING) | |
def _process_code(self, | |
agent_return, | |
inner_history, | |
code, | |
as_dict=False, | |
return_early=False): | |
for node_name, node, adj in self.execute_code( | |
code, return_early=return_early): | |
if as_dict and 'detail' in node: | |
node['detail'] = asdict(node['detail']) | |
if not adj: | |
agent_return.nodes[node_name] = node | |
else: | |
agent_return.adjacency_list[node_name] = adj | |
# state 1进行中,2未开始,3已结束 | |
for start_node, neighbors in agent_return.adjacency_list.items(): | |
for neighbor in neighbors: | |
if neighbor['name'] not in agent_return.nodes: | |
state = 2 | |
elif 'detail' not in agent_return.nodes[neighbor['name']]: | |
state = 2 | |
elif agent_return.nodes[neighbor['name']][ | |
'detail'].state == AgentStatusCode.END: | |
state = 3 | |
else: | |
state = 1 | |
neighbor['state'] = state | |
if not adj: | |
yield deepcopy((agent_return, node_name)) | |
reference, references_url = self._generate_reference( | |
agent_return, code, as_dict) | |
inner_history.append({ | |
'role': 'tool', | |
'content': code, | |
'name': 'plugin' | |
}) | |
inner_history.append({ | |
'role': 'environment', | |
'content': reference, | |
'name': 'plugin' | |
}) | |
agent_return.inner_steps = deepcopy(inner_history) | |
agent_return.state = AgentStatusCode.PLUGIN_RETURN | |
agent_return.references.update(references_url) | |
yield deepcopy(agent_return) | |
def _generate_reference(self, agent_return, code, as_dict): | |
node_list = [ | |
node.strip().strip('\"') for node in re.findall( | |
r'graph\.node\("((?:[^"\\]|\\.)*?)"\)', code) | |
] | |
if 'add_response_node' in code: | |
return self._protocol.response_prompt, dict() | |
references = [] | |
references_url = dict() | |
for node_name in node_list: | |
if as_dict: | |
ref_results = agent_return.nodes[node_name]['detail'][ | |
'actions'][0]['result'][0]['content'] | |
else: | |
ref_results = agent_return.nodes[node_name]['detail'].actions[ | |
0].result[0]['content'] | |
ref_results = json.loads(ref_results) | |
ref2url = {idx: item['url'] for idx, item in ref_results.items()} | |
ref = f"## {node_name}\n\n{agent_return.nodes[node_name]['response']}\n" | |
updated_ref = re.sub( | |
r'\[\[(\d+)\]\]', | |
lambda match: f'[[{int(match.group(1)) + self.ptr}]]', ref) | |
numbers = [int(n) for n in re.findall(r'\[\[(\d+)\]\]', ref)] | |
if numbers: | |
assert all(str(elem) in ref2url for elem in numbers) | |
references_url.update({ | |
str(idx + self.ptr): ref2url[str(idx)] | |
for idx in set(numbers) | |
}) | |
self.ptr += max(numbers) + 1 | |
references.append(updated_ref) | |
return '\n'.join(references), references_url | |
def execute_code(self, command: str, return_early=False): | |
def extract_code(text: str) -> str: | |
text = re.sub(r'from ([\w.]+) import WebSearchGraph', '', text) | |
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) | |
single_match = re.search(r'`([^`]*)`', text, re.DOTALL) | |
if triple_match: | |
return triple_match.group(1) | |
elif single_match: | |
return single_match.group(1) | |
return text | |
def run_command(cmd): | |
try: | |
exec(cmd, globals(), self.local_dict) | |
plan_graph = self.local_dict.get('graph') | |
assert plan_graph is not None | |
for future in as_completed(plan_graph.future_to_query): | |
future.result() | |
plan_graph.future_to_query.clear() | |
plan_graph.searcher_resp_queue.put(plan_graph.end_signal) | |
except Exception as e: | |
logger.exception(f'Error executing code: {e}') | |
command = extract_code(command) | |
producer_thread = threading.Thread(target=run_command, | |
args=(command, )) | |
producer_thread.start() | |
responses = defaultdict(list) | |
ordered_nodes = [] | |
active_node = None | |
while True: | |
try: | |
item = self.local_dict.get('graph').searcher_resp_queue.get( | |
timeout=60) | |
if item is WebSearchGraph.end_signal: | |
for node_name in ordered_nodes: | |
# resp = None | |
for resp in responses[node_name]: | |
yield deepcopy(resp) | |
# if resp: | |
# assert resp[1][ | |
# 'detail'].state == AgentStatusCode.END | |
break | |
node_name, node, adj = item | |
if node_name in ['root', 'response']: | |
yield deepcopy((node_name, node, adj)) | |
else: | |
if node_name not in ordered_nodes: | |
ordered_nodes.append(node_name) | |
responses[node_name].append((node_name, node, adj)) | |
if not active_node and ordered_nodes: | |
active_node = ordered_nodes[0] | |
while active_node and responses[active_node]: | |
if return_early: | |
if 'detail' in responses[active_node][-1][ | |
1] and responses[active_node][-1][1][ | |
'detail'].state == AgentStatusCode.END: | |
item = responses[active_node][-1] | |
else: | |
item = responses[active_node].pop(0) | |
else: | |
item = responses[active_node].pop(0) | |
if 'detail' in item[1] and item[1][ | |
'detail'].state == AgentStatusCode.END: | |
ordered_nodes.pop(0) | |
responses[active_node].clear() | |
active_node = None | |
yield deepcopy(item) | |
except queue.Empty: | |
if not producer_thread.is_alive(): | |
break | |
producer_thread.join() | |
return | |