|
import streamlit as st |
|
from sentence_transformers import SentenceTransformer |
|
import pdfplumber |
|
import numpy as np |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import torch |
|
from transformers import BartForConditionalGeneration, BartTokenizer |
|
import gc |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_sentence_model():
    """Load and cache the sentence-embedding model.

    Streamlit re-executes the whole script on every user interaction;
    @st.cache_resource ensures the model is loaded from disk only once
    per server process instead of on every rerun.
    """
    return SentenceTransformer('paraphrase-MiniLM-L6-v2')


@st.cache_resource(show_spinner=False)
def _load_generator():
    """Load and cache the BART summarization model and its tokenizer."""
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    return model, tokenizer


# Module-level handles keep the original names used by the rest of the script.
sentence_model = _load_sentence_model()
generator, generator_tokenizer = _load_generator()
|
|
|
|
|
# Seed knowledge base: a handful of hand-written passages that are always
# available for retrieval, even before any PDF has been uploaded.
documents = [
    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
    "Letter of Recommendation August 29, 2024 Muthukumaran Azhagesan Senior Software Engineer +8326559594 I am Muthukumaran Azhagesan, a Senior Software Engineer and Lead at Cisco...",
]

# One embedding row per document, in the same order as `documents`.
# Uploaded PDF text is stacked onto this matrix later via np.vstack.
doc_embeddings = sentence_model.encode(documents)
|
|
|
|
|
# Streamlit requires set_page_config to be the first st.* command executed.
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")

# NOTE(review): the leading "π" looks like a mojibake'd emoji — confirm the
# intended glyph with the original author before changing the string.
st.title("π Retrieval-Augmented Generation (RAG) Application")

# File picker restricted to PDFs; yields a file-like UploadedFile, or None
# until the user uploads something.
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
|
if uploaded_file:
    # Extract text from at most the first 20 pages of the uploaded PDF.
    # (The original `if page_num > 20: break` processed 21 pages — off by one.)
    MAX_PAGES = 20
    page_texts = []
    with pdfplumber.open(uploaded_file) as pdf:
        for page in pdf.pages[:MAX_PAGES]:
            page_text = page.extract_text()
            # extract_text() returns None for image-only / empty pages.
            if page_text:
                page_texts.append(page_text)

    # Join once instead of quadratic `+=` concatenation; keep the original
    # trailing-space-per-page layout.
    pdf_text = "".join(text + " " for text in page_texts)

    if pdf_text:
        # Add the PDF to the in-memory knowledge base: append the raw text
        # and stack its embedding onto the matrix so row indices stay aligned
        # with `documents`.
        documents.append(pdf_text)
        pdf_embedding = sentence_model.encode([pdf_text])
        doc_embeddings = np.vstack([doc_embeddings, pdf_embedding])
        st.success("PDF text added to knowledge base for querying!")
    else:
        st.error("No text could be extracted from the PDF.")
|
|
|
|
|
st.markdown("Enter your query below:")
# NOTE(review): the leading "π" looks like a mojibake'd emoji — confirm the
# intended glyph before changing the label string.
query = st.text_input("π Enter your query")
|
|
|
if st.button("π¬ Get Answer"):
    # strip() rejects whitespace-only input, not just the empty string.
    if query.strip():
        try:
            # Embed the query as a single row so it can be compared against
            # every row of the document-embedding matrix.
            query_embedding = sentence_model.encode(query).reshape(1, -1)
            similarity_scores = cosine_similarity(query_embedding, doc_embeddings)

            # Best-matching document; rows of doc_embeddings correspond
            # one-to-one with entries of `documents`.
            top_index = int(similarity_scores[0].argmax())

            if top_index < len(documents):
                retrieved_doc = documents[top_index]
                st.write(f"Retrieved Document: {retrieved_doc}")

                # Truncate to BART's 1024-token context window; otherwise a
                # long uploaded PDF makes generate() fail on oversized input.
                input_ids = generator_tokenizer.encode(
                    f"question: {query} context: {retrieved_doc}",
                    return_tensors="pt",
                    truncation=True,
                    max_length=1024,
                )
                # Inference only — disable autograd to avoid building a graph.
                with torch.no_grad():
                    outputs = generator.generate(
                        input_ids, max_length=200, num_return_sequences=1
                    )

                answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
                st.write("**Answer:**")
                st.write(answer)
            else:
                st.error("No relevant document found.")

        except Exception as e:
            # Surface the failure in the UI instead of crashing the app.
            st.error(f"An error occurred: {str(e)}")
        finally:
            # Release transient tensors/strings from this request.
            gc.collect()
    else:
        st.error("Please enter a query.")
|
|