# RAGGO / app.py
# mohAhmad's picture
# Update app.py
# 2eefa77 verified
import streamlit as st
from sentence_transformers import SentenceTransformer
import pdfplumber
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import gc
# Streamlit re-executes this script on every widget interaction, so the heavy
# models are loaded through st.cache_resource and shared across reruns instead
# of being re-instantiated each time.
# show_spinner=False keeps these calls from emitting any UI element, so the
# st.set_page_config call further down is still the first Streamlit output.
@st.cache_resource(show_spinner=False)
def _load_sentence_model() -> SentenceTransformer:
    """Return the sentence-embedding model used for the vector store."""
    return SentenceTransformer('paraphrase-MiniLM-L6-v2')


@st.cache_resource(show_spinner=False)
def _load_generator() -> tuple:
    """Return the ``(model, tokenizer)`` pair for the BART generator."""
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    return model, tokenizer


# Load Sentence Embedding Model for Vector Store
sentence_model = _load_sentence_model()

# Load the Generator Model
generator, generator_tokenizer = _load_generator()

# Initialize documents list with some sample documents.
# The list is rebuilt on every script run, so an uploaded PDF appended later
# never accumulates across reruns.
# NOTE(review): the last sample looks like it contains a real person's name
# and phone number — consider replacing it with synthetic text.
documents = [
    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
    "Letter of Recommendation August 29, 2024 Muthukumaran Azhagesan Senior Software Engineer +8326559594 I am Muthukumaran Azhagesan, a Senior Software Engineer and Lead at Cisco...",
]

# Encode the initial documents for similarity comparison; row i of
# doc_embeddings corresponds to documents[i].
doc_embeddings = sentence_model.encode(documents)
# Streamlit Frontend
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
st.title("📄 Retrieval-Augmented Generation (RAG) Application")

# Maximum number of leading pages read from an uploaded PDF, for efficiency.
MAX_PDF_PAGES = 20

# File Upload for PDF Documents
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file:
    # Extract text from at most the first MAX_PDF_PAGES pages. The slice makes
    # the limit exact (the previous `page_num > 20` check let a 21st page in).
    with pdfplumber.open(uploaded_file) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages[:MAX_PDF_PAGES]]

    # Pages with no extractable text yield None/"" and are skipped.
    pdf_text = " ".join(text for text in page_texts if text)

    if pdf_text:
        # Add the PDF text to the documents list and append its embedding as a
        # new row, keeping documents and doc_embeddings index-aligned.
        # NOTE(review): the whole PDF is embedded as one document; long texts
        # are presumably cut at the encoder's max sequence length — consider
        # chunking if retrieval quality on large PDFs matters.
        documents.append(pdf_text)
        pdf_embedding = sentence_model.encode([pdf_text])
        doc_embeddings = np.vstack([doc_embeddings, pdf_embedding])
        st.success("PDF text added to knowledge base for querying!")
    else:
        st.error("No text could be extracted from the PDF.")
# User Input
st.markdown("Enter your query below:")
query = st.text_input("🔍 Enter your query")

if st.button("💬 Get Answer"):
    # Treat a whitespace-only query the same as an empty one.
    question = query.strip()
    if question:
        try:
            # Step 1: Encode the query; reshape to a single-row 2-D array as
            # cosine_similarity expects.
            query_embedding = sentence_model.encode(question).reshape(1, -1)

            # Step 2: Cosine similarity of the query against every document.
            similarity_scores = cosine_similarity(query_embedding, doc_embeddings)

            # Step 3: Pick the most similar document. documents and
            # doc_embeddings are extended together, so the argmax over the
            # score row is always a valid index into documents.
            top_index = int(similarity_scores[0].argmax())
            retrieved_doc = documents[top_index]

            # Log the retrieved document for debugging
            st.write(f"Retrieved Document: {retrieved_doc}")

            # Step 4: Use the generator to answer the question. The prompt is
            # truncated to the model's position limit: a retrieved PDF can be
            # far longer than BART accepts, and an over-long input makes
            # generate() fail.
            input_ids = generator_tokenizer.encode(
                f"question: {question} context: {retrieved_doc}",
                return_tensors="pt",
                truncation=True,
                max_length=generator.config.max_position_embeddings,
            )
            outputs = generator.generate(
                input_ids, max_length=200, num_return_sequences=1
            )

            # Decode and display the response
            answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.write("**Answer:**")
            st.write(answer)
        except Exception as e:
            # Top-level UI boundary: surface any failure to the user instead
            # of letting the script run abort with a traceback.
            st.error(f"An error occurred: {str(e)}")
        finally:
            # Reclaim memory held by intermediate tensors from this request.
            gc.collect()
    else:
        st.error("Please enter a query.")