# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable

import torch
import torch.nn as nn
import torchvision
import io
from PIL import Image
import numpy as np

logger = logging.getLogger(__name__)

# ImageNet statistics used to normalize inputs for both ResNet and DINO backbones.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class MultiScaleImageFeatureExtractor(nn.Module):
    """Extracts image features with a ResNet or DINO ViT backbone and averages
    the resulting feature vectors over several input scales."""

    def __init__(
        self,
        modelname: str = "dino_vits16",
        freeze: bool = False,
        scale_factors: list = [1, 1 / 2, 1 / 3],
    ):
        super().__init__()
        self.freeze = freeze
        self.scale_factors = scale_factors

        # Select the backbone and record its output feature dimension.
        if "res" in modelname:
            self._net = getattr(torchvision.models, modelname)(pretrained=True)
            self._output_dim = self._net.fc.weight.shape[1]
            # Drop the classification head so the network returns pooled features.
            self._net.fc = nn.Identity()
        elif "dino" in modelname:
            self._net = torch.hub.load("facebookresearch/dino:main", modelname)
            self._output_dim = self._net.norm.weight.shape[0]
        else:
            raise ValueError(f"Unknown model name {modelname}")

        # Register ImageNet normalization constants as non-persistent buffers
        # so they follow the module across devices but are not saved in checkpoints.
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 3, 1, 1),
                persistent=False,
            )

        if self.freeze:
            for param in self.parameters():
                param.requires_grad = False

    def get_output_dim(self):
        return self._output_dim

    def forward(self, image_rgb: torch.Tensor) -> torch.Tensor:
        img_normed = self._resnet_normalize_image(image_rgb)
        features = self._compute_multiscale_features(img_normed)
        return features

    def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
        return (img - self._resnet_mean) / self._resnet_std

    def _compute_multiscale_features(
        self, img_normed: torch.Tensor
    ) -> torch.Tensor:
        multiscale_features = None

        if len(self.scale_factors) <= 0:
            raise ValueError(
                f"Wrong format of self.scale_factors: {self.scale_factors}"
            )

        # Run the backbone at each scale and accumulate the feature vectors.
        for scale_factor in self.scale_factors:
            if scale_factor == 1:
                inp = img_normed
            else:
                inp = self._resize_image(img_normed, scale_factor)

            if multiscale_features is None:
                multiscale_features = self._net(inp)
            else:
                multiscale_features += self._net(inp)

        # Average the accumulated features over the number of scales.
        averaged_features = multiscale_features / len(self.scale_factors)
        return averaged_features

    @staticmethod
    def _resize_image(image: torch.Tensor, scale_factor: float) -> torch.Tensor:
        return nn.functional.interpolate(
            image,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
        )
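

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes network access so torch.hub can download the DINO weights, and it
# feeds a random tensor in place of a real image batch. Inputs are expected as
# float (B, 3, H, W) RGB tensors in [0, 1], since the extractor applies
# ImageNet normalization internally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = MultiScaleImageFeatureExtractor(
        modelname="dino_vits16",
        freeze=True,
        scale_factors=[1, 1 / 2, 1 / 3],
    )
    extractor.eval()

    # Dummy batch of two 224x224 RGB images in [0, 1].
    images = torch.rand(2, 3, 224, 224)

    with torch.no_grad():
        feats = extractor(images)

    # Features are averaged over the three scales; for dino_vits16 the
    # per-image feature dimension is 384.
    print(feats.shape, extractor.get_output_dim())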