lqhl commited on
Commit
e931b70
1 Parent(s): f9dc2a4

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (48) hide show
  1. .streamlit/config.toml +1 -5
  2. .streamlit/secrets.example.toml +2 -1
  3. app.py +74 -121
  4. backend/__init__.py +0 -0
  5. backend/callbacks/__init__.py +0 -0
  6. backend/callbacks/arxiv_callbacks.py +46 -0
  7. backend/callbacks/llm_thought_with_table.py +36 -0
  8. backend/callbacks/self_query_callbacks.py +57 -0
  9. backend/callbacks/vector_sql_callbacks.py +53 -0
  10. backend/chains/__init__.py +0 -0
  11. backend/chains/retrieval_qa_with_sources.py +70 -0
  12. backend/chains/stuff_documents.py +65 -0
  13. backend/chat_bot/__init__.py +0 -0
  14. backend/chat_bot/chat.py +225 -0
  15. backend/chat_bot/json_decoder.py +24 -0
  16. backend/chat_bot/message_converter.py +67 -0
  17. backend/chat_bot/private_knowledge_base.py +167 -0
  18. backend/chat_bot/session_manager.py +96 -0
  19. backend/chat_bot/tools.py +100 -0
  20. backend/constants/__init__.py +0 -0
  21. backend/constants/myscale_tables.py +128 -0
  22. backend/constants/prompts.py +128 -0
  23. backend/constants/streamlit_keys.py +35 -0
  24. backend/constants/variables.py +58 -0
  25. backend/construct/__init__.py +0 -0
  26. backend/construct/build_agents.py +82 -0
  27. backend/construct/build_all.py +95 -0
  28. backend/construct/build_chains.py +39 -0
  29. backend/construct/build_chat_bot.py +36 -0
  30. backend/construct/build_retriever_tool.py +45 -0
  31. backend/construct/build_retrievers.py +120 -0
  32. backend/retrievers/__init__.py +0 -0
  33. backend/retrievers/self_query.py +89 -0
  34. backend/retrievers/vector_sql_output_parser.py +23 -0
  35. backend/retrievers/vector_sql_query.py +95 -0
  36. backend/types/__init__.py +0 -0
  37. backend/types/chains_and_retrievers.py +34 -0
  38. backend/types/global_config.py +22 -0
  39. backend/types/table_config.py +25 -0
  40. backend/vector_store/__init__.py +0 -0
  41. backend/vector_store/myscale_without_metadata.py +52 -0
  42. logger.py +18 -0
  43. requirements.txt +9 -7
  44. ui/__init__.py +0 -0
  45. ui/chat_page.py +196 -0
  46. ui/home.py +156 -0
  47. ui/retrievers.py +97 -0
  48. ui/utils.py +18 -0
.streamlit/config.toml CHANGED
@@ -1,6 +1,2 @@
1
  [theme]
2
- primaryColor="#523EFD"
3
- backgroundColor="#FFFFFF"
4
- secondaryBackgroundColor="#D4CEFF"
5
- textColor="#262730"
6
- font="sans serif"
 
1
  [theme]
2
+ base="dark"
 
 
 
 
.streamlit/secrets.example.toml CHANGED
@@ -1,7 +1,8 @@
1
- MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
2
  MYSCALE_PORT = 443
3
  MYSCALE_USER = "chatdata"
4
  MYSCALE_PASSWORD = "myscale_rocks"
 
5
  OPENAI_API_BASE = "https://api.openai.com/v1"
6
  OPENAI_API_KEY = "<your-openai-key>"
7
  UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
 
1
+ MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com" # read-only database provided by MyScale
2
  MYSCALE_PORT = 443
3
  MYSCALE_USER = "chatdata"
4
  MYSCALE_PASSWORD = "myscale_rocks"
5
+ MYSCALE_ENABLE_HTTPS = true
6
  OPENAI_API_BASE = "https://api.openai.com/v1"
7
  OPENAI_API_KEY = "<your-openai-key>"
8
  UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
app.py CHANGED
@@ -1,133 +1,86 @@
1
- import pandas as pd
2
- from os import environ
 
3
  import streamlit as st
4
 
5
- from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
6
- ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
7
- ChatDataSQLAskCallBackHandler
 
 
 
 
 
 
 
8
 
9
- from chat import chat_page
10
- from login import login, back_to_main
11
- from lib.helper import build_tools, build_all, sel_map, display
12
 
 
13
 
14
- environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- st.set_page_config(page_title="ChatData",
17
- page_icon="https://myscale.com/favicon.ico")
18
- st.markdown(
19
- f"""
20
- <style>
21
- .st-e4 {{
22
- max-width: 500px
23
- }}
24
- </style>""",
25
- unsafe_allow_html=True,
26
- )
27
- st.header("ChatData")
28
 
29
- if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state:
30
- st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all()
31
- st.session_state["tools"] = build_tools()
 
 
 
 
 
32
 
33
- if login():
34
- if "user_name" in st.session_state:
35
- chat_page()
36
- elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
37
 
