ArturG9 committed
Commit c547536
1 Parent(s): 97134b4

Update app.py

Files changed (1):
  app.py +99 -74
app.py CHANGED
@@ -16,68 +16,59 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
 from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
 from langchain_community.document_loaders.directory import DirectoryLoader
 
-def main():
-    st.set_page_config(page_title="Chat with multiple PDFs", page_icon=":books:")
-    st.header("Chat with multiple PDFs :books:")
-
-    if "pdf_docs" not in st.session_state:
-        st.session_state.pdf_docs = []
-
-    if "conversation_chain" not in st.session_state:
-        st.session_state.conversation_chain = None
-
-    with st.sidebar:
-        st.subheader("Your documents")
-        pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
-        if pdf_docs:
-            st.session_state.pdf_docs.extend(pdf_docs)
-
-        if st.button("Process"):
-            with st.spinner("Processing"):
-                raw_text = get_pdf_text(st.session_state.pdf_docs)
-                text_chunks = get_text_chunks(raw_text)
-                vectorstore = get_vectorstore(text_chunks)
-                st.session_state.conversation_chain = get_conversation_chain(vectorstore)
-                st.success("Documents processed and conversation chain created successfully.")
-
-    user_question = st.text_input("Ask a question about your documents:")
-    if user_question:
-        handle_userinput(st.session_state.conversation_chain, user_question)
 
 def get_pdf_text(pdf_docs):
-    text = ""
-    for pdf in pdf_docs:
-        pdf_reader = PdfReader(pdf)
-        for page in pdf_reader.pages:
-            text += page.extract_text()
-    return text
 
 def get_text_chunks(text):
     text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-        chunk_size=600, chunk_overlap=50,
-        separators=["\n \n \n", "\n \n", "\n1", "(?<=\. )", " ", ""],
     )
     chunks = text_splitter.split_text(text)
     return chunks
 
-def get_vectorstore(text_chunks):
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    model_kwargs = {'device': 'cpu'}
-    encode_kwargs = {'normalize_embeddings': True}
-    embeddings = HuggingFaceEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs
-    )
-    vectorstore = Chroma.from_texts(
-        texts=text_chunks, embedding=embeddings, persist_directory="docs/chroma/")
-    return vectorstore
 
-def get_conversation_chain(vectorstore):
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-    script_dir = os.path.dirname(os.path.abspath(__file__))
     llm = llamacpp.LlamaCpp(
-        model_path=os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf'),
         n_gpu_layers=1,
         temperature=0.1,
         top_p=0.9,
@@ -88,26 +79,14 @@ def get_conversation_chain(vectorstore):
         verbose=False,
     )
 
-    retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})
-
-    contextualize_q_system_prompt = """Given a context, chat history and the latest user question, formulate a standalone question
     which can be understood without the chat history. Do NOT answer the question,
     just reformulate it if needed and otherwise return it as is."""
 
-    contextualize_q_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", contextualize_q_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-
-    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
-    qa_system_prompt = """Use the following pieces of retrieved context to answer the question{input}. \
-    Be informative but don't make too long answers, be polite and formal. \
-    If you don't know the answer, say "I don't know the answer." \
-    {context}"""
 
     qa_prompt = ChatPromptTemplate.from_messages(
         [
@@ -116,12 +95,12 @@ def get_conversation_chain(vectorstore):
             ("human", "{input}"),
         ]
     )
-
     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
     msgs = StreamlitChatMessageHistory(key="special_app_key")
-
     conversation_chain = RunnableWithMessageHistory(
         rag_chain,
         lambda session_id: msgs,
@@ -131,17 +110,63 @@ def get_conversation_chain(vectorstore):
     )
     return conversation_chain
 
-def handle_userinput(conversation_chain, prompt):
-    msgs = StreamlitChatMessageHistory(key="special_app_key")
-    st.chat_message("human").write(prompt)
-    input_dict = {"input": prompt, "chat_history": msgs.messages}
-    config = {"configurable": {"session_id": 1}}
 
     try:
-        response = conversation_chain.invoke(input_dict, config)
-        st.chat_message("ai").write(response["answer"])
     except Exception as e:
-        st.error(f"Error invoking conversation chain: {e}")
 
-if __name__ == '__main__':
     main()
 
 from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
 from langchain_community.document_loaders.directory import DirectoryLoader
 
+script_dir = os.path.dirname(os.path.abspath(__file__))
+data_path = os.path.join(script_dir, "data/")
+model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
+store = {}
+
+model_name = "sentence-transformers/all-mpnet-base-v2"
+model_kwargs = {'device': 'cpu'}
+encode_kwargs = {'normalize_embeddings': True}
+
+hf = HuggingFaceEmbeddings(
+    model_name=model_name,
+    model_kwargs=model_kwargs,
+    encode_kwargs=encode_kwargs
+)
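+# The same embedding model is used for indexing and querying;
+# normalize_embeddings=True makes cosine and dot-product rankings agree.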
 
+def get_vectorstore(text_chunks):
+    # Reuses the module-level HuggingFaceEmbeddings instance defined above
+    # rather than constructing an identical one on each call
+    vectorstore = Chroma.from_documents(documents=text_chunks, embedding=hf, persist_directory="docs/chroma/")
+    return vectorstore
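+# persist_directory writes the index to docs/chroma/, so the embeddings can
+# be reloaded later instead of recomputed.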
 
 
 def get_pdf_text(pdf_docs):
+    # pdf_docs is now a directory path rather than a list of uploaded files
+    document_loader = DirectoryLoader(pdf_docs)
+    return document_loader.load()
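+# DirectoryLoader defaults to UnstructuredFileLoader, so the directory may
+# contain PDFs and other formats, not only plain text.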
 
51
  def get_text_chunks(text):
52
  text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
53
+ separator="\n",
54
+ chunk_size=1000,
55
+ chunk_overlap=200,
56
+ length_function=len
57
  )
58
  chunks = text_splitter.split_text(text)
59
  return chunks
60
 
61
+ def create_conversational_rag_chain(vectorstore):
62
+
63
+ script_dir = os.path.dirname(os.path.abspath(__file__))
64
+ model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
65
+
66
+ retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={"k": 7})
 
 
 
 
 
 
67
 
 
68
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
69
+
70
  llm = llamacpp.LlamaCpp(
71
+ model_path=os.path.join(model_path),
72
  n_gpu_layers=1,
73
  temperature=0.1,
74
  top_p=0.9,
 
79
  verbose=False,
80
  )
81
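+    # search_type="mmr" trades relevance off against diversity across the
+    # k=7 retrieved chunks; temperature=0.1 keeps decoding near-greedy.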
 
+    contextualize_q_system_prompt = """Given a context, chat history and the latest user question
+    which may reference context in the chat history, formulate a standalone question
     which can be understood without the chat history. Do NOT answer the question,
     just reformulate it if needed and otherwise return it as is."""
 
+    # create_history_aware_retriever expects a prompt template that carries the
+    # chat history and an "input" variable, not a bare string
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    ha_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
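+    # The history-aware retriever first asks the LLM to rewrite the question
+    # into a standalone query, then runs vector retrieval on the rewrite.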
 
+    qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
 
     qa_prompt = ChatPromptTemplate.from_messages(
         [
 
             ("human", "{input}"),
         ]
     )
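+    # create_stuff_documents_chain fills {context} in qa_system_prompt with
+    # the concatenated retrieved documents.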
+
     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
+    rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)
     msgs = StreamlitChatMessageHistory(key="special_app_key")
+
     conversation_chain = RunnableWithMessageHistory(
         rag_chain,
         lambda session_id: msgs,
 
     )
     return conversation_chain
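+# RunnableWithMessageHistory replays msgs into the chain on every call and
+# appends the new human/AI turns once the chain finishes.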
 
+def main():
+    """Main function for the Streamlit app."""
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    data_path = os.path.join(script_dir, "data/")
+    if not os.path.exists(data_path):
+        st.error(f"Data path does not exist: {data_path}")
+        return
 
     try:
+        # Load documents from the data path
+        documents = load_txt_documents(data_path)
+        if not documents:
+            st.warning("No documents found in the data path.")
+            return
+        # Split the documents into chunks
+        docs = split_docs(documents, 350, 40)
+        st.success("Documents loaded and processed successfully.")
     except Exception as e:
+        st.error(f"An error occurred while loading documents: {e}")
+        return
+
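+    # Note: the vector index and chain are rebuilt on every Streamlit rerun,
+    # which re-embeds all chunks each time.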
+    vectorstore = get_vectorstore(docs)
+
+    # Initialize chat history if not already present in session state
+    msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
+    chain_with_history = create_conversational_rag_chain(vectorstore)
+
+    st.title("Conversational RAG Chatbot")
+
+    if prompt := st.chat_input():
+        st.chat_message("human").write(prompt)
+
+        # Prepare the input dictionary with the correct keys
+        input_dict = {"input": prompt, "chat_history": msgs.messages}
+        config = {"configurable": {"session_id": "any"}}
+
+        # Process user input and handle response
+        response = chain_with_history.invoke(input_dict, config)
+        st.chat_message("ai").write(response["answer"])
+
+        # Display retrieved documents; create_retrieval_chain returns them
+        # under the "context" key of the response
+        if "context" in response and response["context"]:
+            for index, doc in enumerate(response["context"]):
+                with st.expander(f"Document {index + 1}"):
+                    st.write(doc)
+
+        # Update chat history in session state
+        st.session_state["chat_history"] = msgs
 
+if __name__ == "__main__":
     main()