Spaces:
Sleeping
Sleeping
import io
import os
import sys
from typing import List

import gradio as gr
import requests
import torch
from fastapi import FastAPI, File, UploadFile
from pdf2image import convert_from_bytes, convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
app = FastAPI()

# Load model: PaliGemma base weights, then the ColPali adapter on top.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# ``model.device`` is a torch.device. PyTorch never considers a torch.device
# equal to a plain str, so normalize before comparing. Otherwise this check
# is always true.
if torch.device(device) != model.device:
    model.to(device)

# Blank white placeholder image, passed to process_queries() together with
# the text query.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage, replaced wholesale by each index() call:
#   ds     - one embedding per indexed page (moved to CPU)
#   images - the PIL page images, in the same order as ``ds``
ds = []
images = []
async def index(files: List[UploadFile] = File(...)):
    """Rasterize uploaded PDFs and embed every page with ColPali.

    Replaces the module-level ``images`` (PIL page images) and ``ds``
    (per-page embeddings, moved to CPU) with the pages from this upload.
    Earlier uploads are discarded.

    NOTE(review): no route decorator is visible on this handler. Confirm it
    is registered with ``app`` elsewhere.

    Args:
        files: Uploaded files. Each one is expected to be a PDF.

    Returns:
        A dict whose ``message`` reports how many pages were converted.
    """
    global ds, images

    # convert_from_path() expects a filesystem path, not a file-like object,
    # so rasterize the in-memory PDF bytes directly.
    page_images = []
    for file in files:
        content = await file.read()
        page_images.extend(convert_from_bytes(content))

    dataloader = DataLoader(
        page_images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )

    page_embeddings = []
    # NOTE(review): these forward passes run synchronously inside an async
    # handler, so they block the event loop while indexing.
    with torch.no_grad():
        for batch_doc in dataloader:
            batch_doc = {key: value.to(device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            page_embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))

    # Publish both lists only after every page has been embedded. A failure
    # part-way through then leaves the previous index intact, instead of a
    # half-filled ``images`` paired with an empty ``ds``.
    images = page_images
    ds = page_embeddings
    return {"message": f"Uploaded and converted {len(images)} pages"}
async def search(query: str, k: int):
    """Return the ``k`` indexed pages that best match ``query``.

    The query is embedded with the model and scored against every page
    embedding in the module-level ``ds``, using
    ``CustomEvaluator(is_multi_vector=True)``.

    Args:
        query: Free-text search query.
        k: Number of results wanted. Values <= 0 return no results. Values
            larger than the number of indexed pages are capped at that number.

    Returns:
        ``{"results": [...]}``, best match first. Each entry has the form
        ``{"page": int, "image": str}``, where ``page`` is the position of
        the page in the index and ``image`` is a fixed placeholder string.
    """
    # Cap k to the index size. If k is non-positive or nothing is indexed
    # yet, return early: the old ``[-k:]`` slice returned every page for
    # k == 0, and scoring against an empty index is pointless.
    k = min(k, len(ds))
    if k <= 0:
        return {"results": []}

    with torch.no_grad():
        # NOTE(review): the blank mock_image is sent along with the text;
        # presumably the processor requires an image input. Confirm.
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: value.to(device) for key, value in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # Row 0 is the single query. Sort ascending, reverse to descending,
    # then keep the first k. This assumes ``scores`` is a NumPy array, since
    # ``axis=`` and ``[::-1]`` are NumPy idioms.
    ranked = scores.argsort(axis=1)[0][::-1][:k]
    # int() turns NumPy integer scalars into plain ints so that FastAPI can
    # JSON-encode the response.
    results = [{"page": int(idx), "image": "image_placeholder"} for idx in ranked]
    return {"results": results}