import os
from typing import Sequence, List

import streamlit as st
from langchain.agents import AgentExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
from logger import logger

# declarative_base moved to sqlalchemy.orm in SQLAlchemy 1.4; support both import paths.
try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.memory import SQLChatMessageHistory

def create_agent_executor(
        agent_name: str,
        session_id: str,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_prompt: str,
        **kwargs
) -> AgentExecutor:
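    """Create an AgentExecutor whose chat history is persisted in MyScale (ClickHouse).

    History is stored through ``SQLChatMessageHistory`` using the MYSCALE_*
    environment variables for the connection; ``agent_name`` (spaces replaced
    with underscores) is passed to the ClickHouse message converter and
    ``session_id`` identifies the conversation. Extra keyword arguments are
    forwarded to ``AgentExecutor``.
    """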
    agent_name = agent_name.replace(" ", "_")
    conn_str = (
        f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}'
        f'@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    )
    protocol = "https" if GLOBAL_CONFIG.myscale_enable_https else "http"
    # Persist chat messages in the `chat` database on MyScale (ClickHouse).
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f'{conn_str}/chat?protocol={protocol}',
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
    # Token-buffer memory trims the stored history to fit the LLM's context window.
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
    # The `history` placeholder matches AgentTokenBufferMemory's default memory key.
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        **kwargs
    )


def build_agents(
        session_id: str,
        tool_names: List[str],
        model: str = "gpt-3.5-turbo-0125",
        temperature: float = 0.6,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> AgentExecutor:
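    """Build an AgentExecutor for a session from the selected retrieval tools.

    ``tool_names`` picks tools out of the retriever-tool dictionary kept in the
    Streamlit session state; the selected tools are attached to a streaming
    ChatOpenAI model configured from ``GLOBAL_CONFIG`` together with the given
    system prompt.
    """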
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    # Prefer the per-session tool set; fall back to the globally shared retriever tools.
    tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    selected_tools = [tools[k] for k in tool_names]
    logger.info(f"Creating agent with tools: {selected_tools}")
    agent = create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
    return agent
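
# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of how a Streamlit chat page might call build_agents(); the
# session id and tool name below are hypothetical examples, and the real caller
# lives elsewhere in the app.
#
#     agent = build_agents(
#         session_id="user@example.com?default",  # hypothetical session key
#         tool_names=["wikipedia"],                # keys must exist in the session tool dict
#     )
#     result = agent({"input": "What is MyScale?"})
#     print(result["output"])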