# cvquest-colpali / app.py
# Hugging Face Space source; last updated by HUANG-Stephanie in commit 164d272 ("Update app.py"), 5.47 kB.
import base64
import os
import sys
import tempfile
from io import BytesIO
from typing import List

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse, StreamingResponse
from pdf2image import convert_from_bytes
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Make the vendored colpali-engine sources (./colpali-main) importable; the
# colpali_engine imports below must stay after this line.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
app = FastAPI()

# --- Model loading -----------------------------------------------------------
# ColPali = PaliGemma-3B base weights plus the "vidore/colpali" adapter, both
# fetched from the Hugging Face Hub. HF_TOKEN (if set) is forwarded to the Hub.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Use the first GPU when available, otherwise stay on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# `model.device` is a torch.device, which never compares equal to a plain
# string, so comparing it against `device` directly was always "unequal".
# Normalise to torch.device so the move only happens when actually needed.
if torch.device(device) != model.device:
    model.to(device)

# Blank white 448x448 RGB image passed to `process_queries` together with the
# text query. NOTE(review): presumably a placeholder the processor API
# requires for text-only queries — confirm against colpali_engine.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# --- In-memory index (rebuilt by /index) -------------------------------------
# images[i] is a rendered PDF page; ds[i] is the CPU embedding tensor of
# images[i] (both lists are filled in the same order).
ds = []
images = []
@app.get("/")
def read_root():
    """Send visitors of the bare root URL to the interactive API docs."""
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Replace the in-memory index with the pages of the uploaded PDFs.

    Each uploaded file is rasterised page by page with ``convert_from_bytes``.
    All pages are embedded with ColPali in batches of 4, and the page images
    and their CPU embeddings are stored in the module-level ``images`` / ``ds``
    lists (same order). The previous index is discarded.

    Args:
        files: One or more uploaded PDF files.

    Returns:
        A JSON-serialisable dict with a message giving the number of pages
        indexed.
    """
    global ds, images

    # Build the new index in locals and publish it only after every file has
    # been converted and embedded. Previously the globals were cleared first,
    # so a failure part-way through left `images` filled but `ds` empty/short,
    # i.e. an inconsistent index. Now a failure keeps the previous index.
    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    # Create embeddings for each page, in upload order (shuffle=False keeps
    # new_ds[i] aligned with new_images[i]).
    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    new_ds = []
    with torch.no_grad():
        for batch_doc in dataloader:
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # One embedding tensor per page, moved to CPU for storage.
            new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    images = new_images
    ds = new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}
def generate_pdf(results):
    """Render each result image as one full page of a letter-sized PDF.

    Args:
        results: Iterable of dicts whose ``"image"`` value is a base64-encoded
            image (the endpoints in this file pass PNG bytes). Other keys
            (e.g. ``"page"``) are ignored.

    Returns:
        A ``BytesIO`` rewound to offset 0 that holds the finished PDF.
    """
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img_data = base64.b64decode(result["image"])
        # Hand the decoded bytes to reportlab straight from memory. The former
        # temp-file round trip touched the disk for no benefit and leaked the
        # file whenever drawImage raised (os.remove was not in a finally).
        # The image is drawn at (0, 0) with the full page width and height.
        c.drawImage(ImageReader(BytesIO(img_data)), 0, 0, width, height)
        c.showPage()
    c.save()
    pdf_buffer.seek(0)
    return pdf_buffer
@app.get("/search")
async def search(query: str, k: int = 1):
    """Return the ``k`` indexed pages that best match a text query, as a PDF.

    The query is embedded with ColPali (paired with ``mock_image``) and scored
    against every page embedding in ``ds`` with the multi-vector evaluator.
    The best ``k`` pages, highest score first, are rendered into a PDF.

    Args:
        query: Free-text search query.
        k: Number of pages to return; must be >= 1. Values larger than the
            index size return every indexed page.

    Returns:
        A ``StreamingResponse`` serving ``results.pdf`` as an attachment.

    Raises:
        HTTPException: 400 if ``k`` < 1 or if nothing has been indexed yet.
    """
    # Without this guard k == 0 made the slice below [-0:] == [0:] (every
    # page), and a negative k returned an arbitrary tail of the ranking.
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be >= 1")
    # An empty index previously surfaced as an opaque 500 from the evaluator.
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed yet; call /index first")

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        # Distinct names so the comprehension does not visually shadow `k`.
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Row 0 holds the single query's score against every page. argsort is
    # ascending, so take the last k and reverse for best-first order.
    # NOTE(review): relies on `evaluate` returning a NumPy array; the
    # negative-step slice is not supported on torch tensors.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    pdf_buffer = generate_pdf(results)
    # Use StreamingResponse to serve the in-memory PDF.
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
@app.post("/recommendation")
async def recommendation(file: UploadFile = File(...), k: int = 10):
    """Recommend indexed pages similar to an uploaded PDF, returned as a PDF.

    Every page of the uploaded PDF is rasterised and embedded, but only the
    FIRST page's scores (row 0) drive the ranking. The single best-scoring
    indexed page is skipped (the slice ends at -1). That is presumably because
    the uploaded document is expected to already be in the index and would
    otherwise match itself; verify.

    Args:
        file: Uploaded PDF used as the query document.
        k: Number of recommendations; must be >= 1. At most ``len(ds) - 1``
            pages are returned because the top match is dropped.

    Returns:
        A ``StreamingResponse`` serving ``results.pdf`` as an attachment.

    Raises:
        HTTPException: 400 if ``k`` < 1, if nothing has been indexed yet, or
            if the uploaded PDF yields no pages.
    """
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be >= 1")
    # An empty index previously surfaced as an opaque 500 from the evaluator.
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed yet; call /index first")

    content = await file.read()
    pdf_image_list = convert_from_bytes(content)
    # Row 0 of the score matrix is read below, so at least one page is needed.
    if not pdf_image_list:
        raise HTTPException(status_code=400, detail="Uploaded PDF contains no pages")

    # NOTE(review): all pages are embedded, yet only qs[0] affects the result;
    # either aggregate scores across pages or embed only the first page.
    dataloader = DataLoader(
        pdf_image_list,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    qs = []
    with torch.no_grad():
        for batch_query in dataloader:
            batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
            embeddings_query = model(**batch_query)
            qs.extend(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Ascending argsort of row 0. [-k-1:-1] drops the very best match and keeps
    # the next k; [::-1] puts them best-first.
    top_k_indices = scores.argsort(axis=1)[0][-k-1:-1][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    pdf_buffer = generate_pdf(results)
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
if __name__ == "__main__":
    # Launch the API directly: listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)