File size: 5,472 Bytes
4390904
 
0af5344
814b6ba
 
9b4d509
23a02a2
b61b33c
 
 
 
33de980
c5d3863
33de980
 
37a9f9c
46c7fc6
 
b61b33c
33de980
2902a60
 
d005da4
 
 
 
 
 
33de980
 
4390904
d005da4
 
 
 
 
4390904
d005da4
 
 
 
 
 
 
33de980
 
 
d005da4
b909e86
 
 
 
b1bf444
0523b65
b1bf444
 
 
 
c5d3863
74e7ff4
0523b65
45c1bf0
 
b1bf444
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae29c1
30a1a24
46c7fc6
 
5306b43
46c7fc6
5306b43
 
 
 
68d2467
5306b43
 
 
 
68d2467
 
5306b43
 
 
46c7fc6
 
 
4ae29c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285d8ef
5306b43
aaf0cd4
4ae29c1
 
 
 
d2b07a8
 
4ae29c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe7b387
4ae29c1
 
 
 
 
 
 
 
 
 
 
 
b1bf444
aaf0cd4
f4c3d99
5276af2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import base64
import os
import sys
import tempfile
from io import BytesIO
from typing import List

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse, StreamingResponse
from pdf2image import convert_from_bytes
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from torch.utils.data import DataLoader
from transformers import AutoProcessor

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)

app = FastAPI()

# Load model: base PaliGemma weights on CPU, then the ColPali adapter on top.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token=token
).eval()

model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# model.device is a torch.device. Compare it with another torch.device, not
# with the raw string: a str-vs-torch.device comparison is not a value
# comparison, so the original check could skip the move and leave the model
# on CPU while every request moves its inputs to `device`.
if torch.device(device) != model.device:
    model.to(device)
# Blank white 448x448 image passed to process_queries together with text queries.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage, rebuilt by /index:
#   ds     - one CPU embedding tensor per indexed page
#   images - the matching PIL page images, in the same order as ds
ds = []
images = []

@app.get("/")
def read_root():
    """Redirect requests for the root URL to the /docs page."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)

@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Replace the in-memory index with the pages of the uploaded PDFs.

    Every page of every uploaded PDF is rasterised with ``convert_from_bytes``
    and embedded with the ColPali model in batches of 4. The results are stored
    in the module globals ``images`` (PIL pages) and ``ds`` (one CPU embedding
    tensor per page, in the same order).

    The new index is built in local lists and published only after every file
    has been converted and embedded. A failure part-way through (e.g. an
    unreadable PDF) therefore leaves the previous index intact, and ``images``
    and ``ds`` never get out of step.

    Args:
        files: One or more uploaded PDF files.

    Returns:
        A JSON message with the number of pages indexed.
    """
    global ds, images

    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    # shuffle=False keeps embeddings in the same order as new_images.
    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    new_ds = []
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {key: value.to(device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Store the per-page embeddings on CPU.
        new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    # Publish both lists together, only once everything above has succeeded.
    images, ds = new_images, new_ds

    return {"message": f"Uploaded and converted {len(images)} pages"}

def generate_pdf(results):
    """Render result images into a single in-memory PDF, one image per page.

    Each image is drawn stretched to the full letter-size page.

    Args:
        results: Iterable of dicts whose ``"image"`` value is a base64-encoded
            image (callers in this file pass PNG data). Other keys are ignored.

    Returns:
        A ``BytesIO`` positioned at offset 0 that contains the PDF.
    """
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter

    for result in results:
        img_data = base64.b64decode(result["image"])
        # Read the image directly from memory. No temporary file is written,
        # so nothing is left on disk if drawing fails, and there is no
        # platform-dependent reopen of an open temp file.
        c.drawImage(ImageReader(BytesIO(img_data)), 0, 0, width, height)
        c.showPage()

    c.save()
    pdf_buffer.seek(0)
    return pdf_buffer

@app.get("/search")
async def search(query: str, k: int = 1):
    """Return a PDF of the ``k`` indexed pages that best match a text query.

    The query is embedded with ``process_queries`` (paired with ``mock_image``)
    and scored against every stored page embedding in ``ds`` via
    ``CustomEvaluator``. The top-``k`` pages, highest score first, are rendered
    into a single PDF download.

    Args:
        query: Free-text search query.
        k: Number of pages to return. Values larger than the index size
            simply return every page.

    Raises:
        HTTPException: 400 if ``k`` < 1 or nothing has been indexed yet.
    """
    # k must be positive: with k == 0 the slice [-0:] below selects every page.
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be >= 1")
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed; call /index first")

    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: value.to(device) for key, value in batch_query.items()}
        embeddings_query = model(**batch_query)
        qs.extend(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # argsort is ascending: take the last k (best) indices and reverse them
    # so the best match comes first.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    pdf_buffer = generate_pdf(results)

    # Use StreamingResponse to handle in-memory file
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'

    return response

@app.post("/recommendation")
async def recommendation(file: UploadFile = File(...), k: int = 10):
    """Return a PDF of indexed pages similar to an uploaded PDF.

    Every page of the uploaded PDF is rasterised and embedded (batches of 4),
    then scored against the stored page embeddings in ``ds``. Only the score
    row of the uploaded PDF's first page is used for ranking.

    Args:
        file: The PDF to find recommendations for.
        k: Number of pages to return.

    Raises:
        HTTPException: 400 if ``k`` < 1, nothing has been indexed yet, or the
            uploaded PDF produced no pages.
    """
    # A negative k would turn the slice below into [1:-1] (almost the whole index).
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be >= 1")
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed; call /index first")

    content = await file.read()
    pdf_image_list = convert_from_bytes(content)
    if not pdf_image_list:
        raise HTTPException(status_code=400, detail="Uploaded PDF contains no pages")

    qs = []
    dataloader = DataLoader(
        pdf_image_list,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {key: value.to(device) for key, value in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # argsort is ascending. [-k-1:-1] takes the k pages ranked just below the
    # single best match, and [::-1] orders them best-first. The best match is
    # skipped, presumably because the uploaded document is expected to already
    # be indexed and would match itself.
    # NOTE(review): if the upload is not in the index, this drops the genuinely
    # best recommendation. Also, only row 0 (the first uploaded page) drives
    # the ranking even though every page is embedded above. Confirm intent.
    top_k_indices = scores.argsort(axis=1)[0][-k - 1:-1][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    pdf_buffer = generate_pdf(results)

    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'

    return response

if __name__ == "__main__":
    # Serve the FastAPI app directly on all interfaces, port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)