# File size: 1,415 Bytes
# 0b7b08a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import torch
# One-off conversion script: load a checkpoint, shrink its positional
# embeddings from ORI_IMAGE_SIZE to IMAGE_SIZE, and save the result next to
# the original under a size-suffixed name.
ORI_IMAGE_SIZE = 1024
IMAGE_SIZE = 256
# Target number of rows for the resized relative-position tables.
# NOTE(review): this is not derived from IMAGE_SIZE and must be kept in sync
# by hand — presumably 2 * grid - 1 for the resized grid; TODO confirm.
REL_POS = 31


def _resize_rel_pos(rel_pos, size):
    """Linearly resample a 2-D table along dim 0 so it has ``size`` rows."""
    # Linear interpolation works on the last dim of a 3-D input, so dim 0 is
    # moved to the end and a leading dim is added; both are undone afterwards.
    return torch.nn.functional.interpolate(
        rel_pos.permute(1, 0).unsqueeze(0),
        size=size,
        mode="linear",
        align_corners=True,
    ).squeeze(0).permute(1, 0)


# NOTE: torch.load unpickles the file, so only run this on a checkpoint from a
# trusted source.
# map_location="cpu" lets the conversion run on hosts without the device the
# tensors were saved from; nothing below needs an accelerator.
checkpoint = torch.load(
    "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth",
    map_location="cpu",
)

# Bilinear interpolation resizes the last two dims of a 4-D input, so the last
# axis of pos_embed is moved to position 1, the two middle axes are scaled by
# IMAGE_SIZE / ORI_IMAGE_SIZE, and the original axis order is then restored.
image_encoder_pos_embed = checkpoint["image_encoder.pos_embed"]
image_encoder_pos_embed = torch.nn.functional.interpolate(
    image_encoder_pos_embed.permute(0, 3, 1, 2),
    scale_factor=IMAGE_SIZE / ORI_IMAGE_SIZE,
    mode="bilinear",
    align_corners=True,
).permute(0, 2, 3, 1)
checkpoint["image_encoder.pos_embed"] = image_encoder_pos_embed
print(image_encoder_pos_embed.shape)

# Only these blocks have their rel_pos tables resized.
# NOTE(review): presumably these are the blocks whose attention spans the
# whole grid (the others being unaffected by input size) — confirm against the
# model definition.
for idx in [5, 11, 17, 23]:
    key_h = f"image_encoder.blocks.{idx}.attn.rel_pos_h"
    key_w = f"image_encoder.blocks.{idx}.attn.rel_pos_w"
    rel_pos_h = _resize_rel_pos(checkpoint[key_h], REL_POS)
    rel_pos_w = _resize_rel_pos(checkpoint[key_w], REL_POS)
    checkpoint[key_h] = rel_pos_h
    checkpoint[key_w] = rel_pos_w
    print(rel_pos_h.shape, rel_pos_w.shape)

torch.save(checkpoint, f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_{IMAGE_SIZE}x{IMAGE_SIZE}.pth")