Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer
|
3 |
from transformers import BartForConditionalGeneration, BartTokenizer
|
4 |
from sentence_transformers import SentenceTransformer
|
|
|
5 |
import faiss
|
6 |
import torch
|
7 |
|
@@ -9,9 +10,6 @@ import torch
|
|
9 |
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
10 |
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
|
11 |
|
12 |
-
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
13 |
-
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
14 |
-
|
15 |
# Load the Generator Model
|
16 |
generator = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
|
17 |
generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
@@ -22,14 +20,14 @@ sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
|
|
22 |
# Initialize FAISS index for fast similarity search
|
23 |
index = faiss.IndexFlatIP(384)
|
24 |
|
25 |
-
#
|
26 |
-
# You can replace this list with any collection of documents
|
27 |
documents = [
|
28 |
-
"
|
29 |
-
|
|
|
30 |
]
|
31 |
|
32 |
-
# Encode the documents and add to FAISS index
|
33 |
doc_embeddings = sentence_model.encode(documents, convert_to_tensor=True).cpu().detach().numpy()
|
34 |
index.add(doc_embeddings)
|
35 |
|
@@ -37,29 +35,52 @@ index.add(doc_embeddings)
|
|
37 |
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
|
38 |
st.title("π Retrieval-Augmented Generation (RAG) Application")
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# User Input
|
41 |
-
st.markdown("
|
42 |
query = st.text_input("π Enter your query")
|
43 |
|
44 |
if st.button("π¬ Get Answer"):
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
48 |
|
49 |
-
|
50 |
-
|
51 |
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer
|
3 |
from transformers import BartForConditionalGeneration, BartTokenizer
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
+
import pdfplumber
|
6 |
import faiss
|
7 |
import torch
|
8 |
|
|
|
10 |
# Hugging Face checkpoint identifiers, named once so each model and its
# tokenizer are guaranteed to be loaded from the same checkpoint.
_QUESTION_ENCODER_CHECKPOINT = "facebook/dpr-question_encoder-single-nq-base"
_GENERATOR_CHECKPOINT = "facebook/bart-large-cnn"

# DPR question encoder and its matching tokenizer.
question_encoder = DPRQuestionEncoder.from_pretrained(_QUESTION_ENCODER_CHECKPOINT)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(_QUESTION_ENCODER_CHECKPOINT)

# BART generator and its matching tokenizer.
generator = BartForConditionalGeneration.from_pretrained(_GENERATOR_CHECKPOINT)
generator_tokenizer = BartTokenizer.from_pretrained(_GENERATOR_CHECKPOINT)
|
|
|
20 |
# Flat inner-product FAISS index over 384-dimensional vectors.
index = faiss.IndexFlatIP(384)

# Seed corpus: a few sample passages so the app can answer queries before
# any PDF has been uploaded.
documents = [
    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
]

# Embed the seed corpus with the sentence model, convert the result to a
# NumPy array, and register the vectors in the index. Row i of the index
# corresponds to documents[i].
seed_tensor = sentence_model.encode(documents, convert_to_tensor=True)
doc_embeddings = seed_tensor.cpu().detach().numpy()
index.add(doc_embeddings)
|
33 |
|
|
|
35 |
# Page-level Streamlit configuration (browser tab title, wide layout) and
# the main heading shown at the top of the app.
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji;
# confirm the intended glyph against the original file before changing it.
st.title("π Retrieval-Augmented Generation (RAG) Application")
|
37 |
|
38 |
+
# File Upload for PDF Documents
#
# When a PDF is uploaded, its text is extracted, appended to `documents`,
# embedded with `sentence_model`, and added to the FAISS `index` so that it
# can be retrieved by later queries.
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file:
    # Extract text page by page. Pages for which extract_text() returns
    # nothing are skipped.
    with pdfplumber.open(uploaded_file) as pdf:
        page_texts = [
            page_text
            for page_text in (page.extract_text() for page in pdf.pages)
            if page_text
        ]

    # Join pages with a newline: concatenating them directly would fuse the
    # last word of one page with the first word of the next.
    pdf_text = "\n".join(page_texts)

    if pdf_text:
        # Keep `documents` and `index` in step: the position of the text in
        # the list must equal the row of its vector in the index.
        documents.append(pdf_text)
        pdf_embedding = sentence_model.encode([pdf_text], convert_to_tensor=True).cpu().detach().numpy()
        index.add(pdf_embedding)
        st.success("PDF text added to knowledge base for querying!")
    else:
        st.error("No text could be extracted from the PDF.")
|
57 |
+
|
58 |
# User Input
#
# Reads a free-text query and, on button press, retrieves the closest
# documents from the FAISS index and asks the BART generator to produce an
# answer from the query plus the retrieved context.
st.markdown("Enter your query below:")
query = st.text_input("π Enter your query")

if st.button("π¬ Get Answer"):
    if query:
        # Step 1: Encode the query with the same model that embedded the
        # documents. The index was built as IndexFlatIP(384) from
        # sentence_model vectors; embedding the query with the DPR question
        # encoder yields a vector of a different dimensionality (and from an
        # unrelated embedding space), which makes index.search fail.
        query_embedding = sentence_model.encode([query], convert_to_tensor=True).cpu().detach().numpy()

        # Step 2: Perform FAISS search for document retrieval (top 3).
        _, retrieved_doc_indices = index.search(query_embedding, k=3)

        # Step 3: Map ids back to texts. FAISS fills missing hits with -1,
        # which would otherwise index the last document silently.
        retrieved_docs = [
            documents[doc_id]
            for doc_id in retrieved_doc_indices[0]
            if 0 <= doc_id < len(documents)
        ]

        # Step 4: Concatenate retrieved documents into a single context.
        context = " ".join(retrieved_docs)

        # Step 5: Use the generator to answer the question. The prompt is
        # truncated to the generator's position limit, because an uploaded
        # PDF can easily exceed it and would otherwise raise during
        # generation.
        generator_inputs = generator_tokenizer(
            f"question: {query} context: {context}",
            return_tensors="pt",
            truncation=True,
            max_length=generator.config.max_position_embeddings,
        )
        outputs = generator.generate(
            generator_inputs["input_ids"],
            attention_mask=generator_inputs["attention_mask"],
            max_length=200,
            num_return_sequences=1,
        )

        # Decode and display the response.
        answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.write("**Answer:**")
        st.write(answer)
    else:
        st.error("Please enter a query.")
|