"""
ViTDet-style windowed attention, retrofitted onto a ViT backbone.

A ViTDetHook registers forward (pre-)hooks on the patch embedder and the
transformer blocks so that most blocks attend within local windows while a few
periodically spaced blocks attend globally, without modifying the model code.
"""
from collections import defaultdict
from contextlib import contextmanager
from logging import getLogger
import math
import sys
from typing import List, Optional, Union, Iterable

import numpy as np
import torch
from torch import nn

from timm.models import VisionTransformer
from einops import rearrange

from .extra_models import DinoWrapper

DEFAULT_NUM_WINDOWED = 5
DEFAULT_NUM_GLOBAL = 4


class VitDetArgs:
    def __init__(self,
                 window_size: int,
                 num_summary_tokens: int,
                 num_windowed: Optional[int] = None,
                 num_global: Optional[int] = None,
    ):
        self.window_size = window_size
        self.num_summary_tokens = num_summary_tokens
        self.num_windowed = num_windowed
        self.num_global = num_global


def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
    """Attach ViTDet hooks to a supported backbone.

    Returns a ViTDetHook on success; otherwise prints a warning and returns None.
    """
    if isinstance(model, VisionTransformer):
        patch_embed = getattr(model, 'patch_generator', model.patch_embed)

        return ViTDetHook(patch_embed, model.blocks, args)
    elif isinstance(model, DinoWrapper):
        inner = model.inner

        patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)
        return ViTDetHook(patch_embed, inner.blocks, args)
    else:
        print('Warning: Unable to apply VitDet aug!', file=sys.stderr)


class ViTDetHook:
    def __init__(self,
                 embedder: nn.Module,
                 blocks: nn.Sequential,
                 args: VitDetArgs,
    ):
        self.blocks = blocks
        self.num_summary_tokens = args.num_summary_tokens
        self.window_size = args.window_size

        self._input_resolution = None
        self._num_windows = None
        self._cls_patch = None
        self._order_cache = dict()

        # Capture the input resolution as the image enters the patch embedder.
        embedder.register_forward_pre_hook(self._enter_model)

        # Reorder the patch tokens into window-major order before the first
        # block, so that switching between windowed and global attention is a
        # pure reshape.
        blocks.register_forward_pre_hook(self._enter_blocks)

        is_global = True
        if args.num_windowed is not None:
            period = args.num_windowed + 1
        else:
            num_global = args.num_global or DEFAULT_NUM_GLOBAL
            period = max(len(blocks) // num_global, 1)

        # Alternate attention modes: within each group of `period` blocks, the
        # first block switches to windowed attention and the last switches
        # back to global.
        for i, layer in enumerate(blocks[:-1]):
            ctr = i % period
            if ctr == 0:
                layer.register_forward_pre_hook(self._to_windows)
                is_global = False
            elif ctr == period - 1:
                layer.register_forward_pre_hook(self._to_global)
                is_global = True

        # Ensure the final block always runs global attention.
        if not is_global:
            blocks[-1].register_forward_pre_hook(self._to_global)

        # Restore the original token order on the way out.
        blocks.register_forward_hook(self._exit_model)

    def _enter_model(self, _, input: List[torch.Tensor]):
        self._input_resolution = input[0].shape[-2:]

    def _enter_blocks(self, _, input: List[torch.Tensor]):
        patches = input[0]
        patches = self._rearrange_patches(patches)

        return (patches,) + input[1:]

    def _to_windows(self, _, input: List[torch.Tensor]):
        patches = input[0]

        if self.num_summary_tokens:
            # Summary (e.g. CLS) tokens are held out of windowed attention and
            # re-attached in _to_global.
            self._cls_patch = patches[:, :self.num_summary_tokens]
            patches = patches[:, self.num_summary_tokens:]

        # Fold each window into the batch dimension: one window = one sequence.
        patches = rearrange(
            patches, 'b (p t) c -> (b p) t c',
            p=self._num_windows, t=self.window_size ** 2,
        )

        return (patches,) + input[1:]

    def _to_global(self, _, input: List[torch.Tensor]):
        patches = input[0]

        patches = rearrange(
            patches, '(b p) t c -> b (p t) c',
            p=self._num_windows, t=self.window_size ** 2,
            b=patches.shape[0] // self._num_windows,
        )

        if self.num_summary_tokens:
            patches = torch.cat([
                self._cls_patch,
                patches,
            ], dim=1)

        return (patches,) + input[1:]

    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
        # Invert the window-major permutation applied in _rearrange_patches so
        # callers receive tokens in their original order.
        patch_order = self._order_cache[self._input_resolution][0]
        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)

        ret_patches = torch.empty_like(patches)
        ret_patches = torch.scatter(
            ret_patches,
            dim=1,
            index=patch_order,
            src=patches,
        )

        return ret_patches

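    # _rearrange_patches (below) permutes the row-major token order into
    # window-major order. Illustrative example with hypothetical sizes: for a
    # 4x4 token grid, window_size=2, and no summary tokens, the row-major
    # indices
    #      0  1  2  3
    #      4  5  6  7
    #      8  9 10 11
    #     12 13 14 15
    # become
    #     [0, 1, 4, 5,   2, 3, 6, 7,   8, 9, 12, 13,   10, 11, 14, 15]
    # so every contiguous run of window_size**2 tokens is one 2x2 window.
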
    def _rearrange_patches(self, patches: torch.Tensor):
        # Permute the feature tokens into window-major order so that
        # _to_windows/_to_global only ever need to reshape. The permutation and
        # window count are cached per input resolution.
        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
        if patch_order is None:
            num_feat_patches = patches.shape[1] - self.num_summary_tokens
            num_pixels = self._input_resolution[0] * self._input_resolution[1]

            # Infer the patch size, then the token-grid and window-grid shapes.
            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
            rows = self._input_resolution[-2] // patch_size
            cols = self._input_resolution[-1] // patch_size

            w_rows = rows // self.window_size
            w_cols = cols // self.window_size

            patch_order = torch.arange(0, num_feat_patches, device=patches.device)

            patch_order = rearrange(
                patch_order, '(wy py wx px) -> (wy wx py px)',
                wy=w_rows, wx=w_cols,
                py=self.window_size, px=self.window_size,
            )

            if self.num_summary_tokens:
                # Summary tokens stay at the front, in their original order.
                patch_order = torch.cat([
                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
                    patch_order + self.num_summary_tokens,
                ])

            self._num_windows = w_rows * w_cols
            self._order_cache[self._input_resolution] = (
                patch_order,
                self._num_windows,
            )

        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
        patches = torch.gather(patches, dim=1, index=patch_order)

        return patches
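

# ---------------------------------------------------------------------------
# Illustrative usage sketch. `_example_usage` is a hypothetical helper added
# here for illustration only (not part of the original module). It assumes
# `timm` is installed and uses the stock 'vit_base_patch16_224' model, whose
# 14x14 token grid (plus one CLS token) divides evenly into 7x7 windows.
# ---------------------------------------------------------------------------
def _example_usage():
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False)

    args = VitDetArgs(
        window_size=7,          # 14x14 patch grid -> 2x2 windows of 7x7 tokens
        num_summary_tokens=1,   # the single CLS token of a standard ViT
        num_global=DEFAULT_NUM_GLOBAL,
    )
    apply_vitdet_arch(model, args)

    # The hooks fire inside this call: with ViT-B/16's 12 blocks and
    # num_global=4, blocks 2, 5, 8 and 11 run global attention; the remaining
    # blocks attend within 49-token windows.
    x = torch.randn(1, 3, 224, 224)
    return model.forward_features(x)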