import streamlit as st from langchain.chat_models import ChatOpenAI from langchain.prompts.prompt import PromptTemplate from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain.retrievers.self_query.myscale import MyScaleTranslator from langchain.utilities.sql_database import SQLDatabase from langchain.vectorstores import MyScaleSettings from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain from sqlalchemy import create_engine, MetaData from backend.constants.myscale_tables import MYSCALE_TABLES from backend.constants.prompts import MYSCALE_PROMPT from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson from logger import logger @st.cache_resource def build_self_query_retriever(table_name: str) -> SelfQueryRetriever: with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."): myscale_connection = { "host": GLOBAL_CONFIG.myscale_host, "port": GLOBAL_CONFIG.myscale_port, "username": GLOBAL_CONFIG.myscale_user, "password": GLOBAL_CONFIG.myscale_password, } myscale_settings = MyScaleSettings( **myscale_connection, database=MYSCALE_TABLES[table_name].database, table=MYSCALE_TABLES[table_name].table, column_map={ "id": "id", "text": MYSCALE_TABLES[table_name].text_col_name, "vector": MYSCALE_TABLES[table_name].vector_col_name, # TODO refine MyScaleDB metadata in langchain. "metadata": MYSCALE_TABLES[table_name].metadata_col_name } ) myscale_vector_store = MyScaleWithoutMetadataJson( embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], config=myscale_settings, must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names ) with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."): retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm( llm=ChatOpenAI( model_name=GLOBAL_CONFIG.query_model, base_url=GLOBAL_CONFIG.openai_api_base, api_key=GLOBAL_CONFIG.openai_api_key, temperature=0 ), vectorstore=myscale_vector_store, document_contents=MYSCALE_TABLES[table_name].table_contents, metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes, use_original_query=False, structured_query_translator=MyScaleTranslator() ) return retriever @st.cache_resource def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever: """Get a group of relative docs from MyScaleDB""" with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'): if GLOBAL_CONFIG.myscale_enable_https == False: engine = create_engine( f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' f'/{MYSCALE_TABLES[table_name].database}?protocol=http' ) else: engine = create_engine( f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@' f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}' f'/{MYSCALE_TABLES[table_name].database}?protocol=https' ) metadata = MetaData(bind=engine) logger.info(f"{table_name} metadata is : {metadata}") prompt = PromptTemplate( input_variables=["input", "table_info", "top_k"], template=MYSCALE_PROMPT, ) # Custom `out_put_parser` rewrite search SQL, make it's possible to query custom column. output_parser = VectorSQLRetrieveOutputParser.from_embeddings( model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name], # rewrite columns needs be searched. must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names ) # `db_chain` will generate a SQL vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm( llm=ChatOpenAI( model_name=GLOBAL_CONFIG.query_model, base_url=GLOBAL_CONFIG.openai_api_base, api_key=GLOBAL_CONFIG.openai_api_key, temperature=0 ), prompt=prompt, top_k=10, return_direct=True, db=SQLDatabase( engine, None, metadata, include_tables=[MYSCALE_TABLES[table_name].table], max_string_length=1024 ), sql_cmd_parser=output_parser, # TODO needs update `langchain`, fix return type. native_format=True ) # `retriever` can search a group of documents with `db_chain` vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever( sql_db_chain=vector_sql_db_chain, page_content_key=MYSCALE_TABLES[table_name].text_col_name ) return vector_sql_db_chain_retriever