File size: 4,839 Bytes
994f18f
7d1995c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994f18f
7d1995c
 
 
994f18f
7d1995c
 
 
 
 
 
 
 
994f18f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1995c
 
 
 
 
994f18f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1995c
994f18f
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1995c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import base64
from datetime import datetime
import os
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import GPT4All
from streamlit_chat import message
from huggingface_hub import hf_hub_download

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


def get_pdf_text(pdfs):
    """Concatenate the extracted text of every page in every PDF.

    Args:
        pdfs: iterable of file paths or file-like objects accepted by
            ``PyPDF2.PdfReader``.

    Returns:
        One string with all page text appended in document order.
    """
    pages = []
    for pdf in pdfs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() returns None for pages with no extractable
            # text (e.g. scanned images); treat those as empty instead of
            # raising TypeError on string concatenation.
            pages.append(page.extract_text() or "")
    # join once instead of += in a loop (avoids quadratic behavior).
    return "".join(pages)


def get_text_chunks(text):
    """Split raw document text into overlapping chunks for embedding.

    Chunks are up to 1000 characters long with a 200-character overlap,
    split on newlines, as required by the downstream vector store.
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)


def get_vectorstore(text_chunks):
    """Embed text chunks with a local MiniLM model and index them in FAISS.

    Args:
        text_chunks: list of strings produced by get_text_chunks().

    Returns:
        A FAISS vector store ready to be used as a retriever.
    """
    # embeddings = OpenAIEmbeddings()  # hosted alternative (needs API key)
    embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return FAISS.from_texts(texts=text_chunks, embedding=embedder)


def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over the given vector store.

    Uses the local GPT4All model (downloaded to /tmp by main()) with
    token streaming to stdout, and keeps the full dialogue in a buffer
    memory under the 'chat_history' key.
    """
    llm = GPT4All(
        model="/tmp/ggml-gpt4all-j-v1.3-groovy.bin",
        max_tokens=1000,
        backend='gptj',
        callbacks=[StreamingStdOutCallbackHandler()],
        n_batch=8,
        verbose=False,
    )
    # llm = ChatOpenAI()  # hosted alternative (needs API key)
    chat_memory = ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )


def user_input(user_question):
    """Send the user's question to the conversation chain and render the chat.

    Logs the question and answer to stdout with timestamps, saves the
    updated history in session state, and re-renders the whole dialogue.
    """
    # log user question with timestamp
    print(f"[{datetime.now()}]ask:{user_question}")
    with st.spinner("Thinking ..."):
        response = st.session_state.conversation({'question': user_question})
    # log bot answer with timestamp
    print(f"\n[{datetime.now()}]ans:{response['answer']}")
    st.session_state.chat_history = response['chat_history']
    # History alternates user / assistant messages. Each message widget
    # needs a unique key, otherwise streamlit_chat raises DuplicateWidgetID
    # when two messages have identical content or on script reruns.
    for i, chat_message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            message(chat_message.content, is_user=True, key=str(i))
        else:
            message(chat_message.content, key=str(i))


def display_pdf(pdf):
    """Render a PDF from the ./docs directory inline as a base64 <embed> tag.

    Args:
        pdf: file name of a PDF inside ./docs (no directory component).
    """
    path = os.path.join("./docs", pdf)
    with open(path, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode('utf-8')
    embed_tag = (
        f'<embed src="data:application/pdf;base64,{encoded}" '
        f'width="700" height="1000" type="application/pdf">'
    )
    st.markdown(embed_tag, unsafe_allow_html=True)


# Canned example questions rendered as sidebar buttons once a document
# has been loaded (see main()); clicking one submits it as the question.
examples = [
    "CO2 Emissions trend in 2022?",
    "What is the average CO2 emissions in 2022?",
    "CO2 Emissions of the United States in 2022?",
    "Compare CO2 Emissions of the United States and China in 2022?",
]


def main():
    """Streamlit entry point: pick a PDF from ./docs and chat about it."""
    load_dotenv()
    # Download the local GPT4All weights once; later reruns reuse the cache.
    if "ggml-gpt4all-j-v1.3-groovy.bin" not in os.listdir("/tmp"):
        hf_hub_download(repo_id="dnato/ggml-gpt4all-j-v1.3-groovy.bin",
                        filename="ggml-gpt4all-j-v1.3-groovy.bin", local_dir="/tmp")
    st.set_page_config(page_title="CO2 Emission Document Chatbot")

    if "pdf" not in st.session_state:
        st.session_state.pdf = None
    # Render the header unconditionally: the original nested it under the
    # session-state check above, so it appeared only on the first run and
    # vanished on every Streamlit rerun.
    st.header("Choose document")
    # Sort for a stable button order across reruns (os.listdir order is
    # filesystem-dependent).
    for pdf in sorted(os.listdir("./docs")):
        if st.button(pdf):
            with st.spinner("Loading document..."):
                st.session_state.pdf = pdf
                pdf_text = get_pdf_text(["./docs/" + pdf])
                text_chunks = get_text_chunks(pdf_text)
                vectorstore = get_vectorstore(text_chunks)
                st.session_state.conversation = get_conversation_chain(vectorstore)
                st.session_state.chat_history = None

    if st.session_state.pdf:
        # st.success() returns an always-truthy DeltaGenerator, so the
        # original `if st.success(...)` was a no-op condition; call it for
        # the banner and show the PDF unconditionally.
        st.success("<-- Chat about the document on the left sidebar")
        display_pdf(st.session_state.pdf)

    with st.sidebar:
        if "conversation" not in st.session_state:
            st.session_state.conversation = None
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = None
        st.header("Query your document")
        # Disable the input until a document has been indexed.
        user_question = st.text_input(
            "Ask a question about the document...",
            disabled=not st.session_state.conversation)

        if st.session_state.conversation:
            # Example buttons override the typed question when clicked.
            for example in examples:
                if st.button(example):
                    user_question = example
        if user_question and st.session_state.conversation:
            user_input(user_question)


# Run the Streamlit app when executed as a script (streamlit run app.py).
if __name__ == '__main__':
    main()