Ubai committed
Commit
339ce69
1 Parent(s): 6b0097a

Update app.py

Files changed (1)
  1. app.py +69 -82
app.py CHANGED
@@ -1,95 +1,82 @@
  import gradio as gr
  import os
- from langchain_community.document_loaders import PyPDFLoader  # Corrected import
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma  # Corrected import
- from langchain.chains import ConversationalRetrievalChain  # Note: Not from "langchain_community"
- from langchain_community.embeddings import HuggingFaceEmbeddings  # Corrected import
- from langchain_community.llms import HuggingFacePipeline, HuggingFaceHub  # Corrected import
- from langchain.chains import ConversationChain  # Note: Not from "langchain_community"
  from langchain.memory import ConversationBufferMemory
  from pathlib import Path
  import chromadb
  from transformers import AutoTokenizer
  import transformers
  import torch
- import tqdm
  import accelerate

- # LLM model and parameters (adjusted for clarity)
- chosen_llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
- llm_temperature = 0.7
- max_tokens = 1024
- top_k = 3
-
- # Chunk size and overlap (adjusted for clarity)
- chunk_size = 600
- chunk_overlap = 40
-
- # Initialize vector database in background
- accelerate(initialize_database)()  # Function definition moved here
-
-
- def initialize_database():
-     """
-     This function initializes the vector database (assumed to be ChromaDB).
-     Modify this function based on your specific database needs.
-     """
-     # Replace with your ChromaDB connection and schema creation logic
-     # ...
-     pass
-
-
- def demo():
-     with gr.Blocks(theme="base") as demo:
-         qa_chain = gr.State()  # Store the initialized QA chain
-         collection_name = gr.State()

-         gr.Markdown(
-             """
-             <center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
-             <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
-             <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
-             When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.</i>
-             <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
-             """
          )
-
-         with gr.Row():
-             document = gr.Files(
-                 height=100,
-                 file_count="multiple",
-                 file_types=["pdf"],
-                 interactive=True,
-                 label="Upload your PDF documents (single or multiple)",
-             )
-
-         with gr.Row():
-             chatbot = gr.Chatbot(height=300)
-
-         with gr.Accordion("Advanced - Document references", open=False):
-             with gr.Row():
-                 doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                 source1_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                 source2_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                 source3_page = gr.Number(label="Page", scale=1)
-
-         with gr.Row():
-             msg = gr.Textbox(placeholder="Type message", container=True)
-
-         with gr.Row():
-             submit_btn = gr.Button("Submit")
-             clear_btn = gr.ClearButton([msg, chatbot])
-
-         # Initialize default QA chain when documents are uploaded
-         document.uploaded(initialize_LLM, inputs=[chosen_llm_model])
-
-         # Chatbot events
-         msg.submit(conversation, inputs=[qa_chain, msg, chatbot])
-         submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot])
-         clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
-
  import gradio as gr
  import os
+
+ from langchain.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.llms import HuggingFaceHub
+ from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
+
  from pathlib import Path
  import chromadb
+
  from transformers import AutoTokenizer
  import transformers
  import torch
+ import tqdm
  import accelerate


+ # Default LLM model
+ llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
+
+ # Other settings
+ default_persist_directory = './chroma_HF/'
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+     "google/gemma-7b-it", "google/gemma-2b-it", \
+     "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
+     "google/flan-t5-xxl"
+ ]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         persist_directory=default_persist_directory,
+         embedding_function=embedding)
+     return vectordb
+
+
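Note that the commit never shows the code that populates `./chroma_HF/`; `load_db()` only reopens an existing store. Given the `PyPDFLoader` and `RecursiveCharacterTextSplitter` imports added above, the ingestion step presumably looks roughly like the sketch below. The function name `create_db` is hypothetical, and the chunk sizes (600/40) are carried over from the removed version of the file, not from this commit; it assumes the imports and `default_persist_directory` already defined in app.py.

    # Hypothetical sketch -- not part of this commit. create_db is an invented name;
    # chunk sizes 600/40 come from the removed version of the file.
    def create_db(pdf_paths):
        pages = []
        for path in pdf_paths:
            pages.extend(PyPDFLoader(path).load())  # one Document per PDF page
        splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40)
        splits = splitter.split_documents(pages)
        vectordb = Chroma.from_documents(
            documents=splits,
            embedding=HuggingFaceEmbeddings(),
            persist_directory=default_persist_directory,  # what load_db() reopens
        )
        vectordb.persist()
        return vectordb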
+ # Initialize langchain LLM chain
+ def initialize_llmchain(vector_db, progress=gr.Progress()):
+     progress(0.5, desc="Initializing HF Hub...")
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
+     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3, "load_in_8bit": True}
          )
+     # ... (other model configurations for different model options)
+     else:
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3}
+         )
+
+     progress(0.75, desc="Defining buffer memory...")
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key='answer',
+         return_messages=True
+     )
+     retriever = vector_db.as_retriever()
+     progress(0.8, desc="Defining retrieval chain...")
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         return_source_documents=True,
+         verbose=False,
+     )
+     progress(0.9, desc="Done!")
+     return qa_chain
+
+
+ # ... (other functions remain the same)
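For context, a minimal usage sketch of the two functions this commit adds, assuming a populated store and a HUGGINGFACEHUB_API_TOKEN in the environment (HuggingFaceHub calls the hosted inference API). Since the gr.Progress default expects to run inside a Gradio event handler, this would live in a callback; the question text is illustrative only.

    # Hypothetical usage -- not part of this commit.
    vector_db = load_db()                      # reopen the persisted Chroma store
    qa_chain = initialize_llmchain(vector_db)  # conversational RAG chain with memory

    result = qa_chain({"question": "What is this document about?"})
    print(result["answer"])
    # return_source_documents=True exposes the retrieved chunks and their pages
    for doc in result["source_documents"]:
        print(doc.metadata.get("page"), doc.page_content[:80])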