ChatData / chat.py
lqhl's picture
Synced repo using 'sync_with_huggingface' Github Action
0e573d0 verified
raw
history blame
15.4 kB
import json
import pandas as pd
from os import environ
from time import sleep
import datetime
import streamlit as st
from lib.sessions import SessionManager
from lib.private_kb import PrivateKnowledgeBase
from langchain.schema import HumanMessage, FunctionMessage
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from lib.json_conv import CustomJSONDecoder
from lib.helper import (
build_agents,
MYSCALE_HOST,
MYSCALE_PASSWORD,
MYSCALE_PORT,
MYSCALE_USER,
DEFAULT_SYSTEM_PROMPT,
UNSTRUCTURED_API,
)
from login import back_to_main
# Export the `OPENAI_API_BASE` entry of the Streamlit secrets as an environment
# variable at import time (presumably consumed by the OpenAI/LangChain client
# -- confirm). Raises at import if the secret is not configured.
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
# Human-readable labels for the retriever tools, keyed by internal tool name.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
def on_chat_submit():
    """Render the submitted chat message and run the agent on it.

    Used as the ``on_submit`` callback of the chat input. The user's text is
    read from ``st.session_state.chat_input`` and everything is drawn inside
    the ``next_round`` placeholder stored in session state.
    """
    with st.session_state.next_round.container():
        with st.chat_message("user"):
            question = st.session_state.chat_input
            st.write(question)
        with st.chat_message("assistant"):
            # The callback handler draws the agent's intermediate steps into
            # a fresh container inside the assistant bubble.
            handler = ChatDataAgentCallBackHandler(
                st.container(), collapse_completed_thoughts=False
            )
            answer = st.session_state.agent(
                {"input": question}, callbacks=[handler]
            )
            print(answer)
