"""Resize a SAM ViT-L checkpoint's positional parameters for a smaller input size.

The checkpoint's image encoder expects ``ORI_IMAGE_SIZE`` x ``ORI_IMAGE_SIZE``
inputs. This script rescales the absolute positional embedding
(``image_encoder.pos_embed``) and the ``rel_pos_h`` / ``rel_pos_w`` tables of
selected encoder blocks so the encoder can be loaded at ``IMAGE_SIZE`` x
``IMAGE_SIZE``, then saves the result next to the original checkpoint with an
``_{IMAGE_SIZE}x{IMAGE_SIZE}`` suffix.
"""
import os
from typing import Iterable, Optional

import torch
import torch.nn.functional as F

ORI_IMAGE_SIZE = 1024
IMAGE_SIZE = 256
# Assumes SAM's 16-pixel patches: the token grid is IMAGE_SIZE // 16 per side and a
# relative-position table spans 2 * grid - 1 offsets (31 for a 16x16 grid).
# convert_checkpoint() derives this length from the checkpoint itself; the constant
# is kept for reference and backward compatibility.
PATCH_SIZE = 16
REL_POS = 2 * (IMAGE_SIZE // PATCH_SIZE) - 1

CHECKPOINT_PATH = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth"
# Encoder blocks whose rel_pos tables are resized (presumably the global-attention
# blocks of SAM ViT-L; windowed blocks keep their window-sized tables) — confirm.
GLOBAL_ATTN_INDEXES = (5, 11, 17, 23)


def _resize_pos_embed(pos_embed: torch.Tensor, grid: int) -> torch.Tensor:
    """Bilinearly resize a (1, H, W, C) embedding to (1, grid, grid, C)."""
    # interpolate() works on channels-first tensors, so move C in front and back.
    resized = F.interpolate(
        pos_embed.permute(0, 3, 1, 2),
        size=(grid, grid),
        mode="bilinear",
        align_corners=True,
    )
    return resized.permute(0, 2, 3, 1)


def _resize_rel_pos(rel_pos: torch.Tensor, length: int) -> torch.Tensor:
    """Linearly resize an (L, C) relative-position table to (length, C)."""
    # 1-D interpolate() expects (N, C, L): treat the table's columns as channels.
    resized = F.interpolate(
        rel_pos.permute(1, 0).unsqueeze(0),
        size=length,
        mode="linear",
        align_corners=True,
    )
    return resized.squeeze(0).permute(1, 0)


def resized_checkpoint_path(src: str, image_size: int = IMAGE_SIZE) -> str:
    """Return *src* with ``_{image_size}x{image_size}`` inserted before its extension."""
    root, ext = os.path.splitext(src)
    return f"{root}_{image_size}x{image_size}{ext}"


def convert_checkpoint(
    checkpoint: dict,
    image_size: int = IMAGE_SIZE,
    ori_image_size: int = ORI_IMAGE_SIZE,
    block_indexes: Iterable[int] = GLOBAL_ATTN_INDEXES,
    rel_pos_len: Optional[int] = None,
    verbose: bool = True,
) -> dict:
    """Resize the positional parameters of *checkpoint* in place and return it.

    Args:
        checkpoint: Mapping holding ``image_encoder.pos_embed`` (indexed as
            (1, H, W, C)) and ``image_encoder.blocks.{i}.attn.rel_pos_{h,w}``
            tables (indexed as (L, C)) for every ``i`` in *block_indexes*.
        image_size: Target input resolution.
        ori_image_size: Resolution the checkpoint was trained at.
        block_indexes: Encoder blocks whose rel_pos tables are resized.
        rel_pos_len: Target rel_pos table length; defaults to ``2 * grid - 1``
            where ``grid`` is the resized token-grid side.
        verbose: Print the new tensor shapes, as the original script did.

    Returns:
        The same *checkpoint* object, with the entries replaced by resized tensors.

    Raises:
        KeyError: If an expected key is missing.
        ValueError: If *image_size* is too small to keep at least one token.
    """
    pos_key = "image_encoder.pos_embed"
    pos_embed = checkpoint[pos_key]
    # Integer arithmetic gives the exact resized grid (64 * 256 // 1024 == 16).
    grid = pos_embed.shape[1] * image_size // ori_image_size
    if grid < 1:
        raise ValueError(
            f"image_size={image_size} leaves no tokens "
            f"(original grid {pos_embed.shape[1]} at {ori_image_size})"
        )

    checkpoint[pos_key] = _resize_pos_embed(pos_embed, grid)
    if verbose:
        print(checkpoint[pos_key].shape)

    # Deriving the length from the grid keeps it consistent with image_size.
    if rel_pos_len is None:
        rel_pos_len = 2 * grid - 1

    for idx in block_indexes:
        prefix = f"image_encoder.blocks.{idx}.attn."
        for axis in ("h", "w"):
            key = f"{prefix}rel_pos_{axis}"
            checkpoint[key] = _resize_rel_pos(checkpoint[key], rel_pos_len)
        if verbose:
            print(checkpoint[f"{prefix}rel_pos_h"].shape, checkpoint[f"{prefix}rel_pos_w"].shape)

    return checkpoint


def main(src: str = CHECKPOINT_PATH, image_size: int = IMAGE_SIZE) -> None:
    """Load the checkpoint at *src*, resize it for *image_size*, and save it."""
    # map_location="cpu" lets GPU-saved tensors load on CPU-only machines.
    # NOTE(review): torch.load unpickles arbitrary objects — only use trusted files.
    # On torch >= 1.13, weights_only=True also works if the file holds only tensors.
    checkpoint = torch.load(src, map_location="cpu")
    convert_checkpoint(checkpoint, image_size=image_size)
    torch.save(checkpoint, resized_checkpoint_path(src, image_size))


if __name__ == "__main__":
    main()