Commit 3a0c450
HUANG-Stephanie committed
1 parent: 5d3c3b6

Update colpali-main/demo/app.py

Files changed: colpali-main/demo/app.py (+6 -17)
colpali-main/demo/app.py CHANGED

@@ -1,5 +1,4 @@
 import os
-import sys
 
 import gradio as gr
 import torch
@@ -9,13 +8,12 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoProcessor
 
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-
 from colpali_engine.models.paligemma_colbert_architecture import ColPali
 from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
 from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
 
-def search(query: str, ds, images, k):
+
+def search(query: str, ds, images):
     qs = []
     with torch.no_grad():
         batch_query = process_queries(processor, [query], mock_image)
@@ -26,17 +24,8 @@ def search(query: str, ds, images, k):
     # run evaluation
     retriever_evaluator = CustomEvaluator(is_multi_vector=True)
     scores = retriever_evaluator.evaluate(qs, ds)
-
-
-
-    results = []
-    for idx in top_k_indices:
-        results.append((images[idx], f"Page {idx}"))
-
-    return results
-
-    #best_page = int(scores.argmax(axis=1).item())
-    #return f"The most relevant page is {best_page}", images[best_page]
+    best_page = int(scores.argmax(axis=1).item())
+    return f"The most relevant page is {best_page}", images[best_page]
 
 
 def index(file, ds):
@@ -59,6 +48,7 @@ def index(file, ds):
     ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
     return f"Uploaded and converted {len(images)} pages", ds, images
 
+
 COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
 # Load model
 model_name = "vidore/colpali"
@@ -90,9 +80,8 @@ with gr.Blocks() as demo:
         search_button = gr.Button("🔍 Search")
         message2 = gr.Textbox("Query not yet set")
         output_img = gr.Image()
-        k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
 
-        search_button.click(search, inputs=[query, embeds, imgs
+        search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
 
 
 if __name__ == "__main__":
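Note on the search() rewrite: the removed body built a results list from top_k_indices, a variable never assigned anywhere in the removed hunk, so the old top-k path could not have run; the commit returns to the single best page instead. A minimal sketch of that selection, assuming scores behaves like the (num_queries x num_pages) matrix that CustomEvaluator.evaluate(qs, ds) returns:

import torch

# Hypothetical scores for one query over three pages; the real matrix
# comes from CustomEvaluator(is_multi_vector=True).evaluate(qs, ds).
scores = torch.tensor([[0.12, 0.87, 0.45]])

# argmax over the page axis picks the highest-scoring page per query;
# with a single query, .item() unwraps the result to a plain int.
best_page = int(scores.argmax(axis=1).item())
print(f"The most relevant page is {best_page}")  # -> page 1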
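The wiring half of the fix: the old search_button.click(...) call broke off mid-argument list, and the rewritten search() returns two values, so the handler needs two output components. A self-contained sketch of the pattern with stand-in components and a placeholder search() body (names mirror the diff; the real logic embeds the query with ColPali and scores it against the indexed pages):

import gradio as gr

# Placeholder search(): returns (message, image) like the demo's version.
def search(query, ds, images):
    best_page = 0  # stands in for int(scores.argmax(axis=1).item())
    return f"The most relevant page is {best_page}", images[best_page] if images else None

with gr.Blocks() as demo:
    query = gr.Textbox(placeholder="Enter your query here")
    embeds = gr.State([])  # document embeddings, filled by index()
    imgs = gr.State([])    # page images, filled by index()
    search_button = gr.Button("🔍 Search")
    message2 = gr.Textbox("Query not yet set")
    output_img = gr.Image()

    # Two return values from search() map onto the two outputs.
    search_button.click(search, inputs=[query, embeds, imgs],
                        outputs=[message2, output_img])

if __name__ == "__main__":
    demo.launch()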