38
- sel = st.selectbox('Choose the knowledge base you want to ask with:',
39
- options=['ArXiv Papers', 'Wikipedia'])
40
- sel_map[sel]['hint']()
41
- tab_sql, tab_self_query = st.tabs(
42
- ['Vector SQL', 'Self-Query Retrievers'])
43
- with tab_sql:
44
- sel_map[sel]['hint_sql']()
45
- st.text_input("Ask a question:", key='query_sql')
46
- cols = st.columns([1, 1, 1, 4])
47
- cols[0].button("Query", key='search_sql')
48
- cols[1].button("Ask", key='ask_sql')
49
- cols[2].button("Back", key='back_sql', on_click=back_to_main)
50
- plc_hldr = st.empty()
51
- if st.session_state.search_sql:
52
- plc_hldr = st.empty()
53
- print(st.session_state.query_sql)
54
- with plc_hldr.expander('Query Log', expanded=True):
55
- callback = ChatDataSQLSearchCallBackHandler()
56
- try:
57
- docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
58
- st.session_state.query_sql, callbacks=[callback])
59
- callback.progress_bar.progress(value=1.0, text="Done!")
60
- docs = pd.DataFrame(
61
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
62
- display(docs)
63
- except Exception as e:
64
- st.write('Oops 😵 Something bad happened...')
65
- raise e
66
 
67
- if st.session_state.ask_sql:
68
- plc_hldr = st.empty()
69
- print(st.session_state.query_sql)
70
- with plc_hldr.expander('Chat Log', expanded=True):
71
- callback = ChatDataSQLAskCallBackHandler()
72
- try:
73
- ret = st.session_state.sel_map_obj[sel]["sql_chain"](
74
- st.session_state.query_sql, callbacks=[callback])
75
- callback.progress_bar.progress(value=1.0, text="Done!")
76
- st.markdown(
77
- f"### Answer from LLM\n{ret['answer']}\n### References")
78
- docs = ret['sources']
79
- docs = pd.DataFrame(
80
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
81
- display(
82
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
83
- except Exception as e:
84
- st.write('Oops 😵 Something bad happened...')
85
- raise e
86
 
87
- with tab_self_query:
88
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
89
- st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
90
- st.text_input("Ask a question:", key='query_self')
91
- cols = st.columns([1, 1, 1, 4])
92
- cols[0].button("Query", key='search_self')
93
- cols[1].button("Ask", key='ask_self')
94
- cols[2].button("Back", key='back_self', on_click=back_to_main)
95
- plc_hldr = st.empty()
96
- if st.session_state.search_self:
97
- plc_hldr = st.empty()
98
- print(st.session_state.query_self)
99
- with plc_hldr.expander('Query Log', expanded=True):
100
- call_back = None
101
- callback = ChatDataSelfSearchCallBackHandler()
102
- try:
103
- docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
104
- st.session_state.query_self, callbacks=[callback])
105
- print(docs)
106
- callback.progress_bar.progress(value=1.0, text="Done!")
107
- docs = pd.DataFrame(
108
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
109
- display(docs, sel_map[sel]["must_have_cols"])
110
- except Exception as e:
111
- st.write('Oops 😵 Something bad happened...')
112
- raise e
113
 
114
- if st.session_state.ask_self:
115
- plc_hldr = st.empty()
116
- print(st.session_state.query_self)
117
- with plc_hldr.expander('Chat Log', expanded=True):
118
- call_back = None
119
- callback = ChatDataSelfAskCallBackHandler()
120
- try:
121
- ret = st.session_state.sel_map_obj[sel]["chain"](
122
- st.session_state.query_self, callbacks=[callback])
123
- callback.progress_bar.progress(value=1.0, text="Done!")
124
- st.markdown(
125
- f"### Answer from LLM\n{ret['answer']}\n### References")
126
- docs = ret['sources']
127
- docs = pd.DataFrame(
128
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
129
- display(
130
- docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
131
- except Exception as e:
132
- st.write('Oops 😵 Something bad happened...')
133
- raise e
 
1
+ import os
2
+ import time
3
+
4
  import streamlit as st
5
 
6
+ from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
7
+ DATA_INITIALIZE_STARTED
8
+ from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
9
+ TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
10
+ from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
11
+ from backend.types.global_config import GlobalConfig
12
+ from logger import logger
13
+ from ui.chat_page import chat_page
14
+ from ui.home import render_home
15
+ from ui.retrievers import render_retrievers
16
 
 
 
 
17
 
18
+ # warnings.filterwarnings("ignore", category=UserWarning)
19
 
20
+ def prepare_environment():
21
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
22
+ os.environ["LANGCHAIN_TRACING_V2"] = "false"
23
+ # os.environ["LANGCHAIN_API_KEY"] = ""
24
+ os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE']
25
+ os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']
26
+ os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID']
27
+ os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN']
28
+
29
+ update_global_config(GlobalConfig(
30
+ openai_api_base=st.secrets['OPENAI_API_BASE'],
31
+ openai_api_key=st.secrets['OPENAI_API_KEY'],
32
+ auth0_client_id=st.secrets['AUTH0_CLIENT_ID'],
33
+ auth0_domain=st.secrets['AUTH0_DOMAIN'],
34
+ myscale_user=st.secrets['MYSCALE_USER'],
35
+ myscale_password=st.secrets['MYSCALE_PASSWORD'],
36
+ myscale_host=st.secrets['MYSCALE_HOST'],
37
+ myscale_port=st.secrets['MYSCALE_PORT'],
38
+ query_model="gpt-3.5-turbo-0125",
39
+ chat_model="gpt-3.5-turbo-0125",
40
+ untrusted_api=st.secrets['UNSTRUCTURED_API'],
41
+ myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True),
42
+ ))
43
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # when refresh browser, all session keys will be cleaned.
46
+ def initialize_session_state():
47
+ if DATA_INITIALIZE_STATUS not in st.session_state:
48
+ st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED
49
+ logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}")
50
+ if JUMP_QUERY_ASK not in st.session_state:
51
+ st.session_state[JUMP_QUERY_ASK] = False
52
+ logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}")
53
 
 
 
 
 
54
 
55
+ def initialize_chat_data():
56
+ if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED:
57
+ start_time = time.time()
58
+ st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
59
+ st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
60
+ st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
61
+ st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
62
+ # mark data initialization finished.
63
+ st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
64
+ end_time = time.time()
65
+ logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, "
66
+ f"session state keys: {list(st.session_state.keys())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ st.set_page_config(
70
+ page_title="ChatData",
71
+ page_icon="https://myscale.com/favicon.ico",
72
+ initial_sidebar_state="expanded",
73
+ layout="wide",
74
+ )
75
+
76
+ prepare_environment()
77
+ initialize_session_state()
78
+ initialize_chat_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ if USER_NAME in st.session_state:
81
+ chat_page()
82
+ else:
83
+ if st.session_state[JUMP_QUERY_ASK]:
84
+ render_retrievers()
85
+ else:
86
+ render_home()
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/__init__.py ADDED
File without changes
backend/callbacks/__init__.py ADDED
File without changes
backend/callbacks/arxiv_callbacks.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import textwrap
3
+ from typing import Dict, Any, List
4
+
5
+ from langchain.callbacks.streamlit.streamlit_callback_handler import (
6
+ LLMThought,
7
+ StreamlitCallbackHandler,
8
+ )
9
+
10
+
11
+ class LLMThoughtWithKnowledgeBase(LLMThought):
12
+ def on_tool_end(
13
+ self,
14
+ output: str,
15
+ color=None,
16
+ observation_prefix=None,
17
+ llm_prefix=None,
18
+ **kwargs: Any,
19
+ ) -> None:
20
+ try:
21
+ self._container.markdown(
22
+ "\n\n".join(
23
+ ["### Retrieved Documents:"]
24
+ + [
25
+ f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
26
+ for i, r in enumerate(json.loads(output))
27
+ ]
28
+ )
29
+ )
30
+ except Exception as e:
31
+ super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
32
+
33
+
34
+ class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
35
+ def on_llm_start(
36
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
37
+ ) -> None:
38
+ if self._current_thought is None:
39
+ self._current_thought = LLMThoughtWithKnowledgeBase(
40
+ parent_container=self._parent_container,
41
+ expanded=self._expand_new_thoughts,
42
+ collapse_on_complete=self._collapse_completed_thoughts,
43
+ labeler=self._thought_labeler,
44
+ )
45
+
46
+ self._current_thought.on_llm_start(serialized, prompts)
backend/callbacks/llm_thought_with_table.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import streamlit as st
4
+ from langchain_core.outputs import LLMResult
5
+ from streamlit.external.langchain import StreamlitCallbackHandler
6
+
7
+
8
+ class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
9
+ def __init__(self):
10
+ super().__init__(st.container())
11
+ self._current_thought = None
12
+ self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")
13
+
14
+ def on_llm_start(
15
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
16
+ ) -> None:
17
+ self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
18
+ pass
19
+
20
+ def on_chain_end(self, outputs, **kwargs) -> None:
21
+ if len(kwargs['tags']) == 0:
22
+ self.progress_bar.progress(value=0.75, text="Searching in DB...")
23
+
24
+ def on_chain_start(self, serialized, inputs, **kwargs) -> None:
25
+
26
+ pass
27
+
28
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
29
+ st.markdown("### Generate filter by LLM \n"
30
+ "> Here we get `query_constructor` results \n\n")
31
+
32
+ self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
33
+ for item in response.generations:
34
+ st.markdown(f"{item[0].text}")
35
+
36
+ pass
backend/callbacks/self_query_callbacks.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+
3
+ import streamlit as st
4
+ from langchain.callbacks.streamlit.streamlit_callback_handler import (
5
+ StreamlitCallbackHandler,
6
+ )
7
+ from langchain.schema.output import LLMResult
8
+
9
+
10
+ class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler):
11
+ def __init__(self):
12
+ super().__init__(st.container())
13
+ self._current_thought = None
14
+ self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...")
15
+
16
+ def on_llm_start(
17
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
18
+ ) -> None:
19
+ self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
20
+ pass
21
+
22
+ def on_chain_end(self, outputs, **kwargs) -> None:
23
+ if len(kwargs['tags']) == 0:
24
+ self.progress_bar.progress(value=0.75, text="Searching in DB...")
25
+ pass
26
+
27
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
28
+ st.markdown("### Generate filter by LLM \n"
29
+ "> Here we get `query_constructor` results \n\n")
30
+ self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
31
+ for item in response.generations:
32
+ st.markdown(f"{item[0].text}")
33
+ pass
34
+
35
+
36
+ class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
37
+ def __init__(self) -> None:
38
+ super().__init__(st.container())
39
+ self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...")
40
+
41
+ def on_llm_start(self, serialized, prompts, **kwargs) -> None:
42
+ pass
43
+
44
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
45
+
46
+ if len(kwargs['tags']) != 0:
47
+ self.progress_bar.progress(value=0.5, text="We got filter info from LLM...")
48
+ st.markdown("### Generate filter by LLM \n"
49
+ "> Here we get `query_constructor` results \n\n")
50
+ for item in response.generations:
51
+ st.markdown(f"{item[0].text}")
52
+ pass
53
+
54
+ def on_chain_start(self, serialized, inputs, **kwargs) -> None:
55
+ cid = ".".join(serialized["id"])
56
+ if cid.endswith(".CustomStuffDocumentChain"):
57
+ self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...")
backend/callbacks/vector_sql_callbacks.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.callbacks.streamlit.streamlit_callback_handler import (
3
+ StreamlitCallbackHandler,
4
+ )
5
+ from langchain.schema.output import LLMResult
6
+ from sql_formatter.core import format_sql
7
+
8
+
9
+ class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
10
+ def __init__(self) -> None:
11
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
12
+ self.status_bar = st.empty()
13
+ self.prog_value = 0
14
+ self.prog_interval = 0.2
15
+
16
+ def on_llm_start(self, serialized, prompts, **kwargs) -> None:
17
+ pass
18
+
19
+ def on_llm_end(
20
+ self,
21
+ response: LLMResult,
22
+ *args,
23
+ **kwargs,
24
+ ):
25
+ text = response.generations[0][0].text
26
+ if text.replace(" ", "").upper().startswith("SELECT"):
27
+ st.markdown("### Generated Vector Search SQL Statement \n"
28
+ "> This sql statement is generated by LLM \n\n")
29
+ st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
30
+ self.prog_value += self.prog_interval
31
+ self.progress_bar.progress(
32
+ value=self.prog_value, text="Searching in DB...")
33
+
34
+ def on_chain_start(self, serialized, inputs, **kwargs) -> None:
35
+ cid = ".".join(serialized["id"])
36
+ self.prog_value += self.prog_interval
37
+ self.progress_bar.progress(
38
+ value=self.prog_value, text=f"Running Chain `{cid}`..."
39
+ )
40
+
41
+ def on_chain_end(self, outputs, **kwargs) -> None:
42
+ pass
43
+
44
+
45
+ class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
46
+ def __init__(self, table: str) -> None:
47
+ self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
48
+ self.status_bar = st.empty()
49
+ self.prog_value = 0
50
+ self.prog_interval = 0.1
51
+ self.table = table
52
+
53
+
backend/chains/__init__.py ADDED
File without changes
backend/chains/retrieval_qa_with_sources.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Dict, Any, Optional, List
3
+
4
+ from langchain.callbacks.manager import (
5
+ AsyncCallbackManagerForChainRun,
6
+ CallbackManagerForChainRun,
7
+ )
8
+ from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
9
+ from langchain.docstore.document import Document
10
+
11
+ from logger import logger
12
+
13
+
14
+ class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
15
+ """QA with source chain for Chat ArXiv app with references
16
+
17
+ This chain will automatically assign reference number to the article,
18
+ Then parse it back to titles or anything else.
19
+ """
20
+
21
+ def _call(
22
+ self,
23
+ inputs: Dict[str, Any],
24
+ run_manager: Optional[CallbackManagerForChainRun] = None,
25
+ ) -> Dict[str, str]:
26
+ logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m")
27
+ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
28
+ accepts_run_manager = (
29
+ "run_manager" in inspect.signature(self._get_docs).parameters
30
+ )
31
+ if accepts_run_manager:
32
+ docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager)
33
+ else:
34
+ docs: List[Document] = self._get_docs(inputs) # type: ignore[call-arg]
35
+
36
+ answer = self.combine_documents_chain.run(
37
+ input_documents=docs, callbacks=_run_manager.get_child(), **inputs
38
+ )
39
+ # parse source with ref_id
40
+ sources = []
41
+ ref_cnt = 1
42
+ for d in docs:
43
+ ref_id = d.metadata['ref_id']
44
+ if f"Doc #{ref_id}" in answer:
45
+ answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
46
+ if f"#{ref_id}" in answer:
47
+ title = d.metadata['title'].replace('\n', '')
48
+ d.metadata['ref_id'] = ref_cnt
49
+ answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
50
+ sources.append(d)
51
+ ref_cnt += 1
52
+
53
+ result: Dict[str, Any] = {
54
+ self.answer_key: answer,
55
+ self.sources_answer_key: sources,
56
+ }
57
+ if self.return_source_documents:
58
+ result["source_documents"] = docs
59
+ return result
60
+
61
+ async def _acall(
62
+ self,
63
+ inputs: Dict[str, Any],
64
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
65
+ ) -> Dict[str, Any]:
66
+ raise NotImplementedError
67
+
68
+ @property
69
+ def _chain_type(self) -> str:
70
+ return "custom_retrieval_qa_with_sources_chain"
backend/chains/stuff_documents.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Tuple
2
+
3
+ from langchain.callbacks.manager import Callbacks
4
+ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
5
+ from langchain.docstore.document import Document
6
+ from langchain.schema.prompt_template import format_document
7
+
8
+
9
+ class CustomStuffDocumentChain(StuffDocumentsChain):
10
+ """Combine arxiv documents with PDF reference number"""
11
+
12
+ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
13
+ """Construct inputs from kwargs and docs.
14
+
15
+ Format and the join all the documents together into one input with name
16
+ `self.document_variable_name`. The pluck any additional variables
17
+ from **kwargs.
18
+
19
+ Args:
20
+ docs: List of documents to format and then join into single input
21
+ **kwargs: additional inputs to chain, will pluck any other required
22
+ arguments from here.
23
+
24
+ Returns:
25
+ dictionary of inputs to LLMChain
26
+ """
27
+ # Format each document according to the prompt
28
+ doc_strings = []
29
+ for doc_id, doc in enumerate(docs):
30
+ # add temp reference number in metadata
31
+ doc.metadata.update({'ref_id': doc_id})
32
+ doc.page_content = doc.page_content.replace('\n', ' ')
33
+ doc_strings.append(format_document(doc, self.document_prompt))
34
+ # Join the documents together to put them in the prompt.
35
+ inputs = {
36
+ k: v
37
+ for k, v in kwargs.items()
38
+ if k in self.llm_chain.prompt.input_variables
39
+ }
40
+ inputs[self.document_variable_name] = self.document_separator.join(
41
+ doc_strings)
42
+ return inputs
43
+
44
+ def combine_docs(
45
+ self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
46
+ ) -> Tuple[str, dict]:
47
+ """Stuff all documents into one prompt and pass to LLM.
48
+
49
+ Args:
50
+ docs: List of documents to join together into one variable
51
+ callbacks: Optional callbacks to pass along
52
+ **kwargs: additional parameters to use to get inputs to LLMChain.
53
+
54
+ Returns:
55
+ The first element returned is the single string output. The second
56
+ element returned is a dictionary of other keys to return.
57
+ """
58
+ inputs = self._get_inputs(docs, **kwargs)
59
+ # Call predict on the LLM.
60
+ output = self.llm_chain.predict(callbacks=callbacks, **inputs)
61
+ return output, {}
62
+
63
+ @property
64
+ def _chain_type(self) -> str:
65
+ return "custom_stuff_document_chain"
backend/chat_bot/__init__.py ADDED
File without changes
backend/chat_bot/chat.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from os import environ
4
+ from time import sleep
5
+ import streamlit as st
6
+
7
+ from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
8
+ from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \
9
+ CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \
10
+ EL_BUILD_KB_WITH_FILES, \
11
+ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
12
+ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
13
+ EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES
14
+ from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS
15
+ from backend.construct.build_agents import build_agents
16
+ from backend.chat_bot.session_manager import SessionManager
17
+ from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
18
+
19
+ from logger import logger
20
+
21
+ environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
22
+
23
+ TOOL_NAMES = {
24
+ "langchain_retriever_tool": "Self-querying retriever",
25
+ "vecsql_retriever_tool": "Vector SQL",
26
+ }
27
+
28
+
29
+ def on_chat_submit():
30
+ with st.session_state.next_round.container():
31
+ with st.chat_message("user"):
32
+ st.write(st.session_state.chat_input)
33
+ with st.chat_message("assistant"):
34
+ container = st.container()
35
+ st_callback = ChatDataAgentCallBackHandler(
36
+ container, collapse_completed_thoughts=False
37
+ )
38
+ ret = st.session_state.agent(
39
+ {"input": st.session_state.chat_input}, callbacks=[st_callback]
40
+ )
41
+ logger.info(f"ret:{ret}")
42
+
43
+
44
+ def clear_history():
45
+ if "agent" in st.session_state:
46
+ st.session_state.agent.memory.clear()
47
+
48
+
49
+ def back_to_main():
50
+ if USER_INFO in st.session_state:
51
+ del st.session_state[USER_INFO]
52
+ if USER_NAME in st.session_state:
53
+ del st.session_state[USER_NAME]
54
+ if JUMP_QUERY_ASK in st.session_state:
55
+ del st.session_state[JUMP_QUERY_ASK]
56
+ if EL_SESSION_SELECTOR in st.session_state:
57
+ del st.session_state[EL_SESSION_SELECTOR]
58
+ if CHAT_CURRENT_USER_SESSIONS in st.session_state:
59
+ del st.session_state[CHAT_CURRENT_USER_SESSIONS]
60
+
61
+
62
+ def refresh_sessions():
63
+ chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER]
64
+ current_user_name = st.session_state[USER_NAME]
65
+ current_user_sessions = chat_session_manager.list_sessions(current_user_name)
66
+
67
+ if not isinstance(current_user_sessions, dict) or not current_user_sessions:
68
+ # generate a default session for current user.
69
+ chat_session_manager.add_session(
70
+ user_id=current_user_name,
71
+ session_id=f"{current_user_name}?default",
72
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
73
+ )
74
+ st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name)
75
+ current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS]
76
+ else:
77
+ st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions
78
+
79
+ # load current user files.
80
+ st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(
81
+ current_user_name
82
+ )
83
+ # load current user private knowledge bases.
84
+ st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
85
+ st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
86
+ logger.info(f"current user name: {current_user_name}, "
87
+ f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
88
+ f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
89
+ st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
90
+ # public retrieval tools
91
+ **st.session_state[RETRIEVER_TOOLS],
92
+ # private retrieval tools
93
+ **st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
94
+ }
95
+ # print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}")
96
+ print(f"current_user_sessions is {current_user_sessions}")
97
+ st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]
98
+
99
+
100
+ # process for session add and delete.
101
+ def on_session_change_submit():
102
+ if "session_manager" in st.session_state and "session_editor" in st.session_state:
103
+ try:
104
+ for elem in st.session_state.session_editor["added_rows"]:
105
+ if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
106
+ if elem["session_id"] != "" and "?" not in elem["session_id"]:
107
+ st.session_state.session_manager.add_session(
108
+ user_id=st.session_state.user_name,
109
+ session_id=f"{st.session_state.user_name}?{elem['session_id']}",
110
+ system_prompt=elem["system_prompt"],
111
+ )
112
+ else:
113
+ st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="❌")
114
+ raise KeyError(
115
+ "`session_id` shouldn't be neither empty nor contain char `?`."
116
+ )
117
+ else:
118
+ st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="❌")
119
+ raise KeyError(
120
+ "You should fill both `session_id` and `system_prompt` to add a column!"
121
+ )
122
+ for elem in st.session_state.session_editor["deleted_rows"]:
123
+ user_name = st.session_state[USER_NAME]
124
+ session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id']
125
+ user_with_session_id = f"{user_name}?{session_id}"
126
+ st.session_state.session_manager.remove_session(session_id=user_with_session_id)
127
+ st.toast(f"session `{user_with_session_id}` removed.", icon="✅")
128
+
129
+ refresh_sessions()
130
+ except Exception as e:
131
+ sleep(2)
132
+ st.error(f"{type(e)}: {str(e)}")
133
+ finally:
134
+ st.session_state.session_editor["added_rows"] = []
135
+ st.session_state.session_editor["deleted_rows"] = []
136
+ refresh_agent()
137
+
138
+
139
+ def create_private_knowledge_base_as_tool():
140
+ current_user_name = st.session_state[USER_NAME]
141
+
142
+ if (
143
+ EL_PERSONAL_KB_NAME in st.session_state
144
+ and EL_PERSONAL_KB_DESCRIPTION in st.session_state
145
+ and EL_BUILD_KB_WITH_FILES in st.session_state
146
+ and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0
147
+ and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0
148
+ and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0
149
+ ):
150
+ st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base(
151
+ user_id=current_user_name,
152
+ tool_name=st.session_state[EL_PERSONAL_KB_NAME],
153
+ tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION],
154
+ files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]],
155
+ )
156
+ refresh_sessions()
157
+ else:
158
+ st.session_state[EL_UPLOAD_FILES_STATUS].error(
159
+ "You should fill all fields to build up a tool!"
160
+ )
161
+ sleep(2)
162
+
163
+
164
+ def remove_private_knowledge_bases():
165
+ if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]:
166
+ private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]
167
+ private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove]
168
+ # remove these private knowledge bases.
169
+ st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases(
170
+ user_id=st.session_state[USER_NAME],
171
+ private_knowledge_bases=private_knowledge_base_names
172
+ )
173
+ refresh_sessions()
174
+ else:
175
+ st.session_state[EL_UPLOAD_FILES_STATUS].error(
176
+ "You should specify at least one private knowledge base to delete!"
177
+ )
178
+ time.sleep(2)
179
+
180
+
181
+ def refresh_agent():
182
+ with st.spinner("Initializing session..."):
183
+ user_name = st.session_state[USER_NAME]
184
+ session_id = st.session_state[EL_SESSION_SELECTOR]['session_id']
185
+ user_with_session_id = f"{user_name}?{session_id}"
186
+
187
+ if EL_SELECTED_KBS in st.session_state:
188
+ selected_knowledge_bases = st.session_state[EL_SELECTED_KBS]
189
+ else:
190
+ selected_knowledge_bases = ["Wikipedia + Vector SQL"]
191
+
192
+ logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}")
193
+ if EL_SESSION_SELECTOR in st.session_state:
194
+ system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"]
195
+ else:
196
+ system_prompt = DEFAULT_SYSTEM_PROMPT
197
+
198
+ st.session_state["agent"] = build_agents(
199
+ session_id=user_with_session_id,
200
+ tool_names=selected_knowledge_bases,
201
+ system_prompt=system_prompt
202
+ )
203
+
204
+
205
+ def add_file():
206
+ user_name = st.session_state[USER_NAME]
207
+ if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0:
208
+ st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️")
209
+ sleep(2)
210
+ return
211
+ try:
212
+ st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...")
213
+ st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file(
214
+ user_id=user_name,
215
+ files=st.session_state[EL_UPLOAD_FILES]
216
+ )
217
+ refresh_sessions()
218
+ except ValueError as e:
219
+ st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e))
220
+ sleep(2)
221
+
222
+
223
+ def clear_files():
224
+ st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
225
+ refresh_sessions()
backend/chat_bot/json_decoder.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datetime
3
+
4
+
5
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``datetime`` objects as ISO-8601 strings."""

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        # Anything else falls through to the stock encoder (raises TypeError).
        return super().default(obj)
10
+
11
+
12
class CustomJSONDecoder(json.JSONDecoder):
    """JSON decoder that revives ISO-8601 strings into ``datetime`` objects.

    Every string value of every decoded JSON object is probed with
    ``datetime.fromisoformat``; values that do not parse stay strings.
    """

    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(
            self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, source):
        for k, v in source.items():
            if isinstance(v, str):
                try:
                    source[k] = datetime.datetime.fromisoformat(v)
                except ValueError:
                    # Bug fix: was a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit; fromisoformat only raises
                    # ValueError for non-timestamp strings. The redundant
                    # str(v) wrapper is dropped (v is already a str here).
                    pass
        return source
backend/chat_bot/message_converter.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import time
4
+ from typing import Any
5
+
6
+ from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
7
+ from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
8
+ from langchain.schema.messages import ToolMessage
9
+ from sqlalchemy.orm import declarative_base
10
+
11
+ from backend.chat_bot.tools import create_message_history_table
12
+
13
+
14
def _message_from_dict(message: dict) -> BaseMessage:
    """Revive one serialized message dict into its langchain message class."""
    constructors = {
        "human": HumanMessage,
        "ai": AIMessage,
        "system": SystemMessage,
        "chat": ChatMessage,
        "function": FunctionMessage,
        "tool": ToolMessage,
    }
    _type = message["type"]
    if _type == "AIMessageChunk":
        # Streamed chunks are revived as plain AI messages.
        message["data"]["type"] = "ai"
        return AIMessage(**message["data"])
    if _type not in constructors:
        raise ValueError(f"Got unexpected message type: {_type}")
    return constructors[_type](**message["data"])
33
+
34
+
35
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
    """The default message converter for SQLChatMessageHistory.

    Maps langchain chat messages to/from rows of the ClickHouse-backed
    history table created by ``create_message_history_table``.
    """

    def __init__(self, table_name: str):
        super().__init__(table_name)
        # A fresh declarative base per converter avoids model-class clashes
        # when several history tables are mapped in one process.
        self.model_class = create_message_history_table(table_name, declarative_base())

    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        """Serialize *message* into an ORM row for *session_id*."""
        time_stamp = time.time()
        # Hash of (session, content, time) gives a collision-resistant key.
        msg_id = hashlib.sha256(
            f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest()
        # Session ids are formatted "<user_id>?<session>"; this raises
        # ValueError if the id does not contain exactly one "?".
        user_id, _ = session_id.split("?")
        return self.model_class(
            id=time_stamp,
            msg_id=msg_id,
            user_id=user_id,
            session_id=session_id,
            type=message.type,
            # Column name is a historical misspelling of "additionals";
            # kept for schema compatibility.
            addtionals=json.dumps(message.additional_kwargs),
            message=json.dumps({
                "type": message.type,
                "additional_kwargs": {"timestamp": time_stamp},
                "data": message.dict()})
        )

    def from_sql_model(self, sql_message: Any) -> BaseMessage:
        """Deserialize one ORM row back into a langchain message."""
        msg_dump = json.loads(sql_message.message)
        msg = _message_from_dict(msg_dump)
        # Restore the serialized kwargs (carries the insertion timestamp).
        msg.additional_kwargs = msg_dump["additional_kwargs"]
        return msg

    def get_sql_model_class(self) -> Any:
        return self.model_class
backend/chat_bot/private_knowledge_base.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from datetime import datetime
3
+ from typing import List, Optional
4
+
5
+ import pandas as pd
6
+ from clickhouse_connect import get_client
7
+ from langchain.schema.embeddings import Embeddings
8
+ from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
9
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
10
+
11
+ from backend.chat_bot.tools import parse_files, extract_embedding
12
+ from backend.construct.build_retriever_tool import create_retriever_tool
13
+ from logger import logger
14
+
15
+
16
class ChatBotKnowledgeTable:
    """Storage layer for per-user uploaded files and private knowledge bases.

    Two MyScale/ClickHouse tables back this class:
      * ``{db}.{kb_table}``   -- one row per parsed paragraph plus embedding.
      * ``{db}.{tool_table}`` -- one row per private knowledge base ("tool"),
        referencing the file names it was built from.

    NOTE(review): queries below interpolate ``user_id`` / ``tool_name``
    directly into SQL strings. That is only safe if those values are
    server-controlled; confirm they can never contain quotes.
    """

    def __init__(self, host, port, username, password,
                 embedding: Embeddings, parser_api_key: str, db="chat",
                 kb_table="private_kb", tool_table="private_tool") -> None:
        super().__init__()
        # Paragraph store; vectors must be 768-d cosine embeddings.
        personal_files_schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
            entity_id String,
            file_name String,
            text String,
            user_id String,
            created_by DateTime,
            vector Array(Float32),
            CONSTRAINT cons_vec_len CHECK length(vector) = 768,
            VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
        ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """

        # `tool_name` represents a private knowledge base name.
        private_knowledge_base_schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
            tool_id String,
            tool_name String,
            file_names Array(String),
            user_id String,
            created_by DateTime,
            tool_description String
        ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.personal_files_table = kb_table
        self.private_knowledge_base_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        self.client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(personal_files_schema_)
        self.client.command(private_knowledge_base_schema_)
        self.parser_api_key = parser_api_key
        self.vector_store = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    # List all files with given `user_id`
    def list_files(self, user_id: str):
        # Per file: paragraph count and length of the longest paragraph.
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vector_store.config.database}.{self.personal_files_table}
        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    # Parse and embedding files
    def add_by_file(self, user_id, files: List[UploadedFile]):
        # Parse via the Unstructured API, embed, then bulk-insert.
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vector_store.embeddings, data)
        self.vector_store.client.insert_df(
            table=self.personal_files_table,
            df=pd.DataFrame(data),
            database=self.vector_store.config.database,
        )

    # Remove all files and private_knowledge_bases with given `user_id`
    def clear(self, user_id: str):
        self.vector_store.client.command(
            f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
            f"WHERE user_id='{user_id}'"
        )
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}'"""
        self.vector_store.client.command(query)

    def create_private_knowledge_base(
            self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
    ):
        """Register a named knowledge base over the given uploaded files."""
        self.vector_store.client.insert_df(
            self.private_knowledge_base_table,
            pd.DataFrame(
                [
                    {
                        # Deterministic id: re-creating the same (user, name)
                        # pair replaces the row (ReplacingMergeTree).
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        "tool_name": tool_name,  # tool_name represents the user's private knowledge base.
                        "file_names": files,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vector_store.config.database,
        )

    # Show all private knowledge bases with given `user_id`
    def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
        extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
        WHERE user_id = '{user_id}' {extended_where}
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
        """Delete the named knowledge bases (paragraph rows are kept)."""
        unique_list = list(set(private_knowledge_bases))
        unique_list = ",".join([f"'{t}'" for t in unique_list])
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
        self.vector_store.client.command(query)

    def as_retrieval_tools(self, user_id, tool_name=None):
        """Build one retriever tool per knowledge base owned by *user_id*.

        Each retriever is restricted (via ``where_str``) to the user's own
        rows for the files belonging to that knowledge base.
        """
        logger.info(f"")
        private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, private_knowledge_base=tool_name)
        retrievers = {}
        for private_kb in private_knowledge_bases:
            file_names_sql = f"""
            SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM chat.private_tool
                WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
            )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = []
            for line in res.result_rows:
                file_names.append(line[0])
            file_names = ', '.join(f"'{item}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {file_names}")
            retrievers[private_kb["tool_name"]] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
                ),
                tool_name=private_kb["tool_name"],
                description=private_kb["tool_description"],
            )
        return retrievers
167
+
backend/chat_bot/session_manager.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from backend.chat_bot.tools import create_session_table, create_message_history_table
4
+ from backend.constants.variables import GLOBAL_CONFIG
5
+
6
+ try:
7
+ from sqlalchemy.orm import declarative_base
8
+ except ImportError:
9
+ from sqlalchemy.ext.declarative import declarative_base
10
+ from datetime import datetime
11
+ from sqlalchemy import orm, create_engine
12
+ from logger import logger
13
+
14
+
15
def get_sessions(engine, model_class, user_id):
    """Return all stored chat sessions for *user_id*, newest first.

    Bug fix: the previous implementation called ``json.loads`` on the
    SQLAlchemy ``Query`` object itself, which always raised ``TypeError``.
    The query is now executed and materialized into a list before the ORM
    session closes, so the rows remain usable by the caller.
    """
    with orm.sessionmaker(engine)() as session:
        result = (
            session.query(model_class)
            # NOTE(review): filtering `session_id` by a *user id* looks
            # suspicious -- SessionManager.list_sessions filters on
            # `user_id` instead; confirm which column is intended.
            .where(
                model_class.session_id == user_id
            )
            .order_by(model_class.create_by.desc())
        )
        return list(result)
25
+
26
+
27
class SessionManager:
    """CRUD helper for chat sessions and their ClickHouse-backed history."""

    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db='chat',
            session_table='sessions',
            msg_table='chat_memory'
    ) -> None:
        # Pick the URL protocol matching the deployment's TLS setting.
        if GLOBAL_CONFIG.myscale_enable_https == False:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=http'
        else:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
        self.engine = create_engine(conn_str, echo=False)
        # Create both tables up front if they do not exist yet.
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        """Return ``{"session_id", "system_prompt"}`` dicts, newest first."""
        with self.session_orm() as session:
            result = (
                session.query(self.session_model_class)
                .where(
                    self.session_model_class.user_id == user_id
                )
                .order_by(self.session_model_class.create_by.desc())
            )
            sessions = []
            for r in result:
                sessions.append({
                    # Stored ids are "<user>?<session>"; expose only the
                    # session part to the UI.
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                })
            return sessions

    # Update sys_prompt with given session_id
    def modify_system_prompt(self, session_id, sys_prompt):
        with self.session_orm() as session:
            obj = session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                # Nothing to update; surfaced in logs only.
                logger.warning(f"Session {session_id} not found")

    # Add a session(session_id, sys_prompt)
    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        # Extra kwargs are persisted as a JSON blob in `additionals`.
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    # Remove a session and related chat history.
    def remove_session(self, session_id: str):
        with self.session_orm() as session:
            # remove session
            session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete()
            # remove related chat history.
            session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
backend/chat_bot/tools.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from datetime import datetime
3
+ from multiprocessing.pool import ThreadPool
4
+ from typing import List
5
+
6
+ import requests
7
+ from clickhouse_sqlalchemy import types, engines
8
+ from langchain.schema.embeddings import Embeddings
9
+ from sqlalchemy import Column, Text
10
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
11
+
12
+
13
def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Parse uploaded files into paragraph rows via the Unstructured API.

    Args:
        api_key: key for api.unstructured.io.
        user_id: owner recorded on every resulting row.
        files: Streamlit uploads, parsed concurrently by a thread pool.

    Returns:
        Row dicts (entity_id, file_name, text, user_id, created_by) ready
        for insertion; only NarrativeText elements longer than ten words
        are kept.

    Raises:
        ValueError: if the API answers with a non-200 status.
    """
    def parse_file(file: UploadedFile):
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # read() consumes the stream but is only used for hashing;
        # getvalue() below returns the full buffer regardless of position.
        file_hash = hashlib.sha256(file.read()).hexdigest()
        file_data = {"files": (file.name, file.getvalue(), file.type)}
        response = requests.post(
            url="https://api.unstructured.io/general/v0/general",
            headers=headers,
            data=data,
            files=file_data
        )
        json_response = response.json()
        if response.status_code != 200:
            raise ValueError(str(json_response))
        # entity_id = hash(file content + paragraph): dedupes re-uploads
        # via the ReplacingMergeTree sort key.
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    # Up to eight files in flight at once (network-bound work).
    with ThreadPool(8) as p:
        rows = []
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
        return rows
51
+
52
+
53
def extract_embedding(embeddings: "Embeddings", texts):
    """Attach an embedding vector to every parsed paragraph row.

    Args:
        embeddings: object exposing ``embed_documents(list[str])`` returning
            one vector per input string.
        texts: list of row dicts each containing a ``"text"`` key; mutated
            in place to gain a ``"vector"`` key.

    Returns:
        The same list with ``"vector"`` filled in.

    Raises:
        ValueError: if *texts* is empty.
    """
    # Guard clause first; the original raised only after an empty check at
    # the end, which read as dead code.
    if not texts:
        raise ValueError("No texts extracted!")
    # Embed in one batch; do NOT shadow the `embeddings` argument (the
    # original rebound it to the result list).
    vectors = embeddings.embed_documents([row["text"] for row in texts])
    for row, vector in zip(texts, vectors):
        row["vector"] = vector
    return texts
61
+
62
+
63
def create_message_history_table(table_name: str, base_class):
    """Build a declarative ORM model mapping the chat-history table.

    Args:
        table_name: ClickHouse table to map.
        base_class: declarative base the generated model extends.

    Returns:
        The generated ``Message`` model class.
    """
    class Message(base_class):
        __tablename__ = table_name
        # Insertion timestamp (time.time()); doubles as the sort key.
        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # Historical misspelling of "additionals"; kept so the model keeps
        # matching existing table schemas.
        addtionals = Column(Text)
        message = Column(Text)
        __table_args__ = (
            engines.MergeTree(
                partition_by='session_id',
                order_by=('id', 'msg_id')
            ),
            {'comment': 'Store Chat History'}
        )

    return Message
+
84
+
85
def create_session_table(table_name: str, DynamicBase):
    """Build a declarative ORM model mapping the sessions table.

    Args:
        table_name: ClickHouse table to map.
        DynamicBase: declarative base the generated model extends.

    Returns:
        The generated ``Session`` model class.
    """
    class Session(DynamicBase):
        __tablename__ = table_name
        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # Creation time of the session.
        create_by = Column(types.DateTime)
        # Misspelled variant of "additions"; kept for schema compatibility.
        additionals = Column(Text)
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            {'comment': 'Store Session and Prompts'}
        )

    return Session
backend/constants/__init__.py ADDED
File without changes
backend/constants/myscale_tables.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import streamlit as st
3
+ from langchain.chains.query_constructor.schema import AttributeInfo
4
+ from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
5
+ from langchain.prompts import PromptTemplate
6
+
7
+ from backend.types.table_config import TableConfig
8
+
9
+
10
def hint_arxiv():
    """Render example queries for the ArXiv knowledge base."""
    st.markdown("Here we provide some query samples.")
    st.markdown("- If you want to search papers with filters")
    for sample in (
        "1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
        "than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```",
        "2. ```What is a Bayesian network? Please use articles published later than Feb 2018```",
    ):
        st.markdown(sample)
    st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
    for sample in (
        "1. ```Did Geoffrey Hinton wrote paper about Capsule Neural Networks?```",
        "2. ```Introduce some applications of GANs published around 2019.```",
        "3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```",
    ):
        st.markdown(sample)
20
+
21
+
22
def hint_sql_arxiv():
    """Show the DDL of the ArXiv table so users can write their own SQL."""
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
39
+
40
+
41
def hint_wiki():
    """Render example queries for the Wikipedia knowledge base."""
    st.markdown("Here we provide some query samples.")
    samples = (
        "```Which company did Elon Musk found?```",
        "```What is Iron Gwazi?```",
        "```苹果的发源地是哪里?```",
        "```What is a Ring in mathematics?```",
        "```The producer of Rick and Morty.```",
        "```How low is the temperature on Pluto?```",
    )
    for index, sample in enumerate(samples, start=1):
        st.markdown(f"{index}. {sample}")
50
+
51
def hint_sql_wiki():
    """Show the DDL of the Wikipedia table so users can write their own SQL."""
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
    `id` String,
    `title` String,
    `text` String,
    `url` String,
    `wiki_id` UInt64,
    `views` Float32,
    `paragraph_id` UInt64,
    `langs` UInt32,
    `emb` Array(Float32),
    VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
67
+
68
+
69
# Registry of the public (shared) knowledge bases served by the app, keyed
# by the human-readable name shown in the UI.
MYSCALE_TABLES: Dict[str, TableConfig] = {
    'Wikipedia': TableConfig(
        database="wiki",
        table="Wikipedia",
        # Typo fix: "Snapshort" -> "Snapshot".
        table_contents="Snapshot from Wikipedia for 2022. All in English.",
        hint=hint_wiki,
        hint_sql=hint_sql_wiki,
        # doc_prompt is used by the QA-with-sources chain to format each
        # retrieved document before stuffing it into the combine prompt.
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "url", "title", "ref_id", "views"],
            template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
            AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
            AttributeInfo(name="views", description="number of views", type="float")
        ],
        must_have_col_names=['id', 'title', 'url', 'text', 'views'],
        vector_col_name="emb",
        text_col_name="text",
        metadata_col_name="metadata",
        emb_model=lambda: SentenceTransformerEmbeddings(
            model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        ),
        tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
    ),
    'ArXiv Papers': TableConfig(
        database="default",
        table="ChatArXiv",
        # Bug fix: this description was a copy-paste of the Wikipedia entry
        # ("Snapshort from Wikipedia for 2022"); it now describes ArXiv.
        table_contents="Snapshot of scientific papers from ArXiv. All in English.",
        hint=hint_arxiv,
        hint_sql=hint_sql_arxiv,
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"],
            template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t"
                     "Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"),
            AttributeInfo(name="authors", description="List of author names", type="list[string]"),
            AttributeInfo(name="title", description="Title of the paper", type="string"),
            AttributeInfo(name="categories", description="arxiv categories to this paper", type="list[string]"),
            AttributeInfo(name="length(categories)", description="length of arxiv categories to this paper", type="int")
        ],
        must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
        vector_col_name="vector",
        text_col_name="abstract",
        metadata_col_name="metadata",
        emb_model=lambda: HuggingFaceInstructEmbeddings(
            model_name='hkunlp/instructor-xl',
            embed_instruction="Represent the question for retrieving supporting scientific papers: "
        ),
        tool_desc=(
            "search_among_scientific_papers",
            "Searches among scientific papers from ArXiv and returns research papers"
        )
    )
}

