Upload app.py
app.py
ADDED
@@ -0,0 +1,233 @@
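# One-shot, box-guided object detection demo built on OWL-ViT
# (google/owlvit-base-patch32). Instead of a text prompt, the user supplies a
# query box (xmin/ymin/xmax/ymax) around one example object in the uploaded
# image; the patched model scores every predicted box against the embedding of
# that query box, and the Gradio UI shows all matches above the chosen
# threshold after NMS.
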
import gradio as gr
from PIL import Image, ImageDraw
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTModel, OwlViTImageProcessor
from transformers.image_transforms import center_to_corners_format
from transformers.models.owlvit.modeling_owlvit import box_iou
from functools import partial

# from utils import iou

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

from transformers.models.owlvit.modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTClassPredictionHead


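# Replacement forward pass for OwlViTClassPredictionHead: instead of comparing
# image embeddings against text-query embeddings, it takes the class embedding
# of a single image patch (query_indice, the patch whose predicted box best
# matches the user's box) and uses it as the query, so the logits measure how
# similar every patch is to that example object.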
def classpredictionhead_box_forward(
    self,
    image_embeds,
    query_indice,
    query_mask,
):
    image_class_embeds = self.dense0(image_embeds)

    # Normalize image features
    image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)

    # Use the class embedding of the selected patch as the single query
    query_embeds = image_class_embeds[0, query_indice].unsqueeze(0).unsqueeze(0)
    # query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)

    # Get class predictions
    pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

    # Apply a learnable shift and scale to logits
    logit_shift = self.logit_shift(image_embeds)
    logit_scale = self.logit_scale(image_embeds)
    logit_scale = self.elu(logit_scale) + 1
    pred_logits = (pred_logits + logit_shift) * logit_scale

    if query_mask is not None:
        if query_mask.ndim > 1:
            query_mask = torch.unsqueeze(query_mask, dim=-2)

        pred_logits = pred_logits.to(torch.float64)
        pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
        pred_logits = pred_logits.to(torch.float32)

    return (pred_logits, image_class_embeds)


def class_predictor(
    self,
    image_feats,
    query_indice=None,
    query_mask=None,
):
    (pred_logits, image_class_embeds) = self.class_head.classpredictionhead_box_forward(image_feats, query_indice, query_mask)

    return (pred_logits, image_class_embeds)


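# Rescale the model's normalized, center-format box predictions to pixel corner
# coordinates and return the index of the prediction with the highest IoU
# against the user-supplied query box.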
def get_max_iou_indice(target_pred_boxes, query_box, target_sizes):
    boxes = center_to_corners_format(target_pred_boxes)
    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    boxes = boxes * scale_fct[:, None, :]

    iou, _ = box_iou(boxes.squeeze(0), query_box)

    return iou.argmax()


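# Box-guided variant of OwlViTForObjectDetection.image_guided_detection: the
# target image is embedded once, boxes are predicted, the prediction that best
# overlaps the user's box is selected, and all boxes are scored against that
# patch's embedding. The commented-out lines mark the query-image code path of
# the original method that this variant no longer needs.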
def box_guided_detection(
    self: OwlViTForObjectDetection,
    pixel_values,
    query_box=None,
    target_sizes=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # Compute the feature map for the input image
    # query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0]
    feature_map, vision_outputs = self.image_embedder(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
    image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

    # batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
    # query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
    # # Get top class embedding and best box index for each query image in batch
    # query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)

    # Predict object boxes
    target_pred_boxes = self.box_predictor(image_feats, feature_map)

    # Select the predicted box with the highest IoU against the query box;
    # its patch embedding serves as the query
    query_indice = get_max_iou_indice(target_pred_boxes, query_box, target_sizes)

    # Predict object classes [batch_size, num_patches, num_queries+1]
    (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_indice=query_indice)

    if not return_dict:
        output = (
            feature_map,
            # query_feature_map,
            target_pred_boxes,
            # query_pred_boxes,
            pred_logits,
            class_embeds,
            vision_outputs.to_tuple(),
        )
        output = tuple(x for x in output if x is not None)
        return output

    return OwlViTImageGuidedObjectDetectionOutput(
        image_embeds=feature_map,
        # query_image_embeds=query_feature_map,
        target_pred_boxes=target_pred_boxes,
        # query_pred_boxes=query_pred_boxes,
        logits=pred_logits,
        class_embeds=class_embeds,
        text_model_output=None,
        vision_model_output=vision_outputs,
    )


# Attach the box-guided variants to the loaded model and its class head
model.box_guided_detection = partial(box_guided_detection, model)
model.class_predictor = partial(class_predictor, model)
model.class_head.classpredictionhead_box_forward = partial(classpredictionhead_box_forward, model.class_head)


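# Illustrative sketch only (not wired into the UI below): how the patched model
# can be called directly on a PIL image with a hand-picked query box. The
# default box coordinates are placeholders taken from the UI defaults.
def _example_direct_detection(pil_image, query_box=(8, 198, 100, 428), threshold=0.95, nms=0.3):
    target_sizes = torch.Tensor([pil_image.size[::-1]])
    inputs = processor(images=pil_image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        out = model.box_guided_detection(
            **inputs,
            query_box=torch.Tensor([query_box]),
            target_sizes=target_sizes,
        )
    results = processor.post_process_image_guided_detection(
        outputs=out, threshold=threshold, nms_threshold=nms, target_sizes=target_sizes
    )
    return results[0]["boxes"], results[0]["scores"]

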
outputs = None  # cached model output from the most recent detection run


def prepare_embedds(xmin, ymin, xmax, ymax, image):
    box = (int(xmin), int(ymin), int(xmax), int(ymax))
    return (image, [(box, "manual")])


def manul_box_change(xmin, ymin, xmax, ymax, image):
    box = (int(xmin), int(ymin), int(xmax), int(ymax))
    return (image, [(box, "manual")])


def threshold_change(xmin, ymin, xmax, ymax, image, threshold, nms):
    manul_box = (int(xmin), int(ymin), int(xmax), int(ymax))

    global outputs
    if outputs is None:
        # No detection has been run yet; just show the manual box
        return (image, [(manul_box, "manual")]), 0

    target_sizes = torch.Tensor([image.size[::-1]])

    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms, target_sizes=target_sizes)

    boxes = results[0]['boxes'].type(torch.int64).tolist()
    scores = results[0]['scores'].tolist()
    labels = list(zip(boxes, scores))
    labels.append((manul_box, "manual"))

    # The query box itself is expected among the detections, so exclude it from the count
    cnt = len(boxes) - 1

    return (image, labels), cnt


def one_shot_detect(xmin, ymin, xmax, ymax, image, threshold, nms):
    manul_box = (int(xmin), int(ymin), int(xmax), int(ymax))

    global outputs
    target_sizes = torch.Tensor([image.size[::-1]])
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    outputs = model.box_guided_detection(**inputs, query_box=torch.Tensor([manul_box]), target_sizes=target_sizes)

    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms, target_sizes=target_sizes)

    boxes = results[0]['boxes'].type(torch.int64).tolist()
    scores = results[0]['scores'].tolist()
    labels = list(zip(boxes, scores))
    labels.append((manul_box, "manual"))

    # The query box itself is expected among the detections, so exclude it from the count
    cnt = len(boxes) - 1

    return (image, labels), cnt


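# Gradio UI: the left column holds the input image and the threshold/NMS/count
# fields, the right column shows the annotated result; the xmin/ymin/xmax/ymax
# numbers define the manual query box and the primary button runs detection.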
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil")
            threshold = gr.Number(0.95, label="threshold", step=0.01)
            nms = gr.Number(0.3, label="nms", step=0.01)
            cnt = gr.Number(0, label="count", interactive=False)
        with gr.Column():
            annotatedimage = gr.AnnotatedImage()
    with gr.Row():
        xmin = gr.Number(8, label="xmin")
        ymin = gr.Number(198, label="ymin")
        xmax = gr.Number(100, label="xmax")
        ymax = gr.Number(428, label="ymax")
    button = gr.Button(variant="primary")

    xmin.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
    ymin.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
    xmax.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
    ymax.change(manul_box_change, [xmin, ymin, xmax, ymax, image], [annotatedimage])
    threshold.change(threshold_change, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])
    nms.change(threshold_change, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])
    image.upload(prepare_embedds, [xmin, ymin, xmax, ymax, image], [annotatedimage])
    button.click(one_shot_detect, [xmin, ymin, xmax, ymax, image, threshold, nms], [annotatedimage, cnt])


demo.launch(server_port=7861)