|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
from typing import Callable, Dict, Optional, List, Union |
|
|
|
from timm.models import VisionTransformer |
|
import torch |
|
from torch import nn |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
from .common import RESOURCE_MAP, DEFAULT_VERSION |
|
|
|
|
|
from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput |
|
from .adaptor_generic import GenericAdaptor, AdaptorBase |
|
from .adaptor_mlp import create_mlp_from_config |
|
from .adaptor_registry import adaptor_registry |
|
from .cls_token import ClsToken |
|
from .enable_cpe_support import enable_cpe |
|
from .enable_spectral_reparam import configure_spectral_reparam_from_args |
|
from .eradio_model import eradio |
|
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer |
|
from .radio_model import create_model_from_args |
|
from .radio_model import RADIOModel as RADIOModelBase, Resolution |
|
from .input_conditioner import get_default_conditioner, InputConditioner |
|
from .open_clip_adaptor import OpenCLIP_RADIO |
|
from .vit_patch_generator import ViTPatchGenerator |
|
from .vitdet import apply_vitdet_arch, VitDetArgs |
|
|
|
|
|
from .extra_timm_models import * |
|
|
|
|
|
|
|
def rename_all_gamma_to_weight_with_proxy(module):
    """
    Renames all parameters named 'gamma' in a module (including submodules)
    to 'weight' and sets up a property so that accesses to 'gamma' still work.

    Every direct parameter of ``module`` or of any of its submodules whose
    name contains the substring ``'gamma'`` is re-registered under the same
    name with ``'gamma'`` replaced by ``'weight'`` (e.g. ``gamma`` ->
    ``weight``, ``gamma_1`` -> ``weight_1``). The existing ``nn.Parameter``
    object itself is re-registered, so its data, ``requires_grad`` flag,
    ``.grad`` and identity are all preserved.

    A read/write ``property`` named after the old parameter is installed on
    the submodule's *class* and forwards to the new name. Because it is set
    on the class, the alias applies to every instance of that class, not only
    to the instances found inside ``module``.

    Args:
        module: Root ``nn.Module``. It and all of its submodules are modified
            in place.
    """

    def _alias_to(new_name):
        # Property that forwards reads/writes of the old attribute name to new_name.
        return property(
            lambda self: getattr(self, new_name),
            lambda self, value: setattr(self, new_name, value),
        )

    for submodule in module.modules():
        # Snapshot the parameters: the loop body edits submodule's parameter registry.
        for param_name, param in list(submodule.named_parameters(recurse=False)):
            if 'gamma' not in param_name:
                continue

            new_name = param_name.replace('gamma', 'weight')

            delattr(submodule, param_name)
            # Re-register the very same Parameter instead of wrapping param.data
            # in a new nn.Parameter, which would reset requires_grad to True and
            # detach it from any references held elsewhere (e.g. an optimizer).
            setattr(submodule, new_name, param)

            setattr(submodule.__class__, param_name, _alias_to(new_name))
|
|
|
|
|
class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models.

    Args:
        args: Model-construction arguments. ``RADIOModel`` turns them into a
            namedtuple and passes it to ``create_model_from_args``. A shallow
            copy is stored, so the caller's mapping is never modified. Any
            ``dtype`` / ``amp_dtype`` entries are stored as their bare names
            (e.g. ``torch.float32`` -> ``"float32"``).
        version: Key into ``RESOURCE_MAP``. The selected resource supplies
            defaults for ``patch_size``, ``max_resolution`` and
            ``preferred_resolution``.
        patch_size: Overrides the resource's patch size when truthy.
        max_resolution: Overrides the resource's max resolution when truthy.
        preferred_resolution: Overrides the resource's preferred resolution
            when truthy.
        adaptor_names: Name, or list of names, of the adaptors to build.
        adaptor_configs: Per-adaptor MLP configuration, keyed by adaptor name.
        vitdet_window_size: Forwarded to the model as ``window_size``.
        feature_normalizer_config: Config for the ``FeatureNormalizer``.
            ``None`` means no normalizer is built.
        inter_feature_normalizer_config: Config for the
            ``IntermediateFeatureNormalizer``. ``None`` means no normalizer
            is built.
        rename_gamma_to_weight: Whether the model renames ``gamma``
            parameters to ``weight`` after construction.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = DEFAULT_VERSION,
        patch_size: Optional[int] = None,
        max_resolution: Optional[int] = None,
        preferred_resolution: Optional[Resolution] = None,
        adaptor_names: Optional[Union[str, List[str]]] = None,
        adaptor_configs: Optional[Dict[str, Dict[str, int]]] = None,
        vitdet_window_size: Optional[int] = None,
        feature_normalizer_config: Optional[dict] = None,
        inter_feature_normalizer_config: Optional[dict] = None,
        rename_gamma_to_weight: bool = False,
        **kwargs,
    ):
        # Shallow copy: the dtype normalization below rewrites entries in place,
        # and the caller's mapping may be live state (e.g. ``vars(namespace)``).
        self.args = dict(args) if args is not None else None
        if self.args is not None:
            for field in ("dtype", "amp_dtype"):
                if field in self.args:
                    # str(torch.float32) == "torch.float32" -> keep "float32".
                    # RADIOModel maps ``dtype`` back with getattr(torch, name).
                    self.args[field] = str(self.args[field]).split(".")[-1]

        self.version = version
        resource = RESOURCE_MAP[version]
        # Explicit values win. Falsy values fall back to the resource defaults.
        self.patch_size = patch_size or resource.patch_size
        self.max_resolution = max_resolution or resource.max_resolution
        self.preferred_resolution = (
            preferred_resolution or resource.preferred_resolution
        )
        self.adaptor_names = adaptor_names
        self.adaptor_configs = adaptor_configs
        self.vitdet_window_size = vitdet_window_size
        self.feature_normalizer_config = feature_normalizer_config
        self.inter_feature_normalizer_config = inter_feature_normalizer_config
        self.rename_gamma_to_weight = rename_gamma_to_weight
        super().__init__(**kwargs)
