Spaces: Running on Zero
wondervictor committed · Commit fc81a43
1 Parent(s): 6cd385f

update README
condition/midas/midas/vit.py CHANGED (+33 -13)
@@ -128,12 +128,32 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
     return posemb
 
 
+def _flat_resize_pos_embed(model, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, :model.start_index],
+        posemb[0, model.start_index:],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
+                                      -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid,
+                                size=(gs_h, gs_w),
+                                mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
 def flat_forward_flex(model, x):
     b, c, h, w = x.shape
 
-    pos_embed =
-
-
+    pos_embed = _flat_resize_pos_embed(model, model.pos_embed,
+                                       h // model.patch_size[1],
+                                       w // model.patch_size[0])
 
     B = x.shape[0]
 
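The new _flat_resize_pos_embed helper does the same job as the injected _resize_pos_embed method above, except the backbone is passed as an explicit model argument: it splits off the leading tokens and bilinearly resamples the pretrained position-embedding grid to the current patch grid. A minimal, self-contained sketch of its behaviour follows; the dummy attribute values (start_index=1, a 24x24 grid from 384x384 pretraining with 16-pixel patches, 768-dim embeddings) are illustrative assumptions, not values taken from this repository.

import math
import types

import torch
import torch.nn.functional as F


def _flat_resize_pos_embed(model, posemb, gs_h, gs_w):
    # Split the leading (class/readout) tokens from the patch-grid tokens.
    posemb_tok, posemb_grid = (
        posemb[:, :model.start_index],
        posemb[0, model.start_index:],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    # Bilinearly resample the square pretraining grid to the requested size.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    return torch.cat([posemb_tok, posemb_grid], dim=1)


# Dummy stand-in for the timm VisionTransformer; only start_index is needed here.
dummy = types.SimpleNamespace(start_index=1)
posemb = torch.randn(1, 1 + 24 * 24, 768)                # assumed 24x24 pretraining grid
resized = _flat_resize_pos_embed(dummy, posemb, 32, 32)  # e.g. a 512x512 input, 16px patches
print(resized.shape)                                     # torch.Size([1, 1025, 768])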
@@ -352,10 +372,10 @@ def _make_vit_b16_backbone(
 
     # We inject this function into the VisionTransformer instances so that
     # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model.forward_flex = types.MethodType(forward_flex,
-                                                     pretrained.model)
-    pretrained.model._resize_pos_embed = types.MethodType(
-        _resize_pos_embed, pretrained.model)
+    # pretrained.model.forward_flex = types.MethodType(forward_flex,
+    #                                                  pretrained.model)
+    # pretrained.model._resize_pos_embed = types.MethodType(
+    #     _resize_pos_embed, pretrained.model)
 
     return pretrained
 
@@ -550,13 +570,13 @@ def _make_vit_b_rn50_backbone(
 
     # We inject this function into the VisionTransformer instances so that
     # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model.forward_flex = types.MethodType(forward_flex,
-                                                     pretrained.model)
+    # pretrained.model.forward_flex = types.MethodType(forward_flex,
+    #                                                  pretrained.model)
 
-    # We inject this function into the VisionTransformer instances so that
-    # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model._resize_pos_embed = types.MethodType(
-        _resize_pos_embed, pretrained.model)
+    # # We inject this function into the VisionTransformer instances so that
+    # # we can use it with interpolated position embeddings without modifying the library source.
+    # pretrained.model._resize_pos_embed = types.MethodType(
+    #     _resize_pos_embed, pretrained.model)
 
     return pretrained
 
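Both backbone builders now leave the types.MethodType injection commented out. For reference, a minimal sketch of what that idiom does: it binds a module-level function to a single instance so the function receives that instance as its first argument, without modifying the library class. DummyBackbone below is a stand-in for the timm VisionTransformer and is purely illustrative; the diff itself does not show the real call sites.

import types


class DummyBackbone:
    name = "vit_base_patch16_384"


def forward_flex(self, x):
    # In vit.py this would run the transformer with interpolated position
    # embeddings; here it only shows which instance the function is bound to.
    return f"forward_flex on {self.name}, input={x!r}"


backbone = DummyBackbone()

# The now-commented-out wiring: bind the function to this one instance.
backbone.forward_flex = types.MethodType(forward_flex, backbone)
print(backbone.forward_flex("img"))   # forward_flex on vit_base_patch16_384, input='img'

# The style the new flat_* helpers use: pass the model explicitly instead.
print(forward_flex(backbone, "img"))  # same result, no per-instance binding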