co2 / app.py
ngmitam's picture
Update
994f18f
raw
history blame contribute delete
No virus
4.84 kB
import base64
from datetime import datetime
import os
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import GPT4All
from streamlit_chat import message
from huggingface_hub import hf_hub_download
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
def get_pdf_text(pdfs):
    """Extract and concatenate the text of every page in a list of PDFs.

    Args:
        pdfs: iterable of file paths or file-like objects accepted by
            ``PyPDF2.PdfReader``.

    Returns:
        str: the extracted text of all pages, in document/page order.
    """
    parts = []
    for pdf in pdfs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() may return None for pages with no extractable
            # text (e.g. scanned images) -- guard so we don't TypeError on
            # string concatenation.
            page_text = page.extract_text()
            if page_text:
                parts.append(page_text)
    # Join once instead of repeated `text +=` (quadratic in the worst case).
    return "".join(parts)
def get_text_chunks(text):
    """Split raw document text into overlapping chunks for embedding.

    Chunks are up to 1000 characters long with a 200-character overlap,
    split on newlines.
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
def get_vectorstore(text_chunks):
    """Embed the text chunks and index them in an in-memory FAISS store."""
    # embeddings = OpenAIEmbeddings()  # hosted alternative, kept from the original
    sentence_embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return FAISS.from_texts(texts=text_chunks, embedding=sentence_embedder)
def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over *vectorstore*.

    Uses a local GPT4All model (streaming generated tokens to stdout) and a
    buffer memory that accumulates the dialogue under the 'chat_history' key.
    """
    stream_handlers = [StreamingStdOutCallbackHandler()]
    local_llm = GPT4All(
        model="/tmp/ggml-gpt4all-j-v1.3-groovy.bin",
        max_tokens=1000,
        backend='gptj',
        callbacks=stream_handlers,
        n_batch=8,
        verbose=False,
    )
    # llm = ChatOpenAI()  # hosted alternative, kept from the original
    chat_memory = ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=local_llm,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
def user_input(user_question):
    """Run *user_question* through the conversation chain and render the chat.

    Logs the question and the model's answer to stdout with timestamps,
    stores the updated history in session state, and displays the history
    as alternating user/bot message bubbles.
    """
    # Log the user's question with a timestamp.
    print(f"[{datetime.now()}]ask:{user_question}")
    with st.spinner("Thinking ..."):
        response = st.session_state.conversation({'question': user_question})
    # Log the bot's answer with a timestamp.
    print(f"\n[{datetime.now()}]ans:{response['answer']}")
    st.session_state.chat_history = response['chat_history']
    for turn_index, turn in enumerate(st.session_state.chat_history):
        # History alternates: even indices are user turns, odd are bot turns.
        message(turn.content, is_user=(turn_index % 2 == 0))
def display_pdf(pdf):
    """Render the PDF at ``./docs/<pdf>`` inline via a base64 <embed> tag."""
    with open("./docs/" + pdf, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode('utf-8')
    embed_tag = (
        f'<embed src="data:application/pdf;base64,{encoded}" '
        'width="700" height="1000" type="application/pdf">'
    )
    # unsafe_allow_html is required for the raw <embed> markup to render.
    st.markdown(embed_tag, unsafe_allow_html=True)
# Canned example questions surfaced as sidebar buttons once a document's
# conversation chain is ready (see main()).
examples = [
    "CO2 Emissions trend in 2022?",
    "What is the average CO2 emissions in 2022?",
    "CO2 Emissions of the United States in 2022?",
    "Compare CO2 Emissions of the United States and China in 2022?",
]
def main():
    """Streamlit entry point: pick a document, index it, and chat about it.

    On first run, downloads the GPT4All model weights to /tmp. The main pane
    lists the PDFs in ./docs as buttons; selecting one extracts its text,
    builds a FAISS-backed conversational retrieval chain, and displays the
    PDF inline. The sidebar hosts the question input and example prompts.
    """
    load_dotenv()
    # Fetch the local LLM weights once; subsequent runs reuse the /tmp copy.
    if "ggml-gpt4all-j-v1.3-groovy.bin" not in os.listdir("/tmp"):
        hf_hub_download(repo_id="dnato/ggml-gpt4all-j-v1.3-groovy.bin",
                        filename="ggml-gpt4all-j-v1.3-groovy.bin",
                        local_dir="/tmp")
    st.set_page_config(page_title="CO2 Emission Document Chatbot")
    if "pdf" not in st.session_state:
        st.session_state.pdf = None
    st.header("Choose document")
    for pdf in os.listdir("./docs"):
        if st.button(pdf):
            with st.spinner("Loading document..."):
                st.session_state.pdf = pdf
                pdf_text = get_pdf_text(["./docs/" + pdf])
                text_chunks = get_text_chunks(pdf_text)
                vectorstore = get_vectorstore(text_chunks)
                conversation_chain = get_conversation_chain(vectorstore)
                st.session_state.conversation = conversation_chain
                # Reset the chat whenever a new document is selected.
                st.session_state.chat_history = None
    if st.session_state.pdf:
        # Fix: st.success() returns a DeltaGenerator, which is always truthy,
        # so the original `if st.success(...):` was a misleading no-op
        # condition. Show the banner, then render the PDF unconditionally.
        st.success("<-- Chat about the document on the left sidebar")
        display_pdf(st.session_state.pdf)
    with st.sidebar:
        if "conversation" not in st.session_state:
            st.session_state.conversation = None
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = None
        st.header("Query your document")
        # Input stays disabled until a document has been indexed.
        user_question = st.text_input(
            "Ask a question about the document...",
            disabled=not st.session_state.conversation)
        if st.session_state.conversation:
            # Clicking an example button submits it as the question.
            for example in examples:
                if st.button(example):
                    user_question = example
        if user_question and st.session_state.conversation:
            user_input(user_question)
# Script entry point (run via `streamlit run app.py`).
if __name__ == '__main__':
    main()