pflooky committed
Commit 8324134
1 Parent(s): 0760431

Use gradio for document answering

Files changed (6)
  1. .gitignore +1 -0
  2. app.py +130 -3
  3. llm_model.py +96 -0
  4. requirements.txt +12 -0
  5. streamlit_app.py +158 -0
  6. vector_db.py +46 -0
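The heart of the change is app.py: the old Streamlit slider demo is replaced by a Gradio document-answering UI (an upload column plus a streaming chat), with a full Streamlit variant of the tool added separately as streamlit_app.py. The core pattern is a gr.ChatInterface driven by a generator function; a minimal sketch distilled from the diff below (the echoing predict here is illustrative only, not the committed implementation):

# Minimal Gradio chat sketch (assumes gradio is installed); app.py below streams LLM output instead of echoing.
import gradio as gr

def predict(message, history):
    partial = ""
    for ch in f"You asked: {message}":
        partial += ch
        yield partial  # yielding progressively is what produces the streaming effect used in app.py

gr.ChatInterface(predict, title="Document answering bot").launch()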
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
app.py CHANGED
@@ -1,4 +1,131 @@
- import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import gradio as gr
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
+
+ import vector_db as vdb
+ from llm_model import LLMModel
+
+ chunk_size = 2000
+ chunk_overlap = 200
+ uploaded_docs = []
+ uploaded_df = gr.Dataframe(headers=["file_name", "content_length"])
+ upload_files_section = gr.Files(
+     file_types=[".md", ".mdx", ".rst", ".txt"],
+ )
+ chatbot_stream = gr.Chatbot(bubble_full_width=False, show_copy_button=True)
+
+
+ def load_docs(files):
+     all_docs = []
+     all_qa = []
+     for file in files:
+         if file.name is not None:
+             with open(file.name, "r") as f:
+                 file_content = f.read()
+             file_name = file.name.split("/")[-1]
+             # Create document with metadata
+             doc = Document(page_content=file_content, metadata={"source": file_name})
+             # Create an instance of the RecursiveCharacterTextSplitter class with specific parameters.
+             # It splits text into chunks of 2000 characters each with a 200-character overlap (see chunk_size and chunk_overlap above).
+             language = get_language(file_name)
+             text_splitter = RecursiveCharacterTextSplitter.from_language(
+                 chunk_size=chunk_size,
+                 chunk_overlap=chunk_overlap,
+                 language=language
+             )
+             # Split the text into chunks using the text splitter.
+             doc_chunks = text_splitter.split_documents([doc])
+             print(f"Number of chunks: {len(doc_chunks)}")
+             # For each chunk, send to LLM to get potential questions and answers
+             for doc_chunk in doc_chunks:
+                 gr.Info("Analysing document...")
+                 potential_qa_from_doc = llm_model.get_potential_question_answer(doc_chunk.page_content)
+                 all_qa += [Document(page_content=potential_qa_from_doc, metadata=doc_chunk.metadata)]
+             all_docs += doc_chunks
+             uploaded_docs.append(file.name)
+     vector_db.load_docs_into_vector_db(all_qa)
+     gr.Info("Loaded document(s) into vector db.")
+
+     return uploaded_docs
+
+
+ def get_language(file_name: str):
+     if file_name.endswith(".md") or file_name.endswith(".mdx"):
+         return Language.MARKDOWN
+     elif file_name.endswith(".rst"):
+         return Language.RST
+     else:
+         return Language.MARKDOWN
+
+
+ def get_vector_db():
+     return vdb.VectorDB()
+
+
+ def get_llm_model(_db: vdb.VectorDB):
+     retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
+     # return LLMModel(retriever=retriever).create_qa_chain()
+     return LLMModel(retriever=retriever)
+
+
+ def predict(message, history):
+     # resp = llm_model.answer_question_inference(message)
+     # return resp.get("answer")
+     resp = llm_model.answer_question_inference_text_gen(message)
+     final_resp = ""
+     for c in resp:
+         final_resp += str(c)
+         yield final_resp
+     # start_time = time.time()
+     # res = llm_model({"query": message})
+     # sources = []
+     # for source_docs in res['source_documents']:
+     #     if 'source' in source_docs.metadata:
+     #         sources.append(source_docs.metadata['source'])
+     # # Display assistant response in chat message container
+     # end_time = time.time()
+     # time_taken = "{:.2f}".format(end_time - start_time)
+     # format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
+     # format_source = None
+     # for source_docs in res['source_documents']:
+     #     if 'source' in source_docs.metadata:
+     #         format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
+     #
+     # return format_answer
+
+
+ def vote(data: gr.LikeData):
+     if data.liked:
+         gr.Info("You upvoted this response 😊")
+     else:
+         gr.Warning("You downvoted this response 👀")
+
+
+ vector_db = get_vector_db()
+ llm_model = get_llm_model(vector_db)
+
+ chat_interface_stream = gr.ChatInterface(
+     predict,
+     title="👀 Document answering bot",
+     description="📚🔦 Upload some documents on the side and ask questions!",
+     textbox=gr.Textbox(container=False, scale=7),
+     chatbot=chatbot_stream,
+     examples=["What is Data Caterer?", "Provide a set of potential questions and answers about the README"]
+ )
+
+ with gr.Blocks() as blocks:
+     with gr.Row():
+         with gr.Column(scale=1, min_width=100) as upload_col:
+             gr.Interface(
+                 load_docs,
+                 title="📖 Upload documents",
+                 inputs=upload_files_section,
+                 outputs=gr.Files(),
+                 allow_flagging="never"
+             )
+             # upload_files_section.upload(load_docs, inputs=upload_files_section)
+         with gr.Column(scale=4, min_width=600) as chat_col:
+             chatbot_stream.like(vote, None, None)
+             chat_interface_stream.render()
+
+ blocks.queue().launch()
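A note on running the new entrypoint: llm_model.py reads HF_TOKEN and API_URL from the environment when it is imported (they configure the Hugging Face InferenceClient), so both must be set before app.py is imported; importing app.py builds the Blocks UI and calls blocks.queue().launch(). A minimal local-run sketch with placeholder values:

# Local-run sketch; the token and endpoint below are placeholders, not values from the commit.
import os

os.environ["HF_TOKEN"] = "hf_xxx"  # placeholder Hugging Face token
os.environ["API_URL"] = "https://api-inference.huggingface.co/models/your-model"  # placeholder inference endpoint

import app  # module-level code builds the UI and starts the Gradio server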
llm_model.py ADDED
@@ -0,0 +1,96 @@
+ import os
+
+ import requests
+ from huggingface_hub import InferenceClient
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain_community.llms import CTransformers
+ from langchain_core.vectorstores import VectorStoreRetriever
+
+
+ class LLMModel:
+     base_model = "TheBloke/Llama-2-7B-GGUF"
+     specific_model = "llama-2-7b.Q4_K_M.gguf"
+     token_model = "meta-llama/Llama-2-7b-hf"
+     llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}
+
+     question_answer_system_prompt = """You are a helpful question answer assistant. Given the following context and a question, provide a set of potential questions and answers.
+     Keep answers brief and well-structured. Do not give one word answers."""
+     final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
+     Keep answers brief and well-structured. Do not give one word answers.
+     If the answer is not found in the list, kindly state "I don't know.". Don't try to make up an answer."""
+     template = """<s>[INST] <<SYS>>
+     You are a question answer assistant. Given the following context and a question, generate an answer based on this context only.
+     Keep answers brief and well-structured. Do not give one word answers.
+     If the answer is not found in the context, kindly state "I don't know.". Don't try to make up an answer.
+     <</SYS>>
+
+     Context: {context}
+
+     Question: Give me a step by step explanation of {question}[/INST]
+     Answer:"""
+     qa_chain_prompt = PromptTemplate.from_template(template)
+     retriever = None
+
+     hf_token = os.getenv('HF_TOKEN')
+     api_url = os.getenv('API_URL')
+     headers = {"Authorization": f"Bearer {hf_token}"}
+     client = InferenceClient(api_url)
+
+     # llm = CTransformers(model=base_model, model_file=specific_model, config=llm_config, hf=True)
+     llm = None
+
+     def __init__(self, retriever: VectorStoreRetriever):
+         self.retriever = retriever
+
+     def create_qa_chain(self):
+         return RetrievalQA.from_chain_type(
+             llm=self.llm,
+             chain_type="stuff",
+             retriever=self.retriever,
+             return_source_documents=True,
+             chain_type_kwargs={"prompt": self.qa_chain_prompt},
+         )
+
+     def format_retrieved_docs(self, docs):
+         all_docs = []
+         for doc in docs:
+             if "source" in doc.metadata:
+                 all_docs.append(f"""Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n""")
+         return all_docs
+
+     def format_query(self, question, context, system_prompt):
+         prompt = f"""[INST] {system_prompt}
+
+         Context: {context}
+
+         Question: Give me a step by step explanation of {question}[/INST]"""
+         return prompt
+
+     def format_question(self, question):
+         relevant_docs = self.retriever.get_relevant_documents(question)
+         formatted_docs = self.format_retrieved_docs(relevant_docs)
+         return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)
+
+     def get_potential_question_answer(self, document_chunk: str):
+         prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
+         return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)
+
+     def answer_question_inference_text_gen(self, question):
+         prompt = self.format_question(question)
+         return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)
+
+     def answer_question_inference(self, question):
+         relevant_docs = self.retriever.get_relevant_documents(question)
+         formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
+         if not formatted_docs:
+             return "No uploaded documents. Please try uploading a document on the left side."
+         else:
+             print(formatted_docs)
+             return self.client.question_answering(question=question, context=formatted_docs)
+
+     def answer_question_api(self, question):
+         formatted_prompt = self.format_question(question)
+         resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
+         for c in resp.iter_content():
+             yield c
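For orientation, this mirrors how app.py wires LLMModel to the FAISS-backed retriever and asks a question; a sketch only, assuming HF_TOKEN and API_URL are already set in the environment:

# Usage sketch mirroring get_llm_model() and predict() in app.py.
import vector_db as vdb
from llm_model import LLMModel

db = vdb.VectorDB()
retriever = db.docs_db.as_retriever(search_kwargs={"k": 2})  # retrieve the top-2 chunks, as in app.py
model = LLMModel(retriever=retriever)
answer = model.answer_question_inference_text_gen("What is Data Caterer?")  # returns the generated text
print(answer)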
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ tiktoken
+ faiss-cpu
+ ctransformers
+ transformers
+ sentence-transformers
+ streamlit
+ streamlit_lottie
+ gradio
+ huggingface_hub
+ langchain
+ langchain_experimental
+ llama-cpp-python
streamlit_app.py ADDED
@@ -0,0 +1,158 @@
+ from io import StringIO
+
+ import streamlit as st
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
+ import time
+
+ import vector_db as vdb
+ from llm_model import LLMModel
+
+
+ def default_state():
+     if "startup" not in st.session_state:
+         st.session_state.startup = True
+
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     if "uploaded_docs" not in st.session_state:
+         st.session_state.uploaded_docs = []
+
+     if "llm_option" not in st.session_state:
+         st.session_state.llm_option = "Local"
+
+     if "answer_loading" not in st.session_state:
+         st.session_state.answer_loading = False
+
+
+ def load_doc(file_name: str, file_content: str):
+     if file_name is not None:
+         # Create document with metadata
+         doc = Document(page_content=file_content, metadata={"source": file_name})
+         # Create an instance of the RecursiveCharacterTextSplitter class with specific parameters.
+         # It splits text into chunks of 1000 characters each with a 150-character overlap.
+         language = get_language(file_name)
+         text_splitter = RecursiveCharacterTextSplitter.from_language(chunk_size=1000, chunk_overlap=150,
+                                                                       language=language)
+         # Split the text into chunks using the text splitter.
+         docs = text_splitter.split_documents([doc])
+         return docs
+     else:
+         return None
+
+
+ def get_language(file_name: str):
+     if file_name.endswith(".md") or file_name.endswith(".mdx"):
+         return Language.MARKDOWN
+     elif file_name.endswith(".rst"):
+         return Language.RST
+     else:
+         return Language.MARKDOWN
+
+
+ @st.cache_resource()
+ def get_vector_db():
+     return vdb.VectorDB()
+
+
+ @st.cache_resource()
+ def get_llm_model(_db: vdb.VectorDB):
+     retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
+     return LLMModel(retriever=retriever).create_qa_chain()
+
+
+ # Initialize an instance of the RetrievalQA class with the specified parameters
+ def init_sidebar():
+     with st.sidebar:
+         st.toggle(
+             "Loading from LLM",
+             on_change=enable_sidebar(),
+             disabled=not st.session_state.answer_loading
+         )
+         llm_option = st.selectbox(
+             'Select to use local model or inference API',
+             options=['Local', 'Inference API']
+         )
+         st.session_state.llm_option = llm_option
+         uploaded_files = st.file_uploader(
+             'Upload file(s)',
+             type=['md', 'mdx', 'rst', 'txt'],
+             accept_multiple_files=True
+         )
+         for uploaded_file in uploaded_files:
+             if uploaded_file.name not in st.session_state.uploaded_docs:
+                 # Read the file as a string
+                 stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
+                 string_data = stringio.read()
+                 # Get chunks of text
+                 doc_chunks = load_doc(uploaded_file.name, string_data)
+                 st.write(f"Number of chunks={len(doc_chunks)}")
+                 vector_db.load_docs_into_vector_db(doc_chunks)
+                 st.session_state.uploaded_docs.append(uploaded_file.name)
+
+
+ def init_chat():
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+
+ def disable_sidebar():
+     st.session_state.answer_loading = True
+     st.rerun()
+
+
+ def enable_sidebar():
+     st.session_state.answer_loading = False
+
+
+ st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
+ vector_db = get_vector_db()
+ default_state()
+ init_sidebar()
+ st.header("Document answering tool")
+ st.subheader("Upload your documents on the side and ask questions")
+ init_chat()
+ llm_model = get_llm_model(vector_db)
+ st.session_state.startup = False
+
+
+ # React to user input
+ if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar()):
+     # if st.session_state.answer_loading:
+     #     st.warning("Cannot ask multiple questions at the same time")
+     #     st.session_state.answer_loading = False
+     # else:
+     start_time = time.time()
+     # Display user message in chat message container
+     with st.chat_message("user"):
+         st.markdown(user_prompt)
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": user_prompt})
+
+     if llm_model is not None:
+         assistant_chat = st.chat_message("assistant")
+         if not st.session_state.uploaded_docs:
+             assistant_chat.warning("WARN: Will try answer question without documents")
+         with st.spinner('Resolving question...'):
+             res = llm_model({"query": user_prompt})
+             sources = []
+             for source_docs in res['source_documents']:
+                 if 'source' in source_docs.metadata:
+                     sources.append(source_docs.metadata['source'])
+             # Display assistant response in chat message container
+             end_time = time.time()
+             time_taken = "{:.2f}".format(end_time - start_time)
+             format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
+             assistant_chat.markdown(format_answer)
+             source_expander = assistant_chat.expander("See full sources")
+             for source_docs in res['source_documents']:
+                 if 'source' in source_docs.metadata:
+                     format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
+                     source_expander.markdown(format_source)
+             # Add assistant response to chat history
+             st.session_state.messages.append({"role": "assistant", "content": format_answer})
+     enable_sidebar()
+     st.rerun()
vector_db.py ADDED
@@ -0,0 +1,46 @@
+ from langchain.schema import Document
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores.faiss import FAISS
+
+
+ class VectorDB:
+     embedding_model = "sentence-transformers/all-MiniLM-l6-v2"
+     model_kwargs = {'device': 'cpu'}
+     encode_kwargs = {'normalize_embeddings': False}
+     local_folder = "db/faiss_db"
+     is_load_local = False
+     text_embeddings = None
+     docs_db = None
+
+     def __init__(self):
+         self.text_embeddings = self.init_text_embeddings(self.embedding_model, self.model_kwargs, self.encode_kwargs)
+         self.docs_db = self.init_vector_db(self.local_folder, self.text_embeddings)
+
+     def init_text_embeddings(self, embedding_model: str, model_kwargs: dict, encode_kwargs: dict):
+         return HuggingFaceEmbeddings(
+             model_name=embedding_model,
+             model_kwargs=model_kwargs,
+             encode_kwargs=encode_kwargs
+         )
+
+     def init_vector_db(self, folder_path: str, text_embeddings: HuggingFaceEmbeddings):
+         if self.is_load_local:
+             try:
+                 return FAISS.load_local(folder_path=folder_path, embeddings=text_embeddings)
+             except Exception as e:
+                 return FAISS.from_documents([Document(page_content="")], embedding=text_embeddings)
+         else:
+             return FAISS.from_documents([Document(page_content="")], embedding=text_embeddings)
+
+     def load_docs_into_vector_db(self, doc_chunks: list):
+         if len(doc_chunks) != 0:
+             if self.docs_db is None:
+                 self.docs_db = FAISS.from_documents(doc_chunks, embedding=self.text_embeddings)
+             else:
+                 self.docs_db.add_documents(doc_chunks)
+
+     def save_vector_db(self):
+         if self.docs_db is not None and not self.is_load_local:
+             self.docs_db.save_local(self.local_folder)
+         else:
+             print("No vector db to save.")