# Spaces:
# Runtime error
# Runtime error
"""Downscale a SAM ViT-L checkpoint so it can run at a smaller input resolution.

Resizes ``image_encoder.pos_embed`` and the ``rel_pos_h`` / ``rel_pos_w`` tables
of selected attention blocks, then saves the adapted checkpoint next to the
original with a ``_{IMAGE_SIZE}x{IMAGE_SIZE}`` suffix.
"""
import os

import torch

ORI_IMAGE_SIZE = 1024  # input resolution the source checkpoint was built for
IMAGE_SIZE = 256  # target input resolution
# Target length of each resized rel_pos table.
# NOTE(review): 31 == 2 * 16 - 1, which appears to assume a 16x16 token grid at
# IMAGE_SIZE -- confirm before changing IMAGE_SIZE.
REL_POS = 31
# Blocks whose rel_pos_h / rel_pos_w tables get resized.
# NOTE(review): presumably the global-attention blocks of ViT-L; verify for other backbones.
GLOBAL_ATTN_BLOCKS = (5, 11, 17, 23)
CHECKPOINT_PATH = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth"


def resize_pos_embed(pos_embed, image_size=IMAGE_SIZE, ori_image_size=ORI_IMAGE_SIZE):
    """Bilinearly rescale a channels-last position-embedding grid.

    Args:
        pos_embed: 4-D tensor laid out as (N, H, W, C); the permutes below rely
            on channels being last.
        image_size: Target input resolution.
        ori_image_size: Resolution ``pos_embed`` was built for.

    Returns:
        Tensor of shape (N, H * image_size // ori_image_size,
        W * image_size // ori_image_size, C).
    """
    _, height, width, _ = pos_embed.shape
    # An explicit integer size (instead of a float scale_factor) guarantees the
    # exact target grid and avoids any floating-point flooring surprises.
    # With align_corners=True the sampled values are the same either way.
    size = (height * image_size // ori_image_size, width * image_size // ori_image_size)
    resized = torch.nn.functional.interpolate(
        pos_embed.permute(0, 3, 1, 2),  # (N, H, W, C) -> (N, C, H, W)
        size=size,
        mode="bilinear",
        align_corners=True,
    )
    return resized.permute(0, 2, 3, 1)  # back to (N, H', W', C)


def resize_rel_pos(rel_pos, length=REL_POS):
    """Linearly resample a 2-D relative-position table along its first axis.

    Args:
        rel_pos: Tensor of shape (L, C); the first axis is resampled.
        length: Number of rows in the result.

    Returns:
        Tensor of shape (length, C).
    """
    resized = torch.nn.functional.interpolate(
        # 1-D "linear" interpolation expects (N, C, L): (L, C) -> (1, C, L).
        rel_pos.permute(1, 0).unsqueeze(0),
        size=length,
        mode="linear",
        align_corners=True,
    )
    return resized.squeeze(0).permute(1, 0)


def resize_sam_checkpoint(
    checkpoint,
    image_size=IMAGE_SIZE,
    ori_image_size=ORI_IMAGE_SIZE,
    rel_pos=REL_POS,
    blocks=GLOBAL_ATTN_BLOCKS,
    verbose=True,
):
    """Resize the positional tensors of a SAM image-encoder state dict in place.

    Args:
        checkpoint: Mapping of parameter name -> tensor; it is modified in place.
        image_size: Target input resolution.
        ori_image_size: Resolution the checkpoint was built for.
        rel_pos: Target length of every rel_pos_h / rel_pos_w table.
        blocks: Indices of ``image_encoder.blocks.{idx}.attn`` whose rel_pos
            tables are resized.
        verbose: Print the resulting shapes, as the original script did.

    Returns:
        The same ``checkpoint`` object, for convenience.

    Raises:
        KeyError: If an expected parameter name is missing from ``checkpoint``.
    """
    pos_key = "image_encoder.pos_embed"
    checkpoint[pos_key] = resize_pos_embed(checkpoint[pos_key], image_size, ori_image_size)
    if verbose:
        print(checkpoint[pos_key].shape)

    for idx in blocks:
        prefix = f"image_encoder.blocks.{idx}.attn."
        for name in ("rel_pos_h", "rel_pos_w"):
            checkpoint[prefix + name] = resize_rel_pos(checkpoint[prefix + name], rel_pos)
        if verbose:
            print(checkpoint[prefix + "rel_pos_h"].shape, checkpoint[prefix + "rel_pos_w"].shape)
    return checkpoint


def main(src=CHECKPOINT_PATH):
    """Load ``src``, resize it for IMAGE_SIZE, and save it with a size suffix."""
    # map_location="cpu" makes loading independent of the device the tensors
    # were saved from (e.g. works on CPU-only nodes).
    checkpoint = torch.load(src, map_location="cpu")
    resize_sam_checkpoint(checkpoint)
    stem, ext = os.path.splitext(src)
    torch.save(checkpoint, f"{stem}_{IMAGE_SIZE}x{IMAGE_SIZE}{ext}")


if __name__ == "__main__":
    main()