# theia-tiny-patch16-224-cdiv / theia_model.py
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
import math
from itertools import chain
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import interpolate
from einops.layers.torch import Rearrange
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoProcessor
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel
def handle_feature_output(
x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0
) -> torch.Tensor:
"""Handle feature output from transformer.
Args:
x (torch.Tensor): input feature to be handled. shape is
[B, 1+H*W+N, C] if including both CLS and register tokens.
[B, 1+H*W, C] for standard model (N=0).
[B, H*W, C] for model without CLS.
feature_reduce_method (Optional[str]): method to select token. Options:
- `mean_pooling`: average over spatial tokens (non CLS tokens), output shape = [B, C].
- `max_pooling`: max over spatial tokens, output shape = [B, C].
- `cls`: return CLS token only, output shape = [B, C].
- `identity`: return the feature without touching it, output shape = input shape.
            - `None`: return spatial tokens only, output shape = [B, H*W, C],
                assuming the raw feature is [B, 1+H*W, C] where `1` corresponds to the CLS token.
num_discard_tokens (int):
number of tokens to be discarded. Assuming they are at the end of the sequence.
Returns:
torch.Tensor: selected feature tokens.
"""
match feature_reduce_method:
case "mean_pooling":
return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C]
case "max_pooling":
return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C]
case "cls":
return x[:, 0] # [B, C]
case "identity":
return x
case None:
return x[:, 1 : x.size(1) - num_discard_tokens]
case _:
            raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} is not implemented.")
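# --- Illustrative usage sketch (not part of the original Theia code). ---
# The shapes below are assumptions for demonstration: a ViT-style feature with a
# CLS token, 196 spatial tokens (14x14 patches at 224/16), hidden size 192,
# and 7 register tokens appended at the end of the sequence.
def _example_handle_feature_output() -> None:
    feat = torch.randn(2, 1 + 196 + 7, 192)  # [B, 1+H*W+N, C]
    pooled = handle_feature_output(feat, "mean_pooling", num_discard_tokens=7)
    spatial = handle_feature_output(feat, None, num_discard_tokens=7)
    assert pooled.shape == (2, 192)  # [B, C]
    assert spatial.shape == (2, 196, 192)  # [B, H*W, C]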
# Modified from huggingface transformers ViTEmbeddings
# Original Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ViTEmbeddingsNoCLS(ViTEmbeddings):
"""ViT Embedding Module without CLS token."""
def __init__(self, config: AutoConfig, use_mask_token: bool = False):
"""Initialization.
Args:
config (AutoConfig): config for ViT.
use_mask_token (bool, optional): whether to use mask token. Defaults to False.
"""
super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token)
self.cls_token = None
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
        This method allows interpolating the pre-trained position encodings so that the model can be used on
        higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings[:, 1:]
embeddings = self.dropout(embeddings)
return embeddings
# modified from huggingface transformers ViTModel
class ViTModelNoCLS(ViTModel):
"""ViT Model without CLS token."""
def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token)
self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token)
self.no_cls = True
def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
# modified from huggingface transformers ViTEmbeddings
class ViTEmbeddingsReg(ViTEmbeddings):
"""
ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1
"""
def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7):
super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token)
self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
self.num_reg_tokens = num_reg_tokens
self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
self.reg_pos_embed.data = nn.init.trunc_normal_(
self.reg_pos_embed.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(self.reg_pos_embed.dtype)
self.reg_token.data = nn.init.trunc_normal_(
self.reg_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(self.reg_token.dtype)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
        This method allows interpolating the pre-trained position encodings so that the model can be used on
        higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
reg_pos_embed = self.reg_pos_embed
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
reg_tokens = self.reg_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)
embeddings = self.dropout(embeddings)
return embeddings
# modified from huggingface transformers ViTModel
class ViTModelReg(ViTModel):
"""ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1"""
def __init__(
self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7
):
super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token)
self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens)
self.num_reg_tokens = num_reg_tokens
def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)
module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)
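# --- Illustrative sketch of the token layout produced by ViTModelReg (not part
# of the original Theia code; fetching the DeiT config requires Hub access).
# With a 224x224 input and patch size 16, the sequence is
# [CLS, 196 patch tokens, num_reg_tokens register tokens].
def _example_reg_token_layout() -> None:
    config = AutoConfig.from_pretrained("facebook/deit-tiny-patch16-224")
    model = ViTModelReg(config, num_reg_tokens=7)
    pixels = torch.randn(1, 3, 224, 224)
    out = model(pixel_values=pixels).last_hidden_state
    assert out.shape[1] == 1 + (224 // 16) ** 2 + 7  # CLS + patches + registers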
class DeiT(nn.Module):
"""DeiT model.
Paper: Training data-efficient image transformers & distillation through attention
https://arxiv.org/abs/2012.12877
Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit
Attributes:
model_name (str): name of the model.
pretrained (bool): whether to use pretrained weights.
"""
def __init__(
self,
model_name: str = "facebook/deit-small-patch16-224",
pretrained: bool = False,
image_size: int = 224,
):
super().__init__()
self.image_size = image_size
model = AutoModel.from_pretrained(model_name)
if pretrained:
self.model = model
else:
deit_config = model.config
self.model = AutoModel.from_config(deit_config)
del model
self.model.pooler = nn.Identity()
self.processor = AutoProcessor.from_pretrained(model_name)
def get_feature_size(
self,
keep_spatial: bool = False,
return_torch_size: bool = False,
) -> torch.Size | tuple[int, ...]:
"""Get the size of the feature.
Args:
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
return_torch_size (bool): if true, return torch.Size type. Defaults to False.
Returns:
torch.Size | tuple[int, ...]: returned feature shape.
"""
with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
y = self.forward(x)[:, 1:] # for getting feature size, discard cls token
size = y.size()[1:][::-1]
if keep_spatial:
                assert math.isqrt(size[-1]) ** 2 == size[-1]  # spatial tokens must form a square grid
h = w = int(math.sqrt(size[-1]))
size = (size[0], h, w)
if return_torch_size:
size = torch.Size(size)
return size
def forward(
self,
x: torch.Tensor,
do_resize: bool = True,
interpolate_pos_encoding: Optional[bool] = None,
do_rescale: bool = True,
do_normalize: bool = True,
) -> torch.Tensor:
"""Forward pass of the model
Args:
x (torch.Tensor): model input.
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.
            - argument for the model forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
Returns:
torch.Tensor: model output.
"""
input = self.processor(
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
).to(self.model.device)
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
return y.last_hidden_state
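# --- Illustrative usage sketch for the DeiT wrapper (not part of the original
# Theia code; it downloads the backbone config/weights/processor from the Hub).
# The wrapper accepts uint8 images in [B, H, W, C] and returns the token
# sequence [B, 1+H*W, C], including the CLS token.
def _example_deit_forward() -> None:
    deit = DeiT(model_name="facebook/deit-tiny-patch16-224", pretrained=False)
    images = torch.zeros((1, 224, 224, 3), dtype=torch.uint8)
    tokens = deit(images)
    hidden = deit.model.config.hidden_size
    assert tokens.shape == (1, 1 + 14 * 14, hidden)  # CLS + 14x14 patches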
class DeiTNoCLS(nn.Module):
"""Modified DeiT model without CLS token."""
def __init__(
self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224
):
super().__init__()
self.image_size = image_size
pretrained_model_name = model_name.replace("nocls-", "")
deit_config = AutoConfig.from_pretrained(pretrained_model_name)
self.model = ViTModelNoCLS(deit_config)
if pretrained:
pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
self.load_state_dict(pretrained_dict, strict=False)
del pretrained_model, pretrained_dict
self.model.pooler = nn.Identity()
self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
self.no_cls = True
def get_feature_size(
self,
keep_spatial: bool = False,
return_torch_size: bool = False,
) -> torch.Size | tuple[int, ...]:
"""Get the size of the feature.
Args:
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
return_torch_size (bool): if true, return torch.Size type. Defaults to False.
Returns:
torch.Size | tuple[int, ...]: returned feature shape.
"""
with torch.inference_mode():
image_size = (self.image_size, self.image_size)
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
y = self.forward(x)
size = y.size()[1:][::-1]
if keep_spatial:
                assert math.isqrt(size[-1]) ** 2 == size[-1]  # spatial tokens must form a square grid
h = w = int(math.sqrt(size[-1]))
size = (size[0], h, w)
if return_torch_size:
size = torch.Size(size)
return size
def forward(
self,
x: torch.Tensor,
do_resize: bool = True,
interpolate_pos_encoding: Optional[bool] = None,
do_rescale: bool = True,
do_normalize: bool = True,
) -> torch.Tensor:
"""Forward pass of the model
Args:
x (torch.Tensor): model input.
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.
            - argument for the model forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
Returns:
torch.Tensor: model output.
"""
input = self.processor(
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
).to(self.model.device)
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
return y.last_hidden_state
class DeiTReg(nn.Module):
"""Modified DeiT model with register tokens."""
def __init__(
self,
model_name: str = "reg-facebook/deit-small-patch16-224",
pretrained: bool = False,
image_size: int = 224,
num_reg_tokens: int = 7,
):
super().__init__()
self.image_size = image_size
pretrained_model_name = model_name.replace("reg-", "")
deit_config = AutoConfig.from_pretrained(pretrained_model_name)
self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens)
if pretrained:
pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
self.load_state_dict(pretrained_dict, strict=False)
del pretrained_model, pretrained_dict
self.model.pooler = nn.Identity()
self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
self.num_reg_tokens = num_reg_tokens
def get_feature_size(
self,
keep_spatial: bool = False,
return_torch_size: bool = False,
) -> torch.Size | tuple[int, ...]:
"""Get the size of the feature.
Args:
keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
return_torch_size (bool): if true, return torch.Size type. Defaults to False.
Returns:
torch.Size | tuple[int, ...]: returned feature shape.
"""
with torch.inference_mode():
image_size = (self.image_size, self.image_size)
x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
y = self.forward(x)[:, 1 : -self.num_reg_tokens]
size = y.size()[1:][::-1]
if keep_spatial:
                assert math.isqrt(size[-1]) ** 2 == size[-1]  # spatial tokens must form a square grid
h = w = int(math.sqrt(size[-1]))
size = (size[0], h, w)
if return_torch_size:
size = torch.Size(size)
return size
def forward(
self,
x: torch.Tensor,
do_resize: bool = True,
interpolate_pos_encoding: Optional[bool] = None,
do_rescale: bool = True,
do_normalize: bool = True,
) -> torch.Tensor:
"""Forward pass of the model
Args:
x (torch.Tensor): model input.
            - arguments for self.processor. Details can be found at
              https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): whether to resize in the processor. Defaults to True.
            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
            do_normalize (bool): whether to normalize in the processor. Defaults to True.
            - argument for the model forward
            interpolate_pos_encoding (Optional[bool]): whether to interpolate the positional embedding. Defaults to None.
Returns:
torch.Tensor: model output.
"""
input = self.processor(
x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
).to(self.model.device)
y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
return y.last_hidden_state
def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module:
"""Build the backbone visual encoder of robot vision foundation model.
Args:
model_name (str): name of the model.
pretrained (bool): whether to use pretrained weights. Defaults to False.
image_size (int): size of the image. Assume a square image. Defaults to 224
kwargs (Any): any kwargs specific to some models. For example,
`num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name`
Returns:
nn.Module: backbone network.
"""
if "reg" in model_name:
return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
elif "nocls" in model_name:
return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
elif "deit" in model_name:
return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size)
else:
raise NotImplementedError(f"Requested {model_name} is not implemented.")
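# --- Illustrative sketch of how build_backbone dispatches on the name prefix
# (not part of the original Theia code): "reg-"/"nocls-" select the modified
# variants, while a plain DeiT name selects the standard wrapper.
def _example_build_backbone() -> None:
    plain = build_backbone("facebook/deit-tiny-patch16-224")
    no_cls = build_backbone("nocls-facebook/deit-tiny-patch16-224")
    with_reg = build_backbone("reg-facebook/deit-tiny-patch16-224", num_reg_tokens=7)
    assert isinstance(plain, DeiT)
    assert isinstance(no_cls, DeiTNoCLS)
    assert isinstance(with_reg, DeiTReg)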
class Interpolation(nn.Module):
    """nn.Module wrapper around nn.functional.interpolate.
Attributes:
target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation.
"""
def __init__(self, target_size: tuple[int, int] | torch.Size) -> None:
super().__init__()
self.target_size = target_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Very simple forward pass to call interpolate()."""
return interpolate(x, self.target_size)
class LinearAdapterHead(nn.Module):
"""Adapter head contains a single linear layer."""
def __init__(
self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size
):
"""Initialization function for LinearAdapterHead.
Args:
source_size (tuple[int, ...] | torch.Size): the size of the source feature.
target_size (tuple[int, ...] | torch.Size): the size of the target feature.
"""
super().__init__()
self.source_size = source_size
self.target_size = target_size
source_channel_size = self.source_size[0]
target_channel_size = self.target_size[0]
self.adapter = nn.Sequential(
nn.Linear(source_channel_size, target_channel_size),
)
    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
        """Forward pass for the adapter."""
        # LinearAdapterHead is only used when the backbone has a CLS token.
        assert not backbone_no_cls
        # x: [B, 1+H*W, C]
        x = x[:, 0]
        x = self.adapter(x)
        return x  # [B, C]
class MLPAdapterHead(nn.Module):
"""MLP Adapter module.
Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
Will first do interpolation to match the spatial size [H_t, W_t],
followed by MLP to project to the target channel dimension [C_t].
Attributes:
source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W]
target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W]
adapter (nn.Module): the adapter module.
interpolation (nn.Module): interpolation to adjust sizes before MLP.
"""
def __init__(
self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int
):
"""Initialization function for MLPAdapter.
Args:
source_size (tuple[int, ...] | torch.Size): the size of the source feature.
target_size (tuple[int, ...] | torch.Size): the size of the target feature.
num_layer (int): number of MLP layers (One linear layer if num_layer = 1).
"""
super().__init__()
assert num_layer >= 1, f"`num_layer` in {self._get_name()} should >= 1. Got {num_layer}"
self.source_size = source_size
self.target_size = target_size
source_channel_size = self.source_size[0]
target_channel_size = self.target_size[0]
self.interpolation = nn.Sequential(
nn.Identity(),
)
if self.source_size[1] != self.target_size[1]:
self.interpolation = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
Interpolation(self.target_size[1:]),
Rearrange("b c h w-> b (h w) c"),
)
if num_layer == 1:
self.adapter = nn.Sequential(
nn.Linear(source_channel_size, target_channel_size),
)
elif num_layer >= 2:
hidden_dim = source_channel_size * 2
self.adapter = nn.Sequential(
nn.Linear(source_channel_size, hidden_dim),
*list(
chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)])
),
nn.ReLU(),
nn.Linear(hidden_dim, target_channel_size),
)
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
"""Forward pass for the adapter. First interpolation then MLP."""
# x: [B, (1)+H*W, C]
if not backbone_no_cls:
x = x[:, 1:]
# x: [B, (H*W), C]
x = self.interpolation(x)
x = self.adapter(x)
return x # [B, (H*W), C]
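# --- Illustrative shape walk-through for MLPAdapterHead (not part of the
# original Theia code); the sizes are assumptions: 14x14x192 backbone tokens
# mapped to a 16x16x768 target feature map.
def _example_mlp_adapter() -> None:
    head = MLPAdapterHead(source_size=(192, 14, 14), target_size=(768, 16, 16), num_layer=2)
    tokens = torch.randn(2, 1 + 14 * 14, 192)  # [B, 1+H*W, C] with CLS token
    out = head(tokens)  # CLS dropped, spatially interpolated, channel-projected
    assert out.shape == (2, 16 * 16, 768)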
class ConvAdapterHead(nn.Module):
"""Convolutional Adapter module.
Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
Uses CNN to map channel and spatial sizes jointly.
    Note: only works for (16, 16), (any, any) with any <= 14, and (64, 64) spatial sizes for now.
Attributes:
source_size (tuple[int, ...] | torch.Size): the size of the source feature.
target_size (tuple[int, ...] | torch.Size): the size of the target feature.
adapter (nn.Module): the adapter module.
        pad (nn.Module): padding/reshaping module to adjust the spatial size before the adapter.
"""
def __init__(
self,
source_size: tuple[int, ...] | torch.Size,
target_size: tuple[int, ...] | torch.Size,
):
"""Initialization function for ConvAdapter.
Args:
source_size (tuple[int, ...] | torch.Size): the size of the source feature.
target_size (tuple[int, ...] | torch.Size): the size of the target feature.
"""
super().__init__()
self.source_size = source_size
self.target_size = target_size
hidden_dim = self.source_size[0] * 2
source_channel_size = self.source_size[0]
target_channel_size = self.target_size[0]
if self.source_size[1] < 12:
raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
elif self.source_size[1] < 16: # pad (any, any), any <= 14 to (16, 16)
self.pad = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
nn.ConvTranspose2d(
source_channel_size,
source_channel_size,
kernel_size=3,
stride=1,
output_padding=14 - self.source_size[1],
),
)
self.source_size = (self.source_size[0], 16, 16)
elif self.source_size[1] == 16 or self.source_size[1] == 64: # do nothing for (16, 16) and (64, 64)
self.pad = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
)
else:
raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.")
if self.source_size[1] < self.target_size[1]: # (16, 16) / (14, 14) to (64, 64)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31
nn.ReLU(),
nn.LayerNorm([hidden_dim, 31, 31]),
nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64
nn.ReLU(),
nn.LayerNorm([hidden_dim, 64, 64]),
nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), # 64
Rearrange("b c h w-> b (h w) c"),
)
elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16
Rearrange("b c h w-> b (h w) c"),
)
else: # (64, 64) to (16, 16)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32
nn.ReLU(),
nn.LayerNorm([hidden_dim, 32, 32]),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, 16, 16]),
nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16
Rearrange("b c h w-> b (h w) c"),
)
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
"""Forward pass for ConvAdapter"""
# x: [B, (1)+H*W, C]
if not backbone_no_cls:
x = x[:, 1:]
# x: [B, H*W, C]
x = self.pad(x)
x = self.adapter(x)
return x # B, (H*W), C
class LightConvAdapterHead(nn.Module):
"""Light Convolutional Adapter module.
Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
Uses CNN to map channel and spatial sizes jointly.
    Note: only works for source sizes (H_s, W_s): (16, 16), (any, any), 12 <= any <= 14,
and target sizes (H_t, W_t): (16, 16) and (64, 64) for now.
Attributes:
source_size (tuple[int, ...] | torch.Size): the size of the source feature,
channel first (C, H, W).
target_size (tuple[int, ...] | torch.Size): the size of the target feature,
channel first (C, H, W).
adapter (nn.Module): the adapter module.
        pad (nn.Module): padding/reshaping module to adjust the spatial size before the adapter.
"""
def __init__(
self,
source_size: tuple[int, ...] | torch.Size,
target_size: tuple[int, ...] | torch.Size,
hidden_size_factor: int | float = 1.0,
    ):
        """Initialization function for LightConvAdapterHead.
Args:
source_size (tuple[int, ...] | torch.Size): the size of the source feature.
target_size (tuple[int, ...] | torch.Size): the size of the target feature.
hidden_size_factor (int | float): the size of hidden dim of feature translator
as a factor of input feature hidden dim.
"""
super().__init__()
if source_size[1] != source_size[2] or target_size[1] != target_size[2]:
            raise NotImplementedError(
                f"Currently does not support non-square feature maps like source size "
                f"{source_size} and target size {target_size}."
            )
self.source_size = source_size
self.target_size = target_size
self.hidden_size_factor = hidden_size_factor
hidden_dim = int(self.source_size[0] * hidden_size_factor)
source_channel_size = self.source_size[0]
target_channel_size = self.target_size[0]
if self.source_size[1] < 12:
raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
elif self.source_size[1] < 16 and self.target_size[1] >= 16: # pad (any, any), any <= 14 to (16, 16)
self.pad = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
nn.ConvTranspose2d(
source_channel_size,
source_channel_size,
kernel_size=3,
stride=1,
output_padding=14 - self.source_size[1],
),
)
self.source_size = (self.source_size[0], 16, 16)
elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \
(self.source_size[1] == 14 and self.target_size[1] == 14):
# no padding for (16, 16), (64, 64) and (14, 14) <-> (14, 14)
self.pad = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
)
elif self.target_size[1] < 14:
self.pad = nn.Sequential(
Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
)
else:
raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.")
if self.source_size[1] == 16 and self.target_size[1] == 64: # (16, 16) to (64, 64)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31
nn.ReLU(),
nn.LayerNorm([hidden_dim, 31, 31]),
nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64
nn.ReLU(),
nn.LayerNorm([hidden_dim, 64, 64]),
Rearrange("b c h w-> b (h w) c"),
nn.Linear(hidden_dim, target_channel_size),
)
elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
Rearrange("b c h w-> b (h w) c"),
nn.Linear(hidden_dim, target_channel_size),
)
elif self.source_size[1] == 64 and self.target_size[1] == 16: # (64, 64) to (16, 16)
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32
nn.ReLU(),
nn.LayerNorm([hidden_dim, 32, 32]),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16
nn.ReLU(),
nn.LayerNorm([hidden_dim, 16, 16]),
Rearrange("b c h w-> b (h w) c"),
nn.Linear(hidden_dim, target_channel_size),
)
elif self.target_size[1] == 7:
self.adapter = nn.Sequential(
nn.LayerNorm(self.source_size),
nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), #14x14 -> 7x7
nn.ReLU(),
nn.LayerNorm([hidden_dim, 7, 7]),
Rearrange("b c h w-> b (h w) c"),
nn.Linear(hidden_dim, target_channel_size)
)
else:
            raise NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.")
    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
        """Forward pass for LightConvAdapterHead."""
# x: [B, (1)+H*W, C]
if not backbone_no_cls:
x = x[:, 1:]
x = self.pad(x)
x = self.adapter(x)
return x # [B, H*W, C]
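# --- Illustrative shape sketch for LightConvAdapterHead (not part of the
# original Theia code); the sizes are assumptions: 14x14 backbone tokens are
# padded to 16x16, then upsampled to a 64x64 target grid.
def _example_light_conv_adapter() -> None:
    head = LightConvAdapterHead(source_size=(192, 14, 14), target_size=(256, 64, 64))
    tokens = torch.randn(2, 1 + 14 * 14, 192)  # [B, 1+H*W, C] with CLS token
    out = head(tokens)
    assert out.shape == (2, 64 * 64, 256)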
class FeatureTranslator(nn.Module):
"""Base class for the feature translator.
The flow is backbone_adapter -> translator_stem -> translator_heads.
Attributes:
backbone_feature_size (torch.Size): the size of features of the backbone.
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
        translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
target_model_names (list[str]): convenient attribute to hold all the names of the target models.
backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim.
translator_stem (nn.Module): the shared stem for all target models.
translator_heads (nn.ModuleDict): specific heads for different target models.
"""
def __init__(
self,
backbone_feature_size: torch.Size,
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
translator_hidden_size: int = 1024,
    ) -> None:
        """Initialization function for FeatureTranslator.
Args:
backbone_feature_size (torch.Size): the size of features of the backbone.
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
"""
super().__init__()
self.backbone_feature_size = backbone_feature_size # (C, H, W)
self.target_feature_sizes = target_feature_sizes # [(C, H, W)]
self.translator_hidden_size = translator_hidden_size # C
self.target_model_names = list(target_feature_sizes.keys())
self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names}
self.translator_heads: nn.ModuleDict = None
self.backbone_adapter = nn.Sequential(
nn.LayerNorm(self.backbone_feature_size[0]), # do a pre-norm
nn.Linear(
self.backbone_feature_size[0], # C in [C,H,W]
self.translator_hidden_size,
),
)
self.translator_stem: nn.Module = nn.Identity()
self.build_translator_heads()
def build_translator_heads(self) -> None:
"""Build translator heads to match the dimension of each target feature set.
Example:
translator_heads: dict[str, nn.Module] = ...
self.translator_heads = nn.ModuleDict(translator_heads)
"""
raise NotImplementedError("build_translator_heads() should be overridden")
def forward(
self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
    ) -> dict[str, torch.Tensor]:
"""Forward pass for a base feature translator.
Args:
x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C].
(1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C].
target_model_names (Optional[list[str]]): names of the target models.
backbone_no_cls (bool): indicate backbone has cls token or not.
Can use it to customize whether to drop cls.
Returns:
dict[str, torch.Tensor]: predicted features for target models.
"""
# x: [B, (1)+H*W, C]
x = self.backbone_adapter(x)
x = self.translator_stem(x)
target_model_names = target_model_names if target_model_names is not None else self.target_model_names
        features = {
            t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls)
            for t in target_model_names
        }
return features
class MLPFeatureTranslator(FeatureTranslator):
def __init__(
self,
backbone_feature_size: torch.Size,
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
translator_hidden_size: int = 1024,
translator_n_layer: int = 3,
    ) -> None:
        """Initialization function for MLPFeatureTranslator.
Args:
backbone_feature_size (torch.Size): the size of features of the backbone.
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
translator_n_layer (int): number of MLP layers. Defaults to 3.
"""
self.translator_n_layer = translator_n_layer
super().__init__(
backbone_feature_size=backbone_feature_size,
target_feature_sizes=target_feature_sizes,
translator_hidden_size=translator_hidden_size,
)
def build_translator_heads(self) -> nn.ModuleDict:
"""Build MLP translator heads to match the dimension of each target feature set."""
translator_heads = {}
source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
for target_model, target_size in self.target_feature_sizes.items():
head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer)
translator_heads[self.legit_target_model_name_map[target_model]] = head
self.translator_heads = nn.ModuleDict(translator_heads)
class ConvFeatureTranslator(FeatureTranslator):
def __init__(
self,
backbone_feature_size: torch.Size,
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
translator_hidden_size: int = 1024,
    ) -> None:
        """Initialization function for ConvFeatureTranslator.
Args:
backbone_feature_size (torch.Size): the size of features of the backbone.
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
"""
super().__init__(
backbone_feature_size=backbone_feature_size,
target_feature_sizes=target_feature_sizes,
translator_hidden_size=translator_hidden_size,
)
def build_translator_heads(self) -> nn.ModuleDict:
"""Build translator heads to match the dimension of each target feature set.
Returns:
nn.ModuleDict: the translator heads.
"""
translator_heads = {}
source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
for target_model, target_size in self.target_feature_sizes.items():
head = ConvAdapterHead(source_size=source_size, target_size=target_size)
translator_heads[self.legit_target_model_name_map[target_model]] = head
self.translator_heads = nn.ModuleDict(translator_heads)
class LightConvFeatureTranslator(FeatureTranslator):
def __init__(
self,
backbone_feature_size: torch.Size,
target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
translator_hidden_size: int = 1024,
hidden_size_factor: int | float = 1.0,
    ) -> None:
        """Initialization function for LightConvFeatureTranslator.
It's for a smaller translator compared to ConvFeatureTranslator.
Args:
backbone_feature_size (torch.Size): the size of features of the backbone.
target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 1024.
hidden_size_factor: the size of hidden dim of feature translator
as a factor of input feature hidden dim. Defaults to 1.0
"""
self.hidden_size_factor = hidden_size_factor
super().__init__(
backbone_feature_size=backbone_feature_size,
target_feature_sizes=target_feature_sizes,
translator_hidden_size=translator_hidden_size,
)
self.backbone_adapter = nn.Identity()
def build_translator_heads(self) -> nn.ModuleDict:
"""Build translator heads to match the dimension of each target feature set.
Returns:
nn.ModuleDict: the translator heads.
"""
translator_heads = {}
for target_model, target_size in self.target_feature_sizes.items():
if "_cls" in target_model:
head = LinearAdapterHead(
source_size=self.backbone_feature_size,
target_size=target_size
)
else:
head = LightConvAdapterHead(
source_size=self.backbone_feature_size,
target_size=target_size,
hidden_size_factor=self.hidden_size_factor
)
translator_heads[self.legit_target_model_name_map[target_model]] = head
self.translator_heads = nn.ModuleDict(translator_heads)
class TransformerFeatureTranslator(FeatureTranslator):
def __init__(
self,
backbone_feature_size: torch.Size,
target_feature_sizes: dict[str, torch.Size | tuple[int, int]],
translator_hidden_size: int = 1024,
translator_n_layers: int = 2,
translator_n_heads: int = 8,
translator_activation: str = "gelu",
) -> None:
super().__init__(
backbone_feature_size=backbone_feature_size,
target_feature_sizes=target_feature_sizes,
translator_hidden_size=translator_hidden_size,
)
self.translator_stem = nn.TransformerDecoder(
nn.TransformerDecoderLayer(
d_model=translator_hidden_size,
nhead=translator_n_heads,
dim_feedforward=translator_hidden_size * 2,
activation=translator_activation,
batch_first=True,
norm_first=True,
),
num_layers=translator_n_layers,
)
self.decode_tokens = nn.Parameter(
torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size))
)
self.target_model_emb = nn.ParameterDict(
{
self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size)
for t in self.target_model_names
}
)
def build_translator_heads(self) -> None:
"""Build Transformer translator heads to match the dimension of each target feature set."""
translator_heads = {}
for target_model, target_size in self.target_feature_sizes.items():
head = MLPAdapterHead(
source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]),
target_size=target_size,
num_layer=2,
)
translator_heads[self.legit_target_model_name_map[target_model]] = head
self.translator_heads = nn.ModuleDict(translator_heads)
def forward(
self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
    ) -> dict[str, torch.Tensor]:
        """Forward pass for the transformer feature translator.
Args:
x (torch.Tensor): input features from the backbone.
target_model_names (Optional[str]): names of the target models.
backbone_no_cls (bool): indicate backbone has cls token or not.
Can use it to customize whether to drop cls.
Returns:
dict[str, torch.Tensor]: predicted features for target models.
"""
if not backbone_no_cls:
x = x[:, 1:]
x = self.backbone_adapter(x)
features = {}
target_model_names = target_model_names if target_model_names is not None else self.target_model_names
for t in target_model_names:
feature = self.translator_stem(
torch.cat(
[
self.decode_tokens.repeat(x.size(0), 1, 1),
self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1),
],
dim=1,
),
memory=x,
)[:, 1:, ...]
features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature)
return features
def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator:
"""Handy function to build feature translators given the type
Args:
translator_type (str): the type of the translator,
one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`).
At the moment we are actively using `"lconv"`.
Returns:
FeatureTranslator: the corresponding FeatureTranslator
"""
if translator_type == "mlp":
return MLPFeatureTranslator(**kwargs)
elif translator_type == "conv":
return ConvFeatureTranslator(**kwargs)
elif translator_type == "lconv":
return LightConvFeatureTranslator(**kwargs)
elif translator_type == "transformer" or translator_type == "trans":
        return TransformerFeatureTranslator(**kwargs)
else:
raise NotImplementedError(f"Requested {translator_type} is not implemented yet.")
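# --- Illustrative sketch for build_feature_translator (not part of the
# original Theia code). The target model name and feature size below are
# assumptions for demonstration only.
def _example_build_translator() -> None:
    translator = build_feature_translator(
        "lconv",
        backbone_feature_size=(192, 14, 14),
        target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
        hidden_size_factor=1.0,
    )
    tokens = torch.randn(2, 1 + 14 * 14, 192)  # backbone output with CLS token
    outs = translator(tokens)
    assert outs["facebook/dinov2-large"].shape == (2, 16 * 16, 1024)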
class TheiaConfig(PretrainedConfig):
def __init__(
self,
backbone: str | nn.Module = "facebook/deit-tiny-patch16-224",
pretrained: bool = False,
target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None,
translator_type: str = "lconv",
translator_hidden_size_factor: float | int = 1.0,
target_loss_weights: Optional[dict[str, float]] = None,
feature_reduce_method: Optional[str] = None,
feature_neck: bool = False,
feature_neck_hidden_dim: int = 256,
forward_neck: bool = False,
feature_neck_nonlinearity: str = "relu",
        image_size: int = 224,
num_reg_tokens: int = 0,
**kwargs: Any
):
self.backbone = backbone
self.pretrained = pretrained
self.target_feature_sizes = target_feature_sizes
self.translator_type = translator_type
self.translator_hidden_size_factor = translator_hidden_size_factor
self.target_loss_weights = target_loss_weights
self.feature_reduce_method = feature_reduce_method
self.feature_neck = feature_neck
self.feature_neck_hidden_dim = feature_neck_hidden_dim
self.forward_neck = forward_neck
self.feature_neck_nonlinearity = feature_neck_nonlinearity
        self.image_size = image_size
self.num_reg_tokens = num_reg_tokens
super().__init__(**kwargs)
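# --- Illustrative config sketch (not part of the original Theia code). The
# target model name and feature size are assumptions for demonstration only.
def _example_theia_config() -> TheiaConfig:
    return TheiaConfig(
        backbone="facebook/deit-tiny-patch16-224",
        pretrained=False,
        target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
        translator_type="lconv",
    )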
class TheiaModel(PreTrainedModel):
config_class = TheiaConfig
def __init__(self, config: TheiaConfig):
super().__init__(config)
self.target_feature_sizes = config.target_feature_sizes
self.preprocessor = None
self.pretrained = config.pretrained
# backbone
self.image_size = config.image_size
if "reg" in config.backbone:
            self.backbone: nn.Module = build_backbone(
                config.backbone, config.pretrained, image_size=config.image_size, num_reg_tokens=config.num_reg_tokens
            )
else:
self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, image_size=config.image_size)
# handle output feature (feature reduce)
self.feature_reduce_method = config.feature_reduce_method
self.no_cls = hasattr(self.backbone, "no_cls")
self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0
# translator
backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
if self.target_feature_sizes:
translator_kwargs = {
"hidden_size_factor": config.translator_hidden_size_factor
}
translator_kwargs["backbone_feature_size"] = backbone_feature_size
translator_kwargs["target_feature_sizes"] = config.target_feature_sizes
self.translator = build_feature_translator(
config.translator_type, **translator_kwargs
)
else:
self.translator = None
self.feature_neck = config.feature_neck
self.feature_neck_hidden_dim = config.feature_neck_hidden_dim
self.forward_neck = config.forward_neck
if self.feature_neck:
num_tokens_edge = self.backbone.model.config.image_size // self.backbone.model.config.patch_size
self.neck = nn.Sequential(
Rearrange("b (h w) c -> b c h w", h=num_tokens_edge, w=num_tokens_edge),
nn.Conv2d(self.backbone.model.config.hidden_size, self.feature_neck_hidden_dim, kernel_size=4, stride=2, padding=1), #14x14 -> 7x7
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), # just to keep the same as super class
nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=2), #7x7 -> 3x3
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(),
nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=1), #3x3 -> 1x1
nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(),
nn.Flatten()
)
else:
self.neck = None
# loss
self.mse_loss = nn.MSELoss()
self.l1_loss = nn.SmoothL1Loss()
self.cos_loss = nn.CosineEmbeddingLoss()
self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
self.target_loss_weights = config.target_loss_weights
def load_pretrained_weights(self, checkpoint_path: str) -> None:
"""
Load weights from `checkpoint_path` manually.
Args:
checkpoint_path (str): path to the weights.
"""
# load theia weights
if checkpoint_path:
weights_dict = torch.load(checkpoint_path, map_location="cpu")
# Filter out unnecessary keys
pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()}
self.load_state_dict(pretrained_dict, strict=False)
def freeze_translator(self) -> None:
"""Freeze feature translators `self.translator`."""
if self.translator is not None:
for param in self.translator.parameters():
param.requires_grad = False
def freeze_backbone(self) -> None:
"""Freeze backbone (encoder) `self.backbone`. """
self.freeze_encoder()
def freeze_encoder(self) -> None:
"""Freeze backbone (encoder) `self.backbone`. """
for param in self.backbone.parameters():
param.requires_grad = False
def freeze_neck(self) -> None:
"""Freeze feature neck `self.neck`."""
if self.neck is not None:
for param in self.neck.parameters():
param.requires_grad = False
def freeze_everything(self) -> None:
"""Freeze all parameters in the model."""
self.freeze_translator()
self.freeze_neck()
self.freeze_encoder()
def unfreeze_translator(self) -> None:
if self.translator is not None:
for param in self.translator.parameters():
param.requires_grad = True
def unfreeze_backbone(self) -> None:
"Set parameters in backbone (encoder) `self.backbone` trainable."
self.unfreeze_encoder()
def unfreeze_encoder(self) -> None:
"Set parameters in backbone (encoder) `self.backbone` trainable."
for param in self.backbone.parameters():
param.requires_grad = True
def unfreeze_neck(self) -> None:
"Set parameters in feature neck `self.neck` trainable."
if self.neck is not None:
for param in self.neck.parameters():
param.requires_grad = True
def unfreeze_everything(self) -> None:
"""Set all parameters trainable."""
self.unfreeze_translator()
self.unfreeze_neck()
self.unfreeze_encoder()
def set_forward_neck(self, forward_neck: bool = True) -> None:
"""
Set `self.forward_neck` to `forward_neck` value.
Args:
forward_neck (bool): whether forward the feature through the random initialized neck.
If set to True, the output from `self.forward()` will be in shape [batch_size, self.config.feature_neck_hidden_dim]
"""
self.forward_neck = forward_neck
def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Forward RVFM feature only (before translators).
Args:
x (torch.Tensor): input image. By default it accepts images
in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
kwargs (Any): kwargs including mainly those for huggingface preprocessor:
`do_resize` (bool) defaults to True.
`interpolate_pos_encoding` (Optional[bool]) defaults to None.
`do_rescale` (bool) defaults to True.
`do_normalize` (bool) defaults to True.
Returns:
torch.Tensor: RVFM feature.
"""
feature = self.backbone(x, **kwargs)
# [B, 1+H*W+N, C] if including both CLS and register tokens.
# [B, 1+H*W, C] for standard model (N=0).
# [B, H*W, C] for model without CLS.
return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)
def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor] | torch.Tensor:
"""Forward pass of Robot Vision Foundation Model.
Args:
x (torch.Tensor): input image. By default it accepts images
in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
target_model_names (Optional[list[str]]): names of the target foundation models.
kwargs (Any): kwargs including mainly those for huggingface preprocessor:
`do_resize` (bool) defaults to True.
`interpolate_pos_encoding` (Optional[bool]) defaults to None.
`do_rescale` (bool) defaults to True.
`do_normalize` (bool) defaults to True.
Returns:
if `self.forward_neck`:
torch.Tensor: compact vector feature passed through the neck. [B, C_neck]
else:
dict[str, torch.Tensor]: features that match to each foundation model.
Each feature is in [B, (H*W), C] or [B, C].
"""
if self.forward_neck:
x = self.forward_feature(x)
return self.neck(x)
else:
x = self.backbone(x, **kwargs)
if self.num_reg_tokens > 0:
x = x[:, :-self.num_reg_tokens] # [B, (1)+H*W, C]
features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls) # each is [B, H*W, C] or [B, C]
return features
def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]:
"""Get loss terms given predictions and targets.
Args:
pred_features (dict[str, torch.Tensor]): predictions.
y (dict[str, torch.Tensor]): targets.
Returns:
            dict[str, Any]: averaged mse/cos/l1 losses and per-model loss terms.
"""
mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
mse_losses_per_model = {}
cos_losses_per_model = {}
l1_losses_per_model = {}
for t in pred_features:
pred = pred_features[t]
target = y[t]
# mse loss
mse_loss = self.mse_loss(pred, target)
            weight = self.target_loss_weights[t] if self.target_loss_weights else 1.0 / len(pred_features)
# l1 loss
l1_loss = self.l1_loss(pred, target)
# cos loss
pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
            cos_target = self.cos_target.repeat(pred.size(0)).to(pred.device)
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_target)
mse_loss_avg += mse_loss * weight
cos_loss_avg += cos_loss / len(pred_features) # balance cos by default for meaningful eval
l1_loss_avg += l1_loss * weight
mse_losses_per_model[t] = mse_loss.item()
cos_losses_per_model[t] = cos_loss.item()
l1_losses_per_model[t] = l1_loss.item()
return {
"mse_loss": mse_loss_avg,
"cos_loss": cos_loss_avg,
"l1_loss": l1_loss_avg,
"mse_losses_per_model": mse_losses_per_model,
"cos_losses_per_model": cos_losses_per_model,
"l1_losses_per_model": l1_losses_per_model,
}
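# --- Illustrative end-to-end sketch (not part of the original Theia code). It
# builds an untrained TheiaModel and runs one forward pass; the target name and
# feature size are assumptions, and Hub access is needed for the DeiT backbone
# config/weights/processor.
if __name__ == "__main__":
    demo_config = TheiaConfig(
        backbone="facebook/deit-tiny-patch16-224",
        pretrained=False,
        target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
        translator_type="lconv",
    )
    demo_model = TheiaModel(demo_config)
    demo_images = torch.zeros((1, 224, 224, 3), dtype=torch.uint8)  # [B, H, W, C], uint8
    demo_features = demo_model(demo_images)
    # Expected: a 16x16 grid of 1024-dim features per target, i.e. (1, 256, 1024).
    print({name: tuple(feat.shape) for name, feat in demo_features.items()})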