Upload model
- cls_token.py +47 -0
- config.json +235 -0
- enable_cpe_support.py +59 -0
- hf_model.py +84 -0
- input_conditioner.py +41 -0
- model.py +40 -0
- pytorch_model.bin +3 -0
- vit_patch_generator.py +291 -0
cls_token.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import nn


class ClsToken(nn.Module):
    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: int = 0,
                 ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if register_multiple > 0:
                self.num_registers = register_multiple - (num_tokens % register_multiple)

            scale = ndim ** -0.5
            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([
            token,
            x,
        ], dim=1)

        return x

    def no_weight_decay(self):
        return [
            'token',
        ]
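A minimal usage sketch (not part of the uploaded files): ClsToken prepends its learnable CLS and register tokens to a patch sequence, padding the token count up to the next multiple of register_multiple. The embedding width of 1280 below is only illustrative.

import torch

cls = ClsToken(ndim=1280, num_tokens=3, enabled=True, register_multiple=8)
patches = torch.randn(2, 196, 1280)   # (batch, num_patches, embed_dim)
tokens = cls(patches)

# With num_tokens=3 and register_multiple=8, 5 registers are added,
# so 8 tokens are prepended: (2, 196, 1280) -> (2, 204, 1280).
print(tokens.shape)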
config.json
ADDED
@@ -0,0 +1,235 @@
{
  "architectures": [
    "RADIOModel"
  ],
  "args": {
    "aa": null,
    "amp": true,
    "amp_dtype": "bfloat16",
    "amp_impl": "native",
    "aug_repeats": 0,
    "aug_splits": 0,
    "auto_loss_balance_mode": "adaloss",
    "batch_size": 32,
    "bn_eps": null,
    "bn_momentum": null,
    "cache_dir": null,
    "channels_last": false,
    "checkpoint_hist": 10,
    "class_map": "",
    "clip_grad": null,
    "clip_mode": "norm",
    "cls_token_per_teacher": true,
    "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
    "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
    "color_jitter": 0.4,
    "cooldown_epochs": 0,
    "cpe_max_size": 1050,
    "crd_loss": false,
    "crd_loss_weight": 0.8,
    "crop_pct": null,
    "cutmix": 0.0,
    "cutmix_minmax": null,
    "data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/datacomp/dc1b/stage2",
    "dataset": "nvgpt4",
    "dataset_download": false,
    "debug_full_knn": false,
    "decay_epochs": 90,
    "decay_milestones": [
      90,
      180,
      270
    ],
    "decay_rate": 0.1,
    "device": "cuda:0",
    "dist_bn": "reduce",
    "distributed": true,
    "drop": 0.0,
    "drop_block": null,
    "drop_connect": null,
    "drop_path": null,
    "epoch_repeats": 0.0,
    "epochs": 300,
    "eval": false,
    "eval_metric": "knn_top1",
    "eval_teacher": false,
    "eval_teacher_only": false,
    "eval_throughput": false,
    "experiment": "checkpoints",
    "fast_norm": false,
    "feature_summarizer": "cls_token",
    "feature_upscale_factor": null,
    "fuser": "",
    "gp": "avg",
    "grad_accum_steps": 1,
    "grad_checkpointing": false,
    "head_init_bias": null,
    "head_init_scale": null,
    "hflip": 0.5,
    "img_size": null,
    "in_chans": null,
    "initial_checkpoint": "",
    "input_size": null,
    "interpolation": "",
    "layer_decay": null,
    "local_rank": 0,
    "log_interval": 50,
    "log_mlflow": false,
    "log_wandb": true,
    "loss": "cosine",
    "loss_auto_balance": false,
    "lr": 0.001,
    "lr_base": 0.1,
    "lr_base_scale": "",
    "lr_base_size": 256,
    "lr_cycle_decay": 0.5,
    "lr_cycle_limit": 1,
    "lr_cycle_mul": 1.0,
    "lr_k_decay": 1.0,
    "lr_noise": null,
    "lr_noise_pct": 0.67,
    "lr_noise_std": 1.0,
    "mean": null,
    "min_lr": 0,
    "mixup": 0.0,
    "mixup_mode": "batch",
    "mixup_off_epoch": 0,
    "mixup_prob": 1.0,
    "mixup_switch_prob": 0.5,
    "mlp_hidden_size": 1520,
    "mlp_num_inner": 3,
    "mlp_version": "v2",
    "model": "vit_huge_patch14_224",
    "model_ema": false,
    "model_ema_decay": 0.9998,
    "model_ema_force_cpu": false,
    "model_kwargs": {},
    "momentum": 0.9,
    "no_aug": false,
    "no_ddp_bb": false,
    "no_prefetcher": false,
    "no_resume_opt": false,
    "num_classes": null,
    "opt": "fusedlamb",
    "opt_betas": null,
    "opt_eps": null,
    "opt_kwargs": {},
    "output": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/dfn_oai/11-29-23_vit-h-14-cpe_dfn-oai-dino_maxres",
    "patience_epochs": 10,
    "pin_mem": false,
    "prefetcher": true,
    "pretrained": false,
    "rank": 0,
    "ratio": [
      0.75,
      1.3333333333333333
    ],
    "recount": 1,
    "recovery_interval": 0,
    "register_multiple": 8,
    "remode": "pixel",
    "reprob": 0.0,
    "resplit": false,
    "resume": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/dfn_oai/11-29-23_vit-h-14-cpe_dfn-oai-dino_maxres/checkpoints/last.pth.tar",
    "save_images": false,
    "scale": [
      0.5,
      1.0
    ],
    "sched": "cosine",
    "sched_on_updates": true,
    "seed": 42,
    "smoothing": 0.1,
    "split_bn": false,
    "start_epoch": null,
    "std": null,
    "steps_per_epoch": 2000,
    "sync_bn": false,
    "synchronize_step": false,
    "teachers": [
      {
        "amp": true,
        "amp_dtype": "bfloat16",
        "batch_size": 16,
        "fd_loss_weight": 1.0,
        "fd_normalize": false,
        "feature_distillation": true,
        "input_size": 378,
        "model": "ViT-H-14-378-quickgelu",
        "name": "clip",
        "pretrained": "dfn5b",
        "sample_rate": 16,
        "summary_loss_weight": 1.0,
        "type": "open_clip",
        "vitdet_prob": 0.05,
        "vitdet_window_sizes": [
          3,
          9,
          9,
          9
        ]
      },
      {
        "amp": false,
        "amp_dtype": "bfloat16",
        "batch_size": 16,
        "fd_loss_weight": 0.8,
        "fd_normalize": false,
        "feature_distillation": true,
        "input_size": 336,
        "model": "ViT-L/14@336px",
        "name": "openai_clip",
        "pretrained": "openai",
        "sample_rate": 16,
        "summary_loss_weight": 0.8,
        "type": "openai_clip",
        "use_summary": false
      },
      {
        "amp": true,
        "amp_dtype": "bfloat16",
        "batch_size": 16,
        "fd_loss_weight": 1.0,
        "fd_normalize": false,
        "feature_distillation": true,
        "input_size": 224,
        "model": "dinov2_vitg14",
        "name": "dino_v2",
        "sample_rate": 16,
        "summary_loss_weight": 1.0,
        "type": "dino_v2"
      }
    ],
    "torchcompile": null,
    "torchscript": false,
    "train_interpolation": "random",
    "train_split": "train",
    "tta": 0,
    "use_coco": false,
    "use_multi_epochs_loader": false,
    "val_data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-1k/webdataset",
    "val_img_size": 378,
    "val_split": "val",
    "validation_batch_size": 128,
    "vflip": 0.0,
    "wandb_entity": "",
    "wandb_group": "dfn_oai",
    "wandb_job_type": "",
    "wandb_name": "",
    "wandb_project": "",
    "warmup_epochs": 2.5,
    "warmup_lr": 1e-05,
    "warmup_prefix": false,
    "weight_decay": 2e-05,
    "worker_seeding": "all",
    "workers": 4,
    "world_size": 64
  },
  "auto_map": {
    "AutoConfig": "hf_model.RADIOConfig",
    "AutoModel": "hf_model.RADIOModel"
  },
  "torch_dtype": "float32",
  "transformers_version": "4.29.0",
  "version": "v1"
}
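The auto_map block above routes AutoConfig and AutoModel to hf_model.py inside this repository, so loading requires trust_remote_code=True. A loading sketch follows; the repo id "nvidia/RADIO" is assumed from the checkpoint URL in hf_model.py and may need to be adjusted.

import torch
from transformers import AutoModel

# trust_remote_code is needed because auto_map points at code shipped in this repo.
model = AutoModel.from_pretrained("nvidia/RADIO", trust_remote_code=True)
model.eval()

x = torch.rand(1, 3, 224, 224)   # RGB in [0, 1]; the InputConditioner applies CLIP normalization
with torch.no_grad():
    features = model(x)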
enable_cpe_support.py
ADDED
@@ -0,0 +1,59 @@
from typing import Union, Tuple
from types import MethodType

import torch
from torch import nn

from timm.models import VisionTransformer, checkpoint_seq

from .vit_patch_generator import ViTPatchGenerator


def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
    x = self.patch_generator(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    x = self.norm(x)
    return x


def enable_cpe(model: nn.Module,
               max_img_size: Union[int, Tuple[int, int]] = 1024,
               num_cls_tokens: int = 1,
               pos_dropout: float = 0.1,
               register_multiple: int = 0,
               ):
    if not isinstance(model, VisionTransformer):
        raise ValueError("CPE only support for VisionTransformer models!")

    patch_size = model.patch_embed.patch_size[0]
    embed_dim = model.embed_dim
    input_dims = model.patch_embed.img_size
    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
    cls_token = model.cls_token is not None

    max_img_size = int(round(max_img_size / patch_size) * patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
    )

    model.patch_generator = patch_generator
    model.patch_embed = None
    model.cls_token = None
    model.pos_embed = None
    model.pos_drop = None
    model.num_cls_tokens = num_cls_tokens
    model.num_registers = patch_generator.num_registers

    model.forward_features = MethodType(_forward_cpe, model)
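A sketch (not part of the uploaded files) of applying enable_cpe to a plain timm ViT: the patch embedding, CLS token, and position embedding are replaced by a ViTPatchGenerator, and forward_features is rebound so the model accepts any resolution that is a multiple of the patch size, up to max_img_size.

import timm
import torch

vit = timm.create_model('vit_huge_patch14_224', pretrained=False)
enable_cpe(vit, max_img_size=1050, num_cls_tokens=3, register_multiple=8)

# forward_features now routes through the patch generator; 434 and 616 are
# arbitrary multiples of the patch size (14).
feats = vit.forward_features(torch.randn(1, 3, 434, 616))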
hf_model.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Optional

from timm.models import VisionTransformer
import torch
from transformers import PretrainedConfig, PreTrainedModel


from .model import create_model_from_args
from .input_conditioner import get_default_conditioner, InputConditioner


resource_map = {
    'radio_v1': 'https://huggingface.co/nvidia/RADIO/raw/main/radio_v1.pth.tar'
}


class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models."""

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = "v1",
        **kwargs,
    ):
        self.args = args
        self.version = version
        super().__init__(**kwargs)


class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO."""

    def __init__(self, config):
        super().__init__(config)

        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)
        self.model = create_model_from_args(args)

        self.input_conditioner: InputConditioner = get_default_conditioner()

        #return RADIOModel(mod, conditioner, return_summary=return_summary, return_spatial_features=return_spatial_features)

    def forward(self, x: torch.Tensor):
        x = self.input_conditioner(x)

        y = self.model.forward_features(x)

        if isinstance(y, (list, tuple)):
            summary, all_feat = y
        elif isinstance(self.model, VisionTransformer):
            patch_gen = getattr(self.model, 'patch_generator', None)
            if patch_gen is not None:
                summary = y[:, :patch_gen.num_cls_tokens].flatten(1)
                all_feat = y[:, patch_gen.num_skip:]
            elif self.model.global_pool == 'avg':
                summary = y[:, self.model.num_prefix_tokens:].mean(dim=1)
                all_feat = y
            else:
                summary = y[:, 0]
                all_feat = y[:, 1:]
        else:
            raise ValueError("Unsupported model type")

        if self.return_summary and self.return_spatial_features:
            return summary, all_feat
        elif self.return_summary:
            return summary
        return all_feat
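A construction sketch (not part of the uploaded files). RADIOModel builds the backbone from the args block of config.json; the weights come from pytorch_model.bin when loading through from_pretrained. Note that forward() reads return_summary and return_spatial_features, which __init__ does not set, so this sketch assigns them explicitly before calling the model.

import json
import torch

with open("config.json") as f:
    cfg = json.load(f)

config = RADIOConfig(args=cfg["args"], version=cfg["version"])
radio = RADIOModel(config)
radio.return_summary = True            # flags assumed by forward() above
radio.return_spatial_features = True

summary, spatial = radio(torch.rand(1, 3, 378, 378))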
input_conditioner.py
ADDED
@@ -0,0 +1,41 @@
from typing import Union, Tuple

import torch
from torch import nn


norm_t = Union[Tuple[float, float, float], torch.Tensor]

class InputConditioner(nn.Module):
    def __init__(self,
                 input_scale: float,
                 norm_mean: norm_t,
                 norm_std: norm_t,
                 dtype: torch.dtype = torch.float32,
                 ):
        super().__init__()

        self.dtype = dtype

        # self.input_scale = input_scale
        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)

    def forward(self, x: torch.Tensor):
        # x = x * self.input_scale
        y = (x - self.norm_mean) / self.norm_std
        return y.to(self.dtype)


def get_default_conditioner():
    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

    return InputConditioner(
        input_scale=1.0,
        norm_mean=OPENAI_CLIP_MEAN,
        norm_std=OPENAI_CLIP_STD,
    )


def _to_tensor(v: norm_t):
    return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
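A short sketch (not part of the uploaded files): the default conditioner expects RGB tensors scaled to [0, 1] and standardizes them with the OpenAI CLIP mean and std.

import torch

conditioner = get_default_conditioner()
images = torch.rand(4, 3, 224, 224)    # values in [0, 1]
conditioned = conditioner(images)      # (x - mean) / std, cast to float32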
model.py
ADDED
@@ -0,0 +1,40 @@
from torch import nn

from timm.models import create_model

from .enable_cpe_support import enable_cpe


def create_model_from_args(args) -> nn.Module:
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )

    assert not args.cls_token_per_teacher or args.cpe_max_size is not None, "CPE must be enabled for multiple CLS tokens!"

    if args.cpe_max_size is not None:
        enable_cpe(model,
                   args.cpe_max_size,
                   num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
                   register_multiple=args.register_multiple,
                   )

    return model
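A sketch (not part of the uploaded files) of the fields create_model_from_args actually reads, shown with a small ViT so it builds quickly; the values are illustrative, not those in config.json.

from types import SimpleNamespace

args = SimpleNamespace(
    model='vit_base_patch16_224', pretrained=False, in_chans=None, input_size=None,
    num_classes=None, drop=0.0, drop_path=None, drop_block=None, gp='avg',
    bn_momentum=None, bn_eps=None, torchscript=False, initial_checkpoint='',
    model_kwargs={}, cls_token_per_teacher=False, cpe_max_size=512,
    teachers=[], register_multiple=4,
)

# Builds the timm ViT, then wires in CPE because cpe_max_size is set.
model = create_model_from_args(args)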
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:242360b04b7f78204b535ce8a96e28ef3316520d55be43e6873fd45696fb9d61
size 2662619441
vit_patch_generator.py
ADDED
@@ -0,0 +1,291 @@
import math
from typing import Union, Tuple, Optional

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from .cls_token import ClsToken

input_dim_t = Union[int, Tuple[int, int]]

try:
    # raise ImportError()
    from indirect_grid_sample import indirect_grid_sample
except ImportError:
    indirect_grid_sample = None

class ViTPatchGenerator(nn.Module):
    def __init__(self,
                 patch_size: int,
                 embed_dim: int,
                 input_dims: input_dim_t,
                 abs_pos: bool = True,
                 normalize_patches: bool = False,
                 cls_token: bool = False,
                 max_input_dims: Optional[input_dim_t] = None,
                 pos_dropout: float = 0.0,
                 return_pos_enc: bool = False,
                 num_cls_tokens: int = 1,
                 register_multiple: int = 0,
                 device=None, dtype=None,
                 ):
        super().__init__()

        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size)
            for d in max_input_dims
        )

        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim

        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(patch_size, embed_dim, **factory)

        if abs_pos:
            scale = embed_dim ** -0.5
            self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
        )

        self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)
        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    @property
    def apply_cls_token(self):
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        return self.cls_token.num_tokens

    @property
    def num_registers(self):
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        return self.num_cls_tokens + self.num_registers

    def no_weight_decay(self):
        return [
            'pos_embed',
        ]

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if self.abs_pos:
            self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'

            src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
            src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
            src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
        targ_embed.data.copy_(src_embed)

    def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'

            src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
            src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
            src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(self,
                      patches: torch.Tensor,
                      patch_idxs: Optional[torch.Tensor] = None,
                      input_size: Optional[Tuple[int, int]] = None,
                      ) -> torch.Tensor:
        if not self.abs_pos:
            return patches

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(self,
                    batch_size: int,
                    patch_idxs: Optional[torch.Tensor] = None,
                    input_size: Optional[Tuple[int, int]] = None,
                    ) -> torch.Tensor:
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
        return pos_embed


    def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)

        def window_select(pos_embed):
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., :input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, :input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            if self.training:
                min_scale = math.sqrt(0.1)
                scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
                aspect_min = math.log(3 / 4)
                aspect_max = -aspect_min
                aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)

                scale_x = scale * aspect
                scale_y = scale * (1 / aspect)
                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)

                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)

                lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
                lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])

                lin_xy = torch.stack([lin_x, lin_y], dim=-1)

                grid_xy = lin_xy * scale_xy + pos_xy

                # Convert to [-1, 1] range
                grid_xy.mul_(2).sub_(1)

                pos_embed = F.grid_sample(
                    pos_embed.expand(batch_size, -1, -1, -1),
                    grid=grid_xy,
                    mode='bilinear',
                    padding_mode='zeros',
                    align_corners=True,
                )
            else:
                # i_rows, i_cols = input_dims
                # p_rows, p_cols = pos_embed.shape[2:]
                # if i_rows <= p_rows and i_cols <= p_cols:
                #     left = (p_cols - i_cols) // 2
                #     top = (p_rows - i_rows) // 2
                #     pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
                # else:
                max_dim = max(input_dims)
                pos_embed = F.interpolate(pos_embed, size=(max_dim, max_dim), align_corners=True, mode='bilinear')

                pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(pos_embed, size=input_dims, align_corners=True, mode='bilinear')

        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed


class Im2Patches(nn.Module):
    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.patch_size == 1:
            patches = x.flatten(2)
            patches = patches.permute(0, 2, 1)
            return patches

        py = x.shape[-2] // self.patch_size
        px = x.shape[-1] // self.patch_size
        patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',
                            py=py, yy=self.patch_size,
                            px=px, xx=self.patch_size,
                            )
        return patches


class ViTPatchLinear(nn.Linear):
    def __init__(self, patch_size: int, embed_dim: int, **factory):
        super().__init__(
            3 * (patch_size ** 2),
            embed_dim,
            bias=False,
            **factory
        )
        self.patch_size = patch_size

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if self.bias is not None:
            self.bias.data.copy_(state_dict[f'{prefix}bias'])

        chk_weight = state_dict[f'{prefix}weight']
        if chk_weight.shape != self.weight.shape:
            src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))

            assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'

            chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
            chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
            chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
        self.weight.data.copy_(chk_weight)
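A standalone sketch (not part of the uploaded files): a patch generator in CPE mode, where max_input_dims is larger than the nominal input_dims, so the position embedding is resampled to whatever resolution is passed at runtime. All sizes below are illustrative.

import torch

gen = ViTPatchGenerator(
    patch_size=16,
    embed_dim=768,
    input_dims=224,        # nominal resolution in pixels
    cls_token=True,
    max_input_dims=512,    # largest resolution the position embedding covers
    num_cls_tokens=1,
    register_multiple=4,
)
gen.eval()   # the training path randomly crops/scales the position embedding

tokens = gen(torch.randn(2, 3, 320, 448))   # any multiple of patch_size up to max_input_dims
# 1 CLS + 3 register tokens precede the (320 // 16) * (448 // 16) = 560 patch tokens.
print(tokens.shape)                          # torch.Size([2, 564, 768])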