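# Chainlit + LangChain demo: a spreadsheet-driven dialogue agent. The four
# workbooks are located via environment variables referenced below:
# AGENT_SHEET, DIALOGUE_SHEET, PERSONA_SHEET and PROMPT_SHEET.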
import os
import pandas as pd
import chainlit as cl
from chainlit import user_session
from chainlit.types import LLMSettings
from chainlit.logger import logger
from langchain import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import AzureOpenAI
from langchain.document_loaders import DataFrameLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Chroma
from langchain.vectorstores.base import VectorStoreRetriever

current_agent = "Demo"  # Rows in every sheet are filtered to this agent
vectordb = None  # Lazily initialised Chroma instance, shared across sessions
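
# Column names referenced throughout this module (as used in the code below):
# - Agent sheet:    Agent, History
# - Dialogue sheet: Agent, Utterance, Context, Prompt, Response, Role,
#                   Contextualisation, Continuation, Intent
# - Persona sheet:  Agent, Role, Persona
# - Prompt sheet:   Agent, Prompt, Template, Temperature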

def load_agent():
    df = pd.read_excel(os.environ["AGENT_SHEET"], header=0, keep_default_na=False)
    df = df[df["Agent"] == current_agent]
    return df


def load_dialogues():
    df = pd.read_excel(os.environ["DIALOGUE_SHEET"], header=0, keep_default_na=False)
    df = df[df["Agent"] == current_agent]
    return df.astype(str)


def load_persona():
    df = pd.read_excel(os.environ["PERSONA_SHEET"], header=0, keep_default_na=False)
    df = df[df["Agent"] == current_agent]
    return df


def load_prompts():
    df = pd.read_excel(os.environ["PROMPT_SHEET"], header=0, keep_default_na=False)
    df = df[df["Agent"] == current_agent]
    return df

def load_documents(df, page_content_column: str):
    return DataFrameLoader(df, page_content_column).load()

def init_embedding_function():
    EMBEDDING_MODEL_FOLDER = ".embedding-model"
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        encode_kwargs={"normalize_embeddings": True},
        # cache_folder=EMBEDDING_MODEL_FOLDER,
    )
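
# Hypothetical sanity check: all-mpnet-base-v2 produces 768-dimensional
# vectors, and normalize_embeddings=True makes cosine similarity equivalent
# to a plain dot product.
# assert len(init_embedding_function().embed_query("hello")) == 768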

def load_vectordb(init: bool = False):
    global vectordb
    VECTORDB_FOLDER = ".vectordb"
    if not init and vectordb is None:
        # Try to reopen a previously persisted store first
        vectordb = Chroma(
            embedding_function=init_embedding_function(),
            persist_directory=VECTORDB_FOLDER,
        )
        if not vectordb.get()["ids"]:
            # The store exists but is empty: fall through to a full rebuild
            init = True
        else:
            logger.info("Vector DB loaded")
    if init:
        # (Re)build the store from the dialogue sheet, one document per utterance
        vectordb = Chroma.from_documents(
            documents=load_documents(load_dialogues(), page_content_column="Utterance"),
            embedding=init_embedding_function(),
            persist_directory=VECTORDB_FOLDER,
        )
        vectordb.persist()
        logger.info("Vector DB initialised")
    return vectordb

def get_retriever(context_state: str, vectordb):
    # Retrieve the single closest utterance whose Context is either empty
    # (context-free) or matches the current context state
    return VectorStoreRetriever(
        vectorstore=vectordb,
        search_type="similarity",
        search_kwargs={
            "filter": {
                "$or": [{"Context": {"$eq": ""}}, {"Context": {"$eq": context_state}}]
            },
            "k": 1,
        },
    )
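
# Entry point for Chainlit's legacy (0.x) LangChain integration. The decorator
# below is an assumption inferred from the factory()/run() signatures; it does
# not appear in the source.
@cl.langchain_factory(use_async=True)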
def factory():
    df_agent = load_agent()
    load_vectordb()
    user_session.set("context_state", "")
    user_session.set("df_prompts", load_prompts())
    user_session.set("df_persona", load_persona())
    llm_settings = LLMSettings(
        model_name="text-davinci-003",
        temperature=0.7,
    )
    user_session.set("llm_settings", llm_settings)
    chat_memory = ConversationBufferWindowMemory(
        memory_key="History",
        input_key="Utterance",
        k=df_agent["History"].values[0],
    )
    user_session.set("chat_memory", chat_memory)
    llm = AzureOpenAI(
        deployment_name="davinci003",
        model_name=llm_settings.model_name,
        temperature=llm_settings.temperature,
        streaming=True,
        # Fall back to the key stored in the user session only when it is not
        # already provided via the environment
        openai_api_key=cl.user_session.get("env").get("OPENAI_API_KEY")
        if "OPENAI_API_KEY" not in os.environ
        else None,
    )
    default_prompt = """{History}
##
System: {Persona}
##
Human: {Utterance}
Response: {Response}
##
AI:"""
    return LLMChain(
        prompt=PromptTemplate.from_template(default_prompt),
        llm=llm,
        verbose=True,
        memory=chat_memory,
    )
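
# Companion Chainlit 0.x message handler; as above, the decorator is an
# assumed detail.
@cl.langchain_run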
async def run(agent, input_str):
    global vectordb
    if input_str == "/reload":
        # Force a rebuild of the vector store from the spreadsheets
        vectordb = load_vectordb(True)
        return await cl.Message(content="Data loaded").send()
    df_prompts = user_session.get("df_prompts")
    df_persona = user_session.get("df_persona")
    llm_settings = user_session.get("llm_settings")
    retriever = get_retriever(user_session.get("context_state"), vectordb)
    document = retriever.get_relevant_documents(query=input_str)
    prompt = document[0].metadata["Prompt"]
    if not prompt:
        # No prompt template attached: send the canned response verbatim
        await cl.Message(
            content=document[0].metadata["Response"],
            author=document[0].metadata["Role"],
        ).send()
    else:
        # Swap in the sheet-defined prompt template and temperature, then let
        # the LLM rephrase the canned response in persona
        agent.prompt = PromptTemplate.from_template(
            df_prompts.loc[df_prompts["Prompt"] == prompt]["Template"].values[0]
        )
        llm_settings.temperature = df_prompts.loc[df_prompts["Prompt"] == prompt][
            "Temperature"
        ].values[0]
        agent.llm.temperature = llm_settings.temperature
        response = await agent.acall(
            {
                "Persona": df_persona.loc[
                    df_persona["Role"] == document[0].metadata["Role"]
                ]["Persona"].values[0],
                "Utterance": input_str,
                "Response": document[0].metadata["Response"],
            },
            callbacks=[cl.AsyncLangchainCallbackHandler()],
        )
        await cl.Message(
            content=response["text"],
            author=document[0].metadata["Role"],
            llm_settings=llm_settings,
        ).send()
    user_session.set("context_state", document[0].metadata["Contextualisation"])

    # Follow any chain of scripted continuations until one has no successor
    continuation = document[0].metadata["Continuation"]
    while continuation != "":
        document_continuation = vectordb.get(where={"Intent": continuation})
        prompt = document_continuation["metadatas"][0]["Prompt"]
        if not prompt:
            await cl.Message(
                content=document_continuation["metadatas"][0]["Response"],
                author=document_continuation["metadatas"][0]["Role"],
            ).send()
        else:
            agent.prompt = PromptTemplate.from_template(
                df_prompts.loc[df_prompts["Prompt"] == prompt]["Template"].values[0]
            )
            llm_settings.temperature = df_prompts.loc[df_prompts["Prompt"] == prompt][
                "Temperature"
            ].values[0]
            agent.llm.temperature = llm_settings.temperature
            response = await agent.acall(
                {
                    "Persona": df_persona.loc[
                        df_persona["Role"]
                        == document_continuation["metadatas"][0]["Role"]
                    ]["Persona"].values[0],
                    "Utterance": "",
                    "Response": document_continuation["metadatas"][0]["Response"],
                },
                callbacks=[cl.AsyncLangchainCallbackHandler()],
            )
            await cl.Message(
                content=response["text"],
                author=document_continuation["metadatas"][0]["Role"],
                llm_settings=llm_settings,
            ).send()
        user_session.set(
            "context_state",
            document_continuation["metadatas"][0]["Contextualisation"],
        )
        continuation = document_continuation["metadatas"][0]["Continuation"]