Update app.py
app.py CHANGED
@@ -1,26 +1,93 @@
-
+import streamlit as st
+from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+from transformers import BartForConditionalGeneration, BartTokenizer
+from sentence_transformers import SentenceTransformer
+import pdfplumber
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+import torch
+import gc
+
+# Load the Question Encoder, Context Encoder, and Tokenizers
+question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+
+# Load the Generator Model
+generator = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+
+# Load Sentence Embedding Model for Vector Store
+sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+
+# Initialize documents list with some sample documents
+documents = [
+    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
+    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
+    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
+]
+
+# Encode the initial documents for similarity comparison
+doc_embeddings = sentence_model.encode(documents)
+
+# Streamlit Frontend
+st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
+st.title("📚 Retrieval-Augmented Generation (RAG) Application")
+
+# File Upload for PDF Documents
+uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
+if uploaded_file:
+    # Extract text from PDF
+    pdf_text = ""
+    with pdfplumber.open(uploaded_file) as pdf:
+        for page_num, page in enumerate(pdf.pages):
+            if page_num > 20:  # Limit to first 20 pages for efficiency
+                break
+            page_text = page.extract_text()
+            if page_text:  # Check if text was extracted
+                pdf_text += page_text + " "
+
+    if pdf_text:
+        # Add the PDF text to the documents list and update document embeddings
+        documents.append(pdf_text)
+        pdf_embedding = sentence_model.encode([pdf_text])
+        doc_embeddings = np.vstack([doc_embeddings, pdf_embedding])
+        st.success("PDF text added to knowledge base for querying!")
+    else:
+        st.error("No text could be extracted from the PDF.")
+
+# User Input
+st.markdown("Enter your query below:")
+query = st.text_input("🔍 Enter your query")
+
+if st.button("💬 Get Answer"):
     if query:
         try:
-            #
+            # Step 1: Encode the query
             question_inputs = question_tokenizer(query, return_tensors="pt")
             question_embedding = question_encoder(**question_inputs).pooler_output.detach().cpu().numpy()
 
-            # Cosine
+            # Step 2: Calculate Cosine Similarity
             similarity_scores = cosine_similarity(question_embedding, doc_embeddings)
+
+            # Step 3: Get the indices of the top 3 most similar documents
             top_indices = similarity_scores[0].argsort()[-3:][::-1]
             retrieved_docs = [documents[idx] for idx in top_indices]
+
+            # Step 4: Concatenate retrieved documents
             context = " ".join(retrieved_docs)
 
             # Log the retrieved context for debugging
             st.write(f"Context for the query: {context}")
 
-            #
+            # Step 5: Use the Generator to Answer the Question
             input_ids = generator_tokenizer.encode(f"question: {query} context: {context}", return_tensors="pt")
-            outputs = generator.generate(input_ids, max_length=200)
+            outputs = generator.generate(input_ids, max_length=200, num_return_sequences=1)
+
+            # Decode and display the response
             answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
             st.write("**Answer:**")
             st.write(answer)
+
         except Exception as e:
             st.error(f"An error occurred: {str(e)}")
         finally:
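A note on the new top-of-file section: Streamlit re-executes the whole script on every interaction, so all three models above are reloaded from disk on each rerun. A common mitigation, sketched here as a suggestion rather than something this commit does, is to wrap the loading in st.cache_resource so the objects are created once per process:

@st.cache_resource
def load_models():
    # Loaded once; Streamlit caches the returned objects across reruns.
    q_enc = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    gen = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    gen_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    sent = SentenceTransformer('paraphrase-MiniLM-L6-v2')
    return q_enc, q_tok, gen, gen_tok, sent

question_encoder, question_tokenizer, generator, generator_tokenizer, sentence_model = load_models()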
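One functional issue worth flagging in the retrieval step: the query is embedded with the DPR question encoder (768-dimensional pooler output), while doc_embeddings come from SentenceTransformer('paraphrase-MiniLM-L6-v2') (384-dimensional), so the cosine_similarity call will raise a shape-mismatch ValueError on the first query. A minimal sketch of a fix, assuming retrieval is meant to happen in the sentence-transformer space (which makes the DPR encoder redundant for retrieval):

# Sketch (assumption, not part of this commit): embed the query with the
# same sentence_model used for the documents so both sides are 384-dim.
question_embedding = sentence_model.encode([query])  # shape (1, 384)
similarity_scores = cosine_similarity(question_embedding, doc_embeddings)
top_indices = similarity_scores[0].argsort()[-3:][::-1]  # top 3 matches
retrieved_docs = [documents[idx] for idx in top_indices]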
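The hunk ends at the bare finally: on new line 93, so its body falls outside the visible diff. Given the torch and gc imports added at the top of the file, the block presumably reclaims memory between queries; a plausible sketch, with the exact body an assumption since it is not shown here:

finally:
    # Assumed cleanup (not visible in this hunk): free Python garbage
    # and any cached CUDA memory after each query.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()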