HaohuaLv committed on
Commit 830f83c
1 Parent(s): 58a00fc

Update app.py

Files changed (1)
  1. app.py +62 -27
app.py CHANGED
@@ -1,16 +1,20 @@
 import gradio as gr
 from PIL import Image, ImageDraw
 import torch
-from transformers import OwlViTProcessor, OwlViTForObjectDetection
+from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTModel, OwlViTImageProcessor
 from transformers.image_transforms import center_to_corners_format
 from transformers.models.owlvit.modeling_owlvit import box_iou
 from functools import partial
+import numpy as np
 
+# from utils import iou
 
 processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
 model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
 
-from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput
+from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTClassPredictionHead
+
+
 
 
 
@@ -24,9 +28,7 @@ def classpredictionhead_box_forward(
 
     # Normalize image and text features
    image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
-    print(image_class_embeds.shape)
     query_embeds = image_class_embeds[0, query_indice].unsqueeze(0).unsqueeze(0)
-    print(query_embeds.shape)
     # query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
 
     # Get class predictions
@@ -66,6 +68,8 @@ def class_predictor(
 
 
 
+
+
 def get_max_iou_indice(target_pred_boxes, query_box, target_sizes):
     boxes = center_to_corners_format(target_pred_boxes)
     img_h, img_w = target_sizes.unbind(1)
@@ -104,6 +108,12 @@ def box_guided_detection(
     batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
     image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
 
+    # batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
+    # query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
+    # # Get top class embedding and best box index for each query image in batch
+    # query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)
+
+    # Predict object boxes
     target_pred_boxes = self.box_predictor(image_feats, feature_map)
 
     # Get MAX IOU box corresponding embedding
@@ -113,6 +123,9 @@ def box_guided_detection(
     (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_indice=query_indice)
 
 
+
+
+
     if not return_dict:
         output = (
             feature_map,
@@ -150,31 +163,30 @@ def prepare_embedds(xmin, ymin, xmax, ymax, image):
 
 def manul_box_change(xmin, ymin, xmax, ymax, image):
     box = (int(xmin), int(ymin), int(xmax), int(ymax))
-    return (image, [(box, "manul")])
+    return (image["image"], [(box, "manul")])
 
 def threshold_change(xmin, ymin, xmax, ymax, image, threshold, nms):
     manul_box = (int(xmin), int(ymin), int(xmax), int(ymax))
 
     global outputs
-    target_sizes = torch.Tensor([image.size[::-1]])
+    target_sizes = torch.Tensor([image["image"].size[::-1]])
 
     results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms, target_sizes=target_sizes)
 
     boxes = results[0]['boxes'].type(torch.int64).tolist()
     scores = results[0]['scores'].tolist()
     labels = list(zip(boxes, scores))
-    labels.append((manul_box, "manual"))
 
     cnt = len(boxes)
 
-    return (image, labels), cnt
+    return (image["image"], labels), cnt
 
 def one_shot_detect(xmin, ymin, xmax, ymax, image, threshold, nms):
     manul_box = (int(xmin), int(ymin), int(xmax), int(ymax))
 
     global outputs
-    target_sizes = torch.Tensor([image.size[::-1]])
-    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
+    target_sizes = torch.Tensor([image["image"].size[::-1]])
+    inputs = processor(images=image["image"].convert("RGB"), return_tensors="pt")
     outputs = model.box_guided_detection(**inputs, query_box=torch.Tensor([manul_box]), target_sizes=target_sizes)
 
     results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms, target_sizes=target_sizes)
@@ -182,37 +194,60 @@ def one_shot_detect(xmin, ymin, xmax, ymax, image, threshold, nms):
     boxes = results[0]['boxes'].type(torch.int64).tolist()
     scores = results[0]['scores'].tolist()
     labels = list(zip(boxes, scores))
-    labels.append((manul_box, "manual"))
 
     cnt = len(boxes)
 
-    return (image, labels), cnt
+    return (image["image"], labels), cnt
+
+def save_embedding(exam):
+    print(exam)
+    global outputs
+    embedding = outputs["class_embeds"][0, outputs["logits"].argmax()]
+    return embedding.detach().numpy()
+
+
+def sketch2box(sketch_box):
+    mask = sketch_box["mask"].convert("L")
+    mask = np.array(mask)
+
+    masked_index = np.where(mask == 255)
+    if len(masked_index[0]) == 0:
+        return (sketch_box["image"], []), -1, -1, -1, -1
+    xmin, ymin, xmax, ymax = masked_index[1].min(), masked_index[0].min(), masked_index[1].max(), masked_index[0].max()
+    box = (xmin, ymin, xmax, ymax)
+
+    return (sketch_box["image"], [(box, "manual")]), xmin, ymin, xmax, ymax
 
 
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
-            image = gr.Image(type="pil")
+            sketch_box = gr.Image(type="pil", source="upload", tool="sketch")
+            box_preview = gr.AnnotatedImage(type="pil", interactive=False, height=256)
             threshold = gr.Number(0.95, label="threshold", step=0.01)
             nms = gr.Number(0.3, label="nms", step=0.01)
             cnt = gr.Number(0, label="count", interactive=False)
         with gr.Column():
             annotatedimage = gr.AnnotatedImage()
     with gr.Row():
-        xmin = gr.Number(8, label="xmin")
-        ymin = gr.Number(198, label="ymin")
-        xmax = gr.Number(100, label="xmax")
-        ymax = gr.Number(428, label="ymax")
-        button = gr.Button(variant="primary")
-
-    xmin.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
-    ymin.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
-    xmax.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
-    ymax.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
-    threshold.change(threshold_change, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])
-    nms.change(threshold_change, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])
-    image.upload(prepare_embedds, [xmin, ymin, xmax, ymax, image], [annotatedimage])
-    button.click(one_shot_detect, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])
+        xmin = gr.Number(-1, label="xmin")
+        ymin = gr.Number(-1, label="ymin")
+        xmax = gr.Number(-1, label="xmax")
+        ymax = gr.Number(-1, label="ymax")
+    with gr.Row():
+        run_button = gr.Button(variant="primary")
+        # save_button = gr.Button("Save", variant="secondary")
+
+
+    sketch_box.edit(sketch2box, [sketch_box], [box_preview, xmin, ymin, xmax, ymax])
+    xmin.change(manul_box_change, [xmin, ymin, xmax, ymax, sketch_box], [box_preview])
+    ymin.change(manul_box_change, [xmin, ymin, xmax, ymax, sketch_box], [box_preview])
+    xmax.change(manul_box_change, [xmin, ymin, xmax, ymax, sketch_box], [box_preview])
+    ymax.change(manul_box_change, [xmin, ymin, xmax, ymax, sketch_box], [box_preview])
+    threshold.change(threshold_change, [xmin, ymin, xmax, ymax, sketch_box, threshold, nms], [annotatedimage, cnt])
+    nms.change(threshold_change, [xmin, ymin, xmax, ymax, sketch_box, threshold, nms], [annotatedimage, cnt])
+    run_button.click(one_shot_detect, [xmin, ymin, xmax, ymax, sketch_box, threshold, nms], [annotatedimage, cnt])
+
 
 
 
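For reference, a minimal standalone sketch of the mask-to-box conversion that the new sketch2box helper performs: it takes a binary mask of the kind Gradio's sketch tool returns (here replaced by a synthetic NumPy array) and reduces it to the bounding box of the drawn region. The mask_to_box name and the test mask are illustrative assumptions, not part of app.py.

import numpy as np

def mask_to_box(mask: np.ndarray):
    # Bounding box (xmin, ymin, xmax, ymax) of all pixels equal to 255; None if nothing is drawn.
    ys, xs = np.where(mask == 255)
    if ys.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

# Synthetic 8x8 mask with a filled rectangle spanning rows 2-4 and columns 1-4.
mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:5, 1:5] = 255
print(mask_to_box(mask))  # -> (1, 2, 4, 4)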