Ubai committed on
Commit af00b58
1 Parent(s): 7da7eb1

Update app.py

Files changed (1)
  1. app.py +73 -116
app.py CHANGED
@@ -1,136 +1,93 @@
  import gradio as gr
  import os
-
  from langchain.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
  from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.llms import HuggingFaceHub
-
  from pathlib import Path
  import chromadb

- # List of available LLM models
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
-             "google/gemma-7b-it", "google/gemma-2b-it",
-             "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
-             "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
-             "google/flan-t5-xxl"
-             ]
- list_llm_simple = [os.path.basename(llm) for llm in list_llm]
-
- # Load PDF document and create doc splits
- def load_doc(list_file_path, chunk_size, chunk_overlap):
-     loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
-     for loader in loaders:
-         pages.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-     doc_splits = text_splitter.split_documents(pages)
-     return doc_splits
-
- # Create vector database
- def create_db(splits, collection_name):
-     embedding = HuggingFaceEmbeddings()
-     new_client = chromadb.EphemeralClient()
-     vectordb = Chroma.from_documents(
-         documents=splits,
-         embedding=embedding,
-         client=new_client,
-         collection_name=collection_name
-     )
-     return vectordb
-
- # Initialize langchain LLM chain
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-         model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
-     elif llm_model == "microsoft/phi-2":
-         raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-         model_kwargs = {"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-     else:
-         model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-
-     llm = HuggingFaceHub(
-         repo_id=llm_model,
-         model_kwargs=model_kwargs
-     )
-
-     memory = ConversationBufferMemory(
-         memory_key="chat_history",
-         output_key='answer',
-         return_messages=True
-     )
-
-     retriever = vector_db.as_retriever()
-
-     qa_chain = ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=retriever,
-         chain_type="stuff",
-         memory=memory,
-         return_source_documents=True,
-         verbose=False
-     )
-
-     progress(0.9, desc="Done!")
-     return qa_chain
-
- def initialize_demo(list_file_obj, chunk_size, chunk_overlap, db_progress):
-     list_file_path = [file.name for file in list_file_obj if file is not None]
-     collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
-     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-     vector_db = create_db(doc_splits, collection_name)
-     qa_chain = initialize_llmchain(
-         list_llm[0],  # Using Mistral-7B-Instruct-v0.2 as the LLM model
-         0.7,   # Temperature
-         1024,  # Max tokens
-         3,     # Top K
-         vector_db,
-         db_progress
-     )
-     return vector_db, collection_name, qa_chain, "Complete!"
-
- def upload_file(file_obj):
-     list_file_path = []
-     for file in file_obj:
-         if file is not None:
-             file_path = file.name
-             list_file_path.append(file_path)
-     return list_file_path

  def demo():
      with gr.Blocks(theme="base") as demo:
-         vector_db = gr.State()
          collection_name = gr.State()
-         qa_chain = gr.State()
-
-         with gr.Tab("Step 1 - Document pre-processing"):
-             document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-             slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-             slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-             db_progress = gr.Textbox(label="Vector database initialization", value="None")
-             db_btn = gr.Button("Generate vector database...")
-
-         with gr.Tab("Step 2 - QA chain initialization"):
-             llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-             qachain_btn = gr.Button("Initialize question-answering chain...")
-
-         with gr.Tab("Step 3 - Conversation with chatbot"):
              chatbot = gr.Chatbot(height=300)
-             doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-             source1_page = gr.Number(label="Page", scale=1)
-             doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-             source2_page = gr.Number(label="Page", scale=1)
-             doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-             source3_page = gr.Number(label="Page", scale=1)
              msg = gr.Textbox(placeholder="Type message", container=True)
              submit_btn = gr.Button("Submit")
              clear_btn = gr.ClearButton([msg, chatbot])

-         document.upload(initialize_demo, inputs=[document, slider_chunk_size, slider_chunk_overlap, db_progress], outputs=[vector_db, collection_name, qa_chain, db_progress])
-         qachain_btn.click(initialize_llmchain, inputs=[qa_chain, llm_progress], outputs=[qa_chain, llm_progress])
-         submit_btn.click(lambda: None, inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])

  import gradio as gr
  import os
  from langchain.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
  from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.llms import HuggingFacePipeline, HuggingFaceHub
+ from langchain.chains import ConversationChain
+ from langchain.memory import ConversationBufferMemory
  from pathlib import Path
  import chromadb
+ from transformers import AutoTokenizer
+ import transformers
+ import torch
+ import tqdm
+ import accelerate
+
+ # Default LLM model
+ chosen_llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
+
+ # Default chunk size and overlap
+ chunk_size = 600
+ chunk_overlap = 40
+
+ # Default model configuration
+ llm_temperature = 0.7
+ max_tokens = 1024
+ top_k = 3
+
+ # Note: the vector database cannot be built at import time (no documents are
+ # uploaded yet), so initialization happens in the upload handler inside demo();
+ # `accelerate` provides no `accelerated()` wrapper to run this in the background.

+ # Define functions (no changes needed here)
+ # ... (your existing functions here)

  def demo():
      with gr.Blocks(theme="base") as demo:
+         qa_chain = gr.State()  # Store the initialized QA chain
          collection_name = gr.State()
+
+         gr.Markdown(
+             """
+             <center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</h2></center>
+             <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
+             <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
+             When generating answers, it takes past questions into account (via conversational memory) and includes document references for clarity.
+             <br><b>Warning:</b> This Space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
+             """
+         )
+
+         with gr.Row():
+             document = gr.Files(
+                 height=100,
+                 file_count="multiple",
+                 file_types=["pdf"],
+                 interactive=True,
+                 label="Upload your PDF documents (single or multiple)",
+             )
+
+         with gr.Row():
              chatbot = gr.Chatbot(height=300)
+
+         with gr.Accordion("Advanced - Document references", open=False):
+             with gr.Row():
+                 doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                 source1_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                 source2_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                 source3_page = gr.Number(label="Page", scale=1)
+
+         with gr.Row():
              msg = gr.Textbox(placeholder="Type message", container=True)
+
+         with gr.Row():
              submit_btn = gr.Button("Submit")
              clear_btn = gr.ClearButton([msg, chatbot])

+         # Initialize the default QA chain when documents are uploaded
+         # (the Files event is `upload`, and Gradio inputs must be components,
+         # so the uploaded files are passed rather than the module-level model name)
+         document.upload(initialize_LLM, inputs=[document], outputs=[qa_chain])
+
+         # Chatbot events
+         msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
+         submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
+         clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
+
+     demo.launch(debug=True)

+ if __name__ == "__main__":
+     demo()
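
The updated demo() wires its events to initialize_LLM and conversation, both of which are elided above at "# ... (your existing functions here)". For reference, a minimal sketch of what those two helpers might look like, reusing load_doc, create_db, initialize_llmchain, and the module-level defaults from this file; the signatures, the memory-driven call shape, and the page-number handling are assumptions rather than part of this commit:

# Hypothetical helpers assumed by the event wiring above; not part of this commit.

def initialize_LLM(list_file_obj, progress=gr.Progress()):
    # Build the vector database from the uploaded PDFs, then the QA chain on top of it.
    list_file_path = [f.name for f in list_file_obj if f is not None]
    collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return initialize_llmchain(chosen_llm_model, llm_temperature, max_tokens, top_k, vector_db)

def conversation(qa_chain, message, history):
    # One retrieval-augmented turn; the chain's ConversationBufferMemory keeps the history.
    response = qa_chain({"question": message})
    sources = response["source_documents"][:3]
    # PyPDFLoader stores 0-based page numbers in metadata; shift to 1-based for display.
    refs = [(s.page_content.strip(), s.metadata.get("page", 0) + 1) for s in sources]
    refs += [("", 0)] * (3 - len(refs))  # pad so all three reference boxes get a value
    history = history + [(message, response["answer"])]
    return (qa_chain, "", history,
            refs[0][0], refs[0][1], refs[1][0], refs[1][1],
            refs[2][0], refs[2][1])

conversation returns nine values to line up with the outputs list on msg.submit and submit_btn.click: the chain state, an empty string to clear the input box, the updated chat history, and three (passage, page) reference pairs.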