# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import io
import logging
import math
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image

logger = logging.getLogger(__name__)

# ImageNet normalization statistics shared by the ResNet and DINO backbones.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class MultiScaleImageFeatureExtractor(nn.Module):
    """Extracts global image features with a pretrained backbone (torchvision
    ResNet or DINO ViT), running it at several scales and averaging the results."""

    def __init__(
        self,
        modelname: str = "dino_vits16",
        freeze: bool = False,
        scale_factors: list = [1, 1 / 2, 1 / 3],
    ):
        super().__init__()
        self.freeze = freeze
        self.scale_factors = scale_factors

        if "res" in modelname:
            # torchvision ResNet: drop the classification head and keep the
            # pooled features.
            self._net = getattr(torchvision.models, modelname)(pretrained=True)
            self._output_dim = self._net.fc.weight.shape[1]
            self._net.fc = nn.Identity()
        elif "dino" in modelname:
            # DINO ViT from torch.hub; the output dim is the width of the
            # final LayerNorm.
            self._net = torch.hub.load("facebookresearch/dino:main", modelname)
            self._output_dim = self._net.norm.weight.shape[0]
        else:
            raise ValueError(f"Unknown model name {modelname}")

        # Register the normalization statistics as non-persistent buffers so they
        # follow the module across devices but are not stored in the state dict.
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 3, 1, 1),
                persistent=False,
            )

        if self.freeze:
            for param in self.parameters():
                param.requires_grad = False

    def get_output_dim(self):
        return self._output_dim
    def forward(self, image_rgb: torch.Tensor) -> torch.Tensor:
        img_normed = self._resnet_normalize_image(image_rgb)
        features = self._compute_multiscale_features(img_normed)
        return features

    def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
        return (img - self._resnet_mean) / self._resnet_std

    def _compute_multiscale_features(self, img_normed: torch.Tensor) -> torch.Tensor:
        multiscale_features = None

        if len(self.scale_factors) <= 0:
            raise ValueError(
                f"Wrong format of self.scale_factors: {self.scale_factors}"
            )

        # Run the backbone once per scale and accumulate the features.
        for scale_factor in self.scale_factors:
            if scale_factor == 1:
                inp = img_normed
            else:
                inp = self._resize_image(img_normed, scale_factor)
            if multiscale_features is None:
                multiscale_features = self._net(inp)
            else:
                multiscale_features += self._net(inp)

        # Average over the number of scales.
        averaged_features = multiscale_features / len(self.scale_factors)
        return averaged_features
    @staticmethod
    def _resize_image(image: torch.Tensor, scale_factor: float) -> torch.Tensor:
        # Bilinear resizing of the (B, C, H, W) batch to the given scale.
        return nn.functional.interpolate(
            image,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
        )