def clear_history():
    """Clear the memory of the agent held in session state, if there is one."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
def back_to_main():
    """Remove the per-user entries from session state (used by "Logout").

    Keys that are not present are skipped silently.

    NOTE: this definition shadows the ``back_to_main`` imported from
    ``login`` at the top of the file.
    """
    per_user_keys = (
        "user_info",
        "user_name",
        "jump_query_ask",
        "sel_sess",
        "current_sessions",
    )
    for key in per_user_keys:
        if key in st.session_state:
            del st.session_state[key]
def on_session_change_submit():
    """Apply the rows added/deleted in the session editor, then refresh.

    Callback for the "Submit Change!" button. Does nothing unless both
    ``session_manager`` and ``session_editor`` exist in session state.

    Added rows must carry both ``session_id`` and ``system_prompt``; the id
    must be non-empty and must not contain ``?`` because ``?`` separates the
    user name from the session id in the stored key. Deleted rows are given
    as indices into ``st.session_state.current_sessions``.

    Any failure is shown with ``st.error`` rather than raised. The editor's
    pending rows are always reset and the agent is rebuilt afterwards.
    """
    # NOTE(review): the source had lost its indentation; running
    # `refresh_agent()` inside this guard, after the try block, is a
    # reconstruction -- confirm against the deployed file.
    if (
        "session_manager" not in st.session_state
        or "session_editor" not in st.session_state
    ):
        return

    print(st.session_state.session_editor)
    try:
        user = st.session_state.user_name
        manager = st.session_state.session_manager

        for row in st.session_state.session_editor["added_rows"]:
            if "system_prompt" not in row or "session_id" not in row:
                # ValueError (not KeyError): str(KeyError(msg)) wraps the
                # message in quotes, which garbles the text shown to the user.
                raise ValueError(
                    "You should fill both `session_id` and `system_prompt` to add a column!"
                )
            new_id = row["session_id"]
            if not new_id or "?" in new_id:
                raise ValueError(
                    "`session_id` must be non-empty and must not contain question marks."
                )
            manager.add_session(
                user_id=user,
                session_id=f"{user}?{new_id}",
                system_prompt=row["system_prompt"],
            )

        for row_index in st.session_state.session_editor["deleted_rows"]:
            doomed_id = st.session_state.current_sessions[row_index]["session_id"]
            manager.remove_session(session_id=f"{user}?{doomed_id}")

        refresh_sessions()
    except Exception as e:
        # Show the message first, then pause, matching the other callbacks
        # in this file (error followed by `sleep(2)`).
        st.error(f"{type(e).__name__}: {e}")
        sleep(2)
    finally:
        # Drop the pending edits so they are not re-applied on the next submit.
        st.session_state.session_editor["added_rows"] = []
        st.session_state.session_editor["deleted_rows"] = []
    refresh_agent()
def build_session_manager():
    """Create a ``SessionManager`` bound to session state and the MyScale settings."""
    connection = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(st.session_state, **connection)
def refresh_sessions():
    """Reload the user's sessions, files and tools into session state.

    Writes ``current_sessions``, ``user_files``, ``user_tools``,
    ``tools_with_users`` and ``sel_sess``. When the user has no session yet,
    a ``default`` session with ``DEFAULT_SYSTEM_PROMPT`` is created first.

    ``sel_sess`` keeps pointing at the previously selected session when it
    still exists; otherwise it falls back to the session named "default",
    and finally to the first session in the list.
    """
    state = st.session_state
    user = state.user_name
    manager = state.session_manager

    state["current_sessions"] = manager.list_sessions(user)
    if (
        not isinstance(state.current_sessions, dict)
        and len(state.current_sessions) <= 0
    ):
        # First visit for this user: make sure there is something to select.
        manager.add_session(user, f"{user}?default", DEFAULT_SYSTEM_PROMPT)
        state["current_sessions"] = manager.list_sessions(user)

    state["user_files"] = state.private_kb.list_files(user)
    state["user_tools"] = state.private_kb.list_tools(user)
    state["tools_with_users"] = {
        **state.tools,
        **state.private_kb.as_tools(user),
    }

    # The original tested `"" not in st.session_state` (always true) and read
    # a non-existent `sel_session` key, so the selection was always reset to
    # "default". Use the real `sel_sess` key instead.
    previous = state.get("sel_sess")
    wanted_id = previous["session_id"] if previous else "default"
    try:
        dfl_indx = [s["session_id"] for s in state.current_sessions].index(
            wanted_id
        )
    except ValueError:
        dfl_indx = 0
    state.sel_sess = state.current_sessions[dfl_indx]
def build_kb_as_tool():
    """Create a private knowledge-base tool from the "Build" form.

    Reads ``b_tool_name``, ``b_tool_desc`` and ``b_tool_files`` from session
    state. If any is missing or empty, an error is shown in the
    ``tool_status`` placeholder and nothing is created.
    """
    state = st.session_state
    form_keys = ("b_tool_name", "b_tool_desc", "b_tool_files")
    form_complete = all(key in state for key in form_keys) and all(
        len(state[key]) > 0 for key in form_keys
    )

    if not form_complete:
        state.tool_status.error("You should fill all fields to build up a tool!")
        sleep(2)
        return

    selected_file_names = [entry["file_name"] for entry in state.b_tool_files]
    state.private_kb.create_tool(
        state.user_name,
        state.b_tool_name,
        state.b_tool_desc,
        selected_file_names,
    )
    refresh_sessions()
def remove_kb():
    """Delete the knowledge-base tools picked in the "Delete" multiselect.

    Reads the selection from ``r_tool_names`` in session state. With nothing
    selected, an error is shown in the ``tool_status`` placeholder.
    """
    state = st.session_state
    nothing_selected = (
        "r_tool_names" not in state or len(state.r_tool_names) <= 0
    )

    if nothing_selected:
        state.tool_status.error("You should specify at least one tool to delete!")
        sleep(2)
        return

    doomed_tool_names = [entry["tool_name"] for entry in state.r_tool_names]
    state.private_kb.remove_tools(state.user_name, doomed_tool_names)
    refresh_sessions()
def refresh_agent():
    """Rebuild the agent for the selected session and store it as ``agent``.

    The agent is keyed by ``"<user_name>?<session_id>"``. Tools come from
    ``selected_tools`` when present in session state, otherwise a single
    built-in default is used.
    """
    with st.spinner("Initializing session..."):
        state = st.session_state
        session_key = f"{state.user_name}?{state.sel_sess['session_id']}"
        print("??? Changed to ", session_key)

        if "selected_tools" in state:
            tool_names = state.selected_tools
        else:
            tool_names = ["LangChain Self Query Retriever For Wikipedia"]

        # NOTE(review): `sel_sess` is already dereferenced above, so the
        # fallback branch below looks unreachable -- kept as in the original.
        if "sel_sess" in state:
            prompt = state.sel_sess["system_prompt"]
        else:
            prompt = DEFAULT_SYSTEM_PROMPT

        state["agent"] = build_agents(
            session_key,
            tool_names,
            system_prompt=prompt,
        )
def add_file():
    """Add the uploaded files to the user's private knowledge base.

    Reads ``uploaded_files`` from session state and reports progress and
    failures through the ``tool_status`` placeholder. Only ``ValueError``
    from the upload is caught and shown; other errors propagate.
    """
    state = st.session_state
    has_uploads = "uploaded_files" in state and len(state.uploaded_files) != 0

    if not has_uploads:
        state.tool_status.error("Please upload files!", icon="⚠️")
        sleep(2)
        return

    try:
        state.tool_status.info("Uploading...")
        state.private_kb.add_by_file(state.user_name, state.uploaded_files)
        refresh_sessions()
    except ValueError as err:
        state.tool_status.error("Failed to upload! " + str(err))
        sleep(2)
def clear_files():
    """Clear the user's private knowledge base, then reload session data."""
    knowledge_base = st.session_state.private_kb
    owner = st.session_state.user_name
    knowledge_base.clear(owner)
    refresh_sessions()
