Spaces:
Running
Running
from typing import List | |
import pandas as pd | |
import streamlit as st | |
from langchain.retrievers import SelfQueryRetriever | |
from langchain_core.documents import Document | |
from langchain_core.runnables import RunnableConfig | |
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain | |
from backend.constants.myscale_tables import MYSCALE_TABLES | |
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons | |
from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler | |
from ui.utils import display | |
from logger import logger | |
def _documents_to_dataframe(docs: "List[Document]") -> pd.DataFrame:
    """Build a table with one row per document.

    Each row holds the document's metadata fields plus an ``abstract``
    column containing its ``page_content``.
    """
    return pd.DataFrame(
        [{**doc.metadata, 'abstract': doc.page_content} for doc in docs]
    )


def process_self_query(selected_table, query_type) -> None:
    """Run the self-query flow for ``selected_table`` and render the results.

    Everything is drawn inside an expanded "Chat Log" expander. The query
    text is read from ``st.session_state.query_self``; the retriever and the
    chain are looked up in
    ``st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]``.

    Args:
        selected_table: Key used to look up the retriever/chain in the
            session state and the column list in ``MYSCALE_TABLES``.
        query_type: Which button triggered the query. It is compared against
            ``RetrieverButtons.self_query_from_db`` (retrieve documents only)
            and ``RetrieverButtons.self_query_with_llm`` (retrieve, then let
            the LLM answer). Any other value renders only the divider.

    Raises:
        Exception: Whatever the retriever, the chain or the rendering code
            raises; a short notice is written to the page first and the
            exception is then re-raised unchanged.
    """
    place_holder = st.empty()
    logger.info(
        f"button-1: {RetrieverButtons.self_query_from_db}, "
        f"button-2: {RetrieverButtons.self_query_with_llm}, "
        f"content: {st.session_state.query_self}"
    )
    with place_holder.expander('🪵 Chat Log', expanded=True):
        try:
            if query_type == RetrieverButtons.self_query_from_db:
                callback = CustomSelfQueryRetrieverCallBackHandler()
                retriever: SelfQueryRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
                config: RunnableConfig = {"callbacks": [callback]}
                relevant_docs = retriever.invoke(
                    input=st.session_state.query_self,
                    config=config
                )
                # The callback owns the progress bar; mark it complete once
                # the retriever has returned.
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅"
                )

                st.markdown(f"### Self Query Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    dataframe=_documents_to_dataframe(relevant_docs),
                    columns_=MYSCALE_TABLES[selected_table].must_have_col_names
                )
            elif query_type == RetrieverButtons.self_query_with_llm:
                callback = ChatDataSelfAskCallBackHandler()
                chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
                # NOTE(review): recent LangChain releases deprecate calling a
                # chain directly in favour of `.invoke(...)`; confirm against
                # the pinned version before switching.
                chain_results = chain(st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅"
                )

                # 1) Every document the retriever handed to the LLM.
                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(_documents_to_dataframe(documents_reference))

                # 2) The LLM's answer.
                st.markdown(
                    "### Answer from LLM \n"
                    "> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
                )
                st.write(chain_results['answer'])

                # 3) The subset of documents reported under 'sources'.
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    "> Here shows that which documents used by LLM \n\n"
                )
                sources = chain_results['sources']
                if not sources:
                    st.write("No documents is used by LLM.")
                else:
                    display(
                        dataframe=_documents_to_dataframe(sources),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            # Visual separator closing the log output.
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            # Bare `raise` re-raises the active exception unchanged.
            raise