File size: 4,361 Bytes
3c63951 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from distutils.version import LooseVersion
from typing import List, Optional, Tuple, Union
import warnings
import torch
from torch import nn
from timm.models.registry import register_model
from .forward_intermediates import forward_intermediates
class PaliGemmaWrapper(nn.Module):
    """Adapter exposing a PaliGemma vision tower through a (summary, features) interface.

    Wraps the SigLIP-style vision model pulled out of a
    ``PaliGemmaForConditionalGeneration`` checkpoint so it matches the student/
    teacher calling convention used elsewhere in this file (compare
    ``DinoWrapper``): ``forward`` returns a pooled summary vector plus the full
    patch-token features.

    Args:
        vis_model: The vision tower module (``model.vision_tower.vision_model``).
        embed_dim: Embedding dimension of the tower's output tokens.
    """

    def __init__(self, vis_model: nn.Module, embed_dim: int):
        super().__init__()
        self.vis_model = vis_model
        # BUG FIX: the original class also declared an `embed_dim` @property with
        # no setter. A setter-less property is a data descriptor, so this
        # assignment raised `AttributeError: can't set attribute` and the class
        # could never be constructed. The explicit constructor argument is
        # authoritative (callers pass the known tower width), so the conflicting
        # property has been removed and the plain attribute kept.
        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        # Patch size of the underlying embedding layer of the wrapped tower.
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        # Transformer encoder layers of the wrapped vision model.
        return self.vis_model.encoder.layers

    def forward(self, x: torch.Tensor):
        """Run the vision tower on `x`.

        Returns:
            Tuple ``(summary, features)`` where ``features`` are the patch
            tokens cast to float32 and ``summary`` is their mean over the
            token dimension.
        """
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )
        # outputs[0] holds the last hidden state; upcast for downstream stability.
        features = outputs[0].to(torch.float32)
        summary = features.mean(dim=1)
        return summary, features

    def forward_features(self, x: torch.Tensor):
        # Same contract as forward(); provided for interface parity with DinoWrapper.
        return self(x)
def _get_paligemma_model(repo: str, embed_dim: Optional[int] = None, dtype: Optional[torch.dtype] = torch.bfloat16):
    """Load a PaliGemma checkpoint from HuggingFace and wrap its vision tower.

    Args:
        repo: HuggingFace repo id, e.g. ``'google/paligemma-3b-pt-896'``.
        embed_dim: Embedding dimension forwarded to ``PaliGemmaWrapper``.
        dtype: Desired checkpoint dtype. When not ``None`` it is also used to
            pick the matching HF revision (e.g. ``torch.bfloat16`` selects the
            ``'bfloat16'`` revision). ``None`` loads the repo default.

    Returns:
        ``PaliGemmaWrapper`` around ``model.vision_tower.vision_model``.
    """
    # Imported lazily so `transformers` is only required when PaliGemma is used.
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version
    # Known-good ceiling; newer transformers have broken PaliGemma before.
    if LooseVersion(tx_version) > LooseVersion('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')
    extra_args = dict()
    if dtype is not None:
        extra_args['torch_dtype'] = dtype
        # str(torch.bfloat16) == 'torch.bfloat16' -> revision name 'bfloat16';
        # assumes the repo publishes a per-dtype revision — TODO confirm on the hub.
        rev = str(dtype).split('.')[-1]
        extra_args['revision'] = rev
    model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
    # Only the vision tower is kept; the language model is discarded.
    vis_model = model.vision_tower.vision_model
    vis_model = PaliGemmaWrapper(vis_model, embed_dim)
    return vis_model
@register_model
def paligemma_896_student(**kwargs):
    """timm registry entry: PaliGemma 896px vision tower (default revision/dtype)."""
    return _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    """Fetch a DINOv2 backbone from the facebookresearch/dinov2 torch hub.

    Args:
        dino_v2_model: Hub entrypoint name, e.g. ``'dinov2_vitg14_reg'``.
        cache_dir: Optional override for the torch hub cache directory.
        pretrained: Whether to download pretrained weights.
        **kwargs: Accepted but NOT forwarded to the hub call — the forwarding
            was deliberately disabled in the original; confirm before enabling.
    """
    if cache_dir:
        torch.hub.set_dir(cache_dir)
    return torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
    )
class DinoWrapper(nn.Module):
    """Adapter presenting a DINOv2 backbone through a (summary, features) interface.

    The wrapped model's ``blocks`` list is collapsed into an ``nn.Sequential``
    so the whole transformer stack can be invoked as a single module.
    """

    def __init__(self, dino_model: nn.Module):
        super().__init__()
        self.inner = dino_model
        # Fuse the per-layer block list into one callable module.
        dino_model.blocks = nn.Sequential(*dino_model.blocks)

    @property
    def embed_dim(self):
        # Token embedding width of the wrapped backbone.
        return self.inner.embed_dim

    @property
    def patch_size(self):
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        # DINOv2 exposes this as `num_tokens`; default to a single CLS token.
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        # Register tokens exist only on the `_reg` variants; default to none.
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        # All leading non-patch tokens: CLS token(s) plus registers.
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cls_token, patch_features)`` from the backbone's feature dict."""
        out = self.inner.forward_features(*args, **kwargs)
        return out['x_norm_clstoken'], out['x_norm_patchtokens']

    def forward_features(self, x: torch.Tensor):
        """Manual token pipeline: patchify -> blocks -> norm -> split tokens."""
        tokens = self.inner.prepare_tokens_with_masks(x)
        normed = self.inner.norm(self.inner.blocks(tokens))
        # Position 0 is the CLS summary; patch tokens follow all summary tokens.
        return normed[:, 0], normed[:, self.num_summary_tokens:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        # Expose the backbone's tokenizer (patch embed + CLS/register prepend).
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Delegate to the shared `forward_intermediates` helper."""
        identity = lambda y: y
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=self.inner.norm if norm else identity,
            x=x,
            **kwargs,
        )
@register_model
def dino_v2_g_student(**kwargs):
    """timm registry entry: randomly initialized DINOv2 ViT-g/14 with registers."""
    return DinoWrapper(_load_dino_v2('dinov2_vitg14_reg', pretrained=False))
|