Update app.py
app.py CHANGED
@@ -16,68 +16,59 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
 from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
 from langchain_community.document_loaders.directory import DirectoryLoader
 
-
-
-
-
-if "pdf_docs" not in st.session_state:
-    st.session_state.pdf_docs = []
+script_dir = os.path.dirname(os.path.abspath(__file__))
+data_path = os.path.join(script_dir, "data/")
+model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
+store = {}
 
-
-
+model_name = "sentence-transformers/all-mpnet-base-v2"
+model_kwargs = {'device': 'cpu'}
+encode_kwargs = {'normalize_embeddings': True}
 
-
-
-
-
-
+hf = HuggingFaceEmbeddings(
+    model_name=model_name,
+    model_kwargs=model_kwargs,
+    encode_kwargs=encode_kwargs
+)
 
-
-
-
-
-
-
-
+def get_vectorstore(text_chunks):
+    model_name = "sentence-transformers/all-mpnet-base-v2"
+    model_kwargs = {'device': 'cpu'}
+    encode_kwargs = {'normalize_embeddings': True}
+    hf = HuggingFaceEmbeddings(
+        model_name=model_name,
+        model_kwargs=model_kwargs,
+        encode_kwargs=encode_kwargs
+    )
 
-
-
-    handle_userinput(st.session_state.conversation_chain, user_question)
+    vectorstore = Chroma.from_documents(documents=text_chunks, embedding=hf, persist_directory="docs/chroma/")
+    return vectorstore
 
 def get_pdf_text(pdf_docs):
-
-
-        pdf_reader = PdfReader(pdf)
-        for page in pdf_reader.pages:
-            text += page.extract_text()
-    return text
+    document_loader = DirectoryLoader(pdf_docs)
+    return document_loader.load()
 
 def get_text_chunks(text):
     text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-
-
+        # from_tiktoken_encoder sizes chunks by token count; it does not accept
+        # separator= or length_function= (passing those raises a TypeError)
+        chunk_size=1000,
+        chunk_overlap=200,
     )
     chunks = text_splitter.split_text(text)
     return chunks
 
-def
-
-
-
-
-
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs
-    )
-    vectorstore = Chroma.from_texts(
-        texts=text_chunks, embedding=embeddings, persist_directory="docs/chroma/")
-    return vectorstore
+def create_conversational_rag_chain(vectorstore):
+
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
+
+    retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={"k": 7})
 
-def get_conversation_chain(vectorstore):
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
+
     llm = llamacpp.LlamaCpp(
-        model_path
+        model_path=model_path,
         n_gpu_layers=1,
        temperature=0.1,
        top_p=0.9,
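The new module-level embedder and get_vectorstore switch the store from Chroma.from_texts to Chroma.from_documents, so the function now expects Document objects rather than raw strings. A minimal sketch of exercising it, assuming the imports already at the top of app.py; the sample document is hypothetical:

from langchain_core.documents import Document

# Hypothetical input; in the app the documents come from load_txt_documents()
sample_docs = [Document(page_content="Qwen2 0.5B is a small instruction-tuned model.")]
vectorstore = get_vectorstore(sample_docs)  # embeds with all-mpnet-base-v2, persists to docs/chroma/
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})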
@@ -88,26 +79,14 @@ def get_conversation_chain(vectorstore):
         verbose=False,
     )
 
-
-
-    contextualize_q_system_prompt = """Given a context, chat history and the latest user question, formulate a standalone question
+    contextualize_q_system_prompt = """Given a context, chat history and the latest user question
+    which may reference context in the chat history, formulate a standalone question
     which can be understood without the chat history. Do NOT answer the question,
     just reformulate it if needed and otherwise return it as is."""
 
-
-        [
-            ("system", contextualize_q_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-
-    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
+    # create_history_aware_retriever expects a prompt template with "chat_history"
+    # and "input" variables, not a bare system string
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    ha_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
-    qa_system_prompt = """Use the following pieces of retrieved context to answer the question{
-    Be informative but don't make too long answers, be polite and formal. \
-    If you don't know the answer, say "I don't know the answer." \
-    {context}"""
+    qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
 
     qa_prompt = ChatPromptTemplate.from_messages(
         [
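The history-aware retriever chains two steps: the LLM first rewrites the latest question into a standalone one using the chat history, and the rewritten question is then sent to the MMR retriever. A sketch of invoking it directly; the message classes come from langchain_core and the conversation is made up:

from langchain_core.messages import AIMessage, HumanMessage

# Returns the Documents retrieved for the rewritten standalone question
docs = ha_retriever.invoke({
    "input": "How big is it?",
    "chat_history": [
        HumanMessage(content="Which model does the app run?"),
        AIMessage(content="A local Qwen2 0.5B instruct model served by llama.cpp."),
    ],
})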
@@ -116,12 +95,12 @@ def get_conversation_chain(vectorstore):
             ("human", "{input}"),
         ]
     )
-
+
     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
-    rag_chain = create_retrieval_chain(
+    rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)
     msgs = StreamlitChatMessageHistory(key="special_app_key")
-
+
     conversation_chain = RunnableWithMessageHistory(
         rag_chain,
         lambda session_id: msgs,
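create_retrieval_chain composes the history-aware retriever with the stuff-documents chain; each call returns a dict carrying the generated text under "answer" and the retrieved Documents under "context". RunnableWithMessageHistory then reads and writes the Streamlit-backed message store around every call. A minimal usage sketch; the session id is arbitrary because the lambda above ignores it:

response = conversation_chain.invoke(
    {"input": "What is in the knowledge base?"},
    config={"configurable": {"session_id": "any"}},
)
print(response["answer"])        # generated answer
print(len(response["context"]))  # retrieved Documents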
@@ -131,17 +110,63 @@ def get_conversation_chain(vectorstore):
     )
     return conversation_chain
 
-def
-
-
-
-
+def main():
+    """Main function for the Streamlit app."""
+
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    data_path = os.path.join(script_dir, "data/")
+    if not os.path.exists(data_path):
+        st.error(f"Data path does not exist: {data_path}")
+        return
 
     try:
-
-
+        # Load documents from the data path and split them into chunks
+        documents = load_txt_documents(data_path)
+        if not documents:
+            st.warning("No documents found in the data path.")
+            return
+        docs = split_docs(documents, 350, 40)
+        st.success("Documents loaded and processed successfully.")
     except Exception as e:
-    st.error(f"
+        st.error(f"An error occurred while loading documents: {e}")
+        return
+
+    vectorstore = get_vectorstore(docs)
+
+    msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
+    chain_with_history = create_conversational_rag_chain(vectorstore)
+
+    st.title("Conversational RAG Chatbot")
+
+    if prompt := st.chat_input():
+        st.chat_message("human").write(prompt)
+
+        # Prepare the input dictionary with the keys the chain expects
+        input_dict = {"input": prompt, "chat_history": msgs.messages}
+        config = {"configurable": {"session_id": "any"}}
+
+        # Process user input and handle response
+        response = chain_with_history.invoke(input_dict, config)
+        st.chat_message("ai").write(response["answer"])
+
+        # Display retrieved documents; create_retrieval_chain returns them under "context"
+        if response.get("context"):
+            for index, doc in enumerate(response["context"]):
+                with st.expander(f"Document {index + 1}"):
+                    st.write(doc)
+
+        # Update chat history in session state
+        st.session_state["chat_history"] = msgs
 
-if __name__ ==
+if __name__ == "__main__":
     main()
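main() depends on two helpers imported from the local utills module, whose bodies are not part of this diff. A hypothetical sketch consistent with how they are called above, load_txt_documents(data_path) and split_docs(documents, 350, 40):

# Hypothetical reconstruction of the utills helpers; the real module is not shown here
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_txt_documents(data_path):
    # Load every .txt file under data_path as LangChain Documents
    loader = DirectoryLoader(data_path, glob="**/*.txt", loader_cls=TextLoader)
    return loader.load()

def split_docs(documents, chunk_size, chunk_overlap):
    # Split Documents into overlapping chunks sized for embedding
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)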
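With the GGUF model file sitting next to app.py and .txt sources under data/, the app starts like any Streamlit script (assuming Streamlit is installed):

streamlit run app.py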