import abc
import math

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from timm.models.vision_transformer import (
    VisionTransformer,
    build_model_with_cfg,
    checkpoint_filter_fn,
    checkpoint_seq,
    resolve_pretrained_cfg,
)
from torch import Tensor, nn


class BlankLayer(nn.Module):
    """Empty container module; CLIPpyModel attaches the vision trunk to an instance of it."""


class CustomViT(VisionTransformer):
    def __init__(
        self,
        *args,
        image_pooling="gmp",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.image_pooling = image_pooling

    def forward_head(self, x, pre_logits: bool = False):
        if self.image_pooling:
            if self.image_pooling == "gap":
                # global average pooling over patch tokens
                x = x[:, self.num_prefix_tokens:].mean(dim=1)
            elif self.image_pooling == "gmp":
                # global max pooling over patch tokens
                x = x[:, self.num_prefix_tokens:].max(dim=-2)[0]
            elif self.image_pooling == "all":
                # keep the full sequence of patch tokens
                x = x[:, self.num_prefix_tokens:]
            else:
                # any other value (e.g. "cls") falls back to the class token
                x = x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, get_pos_tokens=False):
        x = self.forward_features(x, get_pos_tokens=get_pos_tokens)
        if get_pos_tokens:
            # return normalized patch tokens instead of the pooled image embedding
            return self.fc_norm(x[:, self.num_prefix_tokens:])
        x = self.forward_head(x)
        return x

    def forward_features(self, x, get_pos_tokens=False):
        _, nc, h, w = x.shape
        x = self.patch_embed(x)
        x = self._pos_embed(x, w, h)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def _pos_embed(self, x, w, h):
        if self.no_embed_class:
            # pos_embed has no entry for the class token: add it to the patch tokens
            # first, then prepend the class token
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # pos_embed includes an entry for the class token: prepend the class token
            # first, then add the (possibly interpolated) position embedding
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self._interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def _interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[1]
        # add a small offset to avoid floating point errors when computing the
        # interpolation scale factor
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get("features_only", None):
        raise RuntimeError("features_only not implemented for Vision Transformer models.")

    pretrained_cfg = resolve_pretrained_cfg(
        variant, pretrained_cfg=kwargs.pop("pretrained_cfg", None)
    )
    model = build_model_with_cfg(
        CustomViT,
        variant,
        pretrained,
        pretrained_cfg=pretrained_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load="npz" in pretrained_cfg["url"],
        **kwargs,
    )
    return model


def vit_base_patch16_224(pretrained=False, variant="vit_base_patch16_224_dino", **kwargs):
    """ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294"""
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(variant, pretrained=pretrained, **model_kwargs)
    return model
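

# Usage sketch (not part of the original module): how the DINO backbone above can be
# used on its own, assuming a timm version compatible with the imports at the top.
# The 224x224 dummy input, the "gap" pooling choice, and the helper name
# `_demo_backbone` are illustrative assumptions.
def _demo_backbone():
    model = vit_base_patch16_224(pretrained=False, image_pooling="gap").eval()
    images = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
    with torch.no_grad():
        pooled = model(images)  # pooled embedding per image (forward_head path)
        patch_tokens = model(images, get_pos_tokens=True)  # (2, 196, 768) per-patch tokens
    return pooled, patch_tokens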


class CLIPpyModel(abc.ABC, torch.nn.Module):
    """Implements code for running inference with a pre-trained CLIPpy model.

    NOTE: the released weights come from a model trained with a smaller batch size,
    so results are below those reported in the paper.
    """

    def __init__(
        self,
        image_pooling: str = "cls",
        text_pooling: str = "gap",
    ):
        super().__init__()

        self.visual = BlankLayer()
        self.visual.trunk = vit_base_patch16_224(pretrained=True, image_pooling=image_pooling)

        self.text = SentenceTransformer("sentence-transformers/sentence-t5-base")
        # learnable temperature, initialized to 1/0.07 as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
        self.set_text_pooling(text_pooling)

        self._divisor_eps = 1e-4
        self._image_pooling = image_pooling
        self._text_pooling = text_pooling

    def forward(
        self,
        images: Tensor,
        input_ids: Tensor,
        input_id_masks: Tensor,
        get_pos_tokens: bool = False,
        **kwargs,
    ):
        image_encodings = self.encode_image(images, get_pos_tokens=get_pos_tokens)

        if get_pos_tokens:
            return {
                "image_encodings": image_encodings,
            }

        text_encodings = self.encode_text(input_ids, input_id_masks)

        return {
            "image_encodings": image_encodings,
            "text_encodings": text_encodings,
        }

    def encode_text(self, input_ids: Tensor, input_id_masks: Tensor = None, **kwargs):
        output = self.text({"input_ids": input_ids, "attention_mask": input_id_masks})[
            "sentence_embedding"
        ]
        return self.text_head(output)

    def text_head(self, hidden_states: Tensor, input_id_masks: Tensor = None, **kwargs):
        # L2-normalize the text embeddings
        return F.normalize(hidden_states, dim=-1, eps=self._divisor_eps).float()

    def encode_image(self, images: Tensor, get_pos_tokens: bool = False, **kwargs):
        output = self.visual.trunk(images, get_pos_tokens=get_pos_tokens)
        return self.image_head(output, get_pos_tokens=get_pos_tokens)

    def image_head(self, hidden_states: Tensor, get_pos_tokens: bool = False, **kwargs):
        # L2-normalize the image (or patch-token) embeddings
        return F.normalize(hidden_states, dim=-1, eps=self._divisor_eps).float()

    def set_text_pooling(self, pooling):
        """Converts pooling in the Hugging Face model to max ("gmp") or average ("gap") pooling."""
        if pooling == "gmp":
            self.text[1].pooling_mode_mean_tokens = False
            self.text[1].pooling_mode_max_tokens = True
        elif pooling == "gap":
            pass  # mean pooling is already the model's default
        else:
            raise NotImplementedError(f"{pooling} not implemented")
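

# Minimal inference sketch (not from the original repo): wiring the pieces above together
# for zero-shot image/text matching. It assumes the pretrained DINO and sentence-t5-base
# weights can be downloaded, and it uses the SentenceTransformer `tokenize` helper to
# build `input_ids` / `attention_mask`; the prompts and the random image are placeholders.
if __name__ == "__main__":
    model = CLIPpyModel().eval()

    images = torch.randn(1, 3, 224, 224)  # replace with a real, preprocessed image batch
    prompts = ["a photo of a dog", "a photo of a cat"]
    tokens = model.text.tokenize(prompts)

    with torch.no_grad():
        image_emb = model.encode_image(images)  # (1, D), L2-normalized
        text_emb = model.encode_text(tokens["input_ids"], tokens["attention_mask"])  # (2, D)
        logits = model.logit_scale.exp() * image_emb @ text_emb.t()  # similarity scores

    print(logits.softmax(dim=-1))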