chendl's picture
Add application file
0b7b08a
raw
history blame
1.42 kB
import os
import torch
# Side length the source checkpoint's positional embedding is scaled from.
ORI_IMAGE_SIZE = 1024
# Side length the converted checkpoint is scaled to.
IMAGE_SIZE = 256
# Target length of the relative-position tables.
# NOTE(review): presumably 2 * (IMAGE_SIZE // 16) - 1 for a 16-pixel patch — confirm.
REL_POS = 31
# Encoder blocks whose rel_pos_h / rel_pos_w tables are resized.
# NOTE(review): presumably the global-attention blocks of this ViT variant — confirm.
REL_POS_BLOCKS = (5, 11, 17, 23)

SRC_CHECKPOINT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth"
DST_CHECKPOINT = f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_{IMAGE_SIZE}x{IMAGE_SIZE}.pth"


def resize_pos_embed(pos_embed, scale_factor):
    """Bilinearly rescale a 4-D positional embedding over its two middle axes.

    The last axis is moved to position 1 so that ``interpolate`` treats the
    two middle axes as the spatial ones, then the layout is restored.

    Args:
        pos_embed: 4-D tensor; axes 1 and 2 are resized, axes 0 and 3 are kept.
        scale_factor: Multiplier applied to both resized axes.

    Returns:
        The rescaled tensor in the same axis order as ``pos_embed``.
    """
    channels_first = pos_embed.permute(0, 3, 1, 2)
    resized = torch.nn.functional.interpolate(
        channels_first,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=True,
    )
    return resized.permute(0, 2, 3, 1)


def resize_rel_pos(rel_pos, size):
    """Linearly resample a 2-D table along its first axis to ``size`` rows.

    Args:
        rel_pos: 2-D tensor; axis 0 is resampled, axis 1 is kept.
        size: Number of rows in the result.

    Returns:
        A tensor of shape ``(size, rel_pos.shape[1])``. Because
        ``align_corners=True`` is used, the first and last rows are preserved.
    """
    # interpolate(mode="linear") wants (batch, channels, length), so the table
    # is transposed and given a leading batch axis, then restored afterwards.
    batched = rel_pos.permute(1, 0).unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        batched,
        size=size,
        mode="linear",
        align_corners=True,
    )
    return resized.squeeze(0).permute(1, 0)


def main():
    """Load the source checkpoint, shrink its position tables, and save it."""
    # map_location="cpu" lets the conversion run on hosts without a GPU even
    # if the checkpoint was saved with CUDA tensors.
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(SRC_CHECKPOINT, map_location="cpu")

    pos_embed = resize_pos_embed(
        checkpoint["image_encoder.pos_embed"],
        IMAGE_SIZE / ORI_IMAGE_SIZE,
    )
    checkpoint["image_encoder.pos_embed"] = pos_embed
    print(pos_embed.shape)

    for idx in REL_POS_BLOCKS:
        key_h = f"image_encoder.blocks.{idx}.attn.rel_pos_h"
        key_w = f"image_encoder.blocks.{idx}.attn.rel_pos_w"
        rel_pos_h = resize_rel_pos(checkpoint[key_h], REL_POS)
        rel_pos_w = resize_rel_pos(checkpoint[key_w], REL_POS)
        checkpoint[key_h] = rel_pos_h
        checkpoint[key_w] = rel_pos_w
        print(rel_pos_h.shape, rel_pos_w.shape)

    torch.save(checkpoint, DST_CHECKPOINT)


if __name__ == "__main__":
    main()