surprisedPikachu007 committed
Commit f13531e
1 Parent(s): 5c30eba

Create app.py

Files changed (1)
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
import re

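# NOTE (a sketch, not part of this commit): the imports above assume these
# packages are available in the Space, e.g. listed in requirements.txt:
#   gradio, langchain, langchain-community, chromadb, pypdf,
#   sentence-transformers, huggingface_hub
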
def load_doc(list_file_path, chunk_size=600, chunk_overlap=40):
    """Load PDF files and split their pages into overlapping text chunks."""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = [page for loader in loaders for page in loader.load()]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

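# Hypothetical usage sketch (the file name is assumed, not from this commit):
#   splits = load_doc(["example.pdf"])
#   print(len(splits), splits[0].page_content[:80])
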
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=client,
        collection_name=collection_name,
    )
    return vectordb

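# Hypothetical usage sketch: the returned store supports standard Chroma
# queries, independent of the QA chain, e.g.:
#   db = create_db(splits, "example-collection")
#   hits = db.similarity_search("What is the main contribution?", k=3)
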
def initialize_llmchain(llm_model, vector_db, progress=gr.Progress()):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=0.7,
        max_new_tokens=1024,
        top_k=3,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

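# NOTE: HuggingFaceEndpoint authenticates via the HUGGINGFACEHUB_API_TOKEN
# environment variable (or a cached `huggingface-cli login`); calls to the
# free inference endpoint will fail without it.
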
def create_collection_name(filepath):
    """Derive a Chroma-safe collection name (3-63 chars, alphanumeric at both ends) from a file name."""
    collection_name = Path(filepath).stem
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
    if len(collection_name) < 3:
        collection_name += 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

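# Worked example: create_collection_name("/tmp/My Paper (v2).pdf")
# -> stem "My Paper (v2)" -> sanitized "My-Paper-v2-" -> the trailing "-"
# is not alphanumeric, so the final name is "My-Paper-v2Z".
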
def initialize_database(list_file_obj, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

def initialize_LLM(llm_model, vector_db, progress=gr.Progress()):
    qa_chain = initialize_llmchain(llm_model, vector_db, progress)
    return qa_chain, "Complete!"

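# Hypothetical headless sketch of the full pipeline, bypassing the Gradio UI
# (file name and question are assumptions, not from this commit):
#   splits = load_doc(["example.pdf"])
#   db = create_db(splits, create_collection_name("example.pdf"))
#   chain = initialize_llmchain("mistralai/Mistral-7B-Instruct-v0.2", db)
#   out = chain({"question": "Summarize the document.", "chat_history": []})
#   print(out["answer"])
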
def conversation(qa_chain, message, history):
    formatted_chat_history = [(f"User: {user_message}", f"Assistant: {bot_message}") for user_message, bot_message in history]
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # Surface the top three retrieved chunks and their 1-indexed page numbers.
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

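# NOTE: the 9-tuple returned by conversation() must stay aligned with the
# `outputs=` lists wired up in demo() below: qa_chain state, cleared input
# box, updated chat history, then three (reference text, page number) pairs.
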
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(
            """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</h2></center>
            <h3>Ask questions about your PDF documents, with follow-ups</h3>
            <b>Note:</b> This AI assistant performs retrieval-augmented generation over your PDF documents.
            When generating answers, it takes past questions into account through conversational memory, and cites document references for clarity.
            <br><b>Warning:</b> This Space runs on Hugging Face's free CPU Basic hardware, and the LLMs below are served through free inference endpoints, so some steps can take a while to produce output.<br>
            """)

        with gr.Tab("Step 1 - Document pre-processing"):
            document = gr.Files(height=100, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            with gr.Row():
                db_progress = gr.Textbox(label="Vector database initialization", value="None")
            with gr.Row():
                db_btn = gr.Button("Generate vector database...")

        with gr.Tab("Step 2 - QA chain initialization"):
            # type="value" so initialize_LLM receives the repo id string rather than the option's index.
            llm_btn = gr.Radio(["mistralai/Mistral-7B-Instruct-v0.2"], label="LLM models", value="mistralai/Mistral-7B-Instruct-v0.2", type="value", info="Choose your LLM model")
            with gr.Row():
                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            with gr.Row():
                qachain_btn = gr.Button("Initialize question-answering chain...")

        with gr.Tab("Step 3 - Conversation with chatbot"):
            chatbot = gr.Chatbot(height=300)
            # Reference displays, so the event outputs match the nine values returned by conversation().
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    source2_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                    source3_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                msg = gr.Textbox(placeholder="Type message", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.ClearButton([msg, chatbot])

        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, vector_db], outputs=[qa_chain, llm_progress]).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )

        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)

    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()