Spaces:
Running
Running
import json | |
import pandas as pd | |
from os import environ | |
from time import sleep | |
import datetime | |
import streamlit as st | |
from lib.sessions import SessionManager | |
from lib.private_kb import PrivateKnowledgeBase | |
from langchain.schema import HumanMessage, FunctionMessage | |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler | |
from lib.json_conv import CustomJSONDecoder | |
from lib.helper import ( | |
build_agents, | |
MYSCALE_HOST, | |
MYSCALE_PASSWORD, | |
MYSCALE_PORT, | |
MYSCALE_USER, | |
DEFAULT_SYSTEM_PROMPT, | |
UNSTRUCTURED_API, | |
) | |
from login import back_to_main | |
# Copy the OpenAI API base URL from Streamlit secrets into the process
# environment at import time.
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]

# Human-readable names keyed by tool identifier.
# NOTE(review): not referenced by any function visible in this file view.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
def on_chat_submit():
    """Chat-input callback: echo the submitted message and invoke the agent.

    Everything is rendered inside a container created from the ``next_round``
    placeholder kept in ``st.session_state``. The agent is called with a
    ``ChatDataAgentCallBackHandler`` bound to a fresh container inside the
    assistant chat message; the agent's return value is printed to stdout.
    """
    user_text = st.session_state.chat_input
    round_slot = st.session_state.next_round.container()
    with round_slot:
        with st.chat_message("user"):
            st.write(user_text)
        with st.chat_message("assistant"):
            handler = ChatDataAgentCallBackHandler(
                st.container(), collapse_completed_thoughts=False
            )
            result = st.session_state.agent(
                {"input": user_text}, callbacks=[handler]
            )
            print(result)