|
|
|
|
|
|
|
class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO.

    This class inherits from PreTrainedModel, which provides
    HuggingFace's functionality for loading and saving models.

    It builds a ``RADIOModelBase`` (stored as ``self.radio_model``) from the
    values in a :class:`RADIOConfig`. The properties and methods below
    delegate to that wrapped model.
    """

    config_class = RADIOConfig

    def __init__(self, config: RADIOConfig):
        super().__init__(config)

        # Give the stored args attribute-style access (``args.teachers`` etc.)
        # for the project factory functions.
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)
        self.config = config

        model = create_model_from_args(args)
        input_conditioner: InputConditioner = get_default_conditioner()

        # RADIOConfig may store the dtype as a bare name such as "float32";
        # resolve it back to the torch dtype object.
        dtype = getattr(args, "dtype", torch.float32)
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        model.to(dtype=dtype)
        input_conditioner.dtype = dtype

        # Indices of teachers whose config does not opt out via use_summary=False.
        summary_idxs = torch.tensor(
            [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
            dtype=torch.int64,
        )

        adaptor_configs = config.adaptor_configs
        adaptor_names = config.adaptor_names or []
        # A single adaptor name may be given as a plain string. Iterating it
        # directly would yield individual characters.
        if isinstance(adaptor_names, str):
            adaptor_names = [adaptor_names]

        adaptors = dict()
        for adaptor_name in adaptor_names:
            mlp_config = adaptor_configs[adaptor_name]
            adaptor = GenericAdaptor(args, None, None, mlp_config)
            adaptor.head_idx = mlp_config["head_idx"]
            adaptors[adaptor_name] = adaptor

        feature_normalizer = None
        if config.feature_normalizer_config is not None:
            feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])

        inter_feature_normalizer = None
        if config.inter_feature_normalizer_config is not None:
            inter_feature_normalizer = IntermediateFeatureNormalizer(
                config.inter_feature_normalizer_config["num_intermediates"],
                config.inter_feature_normalizer_config["embed_dim"],
                rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
                dtype=dtype)

        self.radio_model = RADIOModelBase(
            model,
            input_conditioner,
            summary_idxs=summary_idxs,
            patch_size=config.patch_size,
            max_resolution=config.max_resolution,
            window_size=config.vitdet_window_size,
            preferred_resolution=config.preferred_resolution,
            adaptors=adaptors,
            feature_normalizer=feature_normalizer,
            inter_feature_normalizer=inter_feature_normalizer,
        )

        if config.rename_gamma_to_weight:
            rename_all_gamma_to_weight_with_proxy(self.radio_model)

    @property
    def adaptors(self) -> nn.ModuleDict:
        """Adaptors of the wrapped ``RADIOModelBase``."""
        return self.radio_model.adaptors

    @property
    def model(self) -> VisionTransformer:
        """Backbone of the wrapped ``RADIOModelBase``."""
        return self.radio_model.model

    @property
    def input_conditioner(self) -> InputConditioner:
        """Input conditioner of the wrapped ``RADIOModelBase``."""
        return self.radio_model.input_conditioner

    @property
    def num_summary_tokens(self) -> int:
        """Delegates to ``RADIOModelBase.num_summary_tokens``."""
        return self.radio_model.num_summary_tokens

    @property
    def patch_size(self) -> int:
        """Delegates to ``RADIOModelBase.patch_size``."""
        return self.radio_model.patch_size

    @property
    def max_resolution(self) -> int:
        """Delegates to ``RADIOModelBase.max_resolution``."""
        return self.radio_model.max_resolution

    @property
    def preferred_resolution(self) -> Resolution:
        """Delegates to ``RADIOModelBase.preferred_resolution``."""
        return self.radio_model.preferred_resolution

    @property
    def window_size(self) -> int:
        """Delegates to ``RADIOModelBase.window_size``."""
        return self.radio_model.window_size

    @property
    def min_resolution_step(self) -> int:
        """Delegates to ``RADIOModelBase.min_resolution_step``."""
        return self.radio_model.min_resolution_step

    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Delegates to ``RADIOModelBase.make_preprocessor_external``."""
        return self.radio_model.make_preprocessor_external()

    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
        """Delegates to ``RADIOModelBase.get_nearest_supported_resolution``."""
        return self.radio_model.get_nearest_supported_resolution(height, width)

    def switch_to_deploy(self):
        """Delegates to ``RADIOModelBase.switch_to_deploy``."""
        return self.radio_model.switch_to_deploy()

    def forward(self, x: torch.Tensor):
        """Run the wrapped ``RADIOModelBase`` on ``x``."""
        return self.radio_model.forward(x)
|
|