# Flat list of the underlying MyScale table names, in registry order.
ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()]
backend/constants/prompts.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import ChatPromptTemplate, \
2
+ SystemMessagePromptTemplate, HumanMessagePromptTemplate
3
+
4
# Fallback system prompt used when a chat session has no custom prompt.
DEFAULT_SYSTEM_PROMPT = (
    "Do your best to answer the questions. "
    "Feel free to use any tools available to look up "
    "relevant information. Please keep all details in query "
    "when calling search functions."
)

# System message for the "combine" step of the QA-with-sources chain;
# `{summaries}` is filled with the formatted retrieved documents.
COMBINE_PROMPT_TEMPLATE = (
    "You are a helpful document assistant. "
    "Your task is to provide information and answer any questions related to documents given below. "
    "You should use the sections, title and abstract of the selected documents as your source of information "
    "and try to provide concise and accurate answers to any questions asked by the user. "
    "If you are unable to find relevant information in the given sections, "
    "you will need to let the user know that the source does not contain relevant information but still try to "
    "provide an answer based on your general knowledge. You must refer to the corresponding section name and page "
    "that you refer to when answering. "
    "The following is the related information about the document that will help you answer users' questions, "
    "you MUST answer it using question's language:\n\n {summaries} "
    "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
)

# Final chat prompt: the combine instructions followed by the user question.
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
    string_messages=[(SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE),
                     (HumanMessagePromptTemplate, '{question}')])
