HUANG-Stephanie committed
Commit 4390904
1 Parent(s): 44cf120

Upload folder using huggingface_hub

Files changed (2):
  1. README.md +2 -8
  2. app.py +89 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Cvquest Colpali
-emoji: 🌍
-colorFrom: purple
-colorTo: red
+title: cvquest-colpali
+app_file: app.py
 sdk: gradio
 sdk_version: 4.39.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
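
For reference, the Space front matter that results from this change (read directly off the diff above) is:

    ---
    title: cvquest-colpali
    app_file: app.py
    sdk: gradio
    sdk_version: 4.39.0
    ---
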
app.py ADDED
@@ -0,0 +1,89 @@
+import os
+import sys
+
+import gradio as gr
+import torch
+from pdf2image import convert_from_path
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoProcessor
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from colpali_engine.models.paligemma_colbert_architecture import ColPali
+from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
+from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
+
+def search(query: str, ds, images):
+    qs = []
+    with torch.no_grad():
+        batch_query = process_queries(processor, [query], mock_image)
+        batch_query = {k: v.to(device) for k, v in batch_query.items()}
+        embeddings_query = model(**batch_query)
+        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+
+    # score the query embeddings against the indexed page embeddings
+    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
+    scores = retriever_evaluator.evaluate(qs, ds)
+    best_page = int(scores.argmax(axis=1).item())
+    return f"The most relevant page is {best_page}", images[best_page]
+
+
+def index(file, ds):
+    """Convert the uploaded PDFs to page images and embed each page with ColPali."""
+    images = []
+    for f in file:
+        images.extend(convert_from_path(f))
+
+    # run inference - embed the document pages in batches
+    dataloader = DataLoader(
+        images,
+        batch_size=4,
+        shuffle=False,
+        collate_fn=lambda x: process_images(processor, x),
+    )
+    for batch_doc in tqdm(dataloader):
+        with torch.no_grad():
+            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
+            embeddings_doc = model(**batch_doc)
+        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+    return f"Uploaded and converted {len(images)} pages", ds, images
+
+COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
+# Load the base model on CPU and apply the ColPali adapter
+model_name = "vidore/colpali"
+token = os.environ.get("HF_TOKEN")
+model = ColPali.from_pretrained(
+    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token=token
+).eval()
+model.load_adapter(model_name)
+processor = AutoProcessor.from_pretrained(model_name, token=token)
+device = model.device
+mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
+
+with gr.Blocks() as demo:
+    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚🔍")
+    gr.Markdown("## 1️⃣ Upload PDFs")
+    file = gr.File(file_types=["pdf"], file_count="multiple")
+
+    gr.Markdown("## 2️⃣ Convert the PDFs and upload")
+    convert_button = gr.Button("🔄 Convert and upload")
+    message = gr.Textbox("Files not yet uploaded")
+    embeds = gr.State(value=[])
+    imgs = gr.State(value=[])
+
+    # Define the actions
+    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
+
+    gr.Markdown("## 3️⃣ Search")
+    query = gr.Textbox(placeholder="Enter your query here")
+    search_button = gr.Button("🔍 Search")
+    message2 = gr.Textbox("Query not yet set")
+    output_img = gr.Image()
+
+    search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
+
+
+if __name__ == "__main__":
+    demo.queue(max_size=10).launch(debug=True, share=True)
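
A note on the retrieval step above: `CustomEvaluator(is_multi_vector=True)` scores a multi-vector query against multi-vector page embeddings. Assuming it implements the ColBERT-style late-interaction (MaxSim) scoring that ColPali is built on, a minimal standalone sketch of that computation would look like the following (`maxsim_scores` is an illustrative helper, not part of colpali_engine):

    import torch

    def maxsim_scores(qs, ds):
        # For every query/page pair: match each query token to its most
        # similar page token and sum those per-token maxima.
        scores = torch.zeros(len(qs), len(ds))
        for i, q in enumerate(qs):             # q: (n_query_tokens, dim)
            for j, d in enumerate(ds):         # d: (n_page_tokens, dim)
                sim = q.float() @ d.float().T  # token-to-token dot products
                scores[i, j] = sim.max(dim=1).values.sum()
        return scores

Under that reading, `scores.argmax(axis=1)` in `search` returns the 0-based index of the best-matching page across all converted PDFs. To run the app locally, note that pdf2image additionally requires the poppler system package to be installed.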