Spaces:
Running
Running
import streamlit as st | |
from langchain.callbacks.streamlit.streamlit_callback_handler import ( | |
StreamlitCallbackHandler, | |
) | |
from langchain.schema.output import LLMResult | |
from sql_formatter.core import format_sql | |
class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback that reports vector-SQL search progress.

    Instead of the step-by-step UI of ``StreamlitCallbackHandler``, this
    handler draws a single progress bar that advances by ``prog_interval`` on
    every chain start and every LLM completion. When the LLM output looks
    like a ``SELECT`` statement, it also renders that output as a formatted
    SQL block.
    """

    def __init__(self) -> None:
        # NOTE(review): StreamlitCallbackHandler.__init__ is deliberately not
        # called, so the parent's internal state is never initialised. Hooks
        # not overridden here fall through to the parent and presumably rely
        # on that state -- confirm none of them fire for the attached chains.
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        # Empty placeholder element; not used inside this class.
        self.status_bar = st.empty()
        # Current bar fill; st.progress accepts floats in [0.0, 1.0].
        self.prog_value = 0
        # Fill added per step (chain start / LLM end).
        self.prog_interval = 0.2

    def _advance_progress(self, text: str) -> None:
        """Advance the progress bar by one step and relabel it with *text*.

        The value is clamped at 1.0. With more steps than
        ``1 / prog_interval`` (e.g. nested chains), it would otherwise exceed
        the range ``st.progress`` accepts and raise.
        """
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Do nothing; overrides the parent so no per-LLM UI is drawn."""

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        """Show the generated SQL (if any) and advance the progress bar.

        Args:
            response: Result of the LLM call. Only the first candidate of the
                first prompt is inspected.
        """
        generations = response.generations
        # Guard against an empty result instead of raising IndexError.
        if generations and generations[0]:
            text = generations[0][0].text
            # lstrip() also drops leading newlines/tabs, which LLMs often emit
            # before the statement; removing only spaces missed those.
            if text.lstrip().upper().startswith("SELECT"):
                st.markdown("### Generated Vector Search SQL Statement \n"
                            "> This sql statement is generated by LLM \n\n")
                st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self._advance_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar, labelled with the starting chain's id."""
        # Some LangChain versions pass serialized=None or omit "id"; fall back
        # to the run name (if supplied) rather than raising.
        ids = (serialized or {}).get("id") or []
        cid = ".".join(ids) if ids else (kwargs.get("name") or "chain")
        self._advance_progress(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Do nothing; overrides the parent so no chain-end UI is drawn."""
class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    """Variant of ``VectorSQLSearchDBCallBackHandler`` with finer progress
    steps that also stores a table name."""

    def __init__(self, table: str) -> None:
        """Create the progress widgets and remember *table*.

        Args:
            table: Table name stored as ``self.table``; presumably consumed
                by code outside this constructor -- confirm with callers.
        """
        # Reuse the parent's setup (progress bar "Writing SQL...", empty
        # status slot, counter reset to 0) instead of duplicating it.
        super().__init__()
        # Half the parent's step size: ten steps fill the bar.
        self.prog_interval = 0.1
        self.table = table