"""FastAPI service for page-level retrieval over uploaded PDFs with ColPali.

Endpoints:
  * ``POST /index``          -- rasterise uploaded PDFs and embed every page.
  * ``GET /search``          -- embed a text query, return the top-k pages as a PDF.
  * ``POST /recommendation`` -- embed an uploaded PDF, return similar pages as a PDF.

The page images and their embeddings live only in process memory (the
module-level ``images`` / ``ds`` lists) and are replaced on every ``/index``.
"""
import base64
import os
import sys
import tempfile  # noqa: F401 -- kept from the original file
from io import BytesIO
from typing import List

import gradio as gr  # noqa: F401 -- kept from the original file
import requests  # noqa: F401 -- kept from the original file
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import RedirectResponse, StreamingResponse
from pdf2image import convert_from_bytes
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# The colpali engine is vendored next to this file rather than installed,
# so its directory must be on sys.path before the imports below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))

from colpali_engine.models.paligemma_colbert_architecture import ColPali  # noqa: E402
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator  # noqa: E402
from colpali_engine.utils.colpali_processing_utils import (  # noqa: E402
    process_images,
    process_queries,
)

app = FastAPI()

# Load model: PaliGemma base weights plus the ColPali adapter.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# model.device is a torch.device; compare like with like.
if torch.device(device) != model.device:
    model.to(device)

# Blank white image passed to process_queries alongside text queries.
# NOTE(review): presumably required by the query-processing API -- confirm.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage: images[i] is the rendered page whose embedding is ds[i].
ds = []
images = []


def _embed_images(pil_images):
    """Embed PIL images with ColPali in batches of 4.

    Returns a list with one CPU tensor per input image, in input order.
    An empty input yields an empty list.
    """
    embeddings = []
    dataloader = DataLoader(
        pil_images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    for batch in dataloader:
        with torch.no_grad():
            batch = {key: value.to(device) for key, value in batch.items()}
            batch_embeddings = model(**batch)
        embeddings.extend(torch.unbind(batch_embeddings.to("cpu")))
    return embeddings


def _require_index():
    """Raise HTTP 400 when no documents have been indexed yet."""
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed; call /index first.")


def _pages_to_pdf_response(indices):
    """Build a downloadable PDF response from the indexed pages at ``indices``."""
    results = []
    for idx in indices:
        png_buffer = BytesIO()
        images[idx].save(png_buffer, format='PNG')
        img_base64 = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})
    pdf_buffer = generate_pdf(results)
    # StreamingResponse serves the in-memory PDF without touching disk.
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response


@app.get("/")
def read_root():
    """Redirect the bare root to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Rasterise the uploaded PDFs and replace the in-memory index with their pages.

    The new index is built in locals and published in one step. A failure
    part-way leaves the previous index intact, and ``images``/``ds`` never
    disagree in length.
    """
    global ds, images
    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    # Create embeddings for each page, then publish images and embeddings together.
    new_ds = _embed_images(new_images)
    images, ds = new_images, new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}


def generate_pdf(results):
    """Render each result image as one full letter-size page of a PDF.

    Args:
        results: iterable of dicts whose ``"image"`` value is a base64-encoded
            image (the endpoints in this file pass PNG data).

    Returns:
        BytesIO positioned at offset 0 that holds the finished PDF.
    """
    pdf_buffer = BytesIO()
    pdf_canvas = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img_data = base64.b64decode(result["image"])
        # ImageReader reads from memory, so no temporary file is written
        # (and none can leak if drawing fails). The image is stretched to the
        # full page, as before.
        pdf_canvas.drawImage(ImageReader(BytesIO(img_data)), 0, 0, width, height)
        pdf_canvas.showPage()
    pdf_canvas.save()
    pdf_buffer.seek(0)
    return pdf_buffer


@app.get("/search")
async def search(query: str, k: int = Query(1, ge=1)):
    """Return the ``k`` indexed pages that best match ``query``, as a PDF.

    ``k`` must be >= 1: with the negative slice below, ``k=0`` would
    otherwise select every page.
    """
    _require_index()
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: value.to(device) for key, value in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # argsort is ascending: take the last k columns of row 0, then reverse
    # so the best match comes first.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
    return _pages_to_pdf_response(top_k_indices)


@app.post("/recommendation")
async def recommendation(file: UploadFile = File(...), k: int = Query(10, ge=1)):
    """Return up to ``k`` indexed pages similar to the uploaded PDF, as a PDF."""
    _require_index()
    content = await file.read()
    pdf_image_list = convert_from_bytes(content)
    if not pdf_image_list:
        raise HTTPException(status_code=400, detail="Uploaded file contains no pages.")

    qs = _embed_images(pdf_image_list)
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Only row 0 (the uploaded file's first page) drives the ranking.
    # NOTE(review): the remaining pages are embedded but unused -- confirm intent.
    # The slice drops the single best match, presumably because the uploaded
    # document is expected to be indexed already and would match itself.
    top_k_indices = scores.argsort(axis=1)[0][-k - 1:-1][::-1]
    return _pages_to_pdf_response(top_k_indices)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)