def _selected_session_index():
    """Return the index of the active session within ``current_sessions``.

    Looks up the id of ``sel_sess`` (or "default" when nothing is selected)
    and falls back to 0 on any lookup failure.
    """
    # The original tested `"" not in st.session_state` (always true) and read
    # a non-existent `sel_session` key; the real key is `sel_sess`.
    current = st.session_state.get("sel_sess")
    wanted_id = current["session_id"] if current else "default"
    try:
        return [
            s["session_id"] for s in st.session_state.current_sessions
        ].index(wanted_id)
    except Exception as e:
        print("*** ", str(e))
        return 0


def _render_session_management():
    """Draw the expander holding the editable session table."""
    with st.expander("Session Management"):
        if "current_sessions" not in st.session_state:
            refresh_sessions()
        st.info(
            "Here you can set up your session! \n\nYou can **change your prompt** here!",
            icon="🤖",
        )
        st.info(
            (
                "**Add columns by clicking the empty row**.\n"
                "And **delete columns by selecting rows with a press on `DEL` Key**"
            ),
            icon="💡",
        )
        st.info(
            "Don't forget to **click `Submit Change` to save your change**!",
            icon="📢",
        )
        st.data_editor(
            st.session_state.current_sessions,
            num_rows="dynamic",
            key="session_editor",
            use_container_width=True,
        )
        st.button("Submit Change!", on_click=on_session_change_submit)


def _render_session_selection():
    """Draw the expander with the session selectbox (state key ``sel_sess``)."""
    with st.expander("Session Selection", expanded=True):
        st.info(
            "If no session is attach to your account, then we will add a default session to you!",
            icon="❤️",
        )
        dfl_indx = _selected_session_index()
        st.selectbox(
            "Choose a session to chat:",
            options=st.session_state.current_sessions,
            index=dfl_indx,
            key="sel_sess",
            format_func=lambda x: x["session_id"],
            on_change=refresh_agent,
        )
        print(st.session_state.sel_sess)


def _render_knowledge_base_tab():
    """Draw the controls to build, select and delete knowledge-base tools."""
    st.markdown("#### Build You Own Knowledge")
    st.multiselect(
        "Select Files to Build up",
        st.session_state.user_files,
        placeholder="You should upload files first",
        key="b_tool_files",
        format_func=lambda x: x["file_name"],
    )
    st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
    st.text_input(
        "Tool Description",
        "Searches among user's private files and returns related documents",
        key="b_tool_desc",
    )
    st.button("Build!", on_click=build_kb_as_tool)

    st.markdown("### Knowledge Base Selection")
    if (
        "user_tools" in st.session_state
        and len(st.session_state.user_tools) > 0
    ):
        st.markdown("***User Created Knowledge Bases***")
        st.dataframe(st.session_state.user_tools)
    st.multiselect(
        "Select a Knowledge Base Tool",
        st.session_state.tools.keys()
        if "tools_with_users" not in st.session_state
        else st.session_state.tools_with_users,
        default=["Wikipedia + Self Querying"],
        key="selected_tools",
        on_change=refresh_agent,
    )

    st.markdown("### Delete Knowledge Base")
    st.multiselect(
        "Choose Knowledge Base to Remove",
        st.session_state.user_tools,
        format_func=lambda x: x["tool_name"],
        key="r_tool_names",
    )
    st.button("Delete", on_click=remove_kb)


