ArturG9's picture
Update app.py
237d231 verified
raw
history blame
6.57 kB
import streamlit as st
import os
import sys
import shutil
from langchain.text_splitter import TokenTextSplitter,RecursiveCharacterTextSplitter,CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import pipeline
import torch
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.document_loaders import TextLoader
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.llms import Aphrodite
from typing import Callable, Dict, List, Optional, Union
from langchain.vectorstores import Chroma
import streamlit as st
from langchain_community.llms import llamacpp
from utills import split_docs, retriever_from_chroma, history_aware_retriever,chroma_db
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
# Filesystem locations used by the app.
data_path = "./data/"
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')

# Module-level dictionary; nothing in the visible code reads or writes it.
store = {}

# Sentence-embedding model configuration: run on CPU and normalize the
# produced embedding vectors.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=True)

# Embedding function handed to the vector store when the corpus is indexed.
hf = HuggingFaceEmbeddings(
    encode_kwargs=encode_kwargs,
    model_kwargs=model_kwargs,
    model_name=model_name,
)
# Load every ``.txt`` file found in ``data_path`` into one list of documents.
# Sorting the directory listing makes the indexing order reproducible
# (``os.listdir`` returns entries in arbitrary order).
documents = []
for filename in sorted(os.listdir(data_path)):
    if not filename.endswith('.txt'):
        continue
    file_path = os.path.join(data_path, filename)
    # Extend the accumulator with the freshly loaded documents. Rebinding
    # ``documents`` to the loader result would throw away every previously
    # loaded file and then duplicate the current one.
    documents.extend(TextLoader(file_path).load())

# Split the corpus into chunks.
# NOTE(review): 450 / 20 are presumably chunk size and overlap -- confirm
# against ``utills.split_docs``.
docs = split_docs(documents, 450, 20)

# Build the vector store from the chunks using the ``hf`` embedding function.
# NOTE: this assignment shadows the imported ``chroma_db`` factory; the name is
# kept as-is so that any other code relying on it continues to work.
chroma_db = chroma_db(docs, hf)

# NOTE(review): "mmr" / 6 are presumably the search type and the number of
# documents to return -- confirm against ``utills.retriever_from_chroma``.
retriever = retriever_from_chroma(chroma_db, "mmr", 6)
# The embedding model ``hf`` is already constructed earlier in this file from
# the very same ``model_name`` / ``model_kwargs`` / ``encode_kwargs`` values.
# It is deliberately NOT instantiated a second time here: a second instance
# would only load the same model again and is not used by anything below.

# Stream generated tokens to stdout while the model is producing them.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Local llama.cpp model used both for question rewriting and for answering.
# NOTE(review): the ``model_path`` variable defined earlier in this file points
# at a different file name ('...gguf.2' inside the script directory) and is
# not used here; this relative path is resolved against the current working
# directory. Confirm which file is intended.
llm = llamacpp.LlamaCpp(
    model_path='mistral-7b-v0.1-layla-v4-Q4_K_M.gguf',
    n_gpu_layers=0,  # no layers are offloaded to a GPU
    temperature=0.1,
    top_p=0.5,
    n_ctx=7000,
    max_tokens=350,
    repeat_penalty=1.7,
    # NOTE(review): the first stop entry is an empty string -- confirm that it
    # is intentional and not a token lost in editing.
    stop=["", "Instruction:", "### Instruction:", "###<user>", "</user>"],
    callback_manager=callback_manager,
    verbose=False,
)
# System prompt that asks the LLM to rewrite the latest user question into a
# standalone question (it must not answer it). The text is runtime data and is
# kept verbatim.
contextualize_q_system_prompt = (
    "Given a context, chat history and the latest user question\n"
    "which maybe reference context in the chat history, formulate a standalone question\n"
    "which can be understood without the chat history. Do NOT answer the question,\n"
    "just reformulate it if needed and otherwise return it as is."
)

