Spaces:
Running
Running
Synced repo using 'sync_with_huggingface' Github Action
Browse files- .streamlit/secrets.example.toml +4 -1
- app.py +7 -9
- callbacks/arxiv_callbacks.py +4 -3
- chains/arxiv_chains.py +11 -8
- chat.py +6 -6
- lib/helper.py +43 -28
- lib/json_conv.py +5 -2
- lib/private_kb.py +2 -1
- lib/schemas.py +1 -1
- lib/sessions.py +14 -11
- login.py +9 -8
- prompts/arxiv_prompt.py +1 -1
.streamlit/secrets.example.toml
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
-
MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud"
|
2 |
MYSCALE_PORT = 443
|
3 |
MYSCALE_USER = "chatdata"
|
4 |
MYSCALE_PASSWORD = "myscale_rocks"
|
5 |
OPENAI_API_BASE = "https://api.openai.com/v1"
|
6 |
OPENAI_API_KEY = "<your-openai-key>"
|
|
|
|
|
|
|
|
1 |
+
MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
|
2 |
MYSCALE_PORT = 443
|
3 |
MYSCALE_USER = "chatdata"
|
4 |
MYSCALE_PASSWORD = "myscale_rocks"
|
5 |
OPENAI_API_BASE = "https://api.openai.com/v1"
|
6 |
OPENAI_API_KEY = "<your-openai-key>"
|
7 |
+
UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
|
8 |
+
AUTH0_DOMAIN = "<your-auth0-domain>" # optional if you don't user management
|
9 |
+
AUTH0_CLIENT_ID = "<your-auth0-client-id>" # optiona
|
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
import json
|
2 |
-
import time
|
3 |
import pandas as pd
|
4 |
from os import environ
|
5 |
import streamlit as st
|
@@ -13,10 +11,10 @@ from login import login, back_to_main
|
|
13 |
from lib.helper import build_tools, build_all, sel_map, display
|
14 |
|
15 |
|
16 |
-
|
17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
18 |
|
19 |
-
st.set_page_config(page_title="ChatData",
|
|
|
20 |
st.markdown(
|
21 |
f"""
|
22 |
<style>
|
@@ -36,11 +34,12 @@ if login():
|
|
36 |
if "user_name" in st.session_state:
|
37 |
chat_page()
|
38 |
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
|
39 |
-
|
40 |
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
41 |
-
|
42 |
sel_map[sel]['hint']()
|
43 |
-
tab_sql, tab_self_query = st.tabs(
|
|
|
44 |
with tab_sql:
|
45 |
sel_map[sel]['hint_sql']()
|
46 |
st.text_input("Ask a question:", key='query_sql')
|
@@ -85,7 +84,6 @@ if login():
|
|
85 |
st.write('Oops π΅ Something bad happened...')
|
86 |
raise e
|
87 |
|
88 |
-
|
89 |
with tab_self_query:
|
90 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘')
|
91 |
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
@@ -132,4 +130,4 @@ if login():
|
|
132 |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
133 |
except Exception as e:
|
134 |
st.write('Oops π΅ Something bad happened...')
|
135 |
-
raise e
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
from os import environ
|
3 |
import streamlit as st
|
|
|
11 |
from lib.helper import build_tools, build_all, sel_map, display
|
12 |
|
13 |
|
|
|
14 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
15 |
|
16 |
+
st.set_page_config(page_title="ChatData",
|
17 |
+
page_icon="https://myscale.com/favicon.ico")
|
18 |
st.markdown(
|
19 |
f"""
|
20 |
<style>
|
|
|
34 |
if "user_name" in st.session_state:
|
35 |
chat_page()
|
36 |
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
|
37 |
+
|
38 |
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
39 |
+
options=['ArXiv Papers', 'Wikipedia'])
|
40 |
sel_map[sel]['hint']()
|
41 |
+
tab_sql, tab_self_query = st.tabs(
|
42 |
+
['Vector SQL', 'Self-Query Retrievers'])
|
43 |
with tab_sql:
|
44 |
sel_map[sel]['hint_sql']()
|
45 |
st.text_input("Ask a question:", key='query_sql')
|
|
|
84 |
st.write('Oops π΅ Something bad happened...')
|
85 |
raise e
|
86 |
|
|
|
87 |
with tab_self_query:
|
88 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘')
|
89 |
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
|
|
130 |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
131 |
except Exception as e:
|
132 |
st.write('Oops π΅ Something bad happened...')
|
133 |
+
raise e
|
callbacks/arxiv_callbacks.py
CHANGED
@@ -8,7 +8,6 @@ from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
|
8 |
StreamlitCallbackHandler,
|
9 |
)
|
10 |
from langchain.schema.output import LLMResult
|
11 |
-
from streamlit.delta_generator import DeltaGenerator
|
12 |
|
13 |
|
14 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
@@ -26,7 +25,8 @@ class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
|
26 |
self.progress_bar.progress(value=0.6, text="Searching in DB...")
|
27 |
if "repr" in outputs:
|
28 |
st.markdown("### Generated Filter")
|
29 |
-
st.markdown(
|
|
|
30 |
|
31 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
32 |
pass
|
@@ -88,7 +88,8 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
|
88 |
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
|
89 |
print(f"Vector SQL: {text}")
|
90 |
self.prog_value += self.prog_interval
|
91 |
-
self.progress_bar.progress(
|
|
|
92 |
|
93 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
94 |
cid = ".".join(serialized["id"])
|
|
|
8 |
StreamlitCallbackHandler,
|
9 |
)
|
10 |
from langchain.schema.output import LLMResult
|
|
|
11 |
|
12 |
|
13 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
|
|
25 |
self.progress_bar.progress(value=0.6, text="Searching in DB...")
|
26 |
if "repr" in outputs:
|
27 |
st.markdown("### Generated Filter")
|
28 |
+
st.markdown(
|
29 |
+
f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
|
30 |
|
31 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
32 |
pass
|
|
|
88 |
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
|
89 |
print(f"Vector SQL: {text}")
|
90 |
self.prog_value += self.prog_interval
|
91 |
+
self.progress_bar.progress(
|
92 |
+
value=self.prog_value, text="Searching in DB...")
|
93 |
|
94 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
95 |
cid = ".".join(serialized["id"])
|
chains/arxiv_chains.py
CHANGED
@@ -8,7 +8,6 @@ from langchain.callbacks.manager import (
|
|
8 |
CallbackManagerForChainRun,
|
9 |
)
|
10 |
from langchain.embeddings.base import Embeddings
|
11 |
-
from langchain.schema import BaseRetriever
|
12 |
from langchain.callbacks.manager import Callbacks
|
13 |
from langchain.schema.prompt_template import format_document
|
14 |
from langchain.docstore.document import Document
|
@@ -20,11 +19,12 @@ from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
|
20 |
|
21 |
logger = logging.getLogger()
|
22 |
|
|
|
23 |
class MyScaleWithoutMetadataJson(MyScale):
|
24 |
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
|
25 |
super().__init__(embedding, config, **kwargs)
|
26 |
self.must_have_cols: List[str] = must_have_cols
|
27 |
-
|
28 |
def _build_qstr(
|
29 |
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
30 |
) -> str:
|
@@ -43,7 +43,7 @@ class MyScaleWithoutMetadataJson(MyScale):
|
|
43 |
LIMIT {topk}
|
44 |
"""
|
45 |
return q_str
|
46 |
-
|
47 |
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
|
48 |
q_str = self._build_qstr(embedding, k, where_str)
|
49 |
try:
|
@@ -55,9 +55,11 @@ class MyScaleWithoutMetadataJson(MyScale):
|
|
55 |
for r in self.client.query(q_str).named_results()
|
56 |
]
|
57 |
except Exception as e:
|
58 |
-
logger.error(
|
|
|
59 |
return []
|
60 |
|
|
|
61 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
62 |
"""Based on VectorSQLOutputParser
|
63 |
It also modify the SQL to get all columns
|
@@ -73,9 +75,11 @@ class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
|
73 |
start = text.upper().find("SELECT")
|
74 |
if start >= 0:
|
75 |
end = text.upper().find("FROM")
|
76 |
-
text = text.replace(
|
|
|
77 |
return super().parse(text)
|
78 |
|
|
|
79 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
80 |
"""Combine arxiv documents with PDF reference number"""
|
81 |
|
@@ -172,8 +176,7 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
172 |
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
|
173 |
sources.append(d)
|
174 |
ref_cnt += 1
|
175 |
-
|
176 |
-
|
177 |
result: Dict[str, Any] = {
|
178 |
self.answer_key: answer,
|
179 |
self.sources_answer_key: sources,
|
@@ -191,4 +194,4 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
191 |
|
192 |
@property
|
193 |
def _chain_type(self) -> str:
|
194 |
-
return "arxiv_qa_with_sources_chain"
|
|
|
8 |
CallbackManagerForChainRun,
|
9 |
)
|
10 |
from langchain.embeddings.base import Embeddings
|
|
|
11 |
from langchain.callbacks.manager import Callbacks
|
12 |
from langchain.schema.prompt_template import format_document
|
13 |
from langchain.docstore.document import Document
|
|
|
19 |
|
20 |
logger = logging.getLogger()
|
21 |
|
22 |
+
|
23 |
class MyScaleWithoutMetadataJson(MyScale):
|
24 |
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
|
25 |
super().__init__(embedding, config, **kwargs)
|
26 |
self.must_have_cols: List[str] = must_have_cols
|
27 |
+
|
28 |
def _build_qstr(
|
29 |
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
30 |
) -> str:
|
|
|
43 |
LIMIT {topk}
|
44 |
"""
|
45 |
return q_str
|
46 |
+
|
47 |
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
|
48 |
q_str = self._build_qstr(embedding, k, where_str)
|
49 |
try:
|
|
|
55 |
for r in self.client.query(q_str).named_results()
|
56 |
]
|
57 |
except Exception as e:
|
58 |
+
logger.error(
|
59 |
+
f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
60 |
return []
|
61 |
|
62 |
+
|
63 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
64 |
"""Based on VectorSQLOutputParser
|
65 |
It also modify the SQL to get all columns
|
|
|
75 |
start = text.upper().find("SELECT")
|
76 |
if start >= 0:
|
77 |
end = text.upper().find("FROM")
|
78 |
+
text = text.replace(
|
79 |
+
text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
|
80 |
return super().parse(text)
|
81 |
|
82 |
+
|
83 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
84 |
"""Combine arxiv documents with PDF reference number"""
|
85 |
|
|
|
176 |
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
|
177 |
sources.append(d)
|
178 |
ref_cnt += 1
|
179 |
+
|
|
|
180 |
result: Dict[str, Any] = {
|
181 |
self.answer_key: answer,
|
182 |
self.sources_answer_key: sources,
|
|
|
194 |
|
195 |
@property
|
196 |
def _chain_type(self) -> str:
|
197 |
+
return "arxiv_qa_with_sources_chain"
|
chat.py
CHANGED
@@ -8,9 +8,6 @@ from lib.sessions import SessionManager
|
|
8 |
from lib.private_kb import PrivateKnowledgeBase
|
9 |
from langchain.schema import HumanMessage, FunctionMessage
|
10 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
11 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
12 |
-
StreamlitCallbackHandler,
|
13 |
-
)
|
14 |
from lib.json_conv import CustomJSONDecoder
|
15 |
|
16 |
from lib.helper import (
|
@@ -313,7 +310,8 @@ def chat_page():
|
|
313 |
key="b_tool_files",
|
314 |
format_func=lambda x: x["file_name"],
|
315 |
)
|
316 |
-
st.text_input(
|
|
|
317 |
st.text_input(
|
318 |
"Tool Description",
|
319 |
"Searches among user's private files and returns related documents",
|
@@ -359,14 +357,16 @@ def chat_page():
|
|
359 |
)
|
360 |
st.markdown("### Uploaded Files")
|
361 |
st.dataframe(
|
362 |
-
st.session_state.private_kb.list_files(
|
|
|
363 |
use_container_width=True,
|
364 |
)
|
365 |
col_1, col_2 = st.columns(2)
|
366 |
with col_1:
|
367 |
st.button("Add Files", on_click=add_file)
|
368 |
with col_2:
|
369 |
-
st.button("Clear Files and All Tools",
|
|
|
370 |
|
371 |
st.button("Clear Chat History", on_click=clear_history)
|
372 |
st.button("Logout", on_click=back_to_main)
|
|
|
8 |
from lib.private_kb import PrivateKnowledgeBase
|
9 |
from langchain.schema import HumanMessage, FunctionMessage
|
10 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
|
|
|
|
|
|
11 |
from lib.json_conv import CustomJSONDecoder
|
12 |
|
13 |
from lib.helper import (
|
|
|
310 |
key="b_tool_files",
|
311 |
format_func=lambda x: x["file_name"],
|
312 |
)
|
313 |
+
st.text_input(
|
314 |
+
"Tool Name", "get_relevant_documents", key="b_tool_name")
|
315 |
st.text_input(
|
316 |
"Tool Description",
|
317 |
"Searches among user's private files and returns related documents",
|
|
|
357 |
)
|
358 |
st.markdown("### Uploaded Files")
|
359 |
st.dataframe(
|
360 |
+
st.session_state.private_kb.list_files(
|
361 |
+
st.session_state.user_name),
|
362 |
use_container_width=True,
|
363 |
)
|
364 |
col_1, col_2 = st.columns(2)
|
365 |
with col_1:
|
366 |
st.button("Add Files", on_click=add_file)
|
367 |
with col_2:
|
368 |
+
st.button("Clear Files and All Tools",
|
369 |
+
on_click=clear_files)
|
370 |
|
371 |
st.button("Clear Chat History", on_click=clear_history)
|
372 |
st.button("Logout", on_click=back_to_main)
|
lib/helper.py
CHANGED
@@ -4,10 +4,8 @@ import time
|
|
4 |
import hashlib
|
5 |
from typing import Dict, Any, List, Tuple
|
6 |
import re
|
7 |
-
import pandas as pd
|
8 |
from os import environ
|
9 |
import streamlit as st
|
10 |
-
import datetime
|
11 |
from langchain.schema import BaseRetriever
|
12 |
from langchain.tools import Tool
|
13 |
from langchain.pydantic_v1 import BaseModel, Field
|
@@ -20,7 +18,7 @@ except ImportError:
|
|
20 |
from sqlalchemy.ext.declarative import declarative_base
|
21 |
from sqlalchemy.orm import sessionmaker
|
22 |
from clickhouse_sqlalchemy import (
|
23 |
-
|
24 |
)
|
25 |
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
26 |
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
@@ -43,12 +41,12 @@ from langchain.prompts.prompt import PromptTemplate
|
|
43 |
from langchain.prompts.chat import MessagesPlaceholder
|
44 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
45 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
46 |
-
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage
|
47 |
SystemMessage, ChatMessage, ToolMessage
|
48 |
from langchain.memory import SQLChatMessageHistory
|
49 |
from langchain.memory.chat_message_histories.sql import \
|
50 |
-
|
51 |
-
from langchain.schema.messages import BaseMessage
|
52 |
# from langchain.agents.agent_toolkits import create_retriever_tool
|
53 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
54 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
@@ -73,7 +71,7 @@ UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API']
|
|
73 |
|
74 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
75 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
76 |
-
|
77 |
DEFAULT_SYSTEM_PROMPT = (
|
78 |
"Do your best to answer the questions. "
|
79 |
"Feel free to use any tools available to look up "
|
@@ -81,6 +79,7 @@ DEFAULT_SYSTEM_PROMPT = (
|
|
81 |
"when calling search functions."
|
82 |
)
|
83 |
|
|
|
84 |
def hint_arxiv():
|
85 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
86 |
"For example: \n\n"
|
@@ -150,7 +149,8 @@ sel_map = {
|
|
150 |
"hint": hint_wiki,
|
151 |
"hint_sql": hint_sql_wiki,
|
152 |
"doc_prompt": PromptTemplate(
|
153 |
-
input_variables=["page_content",
|
|
|
154 |
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
155 |
"metadata_cols": [
|
156 |
AttributeInfo(
|
@@ -224,6 +224,7 @@ sel_map = {
|
|
224 |
}
|
225 |
}
|
226 |
|
|
|
227 |
def build_embedding_model(_sel):
|
228 |
"""Build embedding model
|
229 |
"""
|
@@ -253,7 +254,8 @@ def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
|
|
253 |
"sql_retriever": sql_retriever,
|
254 |
"sql_chain": sql_chain
|
255 |
}
|
256 |
-
|
|
|
257 |
def build_self_query(_sel: str) -> SelfQueryRetriever:
|
258 |
"""Build self querying retriever
|
259 |
|
@@ -278,18 +280,20 @@ def build_self_query(_sel: str) -> SelfQueryRetriever:
|
|
278 |
"vector": sel_map[_sel]["vector_col"],
|
279 |
"metadata": sel_map[_sel]["metadata_col"]
|
280 |
})
|
281 |
-
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
282 |
must_have_cols=sel_map[_sel]['must_have_cols'])
|
283 |
|
284 |
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
285 |
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
286 |
retriever = SelfQueryRetriever.from_llm(
|
287 |
-
OpenAI(model_name=query_model_name,
|
|
|
288 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
289 |
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
290 |
return retriever
|
291 |
|
292 |
-
|
|
|
293 |
"""Build Vector SQL Database Retriever
|
294 |
|
295 |
:param _sel: selected knowledge base
|
@@ -308,7 +312,8 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
|
|
308 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
309 |
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
310 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
311 |
-
llm=OpenAI(model_name=query_model_name,
|
|
|
312 |
prompt=PROMPT,
|
313 |
top_k=10,
|
314 |
return_direct=True,
|
@@ -319,8 +324,9 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
|
|
319 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
320 |
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
321 |
return sql_retriever
|
322 |
-
|
323 |
-
|
|
|
324 |
"""_summary_
|
325 |
|
326 |
:param _sel: selected knowledge base
|
@@ -350,6 +356,7 @@ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query")
|
|
350 |
)
|
351 |
return chain
|
352 |
|
|
|
353 |
@st.cache_resource
|
354 |
def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
355 |
"""build all resources
|
@@ -365,6 +372,7 @@ def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
365 |
sel_map_obj[k] = build_chains_retrievers(k)
|
366 |
return sel_map_obj, embeddings
|
367 |
|
|
|
368 |
def create_message_model(table_name, DynamicBase): # type: ignore
|
369 |
"""
|
370 |
Create a message model for a given table name.
|
@@ -397,6 +405,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
|
397 |
|
398 |
return Message
|
399 |
|
|
|
400 |
def _message_from_dict(message: dict) -> BaseMessage:
|
401 |
_type = message["type"]
|
402 |
if _type == "human":
|
@@ -417,6 +426,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
|
417 |
else:
|
418 |
raise ValueError(f"Got unexpected message type: {_type}")
|
419 |
|
|
|
420 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
421 |
"""The default message converter for SQLChatMessageHistory."""
|
422 |
|
@@ -425,27 +435,28 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
425 |
|
426 |
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
427 |
tstamp = time.time()
|
428 |
-
msg_id = hashlib.sha256(
|
|
|
429 |
user_id, _ = session_id.split("?")
|
430 |
return self.model_class(
|
431 |
-
id=tstamp,
|
432 |
msg_id=msg_id,
|
433 |
user_id=user_id,
|
434 |
-
session_id=session_id,
|
435 |
type=message.type,
|
436 |
addtionals=json.dumps(message.additional_kwargs),
|
437 |
message=json.dumps({
|
438 |
-
"type": message.type,
|
439 |
"additional_kwargs": {"timestamp": tstamp},
|
440 |
"data": message.dict()})
|
441 |
)
|
442 |
-
|
443 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
444 |
msg_dump = json.loads(sql_message.message)
|
445 |
msg = _message_from_dict(msg_dump)
|
446 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
447 |
return msg
|
448 |
-
|
449 |
def get_sql_model_class(self) -> Any:
|
450 |
return self.model_class
|
451 |
|
@@ -458,7 +469,7 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
458 |
connection_string=f'{conn_str}/chat?protocol=https',
|
459 |
custom_message_converter=DefaultClickhouseMessageConverter(name))
|
460 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
461 |
-
|
462 |
_system_message = SystemMessage(
|
463 |
content=system_prompt
|
464 |
)
|
@@ -475,10 +486,12 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
475 |
return_intermediate_steps=True,
|
476 |
**kwargs
|
477 |
)
|
478 |
-
|
|
|
479 |
class RetrieverInput(BaseModel):
|
480 |
query: str = Field(description="query to look up in retriever")
|
481 |
|
|
|
482 |
def create_retriever_tool(
|
483 |
retriever: BaseRetriever, name: str, description: str
|
484 |
) -> Tool:
|
@@ -499,7 +512,7 @@ def create_retriever_tool(
|
|
499 |
docs: List[Document] = func(*args, **kwargs)
|
500 |
return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
|
501 |
return wrapped_retrieve
|
502 |
-
|
503 |
return Tool(
|
504 |
name=name,
|
505 |
description=description,
|
@@ -507,7 +520,8 @@ def create_retriever_tool(
|
|
507 |
coroutine=retriever.aget_relevant_documents,
|
508 |
args_schema=RetrieverInput,
|
509 |
)
|
510 |
-
|
|
|
511 |
@st.cache_resource
|
512 |
def build_tools():
|
513 |
"""build all resources
|
@@ -531,8 +545,9 @@ def build_tools():
|
|
531 |
})
|
532 |
return sel_map_obj
|
533 |
|
|
|
534 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
535 |
-
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
536 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
537 |
)
|
538 |
tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
|
@@ -543,7 +558,7 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
|
|
543 |
chat_llm,
|
544 |
tools=sel_tools,
|
545 |
system_prompt=system_prompt
|
546 |
-
|
547 |
return agent
|
548 |
|
549 |
|
@@ -556,4 +571,4 @@ def display(dataframe, columns_=None, index=None):
|
|
556 |
else:
|
557 |
st.dataframe(dataframe)
|
558 |
else:
|
559 |
-
st.write("Sorry π΅ we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
|
|
4 |
import hashlib
|
5 |
from typing import Dict, Any, List, Tuple
|
6 |
import re
|
|
|
7 |
from os import environ
|
8 |
import streamlit as st
|
|
|
9 |
from langchain.schema import BaseRetriever
|
10 |
from langchain.tools import Tool
|
11 |
from langchain.pydantic_v1 import BaseModel, Field
|
|
|
18 |
from sqlalchemy.ext.declarative import declarative_base
|
19 |
from sqlalchemy.orm import sessionmaker
|
20 |
from clickhouse_sqlalchemy import (
|
21 |
+
types, engines
|
22 |
)
|
23 |
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
24 |
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
|
|
41 |
from langchain.prompts.chat import MessagesPlaceholder
|
42 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
43 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
44 |
+
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage, \
|
45 |
SystemMessage, ChatMessage, ToolMessage
|
46 |
from langchain.memory import SQLChatMessageHistory
|
47 |
from langchain.memory.chat_message_histories.sql import \
|
48 |
+
DefaultMessageConverter
|
49 |
+
from langchain.schema.messages import BaseMessage
|
50 |
# from langchain.agents.agent_toolkits import create_retriever_tool
|
51 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
52 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
|
71 |
|
72 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
73 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
74 |
+
(HumanMessagePromptTemplate, '{question}')])
|
75 |
DEFAULT_SYSTEM_PROMPT = (
|
76 |
"Do your best to answer the questions. "
|
77 |
"Feel free to use any tools available to look up "
|
|
|
79 |
"when calling search functions."
|
80 |
)
|
81 |
|
82 |
+
|
83 |
def hint_arxiv():
|
84 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
85 |
"For example: \n\n"
|
|
|
149 |
"hint": hint_wiki,
|
150 |
"hint_sql": hint_sql_wiki,
|
151 |
"doc_prompt": PromptTemplate(
|
152 |
+
input_variables=["page_content",
|
153 |
+
"url", "title", "ref_id", "views"],
|
154 |
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
155 |
"metadata_cols": [
|
156 |
AttributeInfo(
|
|
|
224 |
}
|
225 |
}
|
226 |
|
227 |
+
|
228 |
def build_embedding_model(_sel):
|
229 |
"""Build embedding model
|
230 |
"""
|
|
|
254 |
"sql_retriever": sql_retriever,
|
255 |
"sql_chain": sql_chain
|
256 |
}
|
257 |
+
|
258 |
+
|
259 |
def build_self_query(_sel: str) -> SelfQueryRetriever:
|
260 |
"""Build self querying retriever
|
261 |
|
|
|
280 |
"vector": sel_map[_sel]["vector_col"],
|
281 |
"metadata": sel_map[_sel]["metadata_col"]
|
282 |
})
|
283 |
+
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
284 |
must_have_cols=sel_map[_sel]['must_have_cols'])
|
285 |
|
286 |
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
287 |
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
288 |
retriever = SelfQueryRetriever.from_llm(
|
289 |
+
OpenAI(model_name=query_model_name,
|
290 |
+
openai_api_key=OPENAI_API_KEY, temperature=0),
|
291 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
292 |
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
293 |
return retriever
|
294 |
|
295 |
+
|
296 |
+
def build_vector_sql(_sel: str) -> VectorSQLDatabaseChainRetriever:
|
297 |
"""Build Vector SQL Database Retriever
|
298 |
|
299 |
:param _sel: selected knowledge base
|
|
|
312 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
313 |
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
314 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
315 |
+
llm=OpenAI(model_name=query_model_name,
|
316 |
+
openai_api_key=OPENAI_API_KEY, temperature=0),
|
317 |
prompt=PROMPT,
|
318 |
top_k=10,
|
319 |
return_direct=True,
|
|
|
324 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
325 |
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
326 |
return sql_retriever
|
327 |
+
|
328 |
+
|
329 |
+
def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str = "Self-query") -> ArXivQAwithSourcesChain:
|
330 |
"""_summary_
|
331 |
|
332 |
:param _sel: selected knowledge base
|
|
|
356 |
)
|
357 |
return chain
|
358 |
|
359 |
+
|
360 |
@st.cache_resource
|
361 |
def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
362 |
"""build all resources
|
|
|
372 |
sel_map_obj[k] = build_chains_retrievers(k)
|
373 |
return sel_map_obj, embeddings
|
374 |
|
375 |
+
|
376 |
def create_message_model(table_name, DynamicBase): # type: ignore
|
377 |
"""
|
378 |
Create a message model for a given table name.
|
|
|
405 |
|
406 |
return Message
|
407 |
|
408 |
+
|
409 |
def _message_from_dict(message: dict) -> BaseMessage:
|
410 |
_type = message["type"]
|
411 |
if _type == "human":
|
|
|
426 |
else:
|
427 |
raise ValueError(f"Got unexpected message type: {_type}")
|
428 |
|
429 |
+
|
430 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
431 |
"""The default message converter for SQLChatMessageHistory."""
|
432 |
|
|
|
435 |
|
436 |
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
437 |
tstamp = time.time()
|
438 |
+
msg_id = hashlib.sha256(
|
439 |
+
f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
|
440 |
user_id, _ = session_id.split("?")
|
441 |
return self.model_class(
|
442 |
+
id=tstamp,
|
443 |
msg_id=msg_id,
|
444 |
user_id=user_id,
|
445 |
+
session_id=session_id,
|
446 |
type=message.type,
|
447 |
addtionals=json.dumps(message.additional_kwargs),
|
448 |
message=json.dumps({
|
449 |
+
"type": message.type,
|
450 |
"additional_kwargs": {"timestamp": tstamp},
|
451 |
"data": message.dict()})
|
452 |
)
|
453 |
+
|
454 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
455 |
msg_dump = json.loads(sql_message.message)
|
456 |
msg = _message_from_dict(msg_dump)
|
457 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
458 |
return msg
|
459 |
+
|
460 |
def get_sql_model_class(self) -> Any:
|
461 |
return self.model_class
|
462 |
|
|
|
469 |
connection_string=f'{conn_str}/chat?protocol=https',
|
470 |
custom_message_converter=DefaultClickhouseMessageConverter(name))
|
471 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
472 |
+
|
473 |
_system_message = SystemMessage(
|
474 |
content=system_prompt
|
475 |
)
|
|
|
486 |
return_intermediate_steps=True,
|
487 |
**kwargs
|
488 |
)
|
489 |
+
|
490 |
+
|
491 |
class RetrieverInput(BaseModel):
|
492 |
query: str = Field(description="query to look up in retriever")
|
493 |
|
494 |
+
|
495 |
def create_retriever_tool(
|
496 |
retriever: BaseRetriever, name: str, description: str
|
497 |
) -> Tool:
|
|
|
512 |
docs: List[Document] = func(*args, **kwargs)
|
513 |
return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
|
514 |
return wrapped_retrieve
|
515 |
+
|
516 |
return Tool(
|
517 |
name=name,
|
518 |
description=description,
|
|
|
520 |
coroutine=retriever.aget_relevant_documents,
|
521 |
args_schema=RetrieverInput,
|
522 |
)
|
523 |
+
|
524 |
+
|
525 |
@st.cache_resource
|
526 |
def build_tools():
|
527 |
"""build all resources
|
|
|
545 |
})
|
546 |
return sel_map_obj
|
547 |
|
548 |
+
|
549 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
550 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
551 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
552 |
)
|
553 |
tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
|
|
|
558 |
chat_llm,
|
559 |
tools=sel_tools,
|
560 |
system_prompt=system_prompt
|
561 |
+
)
|
562 |
return agent
|
563 |
|
564 |
|
|
|
571 |
else:
|
572 |
st.dataframe(dataframe)
|
573 |
else:
|
574 |
+
st.write("Sorry π΅ we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
lib/json_conv.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1 |
import json
|
2 |
import datetime
|
3 |
|
|
|
4 |
class CustomJSONEncoder(json.JSONEncoder):
|
5 |
def default(self, obj):
|
6 |
if isinstance(obj, datetime.datetime):
|
7 |
return datetime.datetime.isoformat(obj)
|
8 |
return json.JSONEncoder.default(self, obj)
|
9 |
|
|
|
10 |
class CustomJSONDecoder(json.JSONDecoder):
|
11 |
def __init__(self, *args, **kwargs):
|
12 |
-
json.JSONDecoder.__init__(
|
|
|
13 |
|
14 |
def object_hook(self, source):
|
15 |
for k, v in source.items():
|
@@ -18,4 +21,4 @@ class CustomJSONDecoder(json.JSONDecoder):
|
|
18 |
source[k] = datetime.datetime.fromisoformat(str(v))
|
19 |
except:
|
20 |
pass
|
21 |
-
return source
|
|
|
1 |
import json
|
2 |
import datetime
|
3 |
|
4 |
+
|
5 |
class CustomJSONEncoder(json.JSONEncoder):
|
6 |
def default(self, obj):
|
7 |
if isinstance(obj, datetime.datetime):
|
8 |
return datetime.datetime.isoformat(obj)
|
9 |
return json.JSONEncoder.default(self, obj)
|
10 |
|
11 |
+
|
12 |
class CustomJSONDecoder(json.JSONDecoder):
|
13 |
def __init__(self, *args, **kwargs):
|
14 |
+
json.JSONDecoder.__init__(
|
15 |
+
self, object_hook=self.object_hook, *args, **kwargs)
|
16 |
|
17 |
def object_hook(self, source):
|
18 |
for k, v in source.items():
|
|
|
21 |
source[k] = datetime.datetime.fromisoformat(str(v))
|
22 |
except:
|
23 |
pass
|
24 |
+
return source
|
lib/private_kb.py
CHANGED
@@ -52,7 +52,8 @@ def parse_files(api_key, user_id, files: List[UploadedFile]):
|
|
52 |
|
53 |
def extract_embedding(embeddings: Embeddings, texts):
|
54 |
if len(texts) > 0:
|
55 |
-
embs = embeddings.embed_documents(
|
|
|
56 |
for i, _ in enumerate(texts):
|
57 |
texts[i]["vector"] = embs[i]
|
58 |
return texts
|
|
|
52 |
|
53 |
def extract_embedding(embeddings: Embeddings, texts):
|
54 |
if len(texts) > 0:
|
55 |
+
embs = embeddings.embed_documents(
|
56 |
+
[t["text"] for _, t in enumerate(texts)])
|
57 |
for i, _ in enumerate(texts):
|
58 |
texts[i]["vector"] = embs[i]
|
59 |
return texts
|
lib/schemas.py
CHANGED
@@ -49,4 +49,4 @@ def create_session_table(table_name, DynamicBase): # type: ignore
|
|
49 |
order_by=('session_id')),
|
50 |
{'comment': 'Store Session and Prompts'}
|
51 |
)
|
52 |
-
return Session
|
|
|
49 |
order_by=('session_id')),
|
50 |
{'comment': 'Store Session and Prompts'}
|
51 |
)
|
52 |
+
return Session
|
lib/sessions.py
CHANGED
@@ -6,9 +6,9 @@ except ImportError:
|
|
6 |
from langchain.schema import BaseChatMessageHistory
|
7 |
from datetime import datetime
|
8 |
from sqlalchemy import Column, Text, orm, create_engine
|
9 |
-
from clickhouse_sqlalchemy import types, engines
|
10 |
from .schemas import create_message_model, create_session_table
|
11 |
|
|
|
12 |
def get_sessions(engine, model_class, user_id):
|
13 |
with orm.sessionmaker(engine)() as session:
|
14 |
result = (
|
@@ -20,14 +20,17 @@ def get_sessions(engine, model_class, user_id):
|
|
20 |
)
|
21 |
return json.loads(result)
|
22 |
|
|
|
23 |
class SessionManager:
|
24 |
def __init__(self, session_state, host, port, username, password,
|
25 |
db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
|
26 |
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
|
27 |
self.engine = create_engine(conn_str, echo=False)
|
28 |
-
self.sess_model_class = create_session_table(
|
|
|
29 |
self.sess_model_class.metadata.create_all(self.engine)
|
30 |
-
self.msg_model_class = create_message_model(
|
|
|
31 |
self.msg_model_class.metadata.create_all(self.engine)
|
32 |
self.Session = orm.sessionmaker(self.engine)
|
33 |
self.session_state = session_state
|
@@ -46,14 +49,15 @@ class SessionManager:
|
|
46 |
sessions.append({
|
47 |
"session_id": r.session_id.split("?")[-1],
|
48 |
"system_prompt": r.system_prompt,
|
49 |
-
|
50 |
return sessions
|
51 |
-
|
52 |
def modify_system_prompt(self, session_id, sys_prompt):
|
53 |
with self.Session() as session:
|
54 |
-
session.update(self.sess_model_class).where(
|
|
|
55 |
session.commit()
|
56 |
-
|
57 |
def add_session(self, user_id, session_id, system_prompt, **kwargs):
|
58 |
with self.Session() as session:
|
59 |
elem = self.sess_model_class(
|
@@ -62,14 +66,13 @@ class SessionManager:
|
|
62 |
)
|
63 |
session.add(elem)
|
64 |
session.commit()
|
65 |
-
|
66 |
def remove_session(self, session_id):
|
67 |
with self.Session() as session:
|
68 |
-
session.query(self.sess_model_class).where(
|
|
|
69 |
# session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
|
70 |
if "agent" in self.session_state:
|
71 |
self.session_state.agent.memory.chat_memory.clear()
|
72 |
if "file_analyzer" in self.session_state:
|
73 |
self.session_state.file_analyzer.clear_files()
|
74 |
-
|
75 |
-
|
|
|
6 |
from langchain.schema import BaseChatMessageHistory
|
7 |
from datetime import datetime
|
8 |
from sqlalchemy import Column, Text, orm, create_engine
|
|
|
9 |
from .schemas import create_message_model, create_session_table
|
10 |
|
11 |
+
|
12 |
def get_sessions(engine, model_class, user_id):
|
13 |
with orm.sessionmaker(engine)() as session:
|
14 |
result = (
|
|
|
20 |
)
|
21 |
return json.loads(result)
|
22 |
|
23 |
+
|
24 |
class SessionManager:
|
25 |
def __init__(self, session_state, host, port, username, password,
|
26 |
db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
|
27 |
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
|
28 |
self.engine = create_engine(conn_str, echo=False)
|
29 |
+
self.sess_model_class = create_session_table(
|
30 |
+
sess_table, declarative_base())
|
31 |
self.sess_model_class.metadata.create_all(self.engine)
|
32 |
+
self.msg_model_class = create_message_model(
|
33 |
+
msg_table, declarative_base())
|
34 |
self.msg_model_class.metadata.create_all(self.engine)
|
35 |
self.Session = orm.sessionmaker(self.engine)
|
36 |
self.session_state = session_state
|
|
|
49 |
sessions.append({
|
50 |
"session_id": r.session_id.split("?")[-1],
|
51 |
"system_prompt": r.system_prompt,
|
52 |
+
})
|
53 |
return sessions
|
54 |
+
|
55 |
def modify_system_prompt(self, session_id, sys_prompt):
|
56 |
with self.Session() as session:
|
57 |
+
session.update(self.sess_model_class).where(
|
58 |
+
self.sess_model_class == session_id).value(system_prompt=sys_prompt)
|
59 |
session.commit()
|
60 |
+
|
61 |
def add_session(self, user_id, session_id, system_prompt, **kwargs):
|
62 |
with self.Session() as session:
|
63 |
elem = self.sess_model_class(
|
|
|
66 |
)
|
67 |
session.add(elem)
|
68 |
session.commit()
|
69 |
+
|
70 |
def remove_session(self, session_id):
|
71 |
with self.Session() as session:
|
72 |
+
session.query(self.sess_model_class).where(
|
73 |
+
self.sess_model_class.session_id == session_id).delete()
|
74 |
# session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
|
75 |
if "agent" in self.session_state:
|
76 |
self.session_state.agent.memory.chat_memory.clear()
|
77 |
if "file_analyzer" in self.session_state:
|
78 |
self.session_state.file_analyzer.clear_files()
|
|
|
|
login.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
import json
|
2 |
-
import time
|
3 |
-
import pandas as pd
|
4 |
-
from os import environ
|
5 |
import streamlit as st
|
6 |
from auth0_component import login_button
|
7 |
|
8 |
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
|
9 |
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
|
10 |
|
|
|
11 |
def login():
|
12 |
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
|
13 |
return True
|
14 |
-
st.subheader(
|
|
|
15 |
st.write("You can now chat with ArXiv and Wikipedia! π\n")
|
16 |
st.write("Built purely with streamlit π , LangChain π¦π and love β€οΈ for AI!")
|
17 |
-
st.write(
|
18 |
-
|
|
|
|
|
19 |
st.divider()
|
20 |
col1, col2 = st.columns(2, gap='large')
|
21 |
with col1.container():
|
@@ -33,7 +33,7 @@ def login():
|
|
33 |
st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
|
34 |
"- [Terms of Sevice](https://myscale.com/terms/)")
|
35 |
if st.session_state.auth0 is not None:
|
36 |
-
st.session_state.user_info = dict(st.session_state.auth0)
|
37 |
if 'email' in st.session_state.user_info:
|
38 |
email = st.session_state.user_info["email"]
|
39 |
else:
|
@@ -44,6 +44,7 @@ def login():
|
|
44 |
if st.session_state.jump_query_ask:
|
45 |
st.experimental_rerun()
|
46 |
|
|
|
47 |
def back_to_main():
|
48 |
if "user_info" in st.session_state:
|
49 |
del st.session_state.user_info
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from auth0_component import login_button
|
3 |
|
4 |
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
|
5 |
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
|
6 |
|
7 |
+
|
8 |
def login():
|
9 |
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
|
10 |
return True
|
11 |
+
st.subheader(
|
12 |
+
"π€ Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! π€ ")
|
13 |
st.write("You can now chat with ArXiv and Wikipedia! π\n")
|
14 |
st.write("Built purely with streamlit π , LangChain π¦π and love β€οΈ for AI!")
|
15 |
+
st.write(
|
16 |
+
"Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
|
17 |
+
st.write(
|
18 |
+
"For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
|
19 |
st.divider()
|
20 |
col1, col2 = st.columns(2, gap='large')
|
21 |
with col1.container():
|
|
|
33 |
st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
|
34 |
"- [Terms of Sevice](https://myscale.com/terms/)")
|
35 |
if st.session_state.auth0 is not None:
|
36 |
+
st.session_state.user_info = dict(st.session_state.auth0)
|
37 |
if 'email' in st.session_state.user_info:
|
38 |
email = st.session_state.user_info["email"]
|
39 |
else:
|
|
|
44 |
if st.session_state.jump_query_ask:
|
45 |
st.experimental_rerun()
|
46 |
|
47 |
+
|
48 |
def back_to_main():
|
49 |
if "user_info" in st.session_state:
|
50 |
del st.session_state.user_info
|
prompts/arxiv_prompt.py
CHANGED
@@ -6,7 +6,7 @@ combine_prompt_template = (
|
|
6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
8 |
+ "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
9 |
-
+ "Now you should
|
10 |
)
|
11 |
|
12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|
|
|
6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
8 |
+ "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
9 |
+
+ "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
|
10 |
)
|
11 |
|
12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|