# Spaces: Running
import os
import tempfile

import fitz  # PyMuPDF for parsing PDF
import streamlit as st
from sentence_transformers import SentenceTransformer, util
# Load a pre-trained SentenceTransformer model.
# NOTE(review): this runs at module level, so the model is constructed every
# time the module is executed; `semantic_search` reads the resulting global.
model_name = "paraphrase-MiniLM-L6-v2"  # Can change this to a different model if needed
model = SentenceTransformer(model_name)
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated.

    Pages are read in order and their text is joined with no separator.
    The document is closed when the function returns.
    """
    with fitz.open(pdf_path) as doc:
        page_texts = [
            doc.load_page(index).get_text()
            for index in range(doc.page_count)
        ]
    return "".join(page_texts)
# Function to perform semantic search
def semantic_search(query, documents, top_k=5):
    """Rank *documents* by cosine similarity to *query*.

    Args:
        query: The search string.
        documents: A sequence of text passages to rank.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(document, score)`` tuples sorted by descending score,
        containing at most ``top_k`` entries. Empty when *documents* is empty.
    """
    if not documents:
        return []

    query_embedding = model.encode(query, convert_to_tensor=True)
    # Convert the list of documents to embeddings
    document_embeddings = model.encode(documents, convert_to_tensor=True)

    # The similarity matrix has one row per query and one column per
    # document. There is a single query, so row 0 holds every document's
    # score; iterating over the matrix itself would only ever visit one row.
    similarity = util.pytorch_cos_sim(query_embedding, document_embeddings)
    scores = similarity[0].tolist()

    # Sort the results in decreasing order of similarity
    ranked = sorted(
        zip(documents, scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:top_k]
def main():
    """Render the Streamlit page and run a search when the button is pressed.

    The uploaded PDF is written to a private temporary directory, its text is
    extracted, and the text is scored against the query. The temporary copy
    is removed even if extraction or search raises.
    """
    st.title("Semantic Search on PDF Documents")
    query = st.text_input("Enter your query:")
    pdf_file = st.file_uploader("Upload a PDF file:", type=["pdf"])

    if not st.button("Search"):
        return
    if not pdf_file:
        st.warning("Please upload a PDF file before searching.")
        return

    # Use a throwaway directory and a fixed file name: the target directory
    # is guaranteed to exist, the client-supplied name never reaches the
    # filesystem path, and cleanup happens on every exit path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = os.path.join(tmp_dir, "upload.pdf")
        with open(pdf_path, "wb") as f:
            # getvalue() returns the whole buffer regardless of read position.
            f.write(pdf_file.getvalue())
        pdf_text = extract_text_from_pdf(pdf_path)
        # The whole document is passed as a single passage.
        search_results = semantic_search(query, [pdf_text])

    st.write(f"Search results for query: '{query}'")
    for i, (result, score) in enumerate(search_results, start=1):
        st.write(f"{i}. Score: {score:.2f}")
        st.write(result)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()