Spaces:
Sleeping
Sleeping
File size: 2,585 Bytes
33de980 4390904 33de980 814b6ba 33de980 2902a60 d005da4 33de980 4390904 d005da4 4390904 d005da4 33de980 d005da4 33de980 4390904 33de980 4390904 33de980 4390904 33de980 4390904 33de980 4390904 33de980 4390904 33de980 d005da4 33de980 d005da4 14c5169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import io
import os
import sys
from fastapi import FastAPI, File, UploadFile
import gradio as gr
import requests
from typing import List
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
app = FastAPI()

# Load model: PaliGemma base weights, then the ColPali adapter named by
# `model_name`. Weights are first materialized on CPU in bfloat16 and moved
# to the target device below.
model_name = "vidore/colpali"
# Optional Hugging Face access token from the environment (None if unset);
# passed to both from_pretrained calls.
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# model.device is a torch.device, which never compares equal to a plain str,
# so normalize `device` before comparing; otherwise this check is always true.
if torch.device(device) != model.device:
    model.to(device)

# Blank white 448x448 RGB image handed to process_queries with each text query.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage, replaced by /index and read by /search.
ds = []  # per-page embedding tensors, moved to CPU after encoding
images = []  # per-page images produced by pdf2image, same order as `ds`
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    The module-level ``images`` and ``ds`` stores are replaced (not appended
    to) by the pages and per-page embeddings of this upload.

    Args:
        files: Uploaded PDF files; all pages of all files are indexed in
            upload order.

    Returns:
        A dict whose ``message`` reports how many pages were converted.
    """
    global ds, images

    # The upload lives in memory; convert_from_path expects a filesystem path,
    # so the bytes-based pdf2image entry point is the correct one here.
    from pdf2image import convert_from_bytes

    # Build the new index in locals and publish it only once it is complete,
    # so a failure part-way through does not leave a half-cleared index behind.
    new_images = []
    for upload in files:
        content = await upload.read()
        new_images.extend(convert_from_bytes(content))

    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    new_ds = []
    with torch.no_grad():
        for batch_doc in dataloader:
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Keep stored embeddings on CPU, one tensor per page.
            new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    images = new_images
    ds = new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}
@app.post("/search")
async def search(query: str, k: int):
    """Rank indexed pages against a text query and return the best ``k``.

    Args:
        query: Free-text search query.
        k: Number of top-scoring pages to return. ``k <= 0`` yields no results;
            ``k`` larger than the number of indexed pages returns every page.

    Returns:
        ``{"results": [...]}``, best match first. Each entry holds the 0-based
        ``page`` index into the module-level ``images``/``ds`` stores and an
        ``image`` placeholder string.
    """
    # Nothing indexed yet (the evaluator cannot score an empty corpus) or
    # nothing requested (slicing with [-0:] below would return every page).
    if not ds or k <= 0:
        return {"results": []}

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # A single query, so row 0 holds its score for every page. argsort is
    # ascending; the last k entries reversed are the top-k, best first.
    # NOTE(review): assumes `scores` is a NumPy array (the reversed slice
    # requires it) — confirm against CustomEvaluator.evaluate.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
    # Cast to built-in int: NumPy integer scalars are not JSON-serializable
    # by FastAPI's encoder.
    results = [{"page": int(idx), "image": "image_placeholder"} for idx in top_k_indices]
    return {"results": results}