Update app.py
Browse files
app.py
CHANGED
@@ -16,59 +16,61 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
|
|
16 |
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
|
17 |
from langchain_community.document_loaders.directory import DirectoryLoader
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
store = {}
|
23 |
-
|
24 |
-
model_name = "sentence-transformers/all-mpnet-base-v2"
|
25 |
-
model_kwargs = {'device': 'cpu'}
|
26 |
-
encode_kwargs = {'normalize_embeddings': True}
|
27 |
-
|
28 |
-
hf = HuggingFaceEmbeddings(
|
29 |
-
model_name=model_name,
|
30 |
-
model_kwargs=model_kwargs,
|
31 |
-
encode_kwargs=encode_kwargs
|
32 |
-
)
|
33 |
-
|
34 |
-
def get_vectorstore(text_chunks):
|
35 |
-
model_name = "sentence-transformers/all-mpnet-base-v2"
|
36 |
-
model_kwargs = {'device': 'cpu'}
|
37 |
-
encode_kwargs = {'normalize_embeddings': True}
|
38 |
-
hf = HuggingFaceEmbeddings(
|
39 |
-
model_name=model_name,
|
40 |
-
model_kwargs=model_kwargs,
|
41 |
-
encode_kwargs=encode_kwargs
|
42 |
-
)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
return document_loader.load()
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def create_conversational_rag_chain(vectorstore):
|
62 |
-
|
63 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
64 |
model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
|
65 |
-
|
66 |
-
retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={"k": 7})
|
67 |
|
68 |
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
69 |
|
70 |
llm = llamacpp.LlamaCpp(
|
71 |
-
model_path=
|
72 |
n_gpu_layers=1,
|
73 |
temperature=0.1,
|
74 |
top_p=0.9,
|
@@ -84,7 +86,7 @@ def create_conversational_rag_chain(vectorstore):
|
|
84 |
which can be understood without the chat history. Do NOT answer the question,
|
85 |
just reformulate it if needed and otherwise return it as is."""
|
86 |
|
87 |
-
ha_retriever =
|
88 |
|
89 |
qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
|
90 |
|
@@ -110,63 +112,5 @@ def create_conversational_rag_chain(vectorstore):
|
|
110 |
)
|
111 |
return conversation_chain
|
112 |
|
113 |
-
def main():
|
114 |
-
"""Main function for the Streamlit app."""
|
115 |
-
# Initialize chat history if not already present in session state
|
116 |
-
|
117 |
-
documents = []
|
118 |
-
|
119 |
-
script_dir = os.path.dirname(os.path.abspath(__file__))
|
120 |
-
data_path = os.path.join(script_dir, "data/")
|
121 |
-
if not os.path.exists(data_path):
|
122 |
-
st.error(f"Data path does not exist: {data_path}")
|
123 |
-
return
|
124 |
-
|
125 |
-
try:
|
126 |
-
# Load documents from the data path
|
127 |
-
documents = load_txt_documents(data_path)
|
128 |
-
if not documents:
|
129 |
-
st.warning("No documents found in the data path.")
|
130 |
-
else:
|
131 |
-
# Split the documents into chunks
|
132 |
-
docs = split_docs(documents, 350, 40)
|
133 |
-
# Add your logic here to use `docs`
|
134 |
-
st.success("Documents loaded and processed successfully.")
|
135 |
-
except Exception as e:
|
136 |
-
st.error(f"An error occurred while loading documents: {e}")
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
documents = load_txt_documents(data_path)
|
142 |
-
docs = split_docs(documents, 350, 40)
|
143 |
-
|
144 |
-
vectorstore = get_vectorstore(docs)
|
145 |
-
|
146 |
-
msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
|
147 |
-
chain_with_history = create_conversational_rag_chain(vectorstore)
|
148 |
-
|
149 |
-
st.title("Conversational RAG Chatbot")
|
150 |
-
|
151 |
-
if prompt := st.chat_input():
|
152 |
-
st.chat_message("human").write(prompt)
|
153 |
-
|
154 |
-
# Prepare the input dictionary with the correct keys
|
155 |
-
input_dict = {"input": prompt, "chat_history": msgs.messages}
|
156 |
-
config = {"configurable": {"session_id": "any"}}
|
157 |
-
|
158 |
-
# Process user input and handle response
|
159 |
-
response = chain_with_history.invoke(input_dict, config)
|
160 |
-
st.chat_message("ai").write(response["answer"])
|
161 |
-
|
162 |
-
# Display retrieved documents (if any and present in response)
|
163 |
-
if "docs" in response and response["documents"]:
|
164 |
-
for index, doc in enumerate(response["documents"]):
|
165 |
-
with st.expander(f"Document {index + 1}"):
|
166 |
-
st.write(doc)
|
167 |
-
|
168 |
-
# Update chat history in session state
|
169 |
-
st.session_state["chat_history"] = msgs
|
170 |
-
|
171 |
if __name__ == "__main__":
|
172 |
main()
|
|
|
16 |
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
|
17 |
from langchain_community.document_loaders.directory import DirectoryLoader
|
18 |
|
19 |
+
def main():
|
20 |
+
st.set_page_config(page_title="Conversational RAG Chatbot", page_icon=":robot:")
|
21 |
+
st.title("Conversational RAG Chatbot")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
if "documents" not in st.session_state:
|
24 |
+
st.session_state.documents = []
|
25 |
|
26 |
+
if "conversation_chain" not in st.session_state:
|
27 |
+
st.session_state.conversation_chain = None
|
|
|
28 |
|
29 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
30 |
+
data_path = os.path.join(script_dir, "data/")
|
31 |
+
|
32 |
+
if not os.path.exists(data_path):
|
33 |
+
st.error(f"Data path does not exist: {data_path}")
|
34 |
+
return
|
35 |
+
|
36 |
+
try:
|
37 |
+
documents = load_txt_documents(data_path)
|
38 |
+
if not documents:
|
39 |
+
st.warning("No documents found in the data path.")
|
40 |
+
else:
|
41 |
+
st.session_state.documents = documents
|
42 |
+
docs = split_docs(documents, 350, 40)
|
43 |
+
vectorstore = retriever_from_chroma(docs, HuggingFaceEmbeddings(), "mmr", 7)
|
44 |
+
st.session_state.conversation_chain = create_conversational_rag_chain(vectorstore)
|
45 |
+
st.success("Documents loaded and processed successfully.")
|
46 |
+
except Exception as e:
|
47 |
+
st.error(f"An error occurred while loading documents: {e}")
|
48 |
+
|
49 |
+
if prompt := st.text_input("Enter your question:"):
|
50 |
+
msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
|
51 |
+
st.chat_message("human").write(prompt)
|
52 |
+
|
53 |
+
input_dict = {"input": prompt, "chat_history": msgs.messages}
|
54 |
+
config = {"configurable": {"session_id": "any"}}
|
55 |
+
|
56 |
+
response = st.session_state.conversation_chain.invoke(input_dict, config)
|
57 |
+
st.chat_message("ai").write(response["answer"])
|
58 |
+
|
59 |
+
if "docs" in response and response["documents"]:
|
60 |
+
for index, doc in enumerate(response["documents"]):
|
61 |
+
with st.expander(f"Document {index + 1}"):
|
62 |
+
st.write(doc)
|
63 |
+
|
64 |
+
st.session_state["chat_history"] = msgs
|
65 |
|
66 |
def create_conversational_rag_chain(vectorstore):
|
|
|
67 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
68 |
model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
|
|
|
|
|
69 |
|
70 |
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
71 |
|
72 |
llm = llamacpp.LlamaCpp(
|
73 |
+
model_path=model_path,
|
74 |
n_gpu_layers=1,
|
75 |
temperature=0.1,
|
76 |
top_p=0.9,
|
|
|
86 |
which can be understood without the chat history. Do NOT answer the question,
|
87 |
just reformulate it if needed and otherwise return it as is."""
|
88 |
|
89 |
+
ha_retriever = history_aware_retriever(llm, vectorstore.as_retriever(), contextualize_q_system_prompt)
|
90 |
|
91 |
qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
|
92 |
|
|
|
112 |
)
|
113 |
return conversation_chain
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
if __name__ == "__main__":
|
116 |
main()
|