mohAhmad committed on
Commit
b3f106f
•
1 Parent(s): 6460f90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -6
app.py CHANGED
@@ -1,26 +1,93 @@
1
- if st.button("Get Answer"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  if query:
3
  try:
4
- # Existing embedding process
5
  question_inputs = question_tokenizer(query, return_tensors="pt")
6
  question_embedding = question_encoder(**question_inputs).pooler_output.detach().cpu().numpy()
7
 
8
- # Cosine similarity
9
  similarity_scores = cosine_similarity(question_embedding, doc_embeddings)
 
 
10
  top_indices = similarity_scores[0].argsort()[-3:][::-1]
11
  retrieved_docs = [documents[idx] for idx in top_indices]
 
 
12
  context = " ".join(retrieved_docs)
13
 
14
  # Log the retrieved context for debugging
15
  st.write(f"Context for the query: {context}")
16
 
17
- # Ensure the question and context are correctly formatted
18
  input_ids = generator_tokenizer.encode(f"question: {query} context: {context}", return_tensors="pt")
19
- outputs = generator.generate(input_ids, max_length=200)
 
 
20
  answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
21
-
22
  st.write("**Answer:**")
23
  st.write(answer)
 
24
  except Exception as e:
25
  st.error(f"An error occurred: {str(e)}")
26
  finally:
 
1
+ import streamlit as st
2
+ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
3
+ from transformers import BartForConditionalGeneration, BartTokenizer
4
+ from sentence_transformers import SentenceTransformer
5
+ import pdfplumber
6
+ import numpy as np
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import torch
9
+ import gc
10
+
11
# Model identifiers, named once so each model/tokenizer pair stays in sync.
DPR_QUESTION_MODEL = "facebook/dpr-question_encoder-single-nq-base"
GENERATOR_MODEL = "facebook/bart-large-cnn"
SENTENCE_MODEL = "paraphrase-MiniLM-L6-v2"


# Streamlit re-executes this script on every widget interaction; caching keeps
# the models in memory instead of reloading them on each rerun.
# show_spinner=False avoids rendering an element before st.set_page_config()
# is called further down the script.
@st.cache_resource(show_spinner=False)
def _load_models():
    """Load all models and tokenizers used by the app.

    Returns:
        A tuple ``(question_encoder, question_tokenizer, generator,
        generator_tokenizer, sentence_model)``.
    """
    return (
        DPRQuestionEncoder.from_pretrained(DPR_QUESTION_MODEL),
        DPRQuestionEncoderTokenizer.from_pretrained(DPR_QUESTION_MODEL),
        BartForConditionalGeneration.from_pretrained(GENERATOR_MODEL),
        BartTokenizer.from_pretrained(GENERATOR_MODEL),
        SentenceTransformer(SENTENCE_MODEL),
    )


# Load the Question Encoder, Generator, Sentence Embedding Model and Tokenizers
(
    question_encoder,
    question_tokenizer,
    generator,
    generator_tokenizer,
    sentence_model,
) = _load_models()

# Initialize documents list with some sample documents
documents = [
    "Streamlit is an open-source Python library that makes it easy to build beautiful custom web-apps for machine learning and data science.",
    "Hugging Face is a company that provides tools and models for natural language processing (NLP).",
    "Retrieval-Augmented Generation (RAG) is a method that combines document retrieval with a generative model for question answering.",
]

# Encode the initial documents for similarity comparison.
# NOTE(review): the query handler embeds questions with question_encoder (DPR)
# but compares them against these sentence_model vectors; confirm both models
# produce vectors of the same dimensionality and embedding space, otherwise
# cosine_similarity will fail or return meaningless scores.
doc_embeddings = sentence_model.encode(documents)
31
+
32
# Streamlit Frontend
# Page configuration is set before any other element is rendered.
st.set_page_config(page_title="RAG-based PDF Query Application", layout="wide")
st.title("📄 Retrieval-Augmented Generation (RAG) Application")
35
+
36
# File Upload for PDF Documents
MAX_PDF_PAGES = 20  # cap extraction work on large uploads

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
if uploaded_file:
    # Extract text from at most the first MAX_PDF_PAGES pages. Slicing the
    # page list enforces the limit exactly (the previous `page_num > 20`
    # check let a 21st page through).
    with pdfplumber.open(uploaded_file) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages[:MAX_PDF_PAGES]]

    # Pages for which no text was extracted are skipped.
    pdf_text = " ".join(text for text in page_texts if text)

    if pdf_text:
        # Add the PDF text to the documents list and update document embeddings.
        # NOTE(review): the whole PDF becomes a single document with a single
        # embedding; confirm this fits the embedding model's input limit or
        # consider chunking.
        documents.append(pdf_text)
        pdf_embedding = sentence_model.encode([pdf_text])
        doc_embeddings = np.vstack([doc_embeddings, pdf_embedding])
        st.success("PDF text added to knowledge base for querying!")
    else:
        st.error("No text could be extracted from the PDF.")
57
+
58
# User Input
st.markdown("Enter your query below:")
# `query` is read by the answer-button handler further down the script.
query = st.text_input("🔍 Enter your query")
61
+
62
+ if st.button("πŸ’¬ Get Answer"):
63
  if query:
64
  try:
65
+ # Step 1: Encode the query
66
  question_inputs = question_tokenizer(query, return_tensors="pt")
67
  question_embedding = question_encoder(**question_inputs).pooler_output.detach().cpu().numpy()
68
 
69
+ # Step 2: Calculate Cosine Similarity
70
  similarity_scores = cosine_similarity(question_embedding, doc_embeddings)
71
+
72
+ # Step 3: Get the indices of the top 3 most similar documents
73
  top_indices = similarity_scores[0].argsort()[-3:][::-1]
74
  retrieved_docs = [documents[idx] for idx in top_indices]
75
+
76
+ # Step 4: Concatenate retrieved documents
77
  context = " ".join(retrieved_docs)
78
 
79
  # Log the retrieved context for debugging
80
  st.write(f"Context for the query: {context}")
81
 
82
+ # Step 5: Use the Generator to Answer the Question
83
  input_ids = generator_tokenizer.encode(f"question: {query} context: {context}", return_tensors="pt")
84
+ outputs = generator.generate(input_ids, max_length=200, num_return_sequences=1)
85
+
86
+ # Decode and display the response
87
  answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
88
  st.write("**Answer:**")
89
  st.write(answer)
90
+
91
  except Exception as e:
92
  st.error(f"An error occurred: {str(e)}")
93
  finally: