# ChatData/backend/retrievers/self_query.py
# lqhl's picture
# Synced repo using 'sync_with_huggingface' Github Action
# e931b70 verified
from typing import List
import pandas as pd
import streamlit as st
from langchain.retrievers import SelfQueryRetriever
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler
from ui.utils import display
from logger import logger
def _documents_to_dataframe(documents: "List[Document]") -> pd.DataFrame:
    """Build a table with one row per document.

    Each row contains the document's metadata fields plus its
    ``page_content`` stored under the ``abstract`` column (which therefore
    overrides any metadata field of the same name).
    """
    return pd.DataFrame(
        [{**doc.metadata, 'abstract': doc.page_content} for doc in documents]
    )


def _show_self_query_from_db(selected_table) -> None:
    """Run the table's self-query retriever and render the retrieved documents.

    The retriever is read from
    ``st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]``
    and invoked with the text held in ``st.session_state.query_self``.
    """
    callback = CustomSelfQueryRetrieverCallBackHandler()
    retriever: SelfQueryRetriever = \
        st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
    config: RunnableConfig = {"callbacks": [callback]}
    relevant_docs = retriever.invoke(
        input=st.session_state.query_self,
        config=config,
    )
    # Force the progress bar to 100% once the retriever has returned.
    callback.progress_bar.progress(
        value=1.0,
        text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅",
    )
    st.markdown(f"### Self Query Results from `{selected_table}` \n"
                f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
    display(
        dataframe=_documents_to_dataframe(relevant_docs),
        columns_=MYSCALE_TABLES[selected_table].must_have_col_names,
    )


def _show_self_query_with_llm(selected_table) -> None:
    """Run the table's QA-with-sources chain and render documents, answer and references.

    The chain is read from
    ``st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]``
    and called with the text held in ``st.session_state.query_self``. The
    result is read through its ``source_documents``, ``answer`` and
    ``sources`` keys.
    """
    callback = ChatDataSelfAskCallBackHandler()
    chain: CustomRetrievalQAWithSourcesChain = \
        st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
    # NOTE(review): calling a chain directly is deprecated in recent LangChain
    # releases in favour of `.invoke(...)`; kept as-is until the custom chain
    # is confirmed to behave identically through `invoke`.
    chain_results = chain(st.session_state.query_self, callbacks=[callback])
    # Force the progress bar to 100% once the chain has returned.
    callback.progress_bar.progress(
        value=1.0,
        text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅",
    )

    # 1) Everything the retriever handed to the LLM.
    documents_reference: List[Document] = chain_results["source_documents"]
    st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
                f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
    display(_documents_to_dataframe(documents_reference))

    # 2) The LLM's answer.
    st.markdown(
        "### Answer from LLM \n"
        "> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
    )
    st.write(chain_results['answer'])

    # 3) The subset of documents the chain reports under 'sources'.
    st.markdown(
        f"### References from `{selected_table}`\n"
        "> Here shows that which documents used by LLM \n\n"
    )
    sources = chain_results['sources']
    if not sources:
        st.write("No documents is used by LLM.")
    else:
        display(
            dataframe=_documents_to_dataframe(sources),
            columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
            index='ref_id',
        )


def process_self_query(selected_table, query_type):
    """Handle a self-query request and render its results in a Streamlit expander.

    Args:
        selected_table: Key of the table to query; used to look up the
            retriever/chain in ``st.session_state[CHAINS_RETRIEVERS_MAPPING]``
            and the column list in ``MYSCALE_TABLES``.
        query_type: Which action to run. ``RetrieverButtons.self_query_from_db``
            only retrieves documents; ``RetrieverButtons.self_query_with_llm``
            additionally asks the LLM for an answer. Any other value runs
            neither action.

    Raises:
        Exception: Any error raised while querying or rendering is re-raised
            unchanged after a short notice has been written to the page.
    """
    place_holder = st.empty()
    logger.info(
        f"button-1: {RetrieverButtons.self_query_from_db}, "
        f"button-2: {RetrieverButtons.self_query_with_llm}, "
        f"content: {st.session_state.query_self}"
    )
    with place_holder.expander('🪵 Chat Log', expanded=True):
        try:
            if query_type == RetrieverButtons.self_query_from_db:
                _show_self_query_from_db(selected_table)
            elif query_type == RetrieverButtons.self_query_with_llm:
                _show_self_query_with_llm(selected_table)
            # NOTE(review): the divider is assumed to close the log for every
            # query type — confirm against the intended layout.
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            # Bare `raise` re-raises the active exception with its original traceback.
            raise