def _render_file_upload_tab():
    """Draw the file uploader, the uploaded-file table and its buttons."""
    st.info(
        (
            "We adopted [Unstructured API](https://unstructured.io/api-key) "
            "here and we only store the processed texts from your documents. "
            "For privacy concerns, please refer to "
            "[our policy issue](https://myscale.com/privacy/)."
        ),
        icon="📃",
    )
    st.file_uploader(
        "Upload files", key="uploaded_files", accept_multiple_files=True
    )
    st.markdown("### Uploaded Files")
    st.dataframe(
        st.session_state.private_kb.list_files(st.session_state.user_name),
        use_container_width=True,
    )
    col_1, col_2 = st.columns(2)
    with col_1:
        st.button("Add Files", on_click=add_file)
    with col_2:
        st.button("Clear Files and All Tools", on_click=clear_files)


def _render_tool_settings():
    """Draw the "Tool Settings" expander and create the ``tool_status`` slot."""
    with st.expander("Tool Settings", expanded=True):
        st.info(
            "We provides you several knowledge base tools for you. We are building more tools!",
            icon="🔧",
        )
        # Placeholder the tool/file callbacks write their status messages to.
        st.session_state["tool_status"] = st.empty()
        tab_kb, tab_file = st.tabs(
            [
                "Knowledge Bases",
                "File Upload",
            ]
        )
        with tab_kb:
            _render_knowledge_base_tab()
        with tab_file:
            _render_file_upload_tab()


def _message_time(msg):
    """Return the message's ``timestamp`` kwarg as an ISO-8601 string."""
    return datetime.datetime.fromtimestamp(
        msg.additional_kwargs["timestamp"]
    ).isoformat()


def _render_chat_history():
    """Replay the messages stored in the agent's chat memory."""
    for msg in st.session_state.agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="📖"):
                st.write(f"*{_message_time(msg)}*")
                st.write("Retrieved from knowledge base:")
                try:
                    st.dataframe(
                        pd.DataFrame.from_records(
                            json.loads(msg.content, cls=CustomJSONDecoder)
                        ),
                        use_container_width=True,
                    )
                except Exception:
                    # Best effort: content that is not tabular JSON is shown
                    # raw. `Exception` instead of a bare `except` so that
                    # KeyboardInterrupt/SystemExit are not swallowed.
                    st.write(msg.content)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                st.write(f"*{_message_time(msg)}*")
                st.write(f"{msg.content}")


def chat_page():
    """Render the chat page: sidebar controls, message history and chat input.

    Fills in missing ``sel_sess``, ``private_kb``, ``session_manager`` and
    ``agent`` entries of session state, draws the sidebar (session
    management, session selection, tool settings, clear/logout buttons),
    replays the stored conversation and finally creates the ``next_round``
    placeholder and the chat input.

    Reads ``embeddings``, ``tools`` and ``user_name`` from session state;
    these are not set here and must exist before the page is rendered.
    """
    # NOTE(review): the source had lost its indentation; the nesting of the
    # sidebar sections below is reconstructed from statement order -- confirm
    # against the deployed layout.
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "private_kb" not in st.session_state:
        st.session_state["private_kb"] = PrivateKnowledgeBase(
            host=MYSCALE_HOST,
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
            embedding=st.session_state.embeddings["Wikipedia"],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()

    with st.sidebar:
        _render_session_management()
        _render_session_selection()
        _render_tool_settings()
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)

    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)

    _render_chat_history()

    # Slot that `on_chat_submit` draws the next question/answer pair into.
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")