+
29
# Few-shot prompt for translating natural-language questions into MyScale
# vector SQL. Expects `{top_k}`, `{table_info}` and `{input}` placeholders.
# Typo fixes vs. the original text: "conrtain" -> "constrain",
# "Feartue" -> "Feature".
# NOTE(review): in the first example the question asks about Feature Pyramid
# Networks but the sample SQL searches NeuralArray(PaperRank contribution);
# that mismatch predates this change -- confirm whether it is intentional.
MYSCALE_PROMPT = """
You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
Make sure you never write an isolated `WHERE` keyword and never use undesired condition to constrain the query.

Use the following format:

======== table info ========
<some table infos>

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
    abstract String,
    id String,
    vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatPaper" (
    abstract String,
    id String,
    vector Array(Float32),
    categories Array(String),
    pubdate DateTime,
    title String,
    authors Array(String),
    primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is PaperRank? What is the contribution of those works? Use paper with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatArXiv" (
    primary_category String
    categories Array(String),
    pubdate DateTime,
    abstract String,
    title String,
    paper_id String,
    vector Array(Float32),
    authors Array(String),
) ENGINE = MergeTree()
ORDER BY paper_id
PRIMARY KEY paper_id

Question: Did Geoffrey Hinton wrote about Capsule Neural Networks? Please use articles published later than 2021.
SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k}


======== table info ========
CREATE TABLE "PaperDatabase" (
    abstract String,
    categories Array(String),
    vector Array(Float32),
    pubdate DateTime,
    id String,
    comments String,
    title String,
    authors Array(String),
    primary_category String
) ENGINE = MergeTree()
ORDER BY id
PRIMARY KEY id

Question: Find papers whose abstract has Mutual Information in it.
SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k}


Let's begin:

======== table info ========
{table_info}

Question: {input}
SQLQuery: """
backend/constants/streamlit_keys.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit session-state and widget keys shared across the app.

