import base64
import os
from datetime import datetime

import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import GPT4All
from streamlit_chat import message
from huggingface_hub import hf_hub_download
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


def get_pdf_text(pdfs):
    """Extract the raw text from a list of PDF files."""
    text = ""
    for pdf in pdfs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            text += page.extract_text()
    return text


def get_text_chunks(text):
    """Split the extracted text into overlapping chunks for embedding."""
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
    chunks = text_splitter.split_text(text)
    return chunks


def get_vectorstore(text_chunks):
    """Embed the chunks and index them in an in-memory FAISS store."""
    # embeddings = OpenAIEmbeddings()
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
    return vectorstore


def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain backed by a local GPT4All model."""
    callbacks = [StreamingStdOutCallbackHandler()]
    llm = GPT4All(model="/tmp/ggml-gpt4all-j-v1.3-groovy.bin",
                  max_tokens=1000,
                  backend='gptj',
                  callbacks=callbacks,
                  n_batch=8,
                  verbose=False)
    # llm = ChatOpenAI()
    memory = ConversationBufferMemory(
        memory_key='chat_history', return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory
    )
    return conversation_chain


def user_input(user_question):
    """Send the user's question to the chain and render the chat history."""
    # Log the user question with a timestamp.
    print(f"[{datetime.now()}] ask: {user_question}")
    with st.spinner("Thinking ..."):
        response = st.session_state.conversation({'question': user_question})
    # Log the bot answer with a timestamp.
    print(f"\n[{datetime.now()}] ans: {response['answer']}")
    st.session_state.chat_history = response['chat_history']
    # Messages alternate: even indices are the user, odd indices are the bot.
    for i, messages in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            message(messages.content, is_user=True)
        else:
            message(messages.content)


def display_pdf(pdf):
    """Render the selected PDF inline as a base64-encoded iframe."""
    with open("./docs/" + pdf, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode('utf-8')
    pdf_display = (f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
                   f'width="700" height="1000" type="application/pdf"></iframe>')
    st.markdown(pdf_display, unsafe_allow_html=True)


examples = [
    "CO2 Emissions trend in 2022?",
    "What is the average CO2 emissions in 2022?",
    "CO2 Emissions of the United States in 2022?",
    "Compare CO2 Emissions of the United States and China in 2022?",
]


def main():
    load_dotenv()
    # Download the GPT4All model on first run if it is not already cached.
    if "ggml-gpt4all-j-v1.3-groovy.bin" not in os.listdir("/tmp"):
        hf_hub_download(repo_id="dnato/ggml-gpt4all-j-v1.3-groovy.bin",
                        filename="ggml-gpt4all-j-v1.3-groovy.bin",
                        local_dir="/tmp")

    st.set_page_config(page_title="CO2 Emission Document Chatbot")

    if "pdf" not in st.session_state:
        st.session_state.pdf = None

    st.header("Choose document")
    # One button per PDF in ./docs; clicking one builds the retrieval chain.
    for pdf in os.listdir("./docs"):
        if st.button(pdf):
            with st.spinner("Loading document..."):
                st.session_state.pdf = pdf
                pdf_text = get_pdf_text(["./docs/" + pdf])
                text_chunks = get_text_chunks(pdf_text)
                vectorstore = get_vectorstore(text_chunks)
                conversation_chain = get_conversation_chain(vectorstore)
                st.session_state.conversation = conversation_chain
                st.session_state.chat_history = None

    if st.session_state.pdf:
        st.success("<-- Chat about the document on the left sidebar")
        display_pdf(st.session_state.pdf)

    with st.sidebar:
        if "conversation" not in st.session_state:
            st.session_state.conversation = None
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = None

        st.header("Query your document")
        user_question = st.text_input(
            "Ask a question about the document...",
            disabled=not st.session_state.conversation)
        if st.session_state.conversation:
            for example in examples:
                if st.button(example):
                    user_question = example
        if user_question and st.session_state.conversation:
            user_input(user_question)


if __name__ == '__main__':
    main()
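
# A minimal way to launch the app, assuming this script is saved as app.py
# and a ./docs directory containing the CO2-emission PDFs sits alongside it
# (the filename app.py is an assumption, not from the original source):
#   streamlit run app.py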