Spaces:
Running
Running
import pandas as pd | |
from os import environ | |
import streamlit as st | |
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ | |
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ | |
ChatDataSQLAskCallBackHandler | |
from chat import chat_page | |
from login import login, back_to_main | |
from lib.helper import build_tools, build_all, sel_map, display | |
# Copy the OPENAI_API_BASE secret into the process environment.
environ.update(OPENAI_API_BASE=st.secrets['OPENAI_API_BASE'])

st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
)

# Style override injected into the page: caps the width of elements carrying
# the `.st-e4` class at 500px.
_PAGE_CSS = """
<style>
.st-e4 {
max-width: 500px
}
</style>"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

st.header("ChatData")

# Build the per-session objects once; later reruns of the script reuse what is
# already stored in session state. Everything is rebuilt if either key is missing.
_REQUIRED_STATE_KEYS = ('sel_map_obj', 'embeddings')
if any(key not in st.session_state for key in _REQUIRED_STATE_KEYS):
    built_sel_map, built_embeddings = build_all()
    st.session_state["sel_map_obj"] = built_sel_map
    st.session_state["embeddings"] = built_embeddings
    # Called only after the two entries above are stored, as in the original
    # statement order.
    st.session_state["tools"] = build_tools()
def _docs_to_frame(docs):
    """Return a DataFrame with one row per document.

    Each row contains the document's ``metadata`` entries plus its
    ``page_content`` stored under the ``abstract`` column.
    """
    return pd.DataFrame(
        [{**doc.metadata, 'abstract': doc.page_content} for doc in docs])


def _run_search(sel, retriever_key, query, handler_cls, *,
                restrict_columns=False, echo_docs=False):
    """Retrieve documents for ``query`` and show them in a 'Query Log' expander.

    Args:
        sel: Knowledge-base name chosen in the select box; used as the key
            into ``st.session_state.sel_map_obj`` and ``sel_map``.
        retriever_key: Key of the retriever inside the selected entry.
        query: Question text typed by the user.
        handler_cls: Callback handler class; instantiated inside the expander
            and expected to expose a ``progress_bar`` attribute.
        restrict_columns: When true, ``sel_map[sel]["must_have_cols"]`` is
            passed to ``display`` as its second argument.
        echo_docs: When true, the raw retrieved documents are printed to
            stdout before being rendered.

    Raises:
        Exception: Any error raised while retrieving or rendering is shown as
            a short message in the page and then re-raised unchanged.
    """
    placeholder = st.empty()
    print(query)
    with placeholder.expander('Query Log', expanded=True):
        callback = handler_cls()
        try:
            retriever = st.session_state.sel_map_obj[sel][retriever_key]
            docs = retriever.get_relevant_documents(query, callbacks=[callback])
            if echo_docs:
                print(docs)
            callback.progress_bar.progress(value=1.0, text="Done!")
            frame = _docs_to_frame(docs)
            if restrict_columns:
                display(frame, sel_map[sel]["must_have_cols"])
            else:
                display(frame)
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            # Bare re-raise keeps the original traceback intact.
            raise


def _run_ask(sel, chain_key, query, handler_cls):
    """Run a QA chain on ``query`` and show the answer in a 'Chat Log' expander.

    The chain result is read through its ``'answer'`` and ``'sources'`` keys:
    the answer is rendered as markdown and the sources are rendered as a
    reference table indexed by ``ref_id``.

    Args:
        sel: Knowledge-base name chosen in the select box; used as the key
            into ``st.session_state.sel_map_obj`` and ``sel_map``.
        chain_key: Key of the chain inside the selected entry.
        query: Question text typed by the user.
        handler_cls: Callback handler class; instantiated inside the expander
            and expected to expose a ``progress_bar`` attribute.

    Raises:
        Exception: Any error raised while running the chain or rendering is
            shown as a short message in the page and then re-raised unchanged.
    """
    placeholder = st.empty()
    print(query)
    with placeholder.expander('Chat Log', expanded=True):
        callback = handler_cls()
        try:
            chain = st.session_state.sel_map_obj[sel][chain_key]
            ret = chain(query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            st.markdown(
                f"### Answer from LLM\n{ret['answer']}\n### References")
            display(
                _docs_to_frame(ret['sources']),
                ['ref_id'] + sel_map[sel]["must_have_cols"],
                index='ref_id')
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            # Bare re-raise keeps the original traceback intact.
            raise


# Nothing below is rendered unless login() returns a truthy value.
if login():
    if "user_name" in st.session_state:
        chat_page()
    elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
        sel = st.selectbox('Choose the knowledge base you want to ask with:',
                           options=['ArXiv Papers', 'Wikipedia'])
        sel_map[sel]['hint']()
        tab_sql, tab_self_query = st.tabs(
            ['Vector SQL', 'Self-Query Retrievers'])

        with tab_sql:
            sel_map[sel]['hint_sql']()
            st.text_input("Ask a question:", key='query_sql')
            cols = st.columns([1, 1, 1, 4])
            # Button states are read back from session state via their keys.
            cols[0].button("Query", key='search_sql')
            cols[1].button("Ask", key='ask_sql')
            cols[2].button("Back", key='back_sql', on_click=back_to_main)
            if st.session_state.search_sql:
                _run_search(sel, "sql_retriever", st.session_state.query_sql,
                            ChatDataSQLSearchCallBackHandler)
            if st.session_state.ask_sql:
                _run_ask(sel, "sql_chain", st.session_state.query_sql,
                         ChatDataSQLAskCallBackHandler)

        with tab_self_query:
            st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
            st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
            st.text_input("Ask a question:", key='query_self')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_self')
            cols[1].button("Ask", key='ask_self')
            cols[2].button("Back", key='back_self', on_click=back_to_main)
            if st.session_state.search_self:
                _run_search(sel, "retriever", st.session_state.query_self,
                            ChatDataSelfSearchCallBackHandler,
                            restrict_columns=True, echo_docs=True)
            if st.session_state.ask_self:
                _run_ask(sel, "chain", st.session_state.query_self,
                         ChatDataSelfAskCallBackHandler)