# Data-initialization lifecycle states.
# NOTE(review): "NOT_STATED" is a typo for "NOT_STARTED"; renaming the
# constant would touch every call site, so it is only flagged here.
DATA_INITIALIZE_NOT_STATED = "data_initialize_not_started"
DATA_INITIALIZE_STARTED = "data_initialize_started"
DATA_INITIALIZE_COMPLETED = "data_initialize_completed"


# Currently selected chat session.
CHAT_SESSION = "sel_sess"
# The ChatBotKnowledgeTable instance for the logged-in user.
CHAT_KNOWLEDGE_TABLE = "private_kb"

CHAT_SESSION_MANAGER = "session_manager"
CHAT_CURRENT_USER_SESSIONS = "current_sessions"

# Widget key: session-selector element.
EL_SESSION_SELECTOR = "el_session_selector"

# all personal knowledge bases under a specific user.
USER_PERSONAL_KNOWLEDGE_BASES = "user_tools"
# all personal files under a specific user.
USER_PRIVATE_FILES = "user_files"
# public and personal knowledge bases.
AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users"

# Personal knowledge bases the user marked for deletion.
EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove"

# Files pending upload, and the status widget that reports progress.
EL_UPLOAD_FILES = "el_upload_files"
EL_UPLOAD_FILES_STATUS = "el_upload_files_status"

# Files used to build a new private knowledge base.
EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files"
# build a personal kb, given name.
EL_PERSONAL_KB_NAME = "el_personal_kb_name"
# build a personal kb, given description.
EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description"

# knowledge bases selected by user.
EL_SELECTED_KBS = "el_selected_kbs"
backend/constants/variables.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from backend.types.global_config import GlobalConfig
2
+
3
# ***** str variables ***** #
# Keys used to stash shared objects in streamlit's session state.
EMBEDDING_MODEL_PREFIX = "embedding_model"
CHAINS_RETRIEVERS_MAPPING = "sel_map_obj"
LANGCHAIN_RETRIEVER = "langchain_retriever"
VECTOR_SQL_RETRIEVER = "vecsql_retriever"
TABLE_EMBEDDINGS_MAPPING = "embeddings"
RETRIEVER_TOOLS = "tools"
DATA_INITIALIZE_STATUS = "data_initialized"
UI_INITIALIZED = "ui_initialized"
JUMP_QUERY_ASK = "jump_query_ask"
USER_NAME = "user_name"
USER_INFO = "user_info"

# Rainbow divider rendered between major UI sections.
DIVIDER_HTML = """
<div style="
    height: 4px;
    background: linear-gradient(to right, red, orange, yellow, green, blue, indigo, violet);
    margin-top: 20px;
    margin-bottom: 20px;
"></div>
"""

# Thinner blue-violet divider for sub-sections.
DIVIDER_THIN_HTML = """
<div style="
    height: 2px;
    background: linear-gradient(to right, blue, darkslateblue, indigo, violet);
    margin-top: 20px;
    margin-bottom: 20px;
"></div>
"""


class RetrieverButtons:
    """Widget keys of the retriever-demo action buttons."""
    vector_sql_query_from_db = "vector_sql_query_from_db"
    vector_sql_query_with_llm = "vector_sql_query_with_llm"
    self_query_from_db = "self_query_from_db"
    self_query_with_llm = "self_query_with_llm"


# Shared mutable configuration singleton; refreshed via update_global_config.
GLOBAL_CONFIG = GlobalConfig()
43
+
44
+
45
def update_global_config(new_config: GlobalConfig):
    """Copy every field of *new_config* onto the shared GLOBAL_CONFIG.

    The singleton is mutated in place (never rebound) so that modules which
    already imported GLOBAL_CONFIG keep seeing the updated values.
    """
    for field_name in (
        "openai_api_base",
        "openai_api_key",
        "auth0_client_id",
        "auth0_domain",
        "myscale_user",
        "myscale_password",
        "myscale_host",
        "myscale_port",
        "query_model",
        "chat_model",
        "untrusted_api",
        "myscale_enable_https",
    ):
        setattr(GLOBAL_CONFIG, field_name, getattr(new_config, field_name))
backend/construct/__init__.py ADDED
File without changes
backend/construct/build_agents.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Sequence, List
3
+
4
+ import streamlit as st
5
+ from langchain.agents import AgentExecutor
6
+ from langchain.schema.language_model import BaseLanguageModel
7
+ from langchain.tools import BaseTool
8
+
9
+ from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
10
+ from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
11
+ from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
12
+ from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
13
+ from logger import logger
14
+
15
+ try:
16
+ from sqlalchemy.orm import declarative_base
17
+ except ImportError:
18
+ from sqlalchemy.ext.declarative import declarative_base
19
+ from langchain.chat_models import ChatOpenAI
20
+ from langchain.prompts.chat import MessagesPlaceholder
21
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
22
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
23
+ from langchain.schema.messages import SystemMessage
24
+ from langchain.memory import SQLChatMessageHistory
25
+
26
+
27
def create_agent_executor(
        agent_name: str,
        session_id: str,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_prompt: str,
        **kwargs
) -> AgentExecutor:
    """Create an OpenAI-functions agent whose chat history is stored in MyScaleDB.

    Args:
        agent_name: Logical agent name; spaces are replaced by underscores
            before it is handed to the message converter.
        session_id: Chat session id that keys the stored message history.
        llm: Chat model driving the agent.
        tools: Tools the agent is allowed to call.
        system_prompt: Content of the agent's system message.
        **kwargs: Extra arguments forwarded to ``AgentExecutor``.

    Returns:
        A configured ``AgentExecutor`` with token-buffer memory backed by
        ``SQLChatMessageHistory``.
    """
    agent_name = agent_name.replace(" ", "_")
    # NOTE(review): connection credentials are read from environment variables
    # here, while the rest of the backend reads GLOBAL_CONFIG — confirm both
    # sources are kept in sync.
    conn_str = (
        f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}'
        f'@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    )
    # `is False` preserves the original `== False` semantics exactly: only an
    # explicit False selects HTTP (None/other values default to HTTPS). The
    # protocol is chosen once instead of duplicating the whole URL.
    protocol = "http" if GLOBAL_CONFIG.myscale_enable_https is False else "https"
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f'{conn_str}/chat?protocol={protocol}',
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)

    # Prompt = system message + placeholder where the stored history is injected.
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        **kwargs
    )
56
+
57
+
58
def build_agents(
        session_id: str,
        tool_names: List[str],
        model: str = "gpt-3.5-turbo-0125",
        temperature: float = 0.6,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT
):
    """Assemble an agent executor for a chat session bound to the chosen tools.

    Args:
        session_id: Chat session id used to key the stored history.
        tool_names: Names of retrieval tools to attach; looked up in the
            session-state tool registry.
        model: OpenAI chat model name.
        temperature: Sampling temperature for the chat model.
        system_prompt: System prompt handed to the agent.
    """
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    # Look up the tool registry: AVAILABLE_RETRIEVAL_TOOLS first, falling back
    # to the shared RETRIEVER_TOOLS mapping.
    registry = st.session_state.get(
        AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    selected_tools = [registry[name] for name in tool_names]
    logger.info(f"create agent, use tools: {selected_tools}")
    return create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
backend/construct/build_all.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logger import logger
2
+ from typing import Dict, Any, Union
3
+
4
+ import streamlit as st
5
+
6
+ from backend.constants.myscale_tables import MYSCALE_TABLES
7
+ from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING
8
+ from backend.construct.build_chains import build_retrieval_qa_with_sources_chain
9
+ from backend.construct.build_retriever_tool import create_retriever_tool
10
+ from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever
11
+ from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn
12
+
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \
14
+ SentenceTransformerEmbeddings
15
+
16
+
17
@st.cache_resource
def load_embedding_model_for_table(table_name: str) -> \
        Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]:
    """Instantiate (and cache) the embedding model configured for one table."""
    with st.spinner(f"Loading embedding models for [{table_name}] ..."):
        # Each table's config carries a factory callable for its model.
        return MYSCALE_TABLES[table_name].emb_model()
23
+
24
+
25
@st.cache_resource
def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]:
    """Load the embedding model for every configured table.

    Returns:
        Mapping from table name to its instantiated embedding model.
    """
    # Dict comprehension instead of a manual accumulate-loop (idiomatic,
    # identical behavior).
    return {table: load_embedding_model_for_table(table) for table in MYSCALE_TABLES}
31
+
32
+
33
@st.cache_resource
def update_retriever_tools():
    """Build the retriever tools (self-query and vector SQL) for every table.

    Returns:
        Mapping from display name (e.g. ``"<table> + Self Querying"``) to Tool.
    """
    tools = {}
    for table_name in MYSCALE_TABLES:
        logger.info(f"Updating retriever tools [<retriever>, <sql_retriever>] for table {table_name}")
        table_entry = st.session_state[CHAINS_RETRIEVERS_MAPPING][table_name]
        # Both tools for a table share the same (name, description) pair.
        description = MYSCALE_TABLES[table_name].tool_desc
        tools[f"{table_name} + Self Querying"] = create_retriever_tool(
            table_entry["retriever"], *description)
        tools[f"{table_name} + Vector SQL"] = create_retriever_tool(
            table_entry["sql_retriever"], *description)
    return tools
50
+
51
+
52
@st.cache_resource
def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers:
    """Build both retrievers and their QA chains for one table.

    Produces a self-query retriever/chain pair, a vector-SQL retriever/chain
    pair, and the table's metadata column descriptions, bundled together.
    """
    table_conf = MYSCALE_TABLES[table_name]

    # Self-query path.
    sq_retriever = build_self_query_retriever(table_name)
    sq_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=sq_retriever,
        chain_name="Self Query Retriever"
    )

    # Vector-SQL path.
    vsql_retriever = build_vector_sql_db_chain_retriever(table_name)
    vsql_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=vsql_retriever,
        chain_name="Vector SQL DB Retriever"
    )

    # Convert langchain AttributeInfo entries into the app's MetadataColumn.
    columns = []
    for attr in table_conf.metadata_col_attributes:
        columns.append(
            MetadataColumn(name=attr.name, desc=attr.description, type=attr.type))

    return ChainsAndRetrievers(
        metadata_columns=columns,
        retriever=sq_retriever,
        chain=sq_chain,
        sql_retriever=vsql_retriever,
        sql_chain=vsql_chain
    )
87
+
88
+
89
@st.cache_resource
def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]:
    """Build and cache the chains/retrievers bundle for every configured table.

    Returns:
        Mapping from table name to the dict form of its ChainsAndRetrievers.
    """
    result: Dict[str, Dict[str, Any]] = {}
    # Kept as a loop (not a comprehension) so per-table progress is logged.
    for table_name in MYSCALE_TABLES:
        logger.info(f"Building chains, retrievers for table {table_name}")
        result[table_name] = build_chains_retriever_for_table(table_name).to_dict()
    return result
