ArturG9 committed on
Commit
9f3b8b8
1 Parent(s): c547536

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -101
app.py CHANGED
@@ -16,59 +16,61 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
16
  from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
17
  from langchain_community.document_loaders.directory import DirectoryLoader
18
 
19
- script_dir = os.path.dirname(os.path.abspath(__file__))
20
- data_path = os.path.join(script_dir, "data/")
21
- model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
22
- store = {}
23
-
24
- model_name = "sentence-transformers/all-mpnet-base-v2"
25
- model_kwargs = {'device': 'cpu'}
26
- encode_kwargs = {'normalize_embeddings': True}
27
-
28
- hf = HuggingFaceEmbeddings(
29
- model_name=model_name,
30
- model_kwargs=model_kwargs,
31
- encode_kwargs=encode_kwargs
32
- )
33
-
34
def get_vectorstore(text_chunks, persist_directory="docs/chroma/"):
    """Build a persistent Chroma vector store from pre-split document chunks.

    Args:
        text_chunks: Iterable of LangChain Document chunks to embed.
        persist_directory: Directory where Chroma persists its index.
            The default preserves the original hard-coded location.

    Returns:
        A Chroma vector store backed by all-mpnet-base-v2 CPU embeddings.
    """
    # NOTE(review): this re-instantiates the embedding model on every call,
    # even though the module already builds an identical `hf` embedder at
    # import time — consider reusing that instance to avoid reloading the
    # model weights.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    return Chroma.from_documents(
        documents=text_chunks,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
46
 
47
def get_pdf_text(pdf_docs):
    """Load all documents from the given directory.

    Note: despite the name, `DirectoryLoader` loads whatever file types it
    supports under the path, not only PDFs.

    Args:
        pdf_docs: Path to a directory containing the documents.

    Returns:
        list: The loaded LangChain Document objects.
    """
    return DirectoryLoader(pdf_docs).load()
50
 
51
def get_text_chunks(text):
    """Split raw text into overlapping chunks suitable for embedding.

    Bug fix: the original called
    ``RecursiveCharacterTextSplitter.from_tiktoken_encoder(separator=...)``,
    but ``separator`` is a ``CharacterTextSplitter`` keyword that the
    recursive splitter does not accept, so the call raised ``TypeError``.
    Passing ``length_function=len`` alongside the tiktoken encoder also
    contradicted token-based sizing. A plain character-count recursive
    splitter (what ``length_function=len`` asked for) is used instead.

    Args:
        text: The raw text to split.

    Returns:
        list[str]: Chunks of at most 1000 characters with 200-character
        overlap between consecutive chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def create_conversational_rag_chain(vectorstore):
62
-
63
  script_dir = os.path.dirname(os.path.abspath(__file__))
64
  model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
65
-
66
- retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={"k": 7})
67
 
68
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
69
 
70
  llm = llamacpp.LlamaCpp(
71
- model_path=os.path.join(model_path),
72
  n_gpu_layers=1,
73
  temperature=0.1,
74
  top_p=0.9,
@@ -84,7 +86,7 @@ def create_conversational_rag_chain(vectorstore):
84
  which can be understood without the chat history. Do NOT answer the question,
85
  just reformulate it if needed and otherwise return it as is."""
86
 
87
- ha_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_system_prompt)
88
 
89
  qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
90
 
@@ -110,63 +112,5 @@ def create_conversational_rag_chain(vectorstore):
110
  )
111
  return conversation_chain
112
 