def clear_history():
    """Clear the current agent's memory; do nothing when no agent exists yet."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
def back_to_main():
    """Drop the per-user entries from ``st.session_state``.

    Each key is removed only if present, so calling this repeatedly is safe.

    NOTE(review): this definition shadows the ``back_to_main`` imported from
    ``login`` at the top of the file — confirm which one is intended.
    """
    per_user_keys = (
        "user_info",
        "user_name",
        "jump_query_ask",
        "sel_sess",
        "current_sessions",
    )
    for state_key in per_user_keys:
        if state_key in st.session_state:
            delattr(st.session_state, state_key)
def on_session_change_submit():
    """Apply the pending edits of the ``session_editor`` data editor.

    Added rows become new sessions (stored under ``<user_name>?<session_id>``)
    and deleted rows, given as indices into ``current_sessions``, are removed.
    Any failure is shown with ``st.error`` after a two second pause. The
    editor's pending rows are always reset, and the agent is rebuilt
    afterwards. Nothing happens unless both ``session_manager`` and
    ``session_editor`` exist in ``st.session_state``.
    """
    state = st.session_state
    if "session_manager" not in state or "session_editor" not in state:
        return

    print(state.session_editor)
    try:
        for row in state.session_editor["added_rows"]:
            is_complete = (
                len(row) > 0 and "system_prompt" in row and "session_id" in row
            )
            if not is_complete:
                raise KeyError(
                    "You should fill both `session_id` and `system_prompt` to add a column!"
                )
            new_id = row["session_id"]
            # "?" separates the user name from the session name in stored ids.
            if new_id == "" or "?" in new_id:
                raise KeyError(
                    "`session_id` should NOT be neither empty nor contain question marks."
                )
            state.session_manager.add_session(
                user_id=state.user_name,
                session_id=f"{state.user_name}?{new_id}",
                system_prompt=row["system_prompt"],
            )

        for row_idx in state.session_editor["deleted_rows"]:
            state.session_manager.remove_session(
                session_id=f"{state.user_name}?{state.current_sessions[row_idx]['session_id']}",
            )
        refresh_sessions()
    except Exception as e:
        sleep(2)
        st.error(f"{type(e)}: {str(e)}")
    finally:
        state.session_editor["added_rows"] = []
        state.session_editor["deleted_rows"] = []
    refresh_agent()
def build_session_manager():
    """Return a ``SessionManager`` over ``st.session_state`` using the MyScale settings."""
    connection = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(st.session_state, **connection)
def refresh_sessions():
    """Reload the user's sessions, files and tools into ``st.session_state``.

    Steps, in order:

    1. ``current_sessions`` is re-read from the session manager. If the result
       is an empty non-dict collection, a ``default`` session with
       ``DEFAULT_SYSTEM_PROMPT`` is created and the list is read again.
    2. ``user_files``, ``user_tools`` and ``tools_with_users`` are rebuilt from
       the private knowledge base (user tools override built-in tools that
       share a name, because they are merged last).
    3. ``sel_sess`` is pointed at the refreshed entry whose ``session_id``
       matches the current selection. When nothing is selected the session
       named ``"default"`` is used; when the wanted id no longer exists the
       first session is used.

    Step 3 previously tested ``"" not in st.session_state`` (always true) and
    referenced the misspelled ``sel_session`` key, so the selection was always
    reset to ``"default"``.
    """
    state = st.session_state
    manager = state.session_manager
    user = state.user_name

    state["current_sessions"] = manager.list_sessions(user)
    if (
        not isinstance(state.current_sessions, dict)
        and len(state.current_sessions) <= 0
    ):
        manager.add_session(user, f"{user}?default", DEFAULT_SYSTEM_PROMPT)
        state["current_sessions"] = manager.list_sessions(user)

    state["user_files"] = state.private_kb.list_files(user)
    state["user_tools"] = state.private_kb.list_tools(user)
    state["tools_with_users"] = {
        **state.tools,
        **state.private_kb.as_tools(user),
    }

    # Keep the user's current selection across refreshes when possible.
    wanted_id = "default"
    if "sel_sess" in state:
        try:
            wanted_id = state.sel_sess["session_id"]
        except (KeyError, TypeError):
            # The stored selection is unusable (e.g. None); keep the default.
            pass

    session_ids = [x["session_id"] for x in state.current_sessions]
    try:
        dfl_indx = session_ids.index(wanted_id)
    except ValueError:
        dfl_indx = 0
    state.sel_sess = state.current_sessions[dfl_indx]
def build_kb_as_tool():
    """Create a knowledge-base tool from the "Build" form fields.

    Requires ``b_tool_name``, ``b_tool_desc`` and ``b_tool_files`` to be present
    in ``st.session_state`` and non-empty. Otherwise an error is written to the
    ``tool_status`` placeholder, followed by a two second pause.
    """
    state = st.session_state
    form_keys = ("b_tool_name", "b_tool_desc", "b_tool_files")
    form_filled = all(key in state for key in form_keys) and all(
        len(state[key]) > 0 for key in form_keys
    )
    if not form_filled:
        state.tool_status.error("You should fill all fields to build up a tool!")
        sleep(2)
        return

    selected_file_names = [f["file_name"] for f in state.b_tool_files]
    state.private_kb.create_tool(
        state.user_name,
        state.b_tool_name,
        state.b_tool_desc,
        selected_file_names,
    )
    refresh_sessions()
def remove_kb():
    """Delete the knowledge-base tools chosen in the ``r_tool_names`` widget.

    When nothing is selected, an error is written to the ``tool_status``
    placeholder, followed by a two second pause.
    """
    state = st.session_state
    if "r_tool_names" not in state or len(state.r_tool_names) <= 0:
        state.tool_status.error("You should specify at least one tool to delete!")
        sleep(2)
        return

    doomed_names = [f["tool_name"] for f in state.r_tool_names]
    state.private_kb.remove_tools(state.user_name, doomed_names)
    refresh_sessions()
def refresh_agent():
    """Rebuild the agent for the selected session and store it as ``agent``.

    The agent is keyed by ``<user_name>?<session_id>``. It uses the tools from
    the ``selected_tools`` widget when that key exists, otherwise a single
    built-in default, and takes the selected session's system prompt.
    """
    state = st.session_state
    with st.spinner("Initializing session..."):
        full_session_id = f"{state.user_name}?{state.sel_sess['session_id']}"
        print("??? Changed to ", full_session_id)

        if "selected_tools" not in state:
            tool_names = ["LangChain Self Query Retriever For Wikipedia"]
        else:
            tool_names = state.selected_tools

        # NOTE(review): ``sel_sess`` has already been read above, so the
        # fallback branch below looks unreachable in practice — confirm.
        if "sel_sess" not in state:
            prompt = DEFAULT_SYSTEM_PROMPT
        else:
            prompt = state.sel_sess["system_prompt"]

        state["agent"] = build_agents(
            full_session_id,
            tool_names,
            system_prompt=prompt,
        )
def add_file():
    """Add the files from the ``uploaded_files`` widget to the private knowledge base.

    Progress and errors are written to the ``tool_status`` placeholder. Only
    ``ValueError`` raised during the upload is reported there; each error is
    followed by a two second pause.
    """
    state = st.session_state
    has_uploads = "uploaded_files" in state and len(state.uploaded_files) != 0
    if has_uploads:
        try:
            state.tool_status.info("Uploading...")
            state.private_kb.add_by_file(state.user_name, state.uploaded_files)
            refresh_sessions()
        except ValueError as err:
            state.tool_status.error(f"Failed to upload! {err!s}")
            sleep(2)
    else:
        # NOTE(review): the icon literal looks mis-encoded in this copy of the
        # file; it is reproduced unchanged — confirm against the upstream source.
        state.tool_status.error("Please upload files!", icon="β οΈ")
        sleep(2)
def clear_files():
    """Clear the private knowledge base for the current user, then refresh the cached state."""
    current_user = st.session_state.user_name
    st.session_state.private_kb.clear(current_user)
    refresh_sessions()
# NOTE(review): several ``icon=`` / ``avatar=`` literals below look mis-encoded
# in this copy of the file; they are reproduced unchanged — confirm against the
# upstream source before shipping.


def _selected_session_index():
    """Return the index in ``current_sessions`` of the session to preselect.

    The current ``sel_sess`` selection is looked up by ``session_id``; when no
    selection exists the session named ``"default"`` is used. Any failure is
    printed and index 0 is returned.
    """
    try:
        wanted_id = (
            st.session_state.sel_sess["session_id"]
            if "sel_sess" in st.session_state
            else "default"
        )
        return [
            x["session_id"] for x in st.session_state.current_sessions
        ].index(wanted_id)
    except Exception as e:
        print("*** ", str(e))
        return 0


def _message_timestamp(msg):
    """Format ``msg.additional_kwargs['timestamp']`` as italic ISO-8601 markdown."""
    stamp = datetime.datetime.fromtimestamp(msg.additional_kwargs["timestamp"])
    return f"*{stamp.isoformat()}*"


def _render_session_management():
    """Render the session editor and its submit button."""
    with st.expander("Session Management"):
        if "current_sessions" not in st.session_state:
            refresh_sessions()
        st.info(
            "Here you can set up your session! \n\nYou can **change your prompt** here!",
            icon="π€",
        )
        st.info(
            (
                "**Add columns by clicking the empty row**.\n"
                "And **delete columns by selecting rows with a press on `DEL` Key**"
            ),
            icon="π‘",
        )
        st.info(
            "Don't forget to **click `Submit Change` to save your change**!",
            icon="π",
        )
        st.data_editor(
            st.session_state.current_sessions,
            num_rows="dynamic",
            key="session_editor",
            use_container_width=True,
        )
        st.button("Submit Change!", on_click=on_session_change_submit)


def _render_session_selection():
    """Render the session selectbox; changing it rebuilds the agent."""
    with st.expander("Session Selection", expanded=True):
        st.info(
            "If no session is attach to your account, then we will add a default session to you!",
            icon="β€οΈ",
        )
        dfl_indx = _selected_session_index()
        st.selectbox(
            "Choose a session to chat:",
            options=st.session_state.current_sessions,
            index=dfl_indx,
            key="sel_sess",
            format_func=lambda x: x["session_id"],
            on_change=refresh_agent,
        )
        print(st.session_state.sel_sess)


def _render_knowledge_base_tab():
    """Render the build / select / delete controls for knowledge-base tools."""
    st.markdown("#### Build You Own Knowledge")
    st.multiselect(
        "Select Files to Build up",
        st.session_state.user_files,
        placeholder="You should upload files first",
        key="b_tool_files",
        format_func=lambda x: x["file_name"],
    )
    st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
    st.text_input(
        "Tool Description",
        "Searches among user's private files and returns related documents",
        key="b_tool_desc",
    )
    st.button("Build!", on_click=build_kb_as_tool)
    st.markdown("### Knowledge Base Selection")
    if (
        "user_tools" in st.session_state
        and len(st.session_state.user_tools) > 0
    ):
        st.markdown("***User Created Knowledge Bases***")
        st.dataframe(st.session_state.user_tools)
    st.multiselect(
        "Select a Knowledge Base Tool",
        st.session_state.tools.keys()
        if "tools_with_users" not in st.session_state
        else st.session_state.tools_with_users,
        default=["Wikipedia + Self Querying"],
        key="selected_tools",
        on_change=refresh_agent,
    )
    st.markdown("### Delete Knowledge Base")
    st.multiselect(
        "Choose Knowledge Base to Remove",
        st.session_state.user_tools,
        format_func=lambda x: x["tool_name"],
        key="r_tool_names",
    )
    st.button("Delete", on_click=remove_kb)


def _render_file_upload_tab():
    """Render the file uploader, the uploaded-file table and the file buttons."""
    st.info(
        (
            "We adopted [Unstructured API](https://unstructured.io/api-key) "
            "here and we only store the processed texts from your documents. "
            "For privacy concerns, please refer to "
            "[our policy issue](https://myscale.com/privacy/)."
        ),
        icon="π",
    )
    st.file_uploader(
        "Upload files", key="uploaded_files", accept_multiple_files=True
    )
    st.markdown("### Uploaded Files")
    st.dataframe(
        st.session_state.private_kb.list_files(st.session_state.user_name),
        use_container_width=True,
    )
    col_1, col_2 = st.columns(2)
    with col_1:
        st.button("Add Files", on_click=add_file)
    with col_2:
        st.button("Clear Files and All Tools", on_click=clear_files)


def _render_tool_settings():
    """Render the tool settings expander with its two tabs."""
    with st.expander("Tool Settings", expanded=True):
        st.info(
            "We provides you several knowledge base tools for you. We are building more tools!",
            icon="π§",
        )
        # Placeholder that the tool callbacks write their status messages into.
        st.session_state["tool_status"] = st.empty()
        tab_kb, tab_file = st.tabs(
            [
                "Knowledge Bases",
                "File Upload",
            ]
        )
        with tab_kb:
            _render_knowledge_base_tab()
        with tab_file:
            _render_file_upload_tab()


def _render_chat_history():
    """Replay the agent's stored messages as chat bubbles."""
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
    for msg in st.session_state.agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="π"):
                st.write(_message_timestamp(msg))
                st.write("Retrieved from knowledge base:")
                try:
                    st.dataframe(
                        pd.DataFrame.from_records(
                            json.loads(msg.content, cls=CustomJSONDecoder)
                        ),
                        use_container_width=True,
                    )
                except Exception:
                    # Content that cannot be shown as a table is shown raw.
                    # ``Exception`` (not a bare ``except``) so that
                    # BaseException-based control-flow signals still propagate.
                    st.write(msg.content)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                st.write(_message_timestamp(msg))
                st.write(f"{msg.content}")


def chat_page():
    """Render the chat page: sidebar controls, message history and chat input.

    On first use this seeds ``sel_sess``, ``private_kb`` and
    ``session_manager`` in ``st.session_state`` and builds the agent if none
    exists. The sidebar holds session management, session selection, tool
    settings, and the "Clear Chat History" / "Logout" buttons. Stored messages
    are then replayed, and a ``next_round`` placeholder plus the chat input are
    created for the next exchange.

    The session selectbox now preselects the current ``sel_sess`` entry; the
    earlier check ``"" not in st.session_state`` was always true and its
    ``else`` branch referenced the misspelled ``sel_session`` key.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "private_kb" not in st.session_state:
        st.session_state["private_kb"] = PrivateKnowledgeBase(
            host=MYSCALE_HOST,
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
            embedding=st.session_state.embeddings["Wikipedia"],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()

    with st.sidebar:
        _render_session_management()
        _render_session_selection()
        _render_tool_settings()
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)

    if "agent" not in st.session_state:
        refresh_agent()
    _render_chat_history()

    # Placeholder that on_chat_submit renders the next exchange into.
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")