backend/construct/build_chains.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import LLMChain
2
+ from langchain.chat_models import ChatOpenAI
3
+ from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
4
+ from langchain.schema import BaseRetriever
5
+ import streamlit as st
6
+
7
+ from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
8
+ from backend.chains.stuff_documents import CustomStuffDocumentChain
9
+ from backend.constants.myscale_tables import MYSCALE_TABLES
10
+ from backend.constants.prompts import COMBINE_PROMPT
11
+ from backend.constants.variables import GLOBAL_CONFIG
12
+
13
+
14
def build_retrieval_qa_with_sources_chain(
        table_name: str,
        retriever: BaseRetriever,
        chain_name: str = "<chain_name>"
) -> CustomRetrievalQAWithSourcesChain:
    """Wire *retriever* into a QA-with-sources chain for the given table.

    Args:
        table_name: Table whose document prompt is used to format sources.
        retriever: Retriever that supplies the source documents.
        chain_name: Display name shown in the build spinner.
    """
    with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'):
        answer_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.chat_model,
            openai_api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0.6
        )
        # Stuff-documents chain that also assigns a ref_id to each document.
        combine_documents = CustomStuffDocumentChain(
            llm_chain=LLMChain(prompt=COMBINE_PROMPT, llm=answer_llm),
            document_prompt=MYSCALE_TABLES[table_name].doc_prompt,
            document_variable_name="summaries",
        )
        return CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=combine_documents,
            return_source_documents=True,
            max_tokens_limit=12000,
        )
backend/construct/build_chat_bot.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable
2
+ from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER
3
+ import streamlit as st
4
+
5
+ from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING
6
+ from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
7
+ from backend.chat_bot.session_manager import SessionManager
8
+
9
+
10
+ def build_chat_knowledge_table():
11
+ if CHAT_KNOWLEDGE_TABLE not in st.session_state:
12
+ st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable(
13
+ host=GLOBAL_CONFIG.myscale_host,
14
+ port=GLOBAL_CONFIG.myscale_port,
15
+ username=GLOBAL_CONFIG.myscale_user,
16
+ password=GLOBAL_CONFIG.myscale_password,
17
+ # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"],
18
+ embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"],
19
+ parser_api_key=GLOBAL_CONFIG.untrusted_api,
20
+ )
21
+
22
+
23
+ def initialize_session_manager():
24
+ if CHAT_SESSION not in st.session_state:
25
+ st.session_state[CHAT_SESSION] = {
26
+ "session_id": "default",
27
+ "system_prompt": DEFAULT_SYSTEM_PROMPT,
28
+ }
29
+ if CHAT_SESSION_MANAGER not in st.session_state:
30
+ st.session_state[CHAT_SESSION_MANAGER] = SessionManager(
31
+ st.session_state,
32
+ host=GLOBAL_CONFIG.myscale_host,
33
+ port=GLOBAL_CONFIG.myscale_port,
34
+ username=GLOBAL_CONFIG.myscale_user,
35
+ password=GLOBAL_CONFIG.myscale_password,
36
+ )
backend/construct/build_retriever_tool.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List
3
+
4
+ from langchain.pydantic_v1 import BaseModel, Field
5
+ from langchain.schema import BaseRetriever, Document
6
+ from langchain.tools import Tool
7
+
8
+ from backend.chat_bot.json_decoder import CustomJSONEncoder
9
+
10
+
11
# Input schema for tools built by create_retriever_tool below: a single
# free-text query string.
class RetrieverInput(BaseModel):
    query: str = Field(description="query to look up in retriever")
