|
|
|
|
|
import math |
|
from itertools import chain |
|
from typing import Any, Optional |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.functional import interpolate |
|
from einops.layers.torch import Rearrange |
|
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
from transformers import AutoConfig, AutoModel, AutoProcessor, AutoImageProcessor |
|
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel |
|
|
|
def handle_feature_output( |
|
x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0 |
|
) -> torch.Tensor: |
|
"""Handle feature output from transformer. |
|
|
|
Args: |
|
x (torch.Tensor): input feature to be handled. shape is |
|
[B, 1+H*W+N, C] if including both CLS and register tokens. |
|
[B, 1+H*W, C] for standard model (N=0). |
|
[B, H*W, C] for model without CLS. |
|
        feature_reduce_method (Optional[str]): method to select tokens. Options:

            - `mean_pooling`: average over spatial (non-CLS) tokens, output shape = [B, C].
            - `max_pooling`: max over spatial tokens, output shape = [B, C].
            - `cls`: return the CLS token only, output shape = [B, C].
            - `identity`: return the feature untouched, output shape = input shape.
            - `None`: return spatial tokens, output shape = [B, H*W, C] (assuming input is [B, 1+H*W, C]).

            All options except `identity` assume the raw feature is [B, 1+H*W(+N), C] with the CLS token first.
|
num_discard_tokens (int): |
|
number of tokens to be discarded. Assuming they are at the end of the sequence. |
|
Returns: |
|
torch.Tensor: selected feature tokens. |
|
""" |
|
|
|
match feature_reduce_method: |
|
case "mean_pooling": |
|
return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) |
|
case "max_pooling": |
|
return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) |
|
case "cls": |
|
return x[:, 0] |
|
case "identity": |
|
return x |
|
case None: |
|
return x[:, 1 : x.size(1) - num_discard_tokens] |
|
case _: |
|
raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} it not implemented.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ViTEmbeddingsNoCLS(ViTEmbeddings): |
|
"""ViT Embedding Module without CLS token.""" |
|
|
|
def __init__(self, config: AutoConfig, use_mask_token: bool = False): |
|
"""Initialization. |
|
|
|
Args: |
|
config (AutoConfig): config for ViT. |
|
use_mask_token (bool, optional): whether to use mask token. Defaults to False. |
|
""" |
|
super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token) |
|
self.cls_token = None |
|
|
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
|
""" |
|
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images.
|
|
|
Source: |
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 |
|
""" |
|
|
|
num_patches = embeddings.shape[1] |
|
num_positions = self.position_embeddings.shape[1] - 1 |
|
if num_patches == num_positions and height == width: |
|
return self.position_embeddings |
|
patch_pos_embed = self.position_embeddings[:, 1:] |
|
dim = embeddings.shape[-1] |
|
h0 = height // self.config.patch_size |
|
w0 = width // self.config.patch_size |
|
|
|
|
|
h0, w0 = h0 + 0.1, w0 + 0.1 |
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) |
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) |
|
patch_pos_embed = nn.functional.interpolate( |
|
patch_pos_embed, |
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), |
|
mode="bicubic", |
|
align_corners=False, |
|
) |
|
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] |
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
|
return patch_pos_embed |
|
|
|
def forward( |
|
self, |
|
pixel_values: torch.Tensor, |
|
bool_masked_pos: Optional[torch.BoolTensor] = None, |
|
interpolate_pos_encoding: bool = False, |
|
) -> torch.Tensor: |
|
batch_size, num_channels, height, width = pixel_values.shape |
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
|
|
if bool_masked_pos is not None: |
|
seq_length = embeddings.shape[1] |
|
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) |
|
|
|
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) |
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask |
|
|
|
|
|
if interpolate_pos_encoding: |
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) |
|
else: |
|
embeddings = embeddings + self.position_embeddings[:, 1:] |
|
|
|
embeddings = self.dropout(embeddings) |
|
|
|
return embeddings |
|
|
|
|
|
|
|
class ViTModelNoCLS(ViTModel): |
|
"""ViT Model without CLS token.""" |
|
|
|
def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None: |
|
super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token) |
|
self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token) |
|
self.no_cls = True |
|
|
|
def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: |
|
"""Initialize the weights""" |
|
if isinstance(module, (nn.Linear, nn.Conv2d)): |
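            # Upcast to float32 for trunc_normal_ (which does not support half precision), then cast back.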
|
|
|
|
|
module.weight.data = nn.init.trunc_normal_( |
|
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range |
|
).to(module.weight.dtype) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
elif isinstance(module, ViTEmbeddings): |
|
module.position_embeddings.data = nn.init.trunc_normal_( |
|
module.position_embeddings.data.to(torch.float32), |
|
mean=0.0, |
|
std=self.config.initializer_range, |
|
).to(module.position_embeddings.dtype) |
|
|
|
|
|
|
|
class ViTEmbeddingsReg(ViTEmbeddings): |
|
""" |
|
ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1 |
|
""" |
|
|
|
def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7): |
|
super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token) |
|
self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) |
|
self.num_reg_tokens = num_reg_tokens |
|
self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) |
|
|
|
self.reg_pos_embed.data = nn.init.trunc_normal_( |
|
self.reg_pos_embed.data.to(torch.float32), |
|
mean=0.0, |
|
std=self.config.initializer_range, |
|
).to(self.reg_pos_embed.dtype) |
|
|
|
self.reg_token.data = nn.init.trunc_normal_( |
|
self.reg_token.data.to(torch.float32), |
|
mean=0.0, |
|
std=self.config.initializer_range, |
|
).to(self.reg_token.dtype) |
|
|
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
|
""" |
|
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images.
|
|
|
Source: |
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 |
|
""" |
|
|
|
num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens |
|
num_positions = self.position_embeddings.shape[1] - 1 |
|
if num_patches == num_positions and height == width: |
|
return self.position_embeddings |
|
class_pos_embed = self.position_embeddings[:, 0] |
|
patch_pos_embed = self.position_embeddings[:, 1:] |
|
reg_pos_embed = self.reg_pos_embed |
|
dim = embeddings.shape[-1] |
|
h0 = height // self.config.patch_size |
|
w0 = width // self.config.patch_size |
|
|
|
|
|
h0, w0 = h0 + 0.1, w0 + 0.1 |
|
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) |
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) |
|
patch_pos_embed = nn.functional.interpolate( |
|
patch_pos_embed, |
|
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), |
|
mode="bicubic", |
|
align_corners=False, |
|
) |
|
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] |
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1) |
|
|
|
def forward( |
|
self, |
|
pixel_values: torch.Tensor, |
|
bool_masked_pos: Optional[torch.BoolTensor] = None, |
|
interpolate_pos_encoding: bool = False, |
|
) -> torch.Tensor: |
|
batch_size, num_channels, height, width = pixel_values.shape |
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
|
|
if bool_masked_pos is not None: |
|
seq_length = embeddings.shape[1] |
|
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) |
|
|
|
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) |
|
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask |
|
|
|
|
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1) |
|
reg_tokens = self.reg_token.expand(batch_size, -1, -1) |
|
embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1) |
|
|
|
|
|
if interpolate_pos_encoding: |
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) |
|
else: |
|
embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1) |
|
|
|
embeddings = self.dropout(embeddings) |
|
|
|
return embeddings |
|
|
|
|
|
|
|
class ViTModelReg(ViTModel): |
|
"""ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1""" |
|
|
|
def __init__( |
|
self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7 |
|
): |
|
super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token) |
|
self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens) |
|
self.num_reg_tokens = num_reg_tokens |
|
|
|
def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: |
|
"""Initialize the weights""" |
|
if isinstance(module, (nn.Linear, nn.Conv2d)): |
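            # Upcast to float32 for trunc_normal_ (which does not support half precision), then cast back.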
|
|
|
|
|
module.weight.data = nn.init.trunc_normal_( |
|
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range |
|
).to(module.weight.dtype) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
elif isinstance(module, ViTEmbeddings): |
|
module.position_embeddings.data = nn.init.trunc_normal_( |
|
module.position_embeddings.data.to(torch.float32), |
|
mean=0.0, |
|
std=self.config.initializer_range, |
|
).to(module.position_embeddings.dtype) |
|
module.cls_token.data = nn.init.trunc_normal_( |
|
module.cls_token.data.to(torch.float32), |
|
mean=0.0, |
|
std=self.config.initializer_range, |
|
).to(module.cls_token.dtype) |
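
# Token layout produced by ViTModelReg / ViTEmbeddingsReg (a sketch matching the concatenation
# order in ViTEmbeddingsReg.forward, with N = num_reg_tokens):
#
#   [CLS] [p_1 ... p_{H*W}] [r_1 ... r_N]
#
# so spatial tokens can be recovered with x[:, 1:-N] and register tokens with x[:, -N:].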
|
|
|
|
|
class DeiT(nn.Module): |
|
"""DeiT model. |
|
|
|
Paper: Training data-efficient image transformers & distillation through attention |
|
https://arxiv.org/abs/2012.12877 |
|
Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit |
|
|
|
Attributes: |
|
model_name (str): name of the model. |
|
pretrained (bool): whether to use pretrained weights. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
model_name: str = "facebook/deit-small-patch16-224", |
|
pretrained: bool = False, |
|
image_size: int = 224, |
|
): |
|
super().__init__() |
|
self.image_size = image_size |
|
model = AutoModel.from_pretrained(model_name) |
|
if pretrained: |
|
self.model = model |
|
else: |
|
deit_config = model.config |
|
self.model = AutoModel.from_config(deit_config) |
|
del model |
|
|
|
self.model.pooler = nn.Identity() |
|
|
|
self.processor = AutoProcessor.from_pretrained(model_name) |
|
|
|
def get_feature_size( |
|
self, |
|
keep_spatial: bool = False, |
|
return_torch_size: bool = False, |
|
) -> torch.Size | tuple[int, ...]: |
|
"""Get the size of the feature. |
|
|
|
Args: |
|
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
|
return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
|
|
|
Returns: |
|
torch.Size | tuple[int, ...]: returned feature shape. |
|
""" |
|
with torch.inference_mode(): |
|
            image_size = (self.image_size, self.image_size)
|
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
|
y = self.forward(x)[:, 1:] |
|
size = y.size()[1:][::-1] |
|
if keep_spatial: |
|
assert math.isqrt(size[-1]) |
|
h = w = int(math.sqrt(size[-1])) |
|
size = (size[0], h, w) |
|
if return_torch_size: |
|
size = torch.Size(size) |
|
return size |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
do_resize: bool = True, |
|
interpolate_pos_encoding: Optional[bool] = None, |
|
do_rescale: bool = True, |
|
do_normalize: bool = True, |
|
) -> torch.Tensor: |
|
"""Forward pass of the model |
|
|
|
Args: |
|
x (torch.Tensor): model input. |
|
|
|
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale pixel values (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
|
|
|
Returns: |
|
torch.Tensor: model output. |
|
""" |
|
input = self.processor( |
|
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize |
|
).to(self.model.device) |
|
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
|
return y.last_hidden_state |
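
# Illustrative DeiT usage (a sketch; assumes the default HF checkpoint name resolves and that
# inputs are uint8 images in [B, H, W, C] layout, which the processor accepts):
#
#   deit = DeiT("facebook/deit-small-patch16-224", pretrained=False)
#   images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)
#   tokens = deit(images)   # [2, N_tokens, C] hidden states: special token(s) first, then patch tokens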
|
|
|
|
|
class DeiTNoCLS(nn.Module): |
|
"""Modified DeiT model without CLS token.""" |
|
|
|
def __init__( |
|
self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224 |
|
): |
|
super().__init__() |
|
self.image_size = image_size |
|
pretrained_model_name = model_name.replace("nocls-", "") |
|
deit_config = AutoConfig.from_pretrained(pretrained_model_name) |
|
self.model = ViTModelNoCLS(deit_config) |
|
if pretrained: |
|
pretrained_model = AutoModel.from_pretrained(pretrained_model_name) |
|
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()} |
|
self.load_state_dict(pretrained_dict, strict=False) |
|
del pretrained_model, pretrained_dict |
|
|
|
self.model.pooler = nn.Identity() |
|
self.processor = AutoProcessor.from_pretrained(pretrained_model_name) |
|
self.no_cls = True |
|
|
|
def get_feature_size( |
|
self, |
|
keep_spatial: bool = False, |
|
return_torch_size: bool = False, |
|
) -> torch.Size | tuple[int, ...]: |
|
"""Get the size of the feature. |
|
|
|
Args: |
|
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
|
return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
|
|
|
Returns: |
|
torch.Size | tuple[int, ...]: returned feature shape. |
|
""" |
|
with torch.inference_mode(): |
|
image_size = (self.image_size, self.image_size) |
|
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
|
y = self.forward(x) |
|
size = y.size()[1:][::-1] |
|
if keep_spatial: |
|
                assert math.isqrt(size[-1]) ** 2 == size[-1], "feature map is not square"
|
h = w = int(math.sqrt(size[-1])) |
|
size = (size[0], h, w) |
|
if return_torch_size: |
|
size = torch.Size(size) |
|
return size |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
do_resize: bool = True, |
|
interpolate_pos_encoding: Optional[bool] = None, |
|
do_rescale: bool = True, |
|
do_normalize: bool = True, |
|
) -> torch.Tensor: |
|
"""Forward pass of the model |
|
|
|
Args: |
|
x (torch.Tensor): model input. |
|
|
|
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale pixel values (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
|
|
|
Returns: |
|
torch.Tensor: model output. |
|
""" |
|
input = self.processor( |
|
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize |
|
).to(self.model.device) |
|
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
|
return y.last_hidden_state |
|
|
|
|
|
class DeiTReg(nn.Module): |
|
"""Modified DeiT model with register tokens.""" |
|
|
|
def __init__( |
|
self, |
|
model_name: str = "reg-facebook/deit-small-patch16-224", |
|
pretrained: bool = False, |
|
image_size: int = 224, |
|
num_reg_tokens: int = 7, |
|
): |
|
super().__init__() |
|
self.image_size = image_size |
|
pretrained_model_name = model_name.replace("reg-", "") |
|
deit_config = AutoConfig.from_pretrained(pretrained_model_name) |
|
self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens) |
|
if pretrained: |
|
pretrained_model = AutoModel.from_pretrained(pretrained_model_name) |
|
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()} |
|
self.load_state_dict(pretrained_dict, strict=False) |
|
del pretrained_model, pretrained_dict |
|
|
|
self.model.pooler = nn.Identity() |
|
self.processor = AutoProcessor.from_pretrained(pretrained_model_name) |
|
self.num_reg_tokens = num_reg_tokens |
|
|
|
def get_feature_size( |
|
self, |
|
keep_spatial: bool = False, |
|
return_torch_size: bool = False, |
|
) -> torch.Size | tuple[int, ...]: |
|
"""Get the size of the feature. |
|
|
|
Args: |
|
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. |
|
return_torch_size (bool): if true, return torch.Size type. Defaults to False. |
|
|
|
Returns: |
|
torch.Size | tuple[int, ...]: returned feature shape. |
|
""" |
|
with torch.inference_mode(): |
|
image_size = (self.image_size, self.image_size) |
|
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) |
|
y = self.forward(x)[:, 1 : -self.num_reg_tokens] |
|
size = y.size()[1:][::-1] |
|
if keep_spatial: |
|
                assert math.isqrt(size[-1]) ** 2 == size[-1], "feature map is not square"
|
h = w = int(math.sqrt(size[-1])) |
|
size = (size[0], h, w) |
|
if return_torch_size: |
|
size = torch.Size(size) |
|
return size |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
do_resize: bool = True, |
|
interpolate_pos_encoding: Optional[bool] = None, |
|
do_rescale: bool = True, |
|
do_normalize: bool = True, |
|
) -> torch.Tensor: |
|
"""Forward pass of the model |
|
|
|
Args: |
|
x (torch.Tensor): model input. |
|
|
|
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale pixel values (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
|
|
|
Returns: |
|
torch.Tensor: model output. |
|
""" |
|
input = self.processor( |
|
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize |
|
).to(self.model.device) |
|
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) |
|
return y.last_hidden_state |
|
|
|
|
|
def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module: |
|
"""Build the backbone visual encoder of robot vision foundation model. |
|
|
|
Args: |
|
model_name (str): name of the model. |
|
pretrained (bool): whether to use pretrained weights. Defaults to False. |
|
image_size (int): size of the image. Assume a square image. Defaults to 224 |
|
kwargs (Any): any kwargs specific to some models. For example, |
|
`num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name` |
|
|
|
Returns: |
|
nn.Module: backbone network. |
|
""" |
|
if "reg" in model_name: |
|
return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) |
|
elif "nocls" in model_name: |
|
return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) |
|
elif "deit" in model_name: |
|
return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size) |
|
else: |
|
raise NotImplementedError(f"Requested {model_name} is not implemented.") |
|
|
|
class Interpolation(nn.Module): |
|
"""Interpolation nn.Module wrap for nn.functional.interpolate. |
|
|
|
Attributes: |
|
target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation. |
|
""" |
|
|
|
def __init__(self, target_size: tuple[int, int] | torch.Size) -> None: |
|
super().__init__() |
|
self.target_size = target_size |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Very simple forward pass to call interpolate().""" |
|
return interpolate(x, self.target_size) |
|
|
|
|
|
class LinearAdapterHead(nn.Module): |
|
"""Adapter head contains a single linear layer.""" |
|
def __init__( |
|
self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size |
|
): |
|
"""Initialization function for LinearAdapterHead. |
|
Args: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
|
num_layer (int): number of MLP layers (One linear layer if num_layer = 1). |
|
""" |
|
super().__init__() |
|
|
|
self.source_size = source_size |
|
self.target_size = target_size |
|
|
|
source_channel_size = self.source_size[0] |
|
target_channel_size = self.target_size[0] |
|
|
|
self.adapter = nn.Sequential( |
|
nn.Linear(source_channel_size, target_channel_size), |
|
) |
|
|
|
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
|
"""Forward pass for the adapter. """ |
|
        assert not backbone_no_cls, "LinearAdapterHead expects a CLS token in the input sequence."
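        # Select the CLS token; the linear adapter maps [B, C_source] -> [B, C_target].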
|
|
|
|
|
x = x[:, 0] |
|
x = self.adapter(x) |
|
return x |
|
|
|
|
|
class MLPAdapterHead(nn.Module): |
|
"""MLP Adapter module. |
|
|
|
Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. |
|
Will first do interpolation to match the spatial size [H_t, W_t], |
|
followed by MLP to project to the target channel dimension [C_t]. |
|
|
|
Attributes: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W] |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W] |
|
adapter (nn.Module): the adapter module. |
|
interpolation (nn.Module): interpolation to adjust sizes before MLP. |
|
""" |
|
|
|
def __init__( |
|
self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int |
|
): |
|
"""Initialization function for MLPAdapter. |
|
|
|
Args: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
|
num_layer (int): number of MLP layers (One linear layer if num_layer = 1). |
|
""" |
|
super().__init__() |
|
assert num_layer >= 1, f"`num_layer` in {self._get_name()} should >= 1. Got {num_layer}" |
|
|
|
self.source_size = source_size |
|
self.target_size = target_size |
|
|
|
source_channel_size = self.source_size[0] |
|
target_channel_size = self.target_size[0] |
|
|
|
self.interpolation = nn.Sequential( |
|
nn.Identity(), |
|
) |
|
if self.source_size[1] != self.target_size[1]: |
|
self.interpolation = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
Interpolation(self.target_size[1:]), |
|
Rearrange("b c h w-> b (h w) c"), |
|
) |
|
|
|
if num_layer == 1: |
|
self.adapter = nn.Sequential( |
|
nn.Linear(source_channel_size, target_channel_size), |
|
) |
|
elif num_layer >= 2: |
|
hidden_dim = source_channel_size * 2 |
|
self.adapter = nn.Sequential( |
|
nn.Linear(source_channel_size, hidden_dim), |
|
*list( |
|
chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)]) |
|
), |
|
nn.ReLU(), |
|
nn.Linear(hidden_dim, target_channel_size), |
|
) |
|
|
|
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
|
"""Forward pass for the adapter. First interpolation then MLP.""" |
|
|
|
if not backbone_no_cls: |
|
x = x[:, 1:] |
|
|
|
x = self.interpolation(x) |
|
x = self.adapter(x) |
|
return x |
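
# Shape sketch for MLPAdapterHead (illustrative values; sizes are channel-first (C, H, W)):
#
#   head = MLPAdapterHead(source_size=(384, 16, 16), target_size=(768, 32, 32), num_layer=2)
#   x = torch.randn(2, 1 + 16 * 16, 384)   # backbone tokens with a leading CLS token
#   y = head(x)                            # [2, 32*32, 768] after interpolation + MLP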
|
|
|
|
|
class ConvAdapterHead(nn.Module): |
|
"""Convolutional Adapter module. |
|
|
|
Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. |
|
Uses CNN to map channel and spatial sizes jointly. |
|
    Note: only works for spatial sizes (16, 16), (64, 64), and (any, any) with any <= 14 for now.
|
|
|
Attributes: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
|
adapter (nn.Module): the adapter module. |
|
interpolation (nn.Module): interpolation to adjust sizes before MLP. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
source_size: tuple[int, ...] | torch.Size, |
|
target_size: tuple[int, ...] | torch.Size, |
|
): |
|
"""Initialization function for ConvAdapter. |
|
|
|
Args: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
|
""" |
|
super().__init__() |
|
self.source_size = source_size |
|
self.target_size = target_size |
|
|
|
hidden_dim = self.source_size[0] * 2 |
|
source_channel_size = self.source_size[0] |
|
target_channel_size = self.target_size[0] |
|
|
|
if self.source_size[1] < 12: |
|
raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") |
|
elif self.source_size[1] < 16: |
|
self.pad = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
nn.ConvTranspose2d( |
|
source_channel_size, |
|
source_channel_size, |
|
kernel_size=3, |
|
stride=1, |
|
output_padding=14 - self.source_size[1], |
|
), |
|
) |
|
self.source_size = (self.source_size[0], 16, 16) |
|
elif self.source_size[1] == 16 or self.source_size[1] == 64: |
|
self.pad = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
) |
|
else: |
|
raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.") |
|
|
|
if self.source_size[1] < self.target_size[1]: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 31, 31]), |
|
nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 64, 64]), |
|
nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), |
|
Rearrange("b c h w-> b (h w) c"), |
|
) |
|
elif self.source_size[1] == self.target_size[1]: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
|
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
|
nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), |
|
Rearrange("b c h w-> b (h w) c"), |
|
) |
|
else: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 32, 32]), |
|
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 16, 16]), |
|
nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), |
|
Rearrange("b c h w-> b (h w) c"), |
|
) |
|
|
|
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
|
"""Forward pass for ConvAdapter""" |
|
|
|
if not backbone_no_cls: |
|
x = x[:, 1:] |
|
|
|
x = self.pad(x) |
|
x = self.adapter(x) |
|
return x |
|
|
|
|
|
class LightConvAdapterHead(nn.Module): |
|
"""Light Convolutional Adapter module. |
|
|
|
Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. |
|
Uses CNN to map channel and spatial sizes jointly. |
|
    Note: only works for source sizes (H_s, W_s) of (16, 16) or (any, any) with 12 <= any <= 14,
    and target sizes (H_t, W_t) of (16, 16) and (64, 64) for now.
|
|
|
Attributes: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature, |
|
channel first (C, H, W). |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature, |
|
channel first (C, H, W). |
|
adapter (nn.Module): the adapter module. |
|
interpolation (nn.Module): interpolation to adjust sizes before MLP. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
source_size: tuple[int, ...] | torch.Size, |
|
target_size: tuple[int, ...] | torch.Size, |
|
hidden_size_factor: int | float = 1.0, |
|
): |
|
"""Initialization function for ConvAdapter. |
|
|
|
Args: |
|
source_size (tuple[int, ...] | torch.Size): the size of the source feature. |
|
target_size (tuple[int, ...] | torch.Size): the size of the target feature. |
|
hidden_size_factor (int | float): the size of hidden dim of feature translator |
|
as a factor of input feature hidden dim. |
|
""" |
|
super().__init__() |
|
if source_size[1] != source_size[2] or target_size[1] != target_size[2]: |
|
raise NotImplementedError( |
|
"Currently does not support non-square feature maps like source size" |
|
"{source_size} and target size {target_size}." |
|
) |
|
self.source_size = source_size |
|
self.target_size = target_size |
|
self.hidden_size_factor = hidden_size_factor |
|
|
|
hidden_dim = int(self.source_size[0] * hidden_size_factor) |
|
source_channel_size = self.source_size[0] |
|
target_channel_size = self.target_size[0] |
|
|
|
if self.source_size[1] < 12: |
|
raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") |
|
elif self.source_size[1] < 16 and self.target_size[1] >= 16: |
|
self.pad = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
nn.ConvTranspose2d( |
|
source_channel_size, |
|
source_channel_size, |
|
kernel_size=3, |
|
stride=1, |
|
output_padding=14 - self.source_size[1], |
|
), |
|
) |
|
self.source_size = (self.source_size[0], 16, 16) |
|
elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \ |
|
(self.source_size[1] == 14 and self.target_size[1] == 14): |
|
|
|
self.pad = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
) |
|
elif self.target_size[1] < 14: |
|
self.pad = nn.Sequential( |
|
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), |
|
) |
|
else: |
|
raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.") |
|
|
|
if self.source_size[1] == 16 and self.target_size[1] == 64: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 31, 31]), |
|
nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 64, 64]), |
|
Rearrange("b c h w-> b (h w) c"), |
|
nn.Linear(hidden_dim, target_channel_size), |
|
) |
|
elif self.source_size[1] == self.target_size[1]: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
|
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, *self.source_size[1:]]), |
|
Rearrange("b c h w-> b (h w) c"), |
|
nn.Linear(hidden_dim, target_channel_size), |
|
) |
|
elif self.source_size[1] == 64 and self.target_size[1] == 16: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 32, 32]), |
|
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 16, 16]), |
|
Rearrange("b c h w-> b (h w) c"), |
|
nn.Linear(hidden_dim, target_channel_size), |
|
) |
|
elif self.target_size[1] == 7: |
|
self.adapter = nn.Sequential( |
|
nn.LayerNorm(self.source_size), |
|
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), |
|
nn.ReLU(), |
|
nn.LayerNorm([hidden_dim, 7, 7]), |
|
Rearrange("b c h w-> b (h w) c"), |
|
nn.Linear(hidden_dim, target_channel_size) |
|
) |
|
else: |
|
NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.") |
|
|
|
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: |
|
"""Forward pass for ConvAdapter""" |
|
|
|
if not backbone_no_cls: |
|
x = x[:, 1:] |
|
x = self.pad(x) |
|
x = self.adapter(x) |
|
return x |
|
|
|
|
|
class FeatureTranslator(nn.Module): |
|
"""Base class for the feature translator. |
|
|
|
The flow is backbone_adapter -> translator_stem -> translator_heads. |
|
|
|
Attributes: |
|
backbone_feature_size (torch.Size): the size of features of the backbone. |
|
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
|
        translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
|
target_model_names (list[str]): convenient attribute to hold all the names of the target models. |
|
|
|
backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim. |
|
translator_stem (nn.Module): the shared stem for all target models. |
|
translator_heads (nn.ModuleDict): specific heads for different target models. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
backbone_feature_size: torch.Size, |
|
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
|
translator_hidden_size: int = 1024, |
|
) -> None: |
|
"""Initalization function for FeatureTranslator. |
|
|
|
Args: |
|
backbone_feature_size (torch.Size): the size of features of the backbone. |
|
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
|
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
|
""" |
|
super().__init__() |
|
self.backbone_feature_size = backbone_feature_size |
|
self.target_feature_sizes = target_feature_sizes |
|
self.translator_hidden_size = translator_hidden_size |
|
self.target_model_names = list(target_feature_sizes.keys()) |
|
self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names} |
|
        self.translator_heads: Optional[nn.ModuleDict] = None
|
|
|
self.backbone_adapter = nn.Sequential( |
|
nn.LayerNorm(self.backbone_feature_size[0]), |
|
nn.Linear( |
|
self.backbone_feature_size[0], |
|
self.translator_hidden_size, |
|
), |
|
) |
|
self.translator_stem: nn.Module = nn.Identity() |
|
self.build_translator_heads() |
|
|
|
def build_translator_heads(self) -> None: |
|
"""Build translator heads to match the dimension of each target feature set. |
|
|
|
Example: |
|
translator_heads: dict[str, nn.Module] = ... |
|
self.translator_heads = nn.ModuleDict(translator_heads) |
|
""" |
|
raise NotImplementedError("build_translator_heads() should be overridden") |
|
|
|
def forward( |
|
self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False |
|
    ) -> dict[str, torch.Tensor]:
|
"""Forward pass for a base feature translator. |
|
|
|
Args: |
|
x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C]. |
|
(1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C]. |
|
target_model_names (Optional[list[str]]): names of the target models. |
|
backbone_no_cls (bool): indicate backbone has cls token or not. |
|
Can use it to customize whether to drop cls. |
|
|
|
Returns: |
|
dict[str, torch.Tensor]: predicted features for target models. |
|
""" |
|
|
|
x = self.backbone_adapter(x) |
|
x = self.translator_stem(x) |
|
target_model_names = target_model_names if target_model_names is not None else self.target_model_names |
|
        features = {
            t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls)
            for t in target_model_names
        }
|
return features |
|
|
|
|
|
class MLPFeatureTranslator(FeatureTranslator): |
|
def __init__( |
|
self, |
|
backbone_feature_size: torch.Size, |
|
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
|
translator_hidden_size: int = 1024, |
|
translator_n_layer: int = 3, |
|
) -> None: |
|
"""Initalization function for MLPFeatureTranslator. |
|
|
|
Args: |
|
backbone_feature_size (torch.Size): the size of features of the backbone. |
|
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
|
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
|
translator_n_layer (int): number of MLP layers. Defaults to 3. |
|
""" |
|
self.translator_n_layer = translator_n_layer |
|
|
|
super().__init__( |
|
backbone_feature_size=backbone_feature_size, |
|
target_feature_sizes=target_feature_sizes, |
|
translator_hidden_size=translator_hidden_size, |
|
) |
|
|
|
    def build_translator_heads(self) -> None:
|
"""Build MLP translator heads to match the dimension of each target feature set.""" |
|
translator_heads = {} |
|
source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) |
|
for target_model, target_size in self.target_feature_sizes.items(): |
|
head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer) |
|
translator_heads[self.legit_target_model_name_map[target_model]] = head |
|
self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
|
class ConvFeatureTranslator(FeatureTranslator): |
|
def __init__( |
|
self, |
|
backbone_feature_size: torch.Size, |
|
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
|
translator_hidden_size: int = 1024, |
|
) -> None: |
|
"""Initalization function for ConvFeatureTranslator. |
|
|
|
Args: |
|
backbone_feature_size (torch.Size): the size of features of the backbone. |
|
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
|
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
|
""" |
|
super().__init__( |
|
backbone_feature_size=backbone_feature_size, |
|
target_feature_sizes=target_feature_sizes, |
|
translator_hidden_size=translator_hidden_size, |
|
) |
|
|
|
    def build_translator_heads(self) -> None:
        """Build ConvAdapterHead translator heads to match the dimension of each target feature set."""
|
translator_heads = {} |
|
source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) |
|
for target_model, target_size in self.target_feature_sizes.items(): |
|
head = ConvAdapterHead(source_size=source_size, target_size=target_size) |
|
translator_heads[self.legit_target_model_name_map[target_model]] = head |
|
self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
|
class LightConvFeatureTranslator(FeatureTranslator): |
|
def __init__( |
|
self, |
|
backbone_feature_size: torch.Size, |
|
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], |
|
translator_hidden_size: int = 1024, |
|
hidden_size_factor: int | float = 1.0, |
|
) -> None: |
|
"""Initalization function for LightConvFeatureTranslator. |
|
It's for a smaller translator compared to ConvFeatureTranslator. |
|
|
|
Args: |
|
backbone_feature_size (torch.Size): the size of features of the backbone. |
|
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. |
|
translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 1024. |
|
hidden_size_factor: the size of hidden dim of feature translator |
|
as a factor of input feature hidden dim. Defaults to 1.0 |
|
""" |
|
self.hidden_size_factor = hidden_size_factor |
|
super().__init__( |
|
backbone_feature_size=backbone_feature_size, |
|
target_feature_sizes=target_feature_sizes, |
|
translator_hidden_size=translator_hidden_size, |
|
) |
|
self.backbone_adapter = nn.Identity() |
|
|
|
    def build_translator_heads(self) -> None:
        """Build translator heads (LightConvAdapterHead or LinearAdapterHead) for each target feature set."""
|
translator_heads = {} |
|
for target_model, target_size in self.target_feature_sizes.items(): |
|
if "_cls" in target_model: |
|
head = LinearAdapterHead( |
|
source_size=self.backbone_feature_size, |
|
target_size=target_size |
|
) |
|
else: |
|
head = LightConvAdapterHead( |
|
source_size=self.backbone_feature_size, |
|
target_size=target_size, |
|
hidden_size_factor=self.hidden_size_factor |
|
) |
|
translator_heads[self.legit_target_model_name_map[target_model]] = head |
|
self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
|
|
class TransformerFreatureTranslator(FeatureTranslator): |
|
def __init__( |
|
self, |
|
backbone_feature_size: torch.Size, |
|
target_feature_sizes: dict[str, torch.Size | tuple[int, int]], |
|
translator_hidden_size: int = 1024, |
|
translator_n_layers: int = 2, |
|
translator_n_heads: int = 8, |
|
translator_activation: str = "gelu", |
|
) -> None: |
|
super().__init__( |
|
backbone_feature_size=backbone_feature_size, |
|
target_feature_sizes=target_feature_sizes, |
|
translator_hidden_size=translator_hidden_size, |
|
) |
|
|
|
self.translator_stem = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer( |
|
d_model=translator_hidden_size, |
|
nhead=translator_n_heads, |
|
dim_feedforward=translator_hidden_size * 2, |
|
activation=translator_activation, |
|
batch_first=True, |
|
norm_first=True, |
|
), |
|
num_layers=translator_n_layers, |
|
) |
|
|
|
self.decode_tokens = nn.Parameter( |
|
torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size)) |
|
) |
|
|
|
self.target_model_emb = nn.ParameterDict( |
|
{ |
|
self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size) |
|
for t in self.target_model_names |
|
} |
|
) |
|
|
|
def build_translator_heads(self) -> None: |
|
"""Build Transformer translator heads to match the dimension of each target feature set.""" |
|
translator_heads = {} |
|
for target_model, target_size in self.target_feature_sizes.items(): |
|
head = MLPAdapterHead( |
|
source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]), |
|
target_size=target_size, |
|
num_layer=2, |
|
) |
|
translator_heads[self.legit_target_model_name_map[target_model]] = head |
|
self.translator_heads = nn.ModuleDict(translator_heads) |
|
|
|
def forward( |
|
self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False |
|
    ) -> dict[str, torch.Tensor]:
|
"""Forward pass for a simple linear translator. |
|
|
|
Args: |
|
x (torch.Tensor): input features from the backbone. |
|
target_model_names (Optional[str]): names of the target models. |
|
backbone_no_cls (bool): indicate backbone has cls token or not. |
|
Can use it to customize whether to drop cls. |
|
|
|
Returns: |
|
dict[str, torch.Tensor]: predicted features for target models. |
|
""" |
|
if not backbone_no_cls: |
|
x = x[:, 1:] |
|
x = self.backbone_adapter(x) |
|
features = {} |
|
target_model_names = target_model_names if target_model_names is not None else self.target_model_names |
|
for t in target_model_names: |
|
feature = self.translator_stem( |
|
torch.cat( |
|
[ |
|
self.decode_tokens.repeat(x.size(0), 1, 1), |
|
self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1), |
|
], |
|
dim=1, |
|
), |
|
memory=x, |
|
)[:, 1:, ...] |
|
features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature) |
|
return features |
|
|
|
|
|
def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator: |
|
"""Handy function to build feature translators given the type |
|
|
|
Args: |
|
translator_type (str): the type of the translator, |
|
one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`). |
|
At the moment we are actively using `"lconv"`. |
|
|
|
Returns: |
|
FeatureTranslator: the corresponding FeatureTranslator |
|
""" |
|
if translator_type == "mlp": |
|
return MLPFeatureTranslator(**kwargs) |
|
elif translator_type == "conv": |
|
return ConvFeatureTranslator(**kwargs) |
|
elif translator_type == "lconv": |
|
return LightConvFeatureTranslator(**kwargs) |
|
elif translator_type == "transformer" or translator_type == "trans": |
|
return TransformerFreatureTranslator(**kwargs) |
|
else: |
|
raise NotImplementedError(f"Requested {translator_type} is not implemented yet.") |
|
|
|
|
|
class TheiaConfig(PretrainedConfig): |
|
def __init__( |
|
self, |
|
backbone: str | nn.Module = "facebook/deit-tiny-patch16-224", |
|
pretrained: bool = False, |
|
target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None, |
|
translator_type: str = "lconv", |
|
translator_hidden_size_factor: float | int = 1.0, |
|
target_loss_weights: Optional[dict[str, float]] = None, |
|
feature_reduce_method: Optional[str] = None, |
|
feature_neck: bool = False, |
|
feature_neck_hidden_dim: int = 256, |
|
forward_neck: bool = False, |
|
feature_neck_nonlinearity: str = "relu", |
|
        image_size: int = 224,
|
num_reg_tokens: int = 0, |
|
**kwargs: Any |
|
): |
|
self.backbone = backbone |
|
self.pretrained = pretrained |
|
self.target_feature_sizes = target_feature_sizes |
|
self.translator_type = translator_type |
|
self.translator_hidden_size_factor = translator_hidden_size_factor |
|
self.target_loss_weights = target_loss_weights |
|
self.feature_reduce_method = feature_reduce_method |
|
self.feature_neck = feature_neck |
|
self.feature_neck_hidden_dim = feature_neck_hidden_dim |
|
self.forward_neck = forward_neck |
|
self.feature_neck_nonlinearity = feature_neck_nonlinearity |
|
        self.image_size = image_size
|
self.num_reg_tokens = num_reg_tokens |
|
super().__init__(**kwargs) |
|
|
|
class TheiaModel(PreTrainedModel): |
|
config_class = TheiaConfig |
|
|
|
def __init__(self, config: TheiaConfig): |
|
super().__init__(config) |
|
|
|
self.target_feature_sizes = config.target_feature_sizes |
|
self.preprocessor = None |
|
self.pretrained = config.pretrained |
|
|
|
|
|
self.image_size = config.image_size |
|
if "reg" in config.backbone: |
|
            self.backbone: nn.Module = build_backbone(
                config.backbone, config.pretrained, image_size=config.image_size, num_reg_tokens=config.num_reg_tokens
            )
|
else: |
|
self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, image_size=config.image_size) |
|
|
|
|
|
self.feature_reduce_method = config.feature_reduce_method |
|
self.no_cls = hasattr(self.backbone, "no_cls") |
|
self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0 |
|
|
|
|
|
backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True) |
|
if self.target_feature_sizes: |
|
translator_kwargs = { |
|
"hidden_size_factor": config.translator_hidden_size_factor |
|
} |
|
translator_kwargs["backbone_feature_size"] = backbone_feature_size |
|
translator_kwargs["target_feature_sizes"] = config.target_feature_sizes |
|
self.translator = build_feature_translator( |
|
config.translator_type, **translator_kwargs |
|
) |
|
else: |
|
self.translator = None |
|
|
|
self.feature_neck = config.feature_neck |
|
self.feature_neck_hidden_dim = config.feature_neck_hidden_dim |
|
self.forward_neck = config.forward_neck |
|
if self.feature_neck: |
|
num_tokens_edge = self.backbone.model.config.image_size // self.backbone.model.config.patch_size |
|
self.neck = nn.Sequential( |
|
Rearrange("b (h w) c -> b c h w", h=num_tokens_edge, w=num_tokens_edge), |
|
nn.Conv2d(self.backbone.model.config.hidden_size, self.feature_neck_hidden_dim, kernel_size=4, stride=2, padding=1), |
|
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
|
nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=2), |
|
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
|
nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=1), |
|
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), |
|
nn.Flatten() |
|
) |
|
else: |
|
self.neck = None |
|
|
|
|
|
self.mse_loss = nn.MSELoss() |
|
self.l1_loss = nn.SmoothL1Loss() |
|
self.cos_loss = nn.CosineEmbeddingLoss() |
|
self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False) |
|
self.target_loss_weights = config.target_loss_weights |
|
|
|
def load_pretrained_weights(self, checkpoint_path: str) -> None: |
|
""" |
|
Load weights from `checkpoint_path` manually. |
|
|
|
Args: |
|
checkpoint_path (str): path to the weights. |
|
""" |
|
|
|
if checkpoint_path: |
|
weights_dict = torch.load(checkpoint_path, map_location="cpu") |
|
|
|
pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()} |
|
self.load_state_dict(pretrained_dict, strict=False) |
|
|
|
def freeze_translator(self) -> None: |
|
"""Freeze feature translators `self.translator`.""" |
|
if self.translator is not None: |
|
for param in self.translator.parameters(): |
|
param.requires_grad = False |
|
|
|
def freeze_backbone(self) -> None: |
|
"""Freeze backbone (encoder) `self.backbone`. """ |
|
self.freeze_encoder() |
|
|
|
def freeze_encoder(self) -> None: |
|
"""Freeze backbone (encoder) `self.backbone`. """ |
|
for param in self.backbone.parameters(): |
|
param.requires_grad = False |
|
|
|
def freeze_neck(self) -> None: |
|
"""Freeze feature neck `self.neck`.""" |
|
if self.neck is not None: |
|
for param in self.neck.parameters(): |
|
param.requires_grad = False |
|
|
|
def freeze_everything(self) -> None: |
|
"""Freeze all parameters in the model.""" |
|
self.freeze_translator() |
|
self.freeze_neck() |
|
self.freeze_encoder() |
|
|
|
def unfreeze_translator(self) -> None: |
|
if self.translator is not None: |
|
for param in self.translator.parameters(): |
|
param.requires_grad = True |
|
|
|
def unfreeze_backbone(self) -> None: |
|
"Set parameters in backbone (encoder) `self.backbone` trainable." |
|
self.unfreeze_encoder() |
|
|
|
def unfreeze_encoder(self) -> None: |
|
"Set parameters in backbone (encoder) `self.backbone` trainable." |
|
for param in self.backbone.parameters(): |
|
param.requires_grad = True |
|
|
|
def unfreeze_neck(self) -> None: |
|
"Set parameters in feature neck `self.neck` trainable." |
|
if self.neck is not None: |
|
for param in self.neck.parameters(): |
|
param.requires_grad = True |
|
|
|
def unfreeze_everything(self) -> None: |
|
"""Set all parameters trainable.""" |
|
self.unfreeze_translator() |
|
self.unfreeze_neck() |
|
self.unfreeze_encoder() |
|
|
|
def set_forward_neck(self, forward_neck: bool = True) -> None: |
|
""" |
|
Set `self.forward_neck` to `forward_neck` value. |
|
|
|
Args: |
|
            forward_neck (bool): whether to forward the feature through the (randomly initialized) neck.
                If set to True, the output of `self.forward()` has shape [batch_size, self.config.feature_neck_hidden_dim].
|
""" |
|
self.forward_neck = forward_neck |
|
|
|
def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
|
"""Forward RVFM feature only (before translators). |
|
|
|
Args: |
|
x (torch.Tensor): input image. By default it accepts images |
|
in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. |
|
kwargs (Any): kwargs including mainly those for huggingface preprocessor: |
|
`do_resize` (bool) defaults to True. |
|
`interpolate_pos_encoding` (Optional[bool]) defaults to None. |
|
`do_rescale` (bool) defaults to True. |
|
`do_normalize` (bool) defaults to True. |
|
|
|
Returns: |
|
torch.Tensor: RVFM feature. |
|
""" |
|
feature = self.backbone(x, **kwargs) |
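        # Register tokens (appended at the end of the sequence), if any, are discarded by handle_feature_output.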
|
|
|
|
|
|
|
        return handle_feature_output(feature, self.feature_reduce_method, num_discard_tokens=self.num_reg_tokens)
|
|
|
def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor] | torch.Tensor: |
|
"""Forward pass of Robot Vision Foundation Model. |
|
|
|
Args: |
|
x (torch.Tensor): input image. By default it accepts images |
|
in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. |
|
target_model_names (Optional[list[str]]): names of the target foundation models. |
|
kwargs (Any): kwargs including mainly those for huggingface preprocessor: |
|
`do_resize` (bool) defaults to True. |
|
`interpolate_pos_encoding` (Optional[bool]) defaults to None. |
|
`do_rescale` (bool) defaults to True. |
|
`do_normalize` (bool) defaults to True. |
|
|
|
Returns: |
|
if `self.forward_neck`: |
|
torch.Tensor: compact vector feature passed through the neck. [B, C_neck] |
|
else: |
|
dict[str, torch.Tensor]: features that match to each foundation model. |
|
Each feature is in [B, (H*W), C] or [B, C]. |
|
""" |
|
if self.forward_neck: |
|
x = self.forward_feature(x) |
|
return self.neck(x) |
|
else: |
|
x = self.backbone(x, **kwargs) |
|
if self.num_reg_tokens > 0: |
|
x = x[:, :-self.num_reg_tokens] |
|
features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls) |
|
return features |
|
|
|
def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]: |
|
"""Get loss terms given predictions and targets. |
|
|
|
Args: |
|
pred_features (dict[str, torch.Tensor]): predictions. |
|
y (dict[str, torch.Tensor]): targets. |
|
|
|
Returns: |
|
            dict[str, Any]: averaged losses and per-model loss terms.
|
""" |
|
mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0 |
|
mse_losses_per_model = {} |
|
cos_losses_per_model = {} |
|
l1_losses_per_model = {} |
|
|
|
for t in pred_features: |
|
pred = pred_features[t] |
|
target = y[t] |
|
|
|
|
|
mse_loss = self.mse_loss(pred, target) |
|
            weight = self.target_loss_weights[t] if self.target_loss_weights else 1.0 / len(pred_features)
|
|
|
|
|
l1_loss = self.l1_loss(pred, target) |
|
|
|
|
|
pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2) |
|
target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2) |
|
            cos_target = self.cos_target.repeat(pred.size(0)).to(pred.device)
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_target)
|
|
|
mse_loss_avg += mse_loss * weight |
|
cos_loss_avg += cos_loss / len(pred_features) |
|
l1_loss_avg += l1_loss * weight |
|
|
|
mse_losses_per_model[t] = mse_loss.item() |
|
cos_losses_per_model[t] = cos_loss.item() |
|
l1_losses_per_model[t] = l1_loss.item() |
|
|
|
return { |
|
"mse_loss": mse_loss_avg, |
|
"cos_loss": cos_loss_avg, |
|
"l1_loss": l1_loss_avg, |
|
"mse_losses_per_model": mse_losses_per_model, |
|
"cos_losses_per_model": cos_losses_per_model, |
|
"l1_losses_per_model": l1_losses_per_model, |
|
} |
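

# Illustrative end-to-end sketch (assumes the default DeiT checkpoint is reachable and that no
# target models are configured, so only the backbone feature path is exercised):
#
#   config = TheiaConfig(backbone="facebook/deit-tiny-patch16-224", pretrained=False)
#   model = TheiaModel(config)
#   images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)
#   feats = model.forward_feature(images)   # [2, N, C] token features from the backbone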