113
def main():
    """Streamlit entry point: load documents, build the RAG chain, run chat.

    Bug fixes vs. the original:
    - Documents were loaded and split TWICE (once inside try/except, then
      again unconditionally), doubling startup cost; the second pass also
      ran even after the first had errored.
    - The retrieved-documents guard tested for the key ``"docs"`` but then
      indexed ``response["documents"]``, so the guard never matched the
      access it protected.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(script_dir, "data/")
    if not os.path.exists(data_path):
        st.error(f"Data path does not exist: {data_path}")
        return

    try:
        documents = load_txt_documents(data_path)
    except Exception as e:
        st.error(f"An error occurred while loading documents: {e}")
        return
    if not documents:
        st.warning("No documents found in the data path.")
        return

    # Single load/split pass (the original repeated this work).
    docs = split_docs(documents, 350, 40)
    st.success("Documents loaded and processed successfully.")

    vectorstore = get_vectorstore(docs)
    msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
    chain_with_history = create_conversational_rag_chain(vectorstore)

    st.title("Conversational RAG Chatbot")

    if prompt := st.chat_input():
        st.chat_message("human").write(prompt)

        # The chain expects "input" plus the running message history.
        input_dict = {"input": prompt, "chat_history": msgs.messages}
        config = {"configurable": {"session_id": "any"}}

        response = chain_with_history.invoke(input_dict, config)
        st.chat_message("ai").write(response["answer"])

        # Guard and access now use the same key ("documents").
        if response.get("documents"):
            for index, doc in enumerate(response["documents"]):
                with st.expander(f"Document {index + 1}"):
                    st.write(doc)

        # Persist the updated chat history for the next rerun.
        st.session_state["chat_history"] = msgs
170
-
171
  if __name__ == "__main__":
172
  main()
 
16
  from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
17
  from langchain_community.document_loaders.directory import DirectoryLoader
18
 
19
def main():
    """Streamlit entry point for the conversational RAG chatbot.

    Bug fixes vs. the original:
    - ``st.session_state.conversation_chain`` stays ``None`` when document
      loading fails, yet the original invoked it unguarded, raising
      ``AttributeError`` on the first question.
    - The retrieved-documents guard tested for the key ``"docs"`` but then
      indexed ``response["documents"]``; both now use ``"documents"``.
    """
    st.set_page_config(page_title="Conversational RAG Chatbot", page_icon=":robot:")
    st.title("Conversational RAG Chatbot")

    # Initialize session state on first run.
    if "documents" not in st.session_state:
        st.session_state.documents = []
    if "conversation_chain" not in st.session_state:
        st.session_state.conversation_chain = None

    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(script_dir, "data/")

    if not os.path.exists(data_path):
        st.error(f"Data path does not exist: {data_path}")
        return

    try:
        documents = load_txt_documents(data_path)
        if not documents:
            st.warning("No documents found in the data path.")
        else:
            st.session_state.documents = documents
            docs = split_docs(documents, 350, 40)
            # MMR retriever over the chunked corpus, top-7 results.
            vectorstore = retriever_from_chroma(docs, HuggingFaceEmbeddings(), "mmr", 7)
            st.session_state.conversation_chain = create_conversational_rag_chain(vectorstore)
            st.success("Documents loaded and processed successfully.")
    except Exception as e:
        st.error(f"An error occurred while loading documents: {e}")

    if prompt := st.text_input("Enter your question:"):
        # Guard: the chain is None whenever loading above failed or found
        # no documents; invoking it unguarded crashed the original.
        if st.session_state.conversation_chain is None:
            st.error("The conversation chain is not ready; resolve the document-loading errors above first.")
            return

        msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
        st.chat_message("human").write(prompt)

        input_dict = {"input": prompt, "chat_history": msgs.messages}
        config = {"configurable": {"session_id": "any"}}

        response = st.session_state.conversation_chain.invoke(input_dict, config)
        st.chat_message("ai").write(response["answer"])

        # Guard and access now use the same key ("documents").
        if response.get("documents"):
            for index, doc in enumerate(response["documents"]):
                with st.expander(f"Document {index + 1}"):
                    st.write(doc)

        # Persist the updated chat history for the next rerun.
        st.session_state["chat_history"] = msgs
65
 
66
  def create_conversational_rag_chain(vectorstore):
 
67
  script_dir = os.path.dirname(os.path.abspath(__file__))
68
  model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
 
 
69
 
70
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
71
 
72
  llm = llamacpp.LlamaCpp(
73
+ model_path=model_path,
74
  n_gpu_layers=1,
75
  temperature=0.1,
76
  top_p=0.9,
 
86
  which can be understood without the chat history. Do NOT answer the question,
87
  just reformulate it if needed and otherwise return it as is."""
88
 
89
+ ha_retriever = history_aware_retriever(llm, vectorstore.as_retriever(), contextualize_q_system_prompt)
90
 
91
  qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
92
 
 
112
  )
113
  return conversation_chain
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if __name__ == "__main__":
116
  main()