import streamlit as st
from sentence_transformers import SentenceTransformer
import pdfplumber
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartForConditionalGeneration, BartTokenizer
import gc

# Streamlit page setup (must be the first Streamlit command, before any
# cached loader whose spinner would count as a Streamlit call)
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
st.title("📄 Retrieval-Augmented Generation (RAG) Application")

# Load the models once and cache them across reruns; without caching,
# Streamlit would reload the weights on every widget interaction.
@st.cache_resource
def load_models():
    # Sentence embedding model for the vector store
    sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
    # Generator model. Note: facebook/bart-large-cnn is fine-tuned for
    # summarization, so it condenses the retrieved context rather than
    # answering extractively; a QA-tuned model may give sharper answers.
    generator = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    return sentence_model, generator, generator_tokenizer

sentence_model, generator, generator_tokenizer = load_models()

# Initialize the documents list with some sample documents
documents = [
    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
    "Letter of Recommendation August 29, 2024 Muthukumaran Azhagesan Senior Software Engineer +8326559594 I am Muthukumaran Azhagesan, a Senior Software Engineer and Lead at Cisco...",
]

# Encode the initial documents for similarity comparison
doc_embeddings = sentence_model.encode(documents)

# File Upload for PDF Documents
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
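# Note: Streamlit reruns this script top-to-bottom on every interaction, so an
# uploaded PDF is re-extracted and re-embedded on each rerun; st.session_state
# could be used to persist the embeddings if this becomes a bottleneck.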
if uploaded_file:
    # Extract text from PDF
    pdf_text = ""
    with pdfplumber.open(uploaded_file) as pdf:
        for page_num, page in enumerate(pdf.pages):
            if page_num > 20:  # Limit to first 20 pages for efficiency
                break
            page_text = page.extract_text()
            if page_text:  # Check if text was extracted
                pdf_text += page_text + " "

    if pdf_text:
        # Add the PDF text to the documents list and update document embeddings
        documents.append(pdf_text)
        pdf_embedding = sentence_model.encode([pdf_text])
        doc_embeddings = np.vstack([doc_embeddings, pdf_embedding])
        st.success("PDF text added to knowledge base for querying!")
    else:
        st.error("No text could be extracted from the PDF.")

# User Input
query = st.text_input("🔍 Enter your query")

if st.button("💬 Get Answer"):
    if query:
        try:
            # Step 1: Encode the query using SentenceTransformer
            query_embedding = sentence_model.encode(query).reshape(1, -1)  # Reshape to (1, 384)

            # Step 2: Calculate Cosine Similarity
            similarity_scores = cosine_similarity(query_embedding, doc_embeddings)
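            # Only the single most similar document is used below; a fuller
            # RAG pipeline would typically chunk the documents and retrieve
            # the top-k matches, e.g. np.argsort(similarity_scores[0])[::-1][:k].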

            # Step 3: Get the index of the most similar document
            top_index = similarity_scores[0].argmax()
            
            # Check if top_index is valid
            if top_index < len(documents):
                retrieved_doc = documents[top_index]

                # Log the retrieved document for debugging
                st.write(f"Retrieved Document: {retrieved_doc}")

                # Step 4: Use the Generator to Answer the Question
                input_ids = generator_tokenizer.encode(
                    f"question: {query} context: {retrieved_doc}",
                    return_tensors="pt",
                    truncation=True,
                    max_length=1024,  # BART's encoder accepts at most 1024 tokens
                )
                outputs = generator.generate(input_ids, max_length=200, num_return_sequences=1)
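                # Note: max_length=200 above caps the answer length in tokens;
                # setting num_beams > 1 enables beam search, which often gives
                # more fluent output at the cost of extra compute.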

                # Decode and display the response
                answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
                st.write("**Answer:**")
                st.write(answer)
            else:
                st.error("No relevant document found.")

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
        finally:
            gc.collect()
    else:
        st.error("Please enter a query.")