from io import BytesIO
import base64
import os
import sys
import tempfile
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse, StreamingResponse
from pdf2image import convert_from_bytes
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from torch.utils.data import DataLoader
from transformers import AutoProcessor

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./colpali-main")))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)

app = FastAPI()

# Load the base PaliGemma model and the ColPali retrieval adapter
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
if str(model.device) != device:
    model.to(device)

# Placeholder image required by the query-processing helper
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage: one multi-vector embedding per indexed page, plus the page images
ds = []
images = []


@app.get("/")
def read_root():
    return RedirectResponse(url="/docs")


@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Convert the uploaded PDFs to page images and embed each page."""
    global ds, images
    images = []
    ds = []
    for file in files:
        content = await file.read()
        pdf_image_list = convert_from_bytes(content)
        images.extend(pdf_image_list)

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {key: v.to(device) for key, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return {"message": f"Uploaded and converted {len(images)} pages"}


@app.post("/search")
async def search(query: str, k: int):
    """Embed the query, score it against the indexed pages, and return the top-k pages as a PDF."""
    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: v.to(device) for key, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Late-interaction scoring of the query against every indexed page
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = torch.as_tensor(retriever_evaluator.evaluate(qs, ds))
    # scores has shape (n_queries, n_pages); keep the k best pages for the single query
    top_k_indices = scores[0].topk(min(k, scores.shape[1])).indices.tolist()

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format="PNG")
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
        results.append({"image": img_base64, "page": f"Page {idx + 1}"})

    # Render the retrieved pages into a single PDF
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img_data = base64.b64decode(result["image"])
        # Write the image to a temporary file; close it before drawImage
        # reads it so this also works on Windows
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_file.write(img_data)
            temp_path = temp_file.name
        try:
            c.drawImage(temp_path, 0, 0, width, height, preserveAspectRatio=True)
            c.showPage()
        finally:
            # Clean up the temporary file
            os.remove(temp_path)
    c.save()
    pdf_buffer.seek(0)

    # Use StreamingResponse to return the in-memory PDF as a file download
    response = StreamingResponse(pdf_buffer, media_type="application/pdf")
    response.headers["Content-Disposition"] = 'attachment; filename="search_results.pdf"'
    return response


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
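
# A quick smoke test of the two endpoints, assuming the server is running
# locally on port 7860 and a file named "sample.pdf" exists in the current
# directory (both the host/port and the file name are illustrative):
#
#   curl -X POST "http://localhost:7860/index" -F "files=@sample.pdf"
#   curl -X POST "http://localhost:7860/search?query=quarterly%20revenue&k=3" \
#        --output search_results.pdf
#
# The first call indexes every page of the uploaded PDF; the second embeds
# the query, retrieves the 3 best-matching pages, and saves them as a PDF.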