wondervictor committed on
Commit fc81a43 · 1 Parent(s): 6cd385f

update README

Files changed (1)
  1. condition/midas/midas/vit.py +33 -13
condition/midas/midas/vit.py CHANGED

@@ -128,12 +128,32 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
     return posemb
 
 
+def _flat_resize_pos_embed(model, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, :model.start_index],
+        posemb[0, model.start_index:],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
+                                      -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid,
+                                size=(gs_h, gs_w),
+                                mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
 def flat_forward_flex(model, x):
     b, c, h, w = x.shape
 
-    pos_embed = model._resize_pos_embed(model.pos_embed,
-                                        h // model.patch_size[1],
-                                        w // model.patch_size[0])
+    pos_embed = _flat_resize_pos_embed(model, model.pos_embed,
+                                       h // model.patch_size[1],
+                                       w // model.patch_size[0])
 
     B = x.shape[0]
 
@@ -352,10 +372,10 @@ def _make_vit_b16_backbone(
 
     # We inject this function into the VisionTransformer instances so that
     # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model.forward_flex = types.MethodType(forward_flex,
-                                                     pretrained.model)
-    pretrained.model._resize_pos_embed = types.MethodType(
-        _resize_pos_embed, pretrained.model)
+    # pretrained.model.forward_flex = types.MethodType(forward_flex,
+    #                                                  pretrained.model)
+    # pretrained.model._resize_pos_embed = types.MethodType(
+    #     _resize_pos_embed, pretrained.model)
 
     return pretrained
 
@@ -550,13 +570,13 @@ def _make_vit_b_rn50_backbone(
 
     # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model.forward_flex = types.MethodType(forward_flex,
-                                                     pretrained.model)
+    # pretrained.model.forward_flex = types.MethodType(forward_flex,
+    #                                                  pretrained.model)
 
-    # We inject this function into the VisionTransformer instances so that
-    # we can use it with interpolated position embeddings without modifying the library source.
-    pretrained.model._resize_pos_embed = types.MethodType(
-        _resize_pos_embed, pretrained.model)
+    # # We inject this function into the VisionTransformer instances so that
+    # # we can use it with interpolated position embeddings without modifying the library source.
+    # pretrained.model._resize_pos_embed = types.MethodType(
+    #     _resize_pos_embed, pretrained.model)
 
     return pretrained
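
For context, the new _flat_resize_pos_embed bilinearly resizes the pretrained patch-grid position embeddings to an arbitrary target grid while passing the leading class/readout tokens through unchanged. Below is a minimal, self-contained sketch of the same computation; the ViT-B/16 shapes (24x24 grid, dim 768, one class token) and the freestanding function signature are illustrative assumptions, not code from this repo.

import math
import torch
import torch.nn.functional as F

def flat_resize_pos_embed(posemb, start_index, gs_h, gs_w):
    # Split the leading tokens (class/readout) from the patch-grid embeddings.
    posemb_tok = posemb[:, :start_index]   # (1, start_index, D)
    posemb_grid = posemb[0, start_index:]  # (gs_old * gs_old, D)

    # The pretrained grid is assumed square, e.g. 24x24 for 384px / patch 16.
    gs_old = int(math.sqrt(len(posemb_grid)))

    # (N, D) -> (1, D, gs_old, gs_old) so F.interpolate can resize spatially.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    # Back to (1, gs_h * gs_w, D), then re-attach the token embeddings.
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)

# Toy check: embeddings pretrained for a 24x24 grid, resized for a 384x512
# input with 16px patches (gs_h = 384 // 16 = 24, gs_w = 512 // 16 = 32).
pos_embed = torch.randn(1, 1 + 24 * 24, 768)
out = flat_resize_pos_embed(pos_embed, start_index=1, gs_h=24, gs_w=32)
print(out.shape)  # torch.Size([1, 769, 768]) == (1, 1 + 24 * 32, 768)

Making this a module-level function that takes the model as an argument, rather than a method injected onto the instance, leaves the computation itself identical to the old _resize_pos_embed path.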
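The lines commented out in _make_vit_b16_backbone and _make_vit_b_rn50_backbone used bound-method injection, the technique the in-diff comments describe. For reference, a small sketch of that pattern with types.MethodType; the Model and forward_flex names here are placeholders, not the repo's classes.

import types

class Model:
    pass

def forward_flex(self, x):
    # Once bound, `self` is the instance, exactly as in a normal method.
    return f"forward_flex({x!r}) on {type(self).__name__}"

m = Model()
# Attach the function to this one instance without modifying the class or
# the library source -- the trick described in the in-diff comments.
m.forward_flex = types.MethodType(forward_flex, m)
print(m.forward_flex("input"))  # forward_flex('input') on Model

After this change, flat_forward_flex calls the module-level _flat_resize_pos_embed directly, so the per-instance injection is no longer needed on that path.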