import torch
import torch.nn as nn

from timm.models.layers import to_2tuple


class PatchEmbed_org(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        y = x.flatten(2).transpose(1, 2)  # B, embed_dim, h, w -> B, h*w, embed_dim
        return y


class PatchEmbed_new(nn.Module):
    """Flexible Image to Patch Embedding"""

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )  # with overlapped patches
        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
        # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        _, _, h, w = self.get_output_shape(img_size)  # n, embed_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        # todo: don't be lazy..
        # The dummy tensor must have in_chans channels (not a hardcoded 1),
        # otherwise Conv2d raises a shape error whenever in_chans != 1.
        return self.proj(
            torch.randn(1, self.in_chans, img_size[0], img_size[1])
        ).shape

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.proj(x)  # 32, 1, 1024, 128 -> 32, 768, 101, 12
        x = x.flatten(2)  # 32, 768, 101, 12 -> 32, 768, 1212
        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768
        return x


class PatchEmbed3D_new(nn.Module):
    """Flexible Video to Patch Embedding"""

    def __init__(
        self,
        video_size=(16, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    ):
        super().__init__()

        self.video_size = video_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )
        _, _, t, h, w = self.get_output_shape(video_size)  # n, embed_dim, t, h, w
        self.patch_thw = (t, h, w)
        self.num_patches = t * h * w

    def get_output_shape(self, video_size):
        # todo: don't be lazy..
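        # Same dummy-forward trick as the 2D case: trace one tensor through
        # the Conv3d to read off the (t, h, w) token grid.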
        return self.proj(
            torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
        ).shape

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = self.proj(x)  # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
        x = x.flatten(2)  # 32, 768, 8, 14, 14 -> 32, 768, 1568
        x = x.transpose(1, 2)  # 32, 768, 1568 -> 32, 1568, 768
        return x


if __name__ == "__main__":
    # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
    # input = torch.rand(8,1,1024,128)
    # output = patch_emb(input)
    # print(output.shape)  # (8, 512, 64)

    patch_emb = PatchEmbed3D_new(
        video_size=(6, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    )
    input = torch.rand(8, 3, 6, 224, 224)
    output = patch_emb(input)
    print(output.shape)  # (8, 588, 768): 3 * 14 * 14 = 588 patches
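
    # A minimal sanity check for PatchEmbed_new as well (a sketch: the
    # 1-channel 1024x128 input is an assumed log-mel-spectrogram shape,
    # mirroring the commented-out example above; variable names here are
    # illustrative). With kernel_size=(16, 16) and stride=(16, 16) the conv
    # yields 1024 // 16 = 64 by 128 // 16 = 8 positions, i.e. 512 tokens.
    patch_emb_2d = PatchEmbed_new(
        img_size=(1024, 128), patch_size=16, in_chans=1, embed_dim=64, stride=16
    )
    spec = torch.rand(8, 1, 1024, 128)
    tokens = patch_emb_2d(spec)
    print(tokens.shape)  # torch.Size([8, 512, 64])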