# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Factory functions to build and load ViT models.

from __future__ import annotations

import logging
import types
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional

import timm
import torch
import torch.nn as nn

from .vit import (
    forward_features_eva_fixed,
    make_vit_b16_backbone,
    resize_patch_embed,
    resize_vit,
)

LOGGER = logging.getLogger(__name__)


ViTPreset = Literal[
    "dinov2l16_384",
]


@dataclass
class ViTConfig:
    """Configuration for ViT."""

    in_chans: int
    embed_dim: int

    img_size: int = 384
    patch_size: int = 16

    # In case we need to rescale the backbone when loading from timm.
    timm_preset: Optional[str] = None
    timm_img_size: int = 384
    timm_patch_size: int = 16

    # The following two parameters are only used by DPT. See dpt_factory.py.
    encoder_feature_layer_ids: Optional[List[int]] = None
    """The Beit/ViT layers used to construct encoder features for DPT."""
    encoder_feature_dims: Optional[List[int]] = None
    """The dimensions of the Beit/ViT encoder features used for DPT."""


VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
    "dinov2l16_384": ViTConfig(
        in_chans=3,
        embed_dim=1024,
        encoder_feature_layer_ids=[5, 11, 17, 23],
        encoder_feature_dims=[256, 512, 1024, 1024],
        img_size=384,
        patch_size=16,
        timm_preset="vit_large_patch14_dinov2",
        timm_img_size=518,
        timm_patch_size=14,
    ),
}


def create_vit(
    preset: ViTPreset,
    use_pretrained: bool = False,
    checkpoint_uri: str | None = None,
    use_grad_checkpointing: bool = False,
) -> nn.Module:
    """Create and load a ViT backbone module.

    Args:
    ----
        preset: The ViT preset to load the pre-defined config.
        use_pretrained: Load pretrained weights if True, default is False.
        checkpoint_uri: Checkpoint to load the weights from.
        use_grad_checkpointing: Use gradient checkpointing.

    Returns:
    -------
        A Torch ViT backbone module.

    """
    config = VIT_CONFIG_DICT[preset]

    img_size = (config.img_size, config.img_size)
    patch_size = (config.patch_size, config.patch_size)

    if "eva02" in preset:
        # EVA02 presets need a patched forward_features; see vit.py.
        model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
        model.forward_features = types.MethodType(forward_features_eva_fixed, model)
    else:
        model = timm.create_model(
            config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
        )
    model = make_vit_b16_backbone(
        model,
        encoder_feature_dims=config.encoder_feature_dims,
        encoder_feature_layer_ids=config.encoder_feature_layer_ids,
        vit_features=config.embed_dim,
        use_grad_checkpointing=use_grad_checkpointing,
    )

    # Rescale the patch embedding and positional embeddings whenever the
    # preset's geometry differs from the timm checkpoint's (e.g. 14px patches
    # at 518px resized to 16px patches at 384px for DINOv2-L).
    if config.patch_size != config.timm_patch_size:
        model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
    if config.img_size != config.timm_img_size:
        model.model = resize_vit(model.model, img_size=img_size)

    if checkpoint_uri is not None:
        # Load with strict=False, then surface any key mismatch explicitly.
        state_dict = torch.load(checkpoint_uri, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=False
        )

        if len(unexpected_keys) != 0:
            raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
        if len(missing_keys) != 0:
            raise KeyError(f"Keys are missing when loading vit: {missing_keys}")

    LOGGER.info(model)
    return model.model
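

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): build the
    # DINOv2-L/16 backbone with randomly initialized weights and run a dummy
    # forward pass. Run this as a module (`python -m <package>.vit_factory`)
    # so the relative import of `.vit` resolves. The 384x384 input matches
    # the preset's `img_size`; the token shape depends on the timm model
    # that `create_vit` returns.
    logging.basicConfig(level=logging.INFO)
    backbone = create_vit(preset="dinov2l16_384", use_pretrained=False)
    dummy = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        tokens = backbone.forward_features(dummy)
    print(f"forward_features output shape: {tuple(tokens.shape)}")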