13
+
14
+
15
def create_retriever_tool(
        retriever: BaseRetriever,
        tool_name: str,
        description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    Args:
        retriever: The retriever to use for the retrieval
        tool_name: The name for the tool. This will be passed to the language model,
            so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the language
            model, so should be descriptive.

    Returns:
        Tool class to pass to an agent
    """
    def serialize(docs: List[Document]) -> str:
        # Tool output must be a plain string; documents are serialized to JSON.
        return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

    def retrieve(*args, **kwargs) -> str:
        return serialize(retriever.get_relevant_documents(*args, **kwargs))

    async def aretrieve(*args, **kwargs) -> str:
        # Bug fix: the coroutine previously returned raw Document objects,
        # inconsistent with the JSON string produced by the sync path.
        return serialize(await retriever.aget_relevant_documents(*args, **kwargs))

    return Tool(
        name=tool_name,
        description=description,
        func=retrieve,
        coroutine=aretrieve,
        args_schema=RetrieverInput,
    )
backend/construct/build_retrievers.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.chat_models import ChatOpenAI
3
+ from langchain.prompts.prompt import PromptTemplate
4
+ from langchain.retrievers.self_query.base import SelfQueryRetriever
5
+ from langchain.retrievers.self_query.myscale import MyScaleTranslator
6
+ from langchain.utilities.sql_database import SQLDatabase
7
+ from langchain.vectorstores import MyScaleSettings
8
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
9
+ from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
10
+ from sqlalchemy import create_engine, MetaData
11
+
12
+ from backend.constants.myscale_tables import MYSCALE_TABLES
13
+ from backend.constants.prompts import MYSCALE_PROMPT
14
+ from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG
15
+ from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser
16
+ from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson
17
+ from logger import logger
18
+
19
+
20
@st.cache_resource
def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
    """Construct a SelfQueryRetriever over the MyScaleDB table *table_name*."""
    table_conf = MYSCALE_TABLES[table_name]

    with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
        settings = MyScaleSettings(
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
            database=table_conf.database,
            table=table_conf.table,
            column_map={
                "id": "id",
                "text": table_conf.text_col_name,
                "vector": table_conf.vector_col_name,
                # TODO refine MyScaleDB metadata in langchain.
                "metadata": table_conf.metadata_col_name
            }
        )
        vector_store = MyScaleWithoutMetadataJson(
            embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            config=settings,
            must_have_cols=table_conf.must_have_col_names
        )

    with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
        # Deterministic query-construction model (temperature 0).
        query_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.query_model,
            base_url=GLOBAL_CONFIG.openai_api_base,
            api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0
        )
        return SelfQueryRetriever.from_llm(
            llm=query_llm,
            vectorstore=vector_store,
            document_contents=table_conf.table_contents,
            metadata_field_info=table_conf.metadata_col_attributes,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator()
        )
62
+
63
+
64
@st.cache_resource
def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
    """Build a retriever that fetches documents from MyScaleDB via LLM-generated vector SQL.

    Args:
        table_name: Key into MYSCALE_TABLES describing the target table.

    Returns:
        A VectorSQLDatabaseChainRetriever whose inner chain turns a question
        into a SQL statement and maps the resulting rows to documents.
    """
    table_conf = MYSCALE_TABLES[table_name]
    with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
        # Only the protocol depends on the HTTPS setting; `is False` keeps the
        # original `== False` semantics (anything but an explicit False -> https).
        # Previously the whole engine URL was duplicated across two branches.
        protocol = "http" if GLOBAL_CONFIG.myscale_enable_https is False else "https"
        engine = create_engine(
            f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@'
            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
            f'/{table_conf.database}?protocol={protocol}'
        )
        metadata = MetaData(bind=engine)
        logger.info(f"{table_name} metadata is : {metadata}")
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=MYSCALE_PROMPT,
        )
        # Custom output parser rewrites the generated SQL so the configured
        # columns are always selected.
        output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
            model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            must_have_columns=table_conf.must_have_col_names
        )

        # `db_chain` asks the LLM (temperature 0) to generate the SQL statement.
        vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            prompt=prompt,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(
                engine,
                None,
                metadata,
                include_tables=[table_conf.table],
                max_string_length=1024
            ),
            sql_cmd_parser=output_parser,  # TODO needs update `langchain`, fix return type.
            native_format=True
        )

        # The retriever executes the generated SQL and maps rows to documents.
        return VectorSQLDatabaseChainRetriever(
            sql_db_chain=vector_sql_db_chain,
            page_content_key=table_conf.text_col_name
        )
backend/retrievers/__init__.py ADDED
File without changes
backend/retrievers/self_query.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ from langchain.retrievers import SelfQueryRetriever
6
+ from langchain_core.documents import Document
7
+ from langchain_core.runnables import RunnableConfig
8
+
9
+ from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
10
+ from backend.constants.myscale_tables import MYSCALE_TABLES
11
+ from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
12
+ from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler
13
+ from ui.utils import display
14
+ from logger import logger
15
+
16
+
17
def process_self_query(selected_table, query_type):
    """Run a self-query retrieval against *selected_table* and render results.

    Args:
        selected_table: Name of the MyScaleDB table to query.
        query_type: ``RetrieverButtons.self_query_from_db`` for raw retrieval,
            or ``RetrieverButtons.self_query_with_llm`` for retrieval plus an
            LLM-generated answer.

    Raises:
        NotImplementedError: If *query_type* is not a supported self-query type
            (consistent with ``process_sql_query``; previously unknown types
            were silently ignored).
    """
    place_holder = st.empty()
    logger.info(
        f"button-1: {RetrieverButtons.self_query_from_db}, "
        f"button-2: {RetrieverButtons.self_query_with_llm}, "
        f"content: {st.session_state.query_self}"
    )
    with place_holder.expander('🪵 Chat Log', expanded=True):
        try:
            if query_type == RetrieverButtons.self_query_from_db:
                callback = CustomSelfQueryRetrieverCallBackHandler()
                retriever: SelfQueryRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
                config: RunnableConfig = {"callbacks": [callback]}

                relevant_docs = retriever.invoke(
                    input=st.session_state.query_self,
                    config=config
                )

                callback.progress_bar.progress(
                    value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅")

                st.markdown(f"### Self Query Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    dataframe=pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    ),
                    columns_=MYSCALE_TABLES[selected_table].must_have_col_names
                )
            elif query_type == RetrieverButtons.self_query_with_llm:
                callback = ChatDataSelfAskCallBackHandler()
                chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
                chain_results = chain(st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> Here shows that which documents used by LLM \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents is used by LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            else:
                # Fail loudly on unknown types, matching process_sql_query.
                raise NotImplementedError(f"Unsupported query type: {query_type}")
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            st.write('Oops 😵 Something bad happened...')
            raise e
backend/retrievers/vector_sql_output_parser.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+
3
+ from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
4
+
5
+
6
class VectorSQLRetrieveOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser.

    It also modifies the generated SQL so the SELECT list always contains
    every column in ``must_have_columns``.
    """
    # Columns that must appear in the SELECT clause of the rewritten SQL.
    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        """Rewrite the SELECT list of *text*, then delegate to the base parser."""
        text = text.strip()
        upper = text.upper()
        start = upper.find("SELECT")
        end = upper.find("FROM")
        # Only rewrite when both keywords exist in order. Previously a missing
        # FROM clause (end == -1) made the slice wrap around and corrupted the
        # statement.
        if 0 <= start < end:
            select_list = text[start + len("SELECT") + 1: end - 1]
            # NOTE: str.replace substitutes every occurrence of the original
            # select list; assumed unique within the statement.
            text = text.replace(select_list, ", ".join(self.must_have_columns))
        return super().parse(text)
backend/retrievers/vector_sql_query.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ from langchain.schema import Document
6
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
7
+
8
+ from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
9
+ from backend.constants.myscale_tables import MYSCALE_TABLES
10
+ from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
11
+ from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler
12
+ from ui.utils import display
13
+ from logger import logger
14
+
15
+
16
def process_sql_query(selected_table: str, query_type: str):
    """Run a vector-SQL retrieval against *selected_table* and render results.

    Args:
        selected_table: Name of the MyScaleDB table to query.
        query_type: ``RetrieverButtons.vector_sql_query_from_db`` for raw
            retrieval, or ``RetrieverButtons.vector_sql_query_with_llm`` for
            retrieval plus an LLM-generated answer.

    Raises:
        NotImplementedError: If *query_type* is not one of the two supported types.
    """
    place_holder = st.empty()
    logger.info(
        f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, "
        f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, "
        f"table: {selected_table}, "
        f"content: {st.session_state.query_sql}"
    )
    with place_holder.expander('🪵 Query Log', expanded=True):
        try:
            if query_type == RetrieverButtons.vector_sql_query_from_db:
                # Raw path: LLM generates SQL, rows come back as documents.
                callback = VectorSQLSearchDBCallBackHandler()
                vector_sql_retriever: VectorSQLDatabaseChainRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"]
                relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents(
                    query=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅"
                )

                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with given sql statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    )
                )
            elif query_type == RetrieverButtons.vector_sql_query_with_llm:
                # LLM path: retrieved documents are additionally summarized by the chain.
                callback = VectorSQLSearchLLMCallBackHandler(table=selected_table)
                vector_sql_chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"]
                chain_results = vector_sql_chain(
                    inputs=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> "
                         "(Question,Results) -> LLM -> Results] Done! ✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with given sql statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the vector search results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> Here shows that which documents used by LLM \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents is used by LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            else:
                raise NotImplementedError(f"Unsupported query type: {query_type}")
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            # Show a friendly note in the UI, then propagate for Streamlit's
            # error reporting.
            st.write('Oops 😵 Something bad happened...')
            raise e
95
+
backend/types/__init__.py ADDED
File without changes
backend/types/chains_and_retrievers.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from dataclasses import dataclass
3
+ from typing import List, Any
4
+ from langchain.retrievers import SelfQueryRetriever
5
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
6
+
7
+ from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
8
+
9
+
10
@dataclass
class MetadataColumn:
    """Lightweight record describing one metadata column of a table."""
    # column name
    name: str
    # human-readable description of the column
    desc: str
    # data type of the column, as a string
    type: str
15
+
16
+
17
@dataclass
class ChainsAndRetrievers:
    """Bundle of the retrievers and QA chains built for a single table."""
    metadata_columns: List[MetadataColumn]
    # self-query path
    retriever: SelfQueryRetriever
    chain: CustomRetrievalQAWithSourcesChain
    # vector-SQL path
    sql_retriever: VectorSQLDatabaseChainRetriever
    sql_chain: CustomRetrievalQAWithSourcesChain

    def to_dict(self) -> Dict[str, Any]:
        """Return the bundle as a plain dict keyed by field name."""
        return {
            "metadata_columns": self.metadata_columns,
            "retriever": self.retriever,
            "chain": self.chain,
            "sql_retriever": self.sql_retriever,
            "sql_chain": self.sql_chain
        }
33
+
34
+
backend/types/global_config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+
5
@dataclass
class GlobalConfig:
    """Runtime configuration shared across the backend.

    A module-level instance (GLOBAL_CONFIG in backend.constants.variables) is
    mutated in place by update_global_config().
    """
    # OpenAI endpoint and credentials
    openai_api_base: Optional[str] = ""
    openai_api_key: Optional[str] = ""

    # Auth0 login settings
    auth0_client_id: Optional[str] = ""
    auth0_domain: Optional[str] = ""

    # MyScaleDB connection settings
    myscale_user: Optional[str] = ""
    myscale_password: Optional[str] = ""
    myscale_host: Optional[str] = ""
    myscale_port: Optional[int] = 443

    # model used to generate queries / SQL (temperature 0 callers)
    query_model: Optional[str] = ""
    # model used to produce chat answers
    chat_model: Optional[str] = ""

    # API key passed as parser_api_key to the chat-bot knowledge base
    untrusted_api: Optional[str] = ""
    # whether MyScaleDB connections use HTTPS (default matches port 443)
    myscale_enable_https: Optional[bool] = True
backend/types/table_config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ from langchain.chains.query_constructor.schema import AttributeInfo
3
+ from langchain.prompts import PromptTemplate
4
+ from dataclasses import dataclass
5
+ from typing import List
6
+
7
+
8
@dataclass
class TableConfig:
    """Static description of one MyScaleDB table exposed by the app."""
    database: str
    table: str
    # natural-language description of the table contents (used as
    # document_contents for the self-query retriever)
    table_contents: str
    # column names
    must_have_col_names: List[str]
    vector_col_name: str
    text_col_name: str
    metadata_col_name: str
    # hint for UI
    hint: Callable
    hint_sql: Callable
    # for langchain
    doc_prompt: PromptTemplate
    metadata_col_attributes: List[AttributeInfo]
    # zero-argument factory returning the table's embedding model
    emb_model: Callable
    # (tool_name, description) pair unpacked into create_retriever_tool
    tool_desc: tuple
backend/vector_store/__init__.py ADDED
File without changes
backend/vector_store/myscale_without_metadata.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, List
2
+
3
+ from langchain.docstore.document import Document
4
+ from langchain.embeddings.base import Embeddings
5
+ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
6
+
7
+ from logger import logger
8
+
9
+
10
class MyScaleWithoutMetadataJson(MyScale):
    """MyScale vector store for tables whose metadata lives in plain columns
    instead of a JSON metadata column: the listed columns are selected with
    every search hit and folded into the document metadata."""

    def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None,
                 must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Initialize the store.

        Args:
            embedding: Embedding model used for queries.
            config: MyScale connection/table settings.
            must_have_cols: Extra column names fetched with every search hit.
                (Default changed from a shared mutable ``[]`` to ``None`` to
                avoid the mutable-default-argument pitfall; behavior for
                callers is unchanged.)
        """
        try:
            super().__init__(embedding, config, **kwargs)
        except Exception as e:
            # NOTE(review): base-class init failures are logged and swallowed so
            # construction can proceed without DB connectivity — confirm intended.
            logger.error(e)
        self.must_have_cols: List[str] = must_have_cols or []

    def _build_qstr(
            self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Render the vector-search SQL for the query embedding *q_emb*."""
        q_emb_str = ",".join(map(str, q_emb))
        where_clause = f"PREWHERE {where_str}" if where_str else ""
        # Build the SELECT list explicitly so an empty `must_have_cols` does
        # not leave a dangling comma in the statement.
        select_cols = ", ".join(
            [self.config.column_map['text'], "dist", *self.must_have_cols])

        q_str = f"""
            SELECT {select_cols}
            FROM {self.config.database}.{self.config.table}
            {where_clause}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None,
                                    **kwargs: Any) -> List[Document]:
        """Return up to *k* documents nearest to *embedding*.

        Returns an empty list (after logging) if the query fails.
        """
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    # `col` instead of `k` — the original comprehension
                    # shadowed the result-count parameter.
                    metadata={col: r[col] for col in self.must_have_cols},
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(
                f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
logger.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+
4
def setup_logger():
    """Create (or fetch) the app-wide 'chat-data' logger.

    A single console handler is attached on first call; subsequent calls
    return the same configured logger without duplicating handlers.
    """
    app_logger = logging.getLogger('chat-data')
    app_logger.setLevel(logging.INFO)
    if app_logger.handlers:
        return app_logger
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]'
    ))
    app_logger.addHandler(handler)
    return app_logger


# Shared logger instance imported throughout the project.
logger = setup_logger()
requirements.txt CHANGED
@@ -1,15 +1,17 @@
1
- langchain @ git+https://github.com/myscale/langchain.git@preview#egg=langchain&subdirectory=libs/langchain
2
- langchain-experimental @ git+https://github.com/myscale/langchain.git@preview#egg=langchain-experimental&subdirectory=libs/experimental
3
- # https://github.com/PromtEngineer/localGPT/issues/722
4
- sentence_transformers==2.2.2
 
 
5
  InstructorEmbedding
6
  pandas
7
- sentence_transformers
8
- streamlit==1.25
9
  streamlit-auth0-component
10
  altair==4.2.2
11
  clickhouse-connect
12
- openai==0.28
13
  lark
14
  tiktoken
15
  sql-formatter
 
1
+ langchain==0.2.1
2
+ langchain-community==0.2.1
3
+ langchain-core==0.2.1
4
+ langchain-experimental==0.0.59
5
+ langchain-openai==0.1.7
6
+ sentence-transformers==2.2.2
7
  InstructorEmbedding
8
  pandas
9
+ streamlit
10
+ streamlit-extras
11
  streamlit-auth0-component
12
  altair==4.2.2
13
  clickhouse-connect
14
+ openai==1.35.3
15
  lark
16
  tiktoken
17
  sql-formatter
ui/__init__.py ADDED
File without changes
ui/chat_page.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+
4
+ import pandas as pd
5
+ import streamlit as st
6
+ from langchain_core.messages import HumanMessage, FunctionMessage
7
+ from streamlit.delta_generator import DeltaGenerator
8
+
9
+ from backend.chat_bot.json_decoder import CustomJSONDecoder
10
+ from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \
11
+ EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \
12
+ EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
13
+ USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
14
+ CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS
15
+ from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS
16
+ from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager
17
+ from backend.chat_bot.chat import refresh_sessions, on_session_change_submit, refresh_agent, \
18
+ create_private_knowledge_base_as_tool, \
19
+ remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit
20
+
21
+
22
def render_session_manager():
    """Sidebar expander that lets the user add / edit / delete chat sessions."""
    with st.expander("🤖 Session Management"):
        # Lazily load this user's sessions from the backend on first render.
        if CHAT_CURRENT_USER_SESSIONS not in st.session_state:
            refresh_sessions()
        st.markdown("Here you can update `session_id` and `system_prompt`")
        st.markdown("- Click empty row to add a new item")
        st.markdown("- If needs to delete an item, just click it and press `DEL` key")
        st.markdown("- Don't forget to submit your change.")

        # Editable grid of sessions; changes take effect only when the user
        # presses Submit (the callback reads the "session_editor" state key).
        st.data_editor(
            data=st.session_state[CHAT_CURRENT_USER_SESSIONS],
            num_rows="dynamic",
            key="session_editor",
            use_container_width=True,
        )
        st.button("⏫ Submit", on_click=on_session_change_submit, type="primary")
38
+
39
+
40
def render_session_selection():
    """Expander with a selectbox for switching the active chat session."""
    with st.expander("✅ Session Selection", expanded=True):
        st.selectbox(
            "Choose a `session` to chat",
            options=st.session_state[CHAT_CURRENT_USER_SESSIONS],
            index=None,  # start with no session pre-selected
            key=EL_SESSION_SELECTOR,
            format_func=lambda x: x["session_id"],  # sessions are dict-like rows
            on_change=refresh_agent,  # rebuild the agent for the chosen session
        )
50
+
51
+
52
def render_files_manager():
    """Expander for uploading, listing, and clearing the user's personal files."""
    with st.expander("📃 **Upload your personal files**", expanded=False):
        st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).")
        st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).")
        st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True)
        # st.markdown("### Uploaded Files")
        # Show the files this user has already uploaded.
        st.dataframe(
            data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]),
            use_container_width=True,
        )
        # Placeholder that the upload/clear callbacks use to report status.
        st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty()
        col_1, col_2 = st.columns(2)
        with col_1:
            st.button(label="Upload files", on_click=add_file)
        with col_2:
            st.button(label="Clear all files and tools", on_click=clear_files)
68
+
69
+
70
def _render_create_personal_knowledge_bases(div: DeltaGenerator):
    """Render the form for building a personal knowledge base from uploaded files.

    Args:
        div: tab container (from ``st.tabs``) to render the form into.
    """
    with div:
        # Grammar fixes in the two instruction strings below ("haven't uploaded",
        # "Once your") — display-only text, no widget keys or behavior involved.
        st.markdown("- If you haven't uploaded your personal files, please upload them first.")
        st.markdown("- Select some **files** to build your `personal knowledge base`.")
        st.markdown("- Once your `personal knowledge base` is built, "
                    "it will answer your questions using information from your personal **files**.")
        st.multiselect(
            label="⚡️Select some files to build a **personal knowledge base**",
            options=st.session_state[USER_PRIVATE_FILES],
            placeholder="You should upload some files first",
            key=EL_BUILD_KB_WITH_FILES,
            format_func=lambda x: x["file_name"],  # files are dict-like records
        )
        st.text_input(
            label="⚡️Personal knowledge base name",
            value="get_relevant_documents",  # default tool name exposed to the agent
            key=EL_PERSONAL_KB_NAME
        )
        st.text_input(
            label="⚡️Personal knowledge base description",
            value="Searches from some personal files.",
            key=EL_PERSONAL_KB_DESCRIPTION,
        )
        st.button(
            label="Build 🔧",
            on_click=create_private_knowledge_base_as_tool
        )
97
+
98
+
99
def _render_remove_personal_knowledge_bases(div: DeltaGenerator):
    """Render the management tab: list existing personal KBs and allow deletion.

    Args:
        div: tab container (from ``st.tabs``) to render into.
    """
    with div:
        st.markdown("> Here is all your personal knowledge bases.")
        if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0:
            st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES])
        else:
            st.warning("You don't have any personal knowledge bases, please create a new one.")
        # NOTE(review): the multiselect below reads the same session key
        # unconditionally; if it is absent this raises KeyError — confirm the
        # key is always initialized before this tab renders.
        st.multiselect(
            label="Choose a personal knowledge base to delete",
            placeholder="Choose a personal knowledge base to delete",
            options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES],
            format_func=lambda x: x["tool_name"],  # KBs are dict-like records
            key=EL_PERSONAL_KB_NEEDS_REMOVE,
        )
        st.button("Delete", on_click=remove_private_knowledge_bases, type="primary")
114
+
115
+
116
def render_personal_tools_build():
    """Expander with two tabs: create a personal KB / manage existing ones."""
    with st.expander("🔨 **Build your personal knowledge base**", expanded=True):
        tabs = st.tabs(["Create personal knowledge base", "Personal knowledge base management"])
        _render_create_personal_knowledge_bases(tabs[0])
        _render_remove_personal_knowledge_bases(tabs[1])