# Retriever built from the LLM, the base retriever and the prompt above.
ha_retriever = history_aware_retriever(
    llm,
    retriever,
    contextualize_q_system_prompt,
)
# System prompt for the answering step; ``{context}`` is the template slot
# that receives the retrieved documents.
qa_system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Be as informative as possible, be polite and formal.\n"
    "{context}"
)

# Prompt layout: system instructions, then the prior conversation, then the
# new user input.
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", qa_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Combine the retrieved documents into the prompt and answer with the LLM,
# then put the history-aware retriever in front of that answering chain.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)
# Chat history object created with the Streamlit key "special_app_key".
msgs = StreamlitChatMessageHistory(key="special_app_key")


def _session_history(session_id):
    """Return the single shared history object, whatever the session id is."""
    return msgs


# Wrap the RAG chain so that the message history is supplied under
# "chat_history" and the "input" / "answer" messages are recorded.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    _session_history,
    history_messages_key="chat_history",
    input_messages_key="input",
    output_messages_key="answer",
)
def display_chat_history(chat_history):
    """Render every message of ``chat_history`` as a Streamlit chat bubble.

    Args:
        chat_history: Object exposing a ``messages`` iterable whose items have
            ``type`` (used as the chat role) and ``content`` attributes.
    """
    for entry in chat_history.messages:
        bubble = st.chat_message(entry.type)
        bubble.write(entry.content)
def display_documents(docs, on_click=None):
    """Display retrieved documents, optionally with an "expand" button each.

    Args:
        docs: Iterable of retrieved documents. Items may be objects exposing a
            ``page_content`` attribute or plain strings. ``None`` or an empty
            collection renders nothing.
        on_click: Optional callable invoked with the zero-based index of a
            document when its "Expand Article" button is pressed.
    """
    if not docs:
        return

    for i, document in enumerate(docs):
        st.write(f"**Docs {i+1}**")
        # Render the document's text rather than the document object itself;
        # plain strings (no ``page_content`` attribute) are rendered as-is.
        text = getattr(document, "page_content", document)
        # NOTE(review): HTML inside the document text is rendered unescaped;
        # only keep ``unsafe_allow_html=True`` while the corpus is trusted.
        st.markdown(text, unsafe_allow_html=True)
        # The button is only created when a click handler was supplied.
        if on_click and st.button(f"Expand Article {i+1}"):
            on_click(i)
def main(conversational_rag_chain):
    """Run the Streamlit chat UI on top of ``conversational_rag_chain``.

    Args:
        conversational_rag_chain: Runnable invoked with
            ``{"input": ..., "chat_history": ...}`` and a ``session_id``
            config; its result must provide an ``"answer"`` entry and may
            provide the retrieved documents under ``"context"``.
    """
    # Reuse the history object stored in session state, if there is one.
    msgs = st.session_state.get(
        "chat_history", StreamlitChatMessageHistory(key="special_app_key")
    )
    chain_with_history = conversational_rag_chain

    st.title("Conversational RAG Chatbot")

    # Replay the conversation so far.
    display_chat_history(msgs)

    if prompt := st.chat_input():
        st.chat_message("human").write(prompt)

        # Prepare the input dictionary with the keys the chain expects.
        input_dict = {"input": prompt, "chat_history": msgs.messages}
        config = {"configurable": {"session_id": "any"}}

        response = chain_with_history.invoke(input_dict, config)
        st.chat_message("ai").write(response["answer"])

        # A chain built with ``create_retrieval_chain`` returns the retrieved
        # documents under "context". "documents" is accepted as a fallback for
        # chains that use that key. The lookup no longer tests one key and
        # reads another.
        docs = response.get("context") or response.get("documents")
        if docs:
            def expand_document(index):
                # Placeholder for real expansion logic (e.g. show extra details).
                st.write(f"Expanding document {index+1}...")

            # NOTE(review): these buttons are created inside the
            # ``st.chat_input`` branch; verify that a button press still
            # reaches ``expand_document`` after Streamlit reruns the script.
            display_documents(docs, expand_document)

    # Keep the history object available for the next run of the script.
    st.session_state["chat_history"] = msgs


if __name__ == "__main__":
    main(conversational_rag_chain)