121
+
122
+
123
def render_knowledge_base_selector():
    """Expander for choosing which knowledge-base tools the agent may query."""
    with st.expander("🙋 **Select some knowledge bases to query**", expanded=True):
        st.markdown("- Knowledge bases come in two types: `public` and `private`.")
        st.markdown("- All users can access our `public` knowledge bases.")
        st.markdown("- Only you can access your `personal` knowledge bases.")
        # Default to the global tool registry; prefer the per-user available
        # tool list when it has been computed for this session.
        options = st.session_state[RETRIEVER_TOOLS].keys()
        if AVAILABLE_RETRIEVAL_TOOLS in st.session_state:
            options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS]
        st.multiselect(
            label="Select some knowledge base tool",
            placeholder="Please select some knowledge bases to query",
            options=options,
            default=["Wikipedia + Self Querying"],  # assumes this tool always exists — TODO confirm
            key=EL_SELECTED_KBS,
            on_change=refresh_agent,  # rebuild the agent with the new tool set
        )
139
+
140
+
141
def chat_page():
    """Top-level chat page: sidebar controls, chat history, and the input box."""
    # initialize resources (knowledge table + session manager singletons)
    build_chat_knowledge_table()
    initialize_session_manager()

    # render sidebar
    with st.sidebar:
        left, middle, right = st.columns([1, 1, 2])
        with left:
            st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main)
        with right:
            st.markdown(f"👤 `{st.session_state[USER_NAME]}`")
        st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        render_session_manager()
        render_session_selection()
        render_files_manager()
        render_personal_tools_build()
        render_knowledge_base_selector()

    # render chat history from the agent's memory
    if "agent" not in st.session_state:
        refresh_agent()
    for msg in st.session_state.agent.memory.chat_memory.messages:
        speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
        if isinstance(msg, FunctionMessage):
            # Tool results: render retrieved rows as a table with a timestamp.
            with st.chat_message(name="from knowledge base", avatar="📚"):
                st.write(
                    f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
                )
                st.write("Retrieved from knowledge base:")
                try:
                    st.dataframe(
                        pd.DataFrame.from_records(
                            json.loads(msg.content, cls=CustomJSONDecoder)
                        ),
                        use_container_width=True,
                    )
                except Exception as e:
                    # Fall back to raw text when content is not valid JSON records.
                    st.warning(e)
                    st.write(msg.content)
        else:
            # Regular human/assistant turns; skip empty messages.
            if len(msg.content) > 0:
                with st.chat_message(speaker):
                    # print(type(msg), msg.dict())
                    st.write(
                        f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
                    )
                    st.write(f"{msg.content}")
    st.session_state["next_round"] = st.empty()
    # NOTE(review): `_bottom` is a private Streamlit API and may break on
    # upgrade — confirm there is no public equivalent in the pinned version.
    from streamlit import _bottom
    with _bottom:
        col1, col2 = st.columns([1, 16])
        with col1:
            st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary")
        with col2:
            st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
ui/home.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+
3
+ from streamlit_extras.add_vertical_space import add_vertical_space
4
+ from streamlit_extras.card import card
5
+ from streamlit_extras.colored_header import colored_header
6
+ from streamlit_extras.mention import mention
7
+ from streamlit_extras.tags import tagger_component
8
+
9
+ from logger import logger
10
+ import os
11
+
12
+ import streamlit as st
13
+ from auth0_component import login_button
14
+
15
+ from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML
16
+ from streamlit_extras.let_it_rain import rain
17
+
18
+
19
def render_home():
    """Compose the landing page: header, main content, divider, footer."""
    render_home_header()
    # st.divider()
    # st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
    add_vertical_space(5)
    render_home_content()
    # st.divider()
    st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
    render_home_footer()
28
+
29
+
30
def render_home_header():
    """Render the landing-page header: title, intro blurb, tags, related links."""
    logger.info("render home header")
    st.header("ChatData - Your Intelligent Assistant")
    st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
    st.markdown("> [ChatData](https://github.com/myscale/ChatData) \
        is developed by [MyScale](https://myscale.com/), \
        it's an integration of [LangChain](https://www.langchain.com/) \
        and [MyScaleDB](https://github.com/myscale/myscaledb)")

    # Keyword tags with one color per tag (lists are parallel).
    tagger_component(
        "Keywords:",
        ["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"],
        color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"],
    )
    # Row of "Related:" mention badges.
    text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4])
    with text:
        st.markdown("Related:")
    with col1.container():
        mention(
            label="streamlit",
            icon="streamlit",
            url="https://streamlit.io/",
            write=True
        )
    with col2.container():
        mention(
            label="langchain",
            icon="🦜🔗",
            url="https://www.langchain.com/",
            write=True
        )
    with col3.container():
        mention(
            label="streamlit-extras",
            icon="🪢",
            url="https://github.com/arnaudmiribel/streamlit-extras",
            write=True
        )
68
+
69
+
70
def _render_self_query_chain_content():
    """Feature card: illustration plus intro to the Vector SQL / self-query sample."""
    col1, col2 = st.columns([1, 1], gap='large')
    with col1.container():
        # NOTE(review): relative path assumes the app's working directory has
        # ../assets available — confirm against the deployment layout.
        st.image(image='../assets/home_page_background_1.png',
                 caption=None,
                 width=None,
                 use_column_width=True,
                 clamp=False,
                 channels="RGB",
                 output_format="PNG")
    with col2.container():
        st.header("VectorSearch & SelfQuery with Sources")
        st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.")
        st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options:

1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content.

2. `Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""")
        add_vertical_space(3)
        _, middle, _ = st.columns([2, 1, 2], gap='small')
        with middle.container():
            # Button press is stored in session state; render_home_footer
            # triggers st.rerun() when it is truthy.
            st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary")
92
+
93
+
94
def _render_chat_bot_content():
    """Feature card: illustration plus intro to the chat bot, with Auth0 login."""
    col1, col2 = st.columns(2, gap='large')
    with col1.container():
        # NOTE(review): relative path assumes ../assets exists from the app's
        # working directory — confirm against the deployment layout.
        st.image(image='../assets/home_page_background_2.png',
                 caption=None,
                 width=None,
                 use_column_width=True,
                 clamp=False,
                 channels="RGB",
                 output_format="PNG")
    with col2.container():
        st.header("Chat Bot")
        st.info("Now you can try our chatbot, this chatbot is built with MyScale and LangChain.")
        st.markdown("- You need to log in. We use `user_name` to identify each customer.")
        st.markdown("- You can upload your own PDF files and build your own knowledge base. \
            (This is just a sample application. Please do not upload important or confidential files.)")
        st.markdown("- A default session will be assigned as your initial chat session. \
            You can create and switch to other sessions to jump between different chat conversations.")
        add_vertical_space(1)
        _, middle, _ = st.columns([1, 2, 1], gap='small')
        with middle.container():
            # Show the Auth0 login button only for anonymous visitors; the
            # login result is written to st.session_state["auth0"].
            if USER_NAME not in st.session_state:
                login_button(clientId=os.environ["AUTH0_CLIENT_ID"],
                             domain=os.environ["AUTH0_DOMAIN"],
                             key="auth0")
            # if user_info:
            #     user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null")
            #     st.session_state[USER_NAME] = user_name
            # print(user_info)
123
+
124
+
125
def render_home_content():
    """Render the two feature cards with vertical spacing between them."""
    logger.info("render home content")
    _render_self_query_chain_content()
    add_vertical_space(3)
    _render_chat_bot_content()
130
+
131
+
132
def render_home_footer():
    """Render footer links and handle post-login / jump-to-sample transitions."""
    logger.info("render home footer")
    st.write(
        "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!"
    )
    st.write(
        "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
    st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)")

    # st.write(
    #     "Recommended to use the standalone version of Chat-Data, "
    #     "available [here](https://myscale-chatdata.hf.space/)."
    # )

    # After Auth0 login, derive a unique user name from the returned profile
    # and rerun so the app switches into the logged-in state.
    # NOTE(review): attribute access `st.session_state.auth0` raises when the
    # login button was never rendered this run — confirm `auth0` is always set.
    if st.session_state.auth0 is not None:
        st.session_state[USER_INFO] = dict(st.session_state.auth0)
        if 'email' in st.session_state[USER_INFO]:
            email = st.session_state[USER_INFO]["email"]
        else:
            # Fall back to nickname + Auth0 subject when no email is present.
            email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}"
        st.session_state["user_name"] = email
        del st.session_state.auth0
        st.rerun()
    # NOTE(review): lower-case attribute access; assumed equal to the
    # JUMP_QUERY_ASK constant key — verify.
    if st.session_state.jump_query_ask:
        st.rerun()
ui/retrievers.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_extras.add_vertical_space import add_vertical_space
3
+
4
+ from backend.constants.myscale_tables import MYSCALE_TABLES
5
+ from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
6
+ from backend.retrievers.self_query import process_self_query
7
+ from backend.retrievers.vector_sql_query import process_sql_query
8
+ from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO
9
+
10
+
11
def back_to_main():
    """Log out and return to the landing page by clearing the relevant session keys."""
    for state_key in (USER_INFO, USER_NAME, JUMP_QUERY_ASK):
        if state_key in st.session_state:
            del st.session_state[state_key]
19
+
20
def _render_table_selector() -> str:
    """Let the user pick a public knowledge base (a read-only MyScaleDB table).

    Returns:
        The selected table name (a key of ``MYSCALE_TABLES``).
    """
    col1, col2 = st.columns(2)
    with col1:
        selected_table = st.selectbox(
            label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
            options=MYSCALE_TABLES.keys(),
        )
        # Table-specific usage hint for the chosen knowledge base.
        MYSCALE_TABLES[selected_table].hint()
    with col2:
        add_vertical_space(1)
        st.info(f"Here is your selected public knowledge base schema in MyScaleDB",
                icon='📚')
        # Show the table's schema as SQL.
        MYSCALE_TABLES[selected_table].hint_sql()

    return selected_table
35
+
36
+
37
def render_retrievers():
    """Page for querying public knowledge bases via Vector SQL or self-querying."""
    st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
    st.subheader('Please choose a public knowledge base to search.')
    chosen_table = _render_table_selector()

    sql_tab, self_query_tab = st.tabs(
        tabs=['Vector SQL', 'Self-querying Retriever']
    )
    with sql_tab:
        render_tab_sql(chosen_table)
    with self_query_tab:
        render_tab_self_query(chosen_table)
51
+
52
+
53
def render_tab_sql(selected_table: str):
    """Render the 'Vector SQL' tab: query input plus the two retrieval actions.

    Args:
        selected_table: name of the public knowledge base table to query.
    """
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    # Show which metadata columns are filterable for this table.
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])

    cols = st.columns([8, 3, 3, 2])
    cols[0].text_input("Input your question:", key='query_sql')
    with cols[1].container():
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)
    with cols[2].container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)

    # A button's session-state value is True only on the rerun where it was clicked.
    if st.session_state[RetrieverButtons.vector_sql_query_from_db]:
        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db)

    if st.session_state[RetrieverButtons.vector_sql_query_with_llm]:
        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm)
74
+
75
+
76
def render_tab_self_query(selected_table):
    """Render the 'Self-querying Retriever' tab: query input plus retrieval actions.

    Args:
        selected_table: name of the public knowledge base table to query.
    """
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    # Show which metadata columns are filterable for this table.
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])

    cols = st.columns([8, 3, 3, 2])
    cols[0].text_input("Input your question:", key='query_self')

    # NOTE(review): unlike render_tab_sql, the button keys here ('search_self',
    # 'ask_self') are string literals rather than the RetrieverButtons constants
    # passed to process_self_query below — confirm the mismatch is intentional.
    with cols[1].container():
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key='search_self')
    with cols[2].container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key='ask_self')

    if st.session_state.search_self:
        process_self_query(selected_table, RetrieverButtons.self_query_from_db)

    if st.session_state.ask_self:
        process_self_query(selected_table, RetrieverButtons.self_query_with_llm)
ui/utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
def display(dataframe, columns_=None, index=None):
    """Render a results DataFrame via Streamlit, or an apology when it is empty.

    Args:
        dataframe: pandas DataFrame of retrieved results.
        columns_: optional subset of columns to show.
        index: optional column name to use as the frame's index.
    """
    if len(dataframe) > 0:
        if index:
            # Bug fix: DataFrame.set_index returns a new frame (inplace=False
            # by default); the original discarded the result, so the chosen
            # index was never actually applied.
            dataframe = dataframe.set_index(index)
        if columns_:
            st.dataframe(dataframe[columns_])
        else:
            st.dataframe(dataframe)
    else:
        st.write(
            "Sorry 😵 we didn't find any articles related to your query.\n\n"
            "Maybe the LLM is too naughty that does not follow our instruction... \n\n"
            "Please try again and use verbs that may match the datatype.",
            unsafe_allow_html=True
        )