Brandon May committed on
Commit
4da7cab
1 Parent(s): 6bffaa0

Upload model

Files changed (4)
  1. README.md +199 -5
  2. config.json +40 -0
  3. model.safetensors +3 -0
  4. theia_model.py +1495 -0
README.md CHANGED
@@ -1,5 +1,199 @@
- ---
- license: other
- license_name: theaiinstitute-license
- license_link: https://github.com/bdaiinstitute/theia/blob/main/LICENSE
- ---
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "architectures": [
+     "TheiaModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "theia_model.TheiaConfig",
+     "AutoModel": "theia_model.TheiaModel"
+   },
+   "backbone": "facebook/deit-tiny-patch16-224",
+   "feature_neck": false,
+   "feature_neck_hidden_dim": 256,
+   "feature_neck_nonlinearity": "relu",
+   "feature_reduce_method": null,
+   "forward_neck": false,
+   "image_size": 224,
+   "num_reg_tokens": 0,
+   "pretrained": false,
+   "target_feature_sizes": {
+     "facebook/dinov2-large": [
+       1024,
+       16,
+       16
+     ],
+     "google/vit-huge-patch14-224-in21k": [
+       1280,
+       16,
+       16
+     ],
+     "openai/clip-vit-large-patch14": [
+       1024,
+       16,
+       16
+     ]
+   },
+   "target_loss_weights": null,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.2",
+   "translator_hidden_size_factor": 1.0,
+   "translator_type": "lconv"
+ }
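Because `auto_map` above points at the bundled `theia_model.py`, the checkpoint is intended to be loaded through the 🤗 `transformers` auto classes with remote code enabled. A minimal loading sketch (the repo id is a placeholder for wherever this commit lives on the Hub, not a value from the diff):

```python
from transformers import AutoModel

# Placeholder repo id -- substitute the actual Hub path of this upload.
# trust_remote_code=True lets transformers import TheiaConfig / TheiaModel
# from the theia_model.py file shipped in this commit.
model = AutoModel.from_pretrained("<user>/<theia-repo>", trust_remote_code=True)
print(type(model).__name__)  # expected: TheiaModel, per "architectures" above
```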
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1dd195b67e7e7536455879b5d9eaea35cf85ce417ccb94482c1fd37ca02afd05
+ size 40187456
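The entry above is only a Git LFS pointer; the ~40 MB of weights live in LFS storage and are resolved on download. A hedged sketch for fetching and inspecting the tensor file (same placeholder repo id as above):

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Placeholder repo id; hf_hub_download resolves the LFS pointer to a local file path.
path = hf_hub_download(repo_id="<user>/<theia-repo>", filename="model.safetensors")

state_dict = load_file(path)  # dict[str, torch.Tensor]
n_params = sum(t.numel() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {n_params:,} parameters")
```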
theia_model.py ADDED
@@ -0,0 +1,1495 @@
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ import math
4
+ from itertools import chain
5
+ from typing import Any, Optional
6
+ from omegaconf import OmegaConf
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn.functional import interpolate
12
+ from einops.layers.torch import Rearrange
13
+
14
+ from transformers import PretrainedConfig, PreTrainedModel
15
+ from transformers import AutoConfig, AutoModel, AutoProcessor, AutoImageProcessor
16
+ from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel
17
+
18
+ def handle_feature_output(
19
+ x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0
20
+ ) -> torch.Tensor:
21
+ """Handle feature output from transformer.
22
+
23
+ Args:
24
+ x (torch.Tensor): input feature to be handled. shape is
25
+ [B, 1+H*W+N, C] if including both CLS and register tokens.
26
+ [B, 1+H*W, C] for standard model (N=0).
27
+ [B, H*W, C] for model without CLS.
28
+ feature_reduce_method (Optional[str]): method to select token. Options:
29
+ - `mean_pooling`: average over spatial tokens (non CLS tokens), output shape = [B, C].
30
+ - `max_pooling`: max over spatial tokens, output shape = [B, C].
31
+ - `cls`: return CLS token only, output shape = [B, C].
32
+ - `identity`: return the feature without touching it, output shape = input shape.
33
+ - `None`: return spatial tokens, output shape = [B, H*W, C] (assuming input is [B, 1+H*W, C]).
34
+ suppose raw feature is in shape [B, 1+H*W, C], `1` corresponds to CLS token.
35
+ num_discard_tokens (int):
36
+ number of tokens to be discarded. Assuming they are at the end of the sequence.
37
+ Returns:
38
+ torch.Tensor: selected feature tokens.
39
+ """
40
+
41
+ match feature_reduce_method:
42
+ case "mean_pooling":
43
+ return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C]
44
+ case "max_pooling":
45
+ return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C]
46
+ case "cls":
47
+ return x[:, 0] # [B, C]
48
+ case "identity":
49
+ return x
50
+ case None:
51
+ return x[:, 1 : x.size(1) - num_discard_tokens]
52
+ case _:
53
+ raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} is not implemented.")
54
+
55
+
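Since `handle_feature_output` is the entry point for reshaping backbone tokens, a short usage sketch may help; the shapes assume a ViT-style output with a CLS token and a 14x14 patch grid (illustrative values, not taken from the diff):

```python
import torch

feats = torch.randn(2, 1 + 14 * 14, 192)               # [B, 1+H*W, C]

pooled  = handle_feature_output(feats, "mean_pooling")  # [2, 192], mean over patch tokens
cls_tok = handle_feature_output(feats, "cls")           # [2, 192], CLS token only
spatial = handle_feature_output(feats)                  # [2, 196, 192], patch tokens with CLS dropped
```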
56
+ # Modified from huggingface transformers ViTEmbeddings
57
+ # Original Copyright 2021 The HuggingFace Inc. team. All rights reserved.
58
+ #
59
+ # Licensed under the Apache License, Version 2.0 (the "License");
60
+ # you may not use this file except in compliance with the License.
61
+ # You may obtain a copy of the License at
62
+ #
63
+ # http://www.apache.org/licenses/LICENSE-2.0
64
+ #
65
+ # Unless required by applicable law or agreed to in writing, software
66
+ # distributed under the License is distributed on an "AS IS" BASIS,
67
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
68
+ # See the License for the specific language governing permissions and
69
+ # limitations under the License.
70
+ class ViTEmbeddingsNoCLS(ViTEmbeddings):
71
+ """ViT Embedding Module without CLS token."""
72
+
73
+ def __init__(self, config: AutoConfig, use_mask_token: bool = False):
74
+ """Initialization.
75
+
76
+ Args:
77
+ config (AutoConfig): config for ViT.
78
+ use_mask_token (bool, optional): whether to use mask token. Defaults to False.
79
+ """
80
+ super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token)
81
+ self.cls_token = None
82
+
83
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
84
+ """
85
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
86
+ resolution images.
87
+
88
+ Source:
89
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
90
+ """
91
+
92
+ num_patches = embeddings.shape[1]
93
+ num_positions = self.position_embeddings.shape[1] - 1
94
+ if num_patches == num_positions and height == width:
95
+ return self.position_embeddings
96
+ patch_pos_embed = self.position_embeddings[:, 1:]
97
+ dim = embeddings.shape[-1]
98
+ h0 = height // self.config.patch_size
99
+ w0 = width // self.config.patch_size
100
+ # we add a small number to avoid floating point error in the interpolation
101
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
102
+ h0, w0 = h0 + 0.1, w0 + 0.1
103
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
104
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
105
+ patch_pos_embed = nn.functional.interpolate(
106
+ patch_pos_embed,
107
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
108
+ mode="bicubic",
109
+ align_corners=False,
110
+ )
111
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
112
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
113
+ return patch_pos_embed
114
+
115
+ def forward(
116
+ self,
117
+ pixel_values: torch.Tensor,
118
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
119
+ interpolate_pos_encoding: bool = False,
120
+ ) -> torch.Tensor:
121
+ batch_size, num_channels, height, width = pixel_values.shape
122
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
123
+
124
+ if bool_masked_pos is not None:
125
+ seq_length = embeddings.shape[1]
126
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
127
+ # replace the masked visual tokens by mask_tokens
128
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
129
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
130
+
131
+ # add positional encoding to each token
132
+ if interpolate_pos_encoding:
133
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
134
+ else:
135
+ embeddings = embeddings + self.position_embeddings[:, 1:]
136
+
137
+ embeddings = self.dropout(embeddings)
138
+
139
+ return embeddings
140
+
141
+
142
+ # modified from huggingface transformers ViTModel
143
+ class ViTModelNoCLS(ViTModel):
144
+ """ViT Model without CLS token."""
145
+
146
+ def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
147
+ super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token)
148
+ self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token)
149
+ self.no_cls = True
150
+
151
+ def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
152
+ """Initialize the weights"""
153
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
154
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
155
+ # `trunc_normal_cpu` not implemented in `half` issues
156
+ module.weight.data = nn.init.trunc_normal_(
157
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
158
+ ).to(module.weight.dtype)
159
+ if module.bias is not None:
160
+ module.bias.data.zero_()
161
+ elif isinstance(module, nn.LayerNorm):
162
+ module.bias.data.zero_()
163
+ module.weight.data.fill_(1.0)
164
+ elif isinstance(module, ViTEmbeddings):
165
+ module.position_embeddings.data = nn.init.trunc_normal_(
166
+ module.position_embeddings.data.to(torch.float32),
167
+ mean=0.0,
168
+ std=self.config.initializer_range,
169
+ ).to(module.position_embeddings.dtype)
170
+
171
+
172
+ # modified from huggingface transformers ViTEmbeddings
173
+ class ViTEmbeddingsReg(ViTEmbeddings):
174
+ """
175
+ ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1
176
+ """
177
+
178
+ def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7):
179
+ super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token)
180
+ self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
181
+ self.num_reg_tokens = num_reg_tokens
182
+ self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
183
+
184
+ self.reg_pos_embed.data = nn.init.trunc_normal_(
185
+ self.reg_pos_embed.data.to(torch.float32),
186
+ mean=0.0,
187
+ std=self.config.initializer_range,
188
+ ).to(self.reg_pos_embed.dtype)
189
+
190
+ self.reg_token.data = nn.init.trunc_normal_(
191
+ self.reg_token.data.to(torch.float32),
192
+ mean=0.0,
193
+ std=self.config.initializer_range,
194
+ ).to(self.reg_token.dtype)
195
+
196
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
197
+ """
198
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
199
+ resolution images.
200
+
201
+ Source:
202
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
203
+ """
204
+
205
+ num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens
206
+ num_positions = self.position_embeddings.shape[1] - 1
207
+ if num_patches == num_positions and height == width:
208
+ return self.position_embeddings
209
+ class_pos_embed = self.position_embeddings[:, 0]
210
+ patch_pos_embed = self.position_embeddings[:, 1:]
211
+ reg_pos_embed = self.reg_pos_embed
212
+ dim = embeddings.shape[-1]
213
+ h0 = height // self.config.patch_size
214
+ w0 = width // self.config.patch_size
215
+ # we add a small number to avoid floating point error in the interpolation
216
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
217
+ h0, w0 = h0 + 0.1, w0 + 0.1
218
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
219
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
220
+ patch_pos_embed = nn.functional.interpolate(
221
+ patch_pos_embed,
222
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
223
+ mode="bicubic",
224
+ align_corners=False,
225
+ )
226
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
227
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
228
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1)
229
+
230
+ def forward(
231
+ self,
232
+ pixel_values: torch.Tensor,
233
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
234
+ interpolate_pos_encoding: bool = False,
235
+ ) -> torch.Tensor:
236
+ batch_size, num_channels, height, width = pixel_values.shape
237
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
238
+
239
+ if bool_masked_pos is not None:
240
+ seq_length = embeddings.shape[1]
241
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
242
+ # replace the masked visual tokens by mask_tokens
243
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
244
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
245
+
246
+ # add the [CLS] token to the embedded patch tokens
247
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
248
+ reg_tokens = self.reg_token.expand(batch_size, -1, -1)
249
+ embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1)
250
+
251
+ # add positional encoding to each token
252
+ if interpolate_pos_encoding:
253
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
254
+ else:
255
+ embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)
256
+
257
+ embeddings = self.dropout(embeddings)
258
+
259
+ return embeddings
260
+
261
+
262
+ # modified from huggingface transformers ViTModel
263
+ class ViTModelReg(ViTModel):
264
+ """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1"""
265
+
266
+ def __init__(
267
+ self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7
268
+ ):
269
+ super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token)
270
+ self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens)
271
+ self.num_reg_tokens = num_reg_tokens
272
+
273
+ def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
274
+ """Initialize the weights"""
275
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
276
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
277
+ # `trunc_normal_cpu` not implemented in `half` issues
278
+ module.weight.data = nn.init.trunc_normal_(
279
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
280
+ ).to(module.weight.dtype)
281
+ if module.bias is not None:
282
+ module.bias.data.zero_()
283
+ elif isinstance(module, nn.LayerNorm):
284
+ module.bias.data.zero_()
285
+ module.weight.data.fill_(1.0)
286
+ elif isinstance(module, ViTEmbeddings):
287
+ module.position_embeddings.data = nn.init.trunc_normal_(
288
+ module.position_embeddings.data.to(torch.float32),
289
+ mean=0.0,
290
+ std=self.config.initializer_range,
291
+ ).to(module.position_embeddings.dtype)
292
+ module.cls_token.data = nn.init.trunc_normal_(
293
+ module.cls_token.data.to(torch.float32),
294
+ mean=0.0,
295
+ std=self.config.initializer_range,
296
+ ).to(module.cls_token.dtype)
297
+
298
+
299
+ class DeiT(nn.Module):
300
+ """DeiT model.
301
+
302
+ Paper: Training data-efficient image transformers & distillation through attention
303
+ https://arxiv.org/abs/2012.12877
304
+ Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit
305
+
306
+ Attributes:
307
+ model_name (str): name of the model.
308
+ pretrained (bool): whether to use pretrained weights.
309
+ """
310
+
311
+ def __init__(
312
+ self,
313
+ model_name: str = "facebook/deit-small-patch16-224",
314
+ pretrained: bool = False,
315
+ image_size: int = 224,
316
+ ):
317
+ super().__init__()
318
+ self.image_size = image_size
319
+ model = AutoModel.from_pretrained(model_name)
320
+ if pretrained:
321
+ self.model = model
322
+ else:
323
+ deit_config = model.config
324
+ self.model = AutoModel.from_config(deit_config)
325
+ del model
326
+
327
+ self.model.pooler = nn.Identity()
328
+
329
+ self.processor = AutoProcessor.from_pretrained(model_name)
330
+
331
+ def get_feature_size(
332
+ self,
333
+ keep_spatial: bool = False,
334
+ return_torch_size: bool = False,
335
+ ) -> torch.Size | tuple[int, ...]:
336
+ """Get the size of the feature.
337
+
338
+ Args:
339
+ keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
340
+ return_torch_size (bool): if true, return torch.Size type. Defaults to False.
341
+
342
+ Returns:
343
+ torch.Size | tuple[int, ...]: returned feature shape.
344
+ """
345
+ with torch.inference_mode():
346
+ image_size = (224, 224)
347
+ x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
348
+ y = self.forward(x)[:, 1:] # for getting feature size, discard cls token
349
+ size = y.size()[1:][::-1]
350
+ if keep_spatial:
351
+ assert math.isqrt(size[-1])
352
+ h = w = int(math.sqrt(size[-1]))
353
+ size = (size[0], h, w)
354
+ if return_torch_size:
355
+ size = torch.Size(size)
356
+ return size
357
+
358
+ def forward(
359
+ self,
360
+ x: torch.Tensor,
361
+ do_resize: bool = True,
362
+ interpolate_pos_encoding: Optional[bool] = None,
363
+ do_rescale: bool = True,
364
+ do_normalize: bool = True,
365
+ ) -> torch.Tensor:
366
+ """Forward pass of the model
367
+
368
+ Args:
369
+ x (torch.Tensor): model input.
370
+
371
+ - arguments for self.processor. Details can be found at
372
+ https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
373
+ do_resize (bool): if do resizing in processor. Defaults to True.
374
+ interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.
375
+ do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
376
+ do_normalize (bool): if do normalize in processor. Defaults to True.
377
+
378
+ Returns:
379
+ torch.Tensor: model output.
380
+ """
381
+ input = self.processor(
382
+ x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
383
+ ).to(self.model.device)
384
+ y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
385
+ return y.last_hidden_state
386
+
387
+
388
+ class DeiTNoCLS(nn.Module):
389
+ """Modified DeiT model without CLS token."""
390
+
391
+ def __init__(
392
+ self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224
393
+ ):
394
+ super().__init__()
395
+ self.image_size = image_size
396
+ pretrained_model_name = model_name.replace("nocls-", "")
397
+ deit_config = AutoConfig.from_pretrained(pretrained_model_name)
398
+ self.model = ViTModelNoCLS(deit_config)
399
+ if pretrained:
400
+ pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
401
+ pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
402
+ self.load_state_dict(pretrained_dict, strict=False)
403
+ del pretrained_model, pretrained_dict
404
+
405
+ self.model.pooler = nn.Identity()
406
+ self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
407
+ self.no_cls = True
408
+
409
+ def get_feature_size(
410
+ self,
411
+ keep_spatial: bool = False,
412
+ return_torch_size: bool = False,
413
+ ) -> torch.Size | tuple[int, ...]:
414
+ """Get the size of the feature.
415
+
416
+ Args:
417
+ keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
418
+ return_torch_size (bool): if true, return torch.Size type. Defaults to False.
419
+
420
+ Returns:
421
+ torch.Size | tuple[int, ...]: returned feature shape.
422
+ """
423
+ with torch.inference_mode():
424
+ image_size = (self.image_size, self.image_size)
425
+ x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
426
+ y = self.forward(x)
427
+ size = y.size()[1:][::-1]
428
+ if keep_spatial:
429
+ assert math.isqrt(size[-1])
430
+ h = w = int(math.sqrt(size[-1]))
431
+ size = (size[0], h, w)
432
+ if return_torch_size:
433
+ size = torch.Size(size)
434
+ return size
435
+
436
+ def forward(
437
+ self,
438
+ x: torch.Tensor,
439
+ do_resize: bool = True,
440
+ interpolate_pos_encoding: Optional[bool] = None,
441
+ do_rescale: bool = True,
442
+ do_normalize: bool = True,
443
+ ) -> torch.Tensor:
444
+ """Forward pass of the model
445
+
446
+ Args:
447
+ x (torch.Tensor): model input.
448
+
449
+ - arguments for self.processor. Details can be found at
450
+ https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
451
+ do_resize (bool): if do resizing in processor. Defaults to True.
452
+ do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
453
+ do_normalize (bool): if do normalize in processor. Defaults to True.
454
+
455
+ - argument for forward
456
+ interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.
457
+
458
+ Returns:
459
+ torch.Tensor: model output.
460
+ """
461
+ input = self.processor(
462
+ x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
463
+ ).to(self.model.device)
464
+ y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
465
+ return y.last_hidden_state
466
+
467
+
468
+ class DeiTReg(nn.Module):
469
+ """Modified DeiT model with register tokens."""
470
+
471
+ def __init__(
472
+ self,
473
+ model_name: str = "reg-facebook/deit-small-patch16-224",
474
+ pretrained: bool = False,
475
+ image_size: int = 224,
476
+ num_reg_tokens: int = 7,
477
+ ):
478
+ super().__init__()
479
+ self.image_size = image_size
480
+ pretrained_model_name = model_name.replace("reg-", "")
481
+ deit_config = AutoConfig.from_pretrained(pretrained_model_name)
482
+ self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens)
483
+ if pretrained:
484
+ pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
485
+ pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
486
+ self.load_state_dict(pretrained_dict, strict=False)
487
+ del pretrained_model, pretrained_dict
488
+
489
+ self.model.pooler = nn.Identity()
490
+ self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
491
+ self.num_reg_tokens = num_reg_tokens
492
+
493
+ def get_feature_size(
494
+ self,
495
+ keep_spatial: bool = False,
496
+ return_torch_size: bool = False,
497
+ ) -> torch.Size | tuple[int, ...]:
498
+ """Get the size of the feature.
499
+
500
+ Args:
501
+ keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
502
+ return_torch_size (bool): if true, return torch.Size type. Defaults to False.
503
+
504
+ Returns:
505
+ torch.Size | tuple[int, ...]: returned feature shape.
506
+ """
507
+ with torch.inference_mode():
508
+ image_size = (self.image_size, self.image_size)
509
+ x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
510
+ y = self.forward(x)[:, 1 : -self.num_reg_tokens]
511
+ size = y.size()[1:][::-1]
512
+ if keep_spatial:
513
+ assert math.isqrt(size[-1])
514
+ h = w = int(math.sqrt(size[-1]))
515
+ size = (size[0], h, w)
516
+ if return_torch_size:
517
+ size = torch.Size(size)
518
+ return size
519
+
520
+ def forward(
521
+ self,
522
+ x: torch.Tensor,
523
+ do_resize: bool = True,
524
+ interpolate_pos_encoding: Optional[bool] = None,
525
+ do_rescale: bool = True,
526
+ do_normalize: bool = True,
527
+ ) -> torch.Tensor:
528
+ """Forward pass of the model
529
+
530
+ Args:
531
+ x (torch.Tensor): model input.
532
+
533
+ - arguments for self.processor. Details can be found at
534
+ https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
535
+ do_resize (bool): if do resizing in processor. Defaults to True.
536
+ interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.
537
+ do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
538
+ do_normalize (bool): if do normalize in processor. Defaults to True.
539
+
540
+ Returns:
541
+ torch.Tensor: model output.
542
+ """
543
+ input = self.processor(
544
+ x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
545
+ ).to(self.model.device)
546
+ y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
547
+ return y.last_hidden_state
548
+
549
+
550
+ def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module:
551
+ """Build the backbone visual encoder of robot vision foundation model.
552
+
553
+ Args:
554
+ model_name (str): name of the model.
555
+ pretrained (bool): whether to use pretrained weights. Defaults to False.
556
+ image_size (int): size of the image. Assume a square image. Defaults to 224
557
+ kwargs (Any): any kwargs specific to some models. For example,
558
+ `num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name`
559
+
560
+ Returns:
561
+ nn.Module: backbone network.
562
+ """
563
+ if "reg" in model_name:
564
+ return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
565
+ elif "nocls" in model_name:
566
+ return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
567
+ elif "deit" in model_name:
568
+ return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size)
569
+ else:
570
+ raise NotImplementedError(f"Requested {model_name} is not implemented.")
571
+
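A hedged usage sketch for `build_backbone`; the model name mirrors the `backbone` entry in config.json, and the (192, 14, 14) feature size for deit-tiny is an expectation, not something stated in the diff. Note that the constructor still fetches the config and image processor from the Hub even with `pretrained=False`.

```python
# Plain DeiT backbone, randomly initialized, as configured in config.json.
backbone = build_backbone("facebook/deit-tiny-patch16-224", pretrained=False, image_size=224)

# Register-token and no-CLS variants are selected by substring in the model name, e.g.:
#   build_backbone("reg-facebook/deit-tiny-patch16-224", num_reg_tokens=7)
#   build_backbone("nocls-facebook/deit-tiny-patch16-224")

print(backbone.get_feature_size(keep_spatial=True))  # expected: (192, 14, 14) for deit-tiny
```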
572
+ class Interpolation(nn.Module):
573
+ """Interpolation nn.Module wrap for nn.functional.interpolate.
574
+
575
+ Attributes:
576
+ target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation.
577
+ """
578
+
579
+ def __init__(self, target_size: tuple[int, int] | torch.Size) -> None:
580
+ super().__init__()
581
+ self.target_size = target_size
582
+
583
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
584
+ """Very simple forward pass to call interpolate()."""
585
+ return interpolate(x, self.target_size)
586
+
587
+
588
+ class LinearAdapterHead(nn.Module):
589
+ """Adapter head contains a single linear layer."""
590
+ def __init__(
591
+ self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size
592
+ ):
593
+ """Initialization function for LinearAdapterHead.
594
+ Args:
595
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
596
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
598
+ """
599
+ super().__init__()
600
+
601
+ self.source_size = source_size
602
+ self.target_size = target_size
603
+
604
+ source_channel_size = self.source_size[0]
605
+ target_channel_size = self.target_size[0]
606
+
607
+ self.adapter = nn.Sequential(
608
+ nn.Linear(source_channel_size, target_channel_size),
609
+ )
610
+
611
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
612
+ """Forward pass for the adapter. """
613
+ assert not backbone_no_cls
614
+ # x: [B, (1+H*W), C]
615
+ # LinearAdapterHead is used only when there is cls token in the backbone.
616
+ x = x[:, 0]
617
+ x = self.adapter(x)
618
+ return x # [B, C]
619
+
620
+
621
+ class MLPAdapterHead(nn.Module):
622
+ """MLP Adapter module.
623
+
624
+ Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
625
+ Will first do interpolation to match the spatial size [H_t, W_t],
626
+ followed by MLP to project to the target channel dimension [C_t].
627
+
628
+ Attributes:
629
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W]
630
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W]
631
+ adapter (nn.Module): the adapter module.
632
+ interpolation (nn.Module): interpolation to adjust sizes before MLP.
633
+ """
634
+
635
+ def __init__(
636
+ self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int
637
+ ):
638
+ """Initialization function for MLPAdapter.
639
+
640
+ Args:
641
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
642
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
643
+ num_layer (int): number of MLP layers (One linear layer if num_layer = 1).
644
+ """
645
+ super().__init__()
646
+ assert num_layer >= 1, f"`num_layer` in {self._get_name()} should >= 1. Got {num_layer}"
647
+
648
+ self.source_size = source_size
649
+ self.target_size = target_size
650
+
651
+ source_channel_size = self.source_size[0]
652
+ target_channel_size = self.target_size[0]
653
+
654
+ self.interpolation = nn.Sequential(
655
+ nn.Identity(),
656
+ )
657
+ if self.source_size[1] != self.target_size[1]:
658
+ self.interpolation = nn.Sequential(
659
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
660
+ Interpolation(self.target_size[1:]),
661
+ Rearrange("b c h w-> b (h w) c"),
662
+ )
663
+
664
+ if num_layer == 1:
665
+ self.adapter = nn.Sequential(
666
+ nn.Linear(source_channel_size, target_channel_size),
667
+ )
668
+ elif num_layer >= 2:
669
+ hidden_dim = source_channel_size * 2
670
+ self.adapter = nn.Sequential(
671
+ nn.Linear(source_channel_size, hidden_dim),
672
+ *list(
673
+ chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)])
674
+ ),
675
+ nn.ReLU(),
676
+ nn.Linear(hidden_dim, target_channel_size),
677
+ )
678
+
679
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
680
+ """Forward pass for the adapter. First interpolation then MLP."""
681
+ # x: [B, (1)+H*W, C]
682
+ if not backbone_no_cls:
683
+ x = x[:, 1:]
684
+ # x: [B, (H*W), C]
685
+ x = self.interpolation(x)
686
+ x = self.adapter(x)
687
+ return x # [B, (H*W), C]
688
+
689
+
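A shape-level sketch of `MLPAdapterHead` with illustrative sizes (a 384-dim, 14x14 source mapped to a 1024-dim, 16x16 target; the numbers are examples, not values from the diff):

```python
import torch

head = MLPAdapterHead(source_size=(384, 14, 14), target_size=(1024, 16, 16), num_layer=2)

x = torch.randn(2, 1 + 14 * 14, 384)  # [B, 1+H*W, C], CLS token included
y = head(x)                           # CLS dropped, interpolated to 16x16, projected to 1024 channels
print(y.shape)                        # torch.Size([2, 256, 1024])
```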
690
+ class ConvAdapterHead(nn.Module):
691
+ """Convolutional Adapter module.
692
+
693
+ Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
694
+ Uses CNN to map channel and spatial sizes jointly.
695
+ Note: only works for (16, 16), (any, any), any <= 14, and (64, 64) spatial sizes for now.
696
+
697
+ Attributes:
698
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
699
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
700
+ adapter (nn.Module): the adapter module.
701
+ interpolation (nn.Module): interpolation to adjust sizes before MLP.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ source_size: tuple[int, ...] | torch.Size,
707
+ target_size: tuple[int, ...] | torch.Size,
708
+ ):
709
+ """Initialization function for ConvAdapter.
710
+
711
+ Args:
712
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
713
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
714
+ """
715
+ super().__init__()
716
+ self.source_size = source_size
717
+ self.target_size = target_size
718
+
719
+ hidden_dim = self.source_size[0] * 2
720
+ source_channel_size = self.source_size[0]
721
+ target_channel_size = self.target_size[0]
722
+
723
+ if self.source_size[1] < 12:
724
+ raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
725
+ elif self.source_size[1] < 16: # pad (any, any), any <= 14 to (16, 16)
726
+ self.pad = nn.Sequential(
727
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
728
+ nn.ConvTranspose2d(
729
+ source_channel_size,
730
+ source_channel_size,
731
+ kernel_size=3,
732
+ stride=1,
733
+ output_padding=14 - self.source_size[1],
734
+ ),
735
+ )
736
+ self.source_size = (self.source_size[0], 16, 16)
737
+ elif self.source_size[1] == 16 or self.source_size[1] == 64: # do nothing for (16, 16) and (64, 64)
738
+ self.pad = nn.Sequential(
739
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
740
+ )
741
+ else:
742
+ raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.")
743
+
744
+ if self.source_size[1] < self.target_size[1]: # (16, 16) / (14, 14) to (64, 64)
745
+ self.adapter = nn.Sequential(
746
+ nn.LayerNorm(self.source_size),
747
+ nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31
748
+ nn.ReLU(),
749
+ nn.LayerNorm([hidden_dim, 31, 31]),
750
+ nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64
751
+ nn.ReLU(),
752
+ nn.LayerNorm([hidden_dim, 64, 64]),
753
+ nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), # 64
754
+ Rearrange("b c h w-> b (h w) c"),
755
+ )
756
+ elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16)
757
+ self.adapter = nn.Sequential(
758
+ nn.LayerNorm(self.source_size),
759
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16
760
+ nn.ReLU(),
761
+ nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
762
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16
763
+ nn.ReLU(),
764
+ nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
765
+ nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16
766
+ Rearrange("b c h w-> b (h w) c"),
767
+ )
768
+ else: # (64, 64) to (16, 16)
769
+ self.adapter = nn.Sequential(
770
+ nn.LayerNorm(self.source_size),
771
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32
772
+ nn.ReLU(),
773
+ nn.LayerNorm([hidden_dim, 32, 32]),
774
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16
775
+ nn.ReLU(),
776
+ nn.LayerNorm([hidden_dim, 16, 16]),
777
+ nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16
778
+ Rearrange("b c h w-> b (h w) c"),
779
+ )
780
+
781
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
782
+ """Forward pass for ConvAdapter"""
783
+ # x: [B, (1)+H*W, C]
784
+ if not backbone_no_cls:
785
+ x = x[:, 1:]
786
+ # x: [B, H*W, C]
787
+ x = self.pad(x)
788
+ x = self.adapter(x)
789
+ return x # B, (H*W), C
790
+
791
+
792
+ class LightConvAdapterHead(nn.Module):
793
+ """Light Convolutional Adapter module.
794
+
795
+ Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
796
+ Uses CNN to map channel and spatial sizes jointly.
797
+ Note: only works for source sizes (H_s, W_s): (16, 16), (any, any), 12 <= any <= 14,
798
+ and target sizes (H_t, W_t): (16, 16) and (64, 64) for now.
799
+
800
+ Attributes:
801
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature,
802
+ channel first (C, H, W).
803
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature,
804
+ channel first (C, H, W).
805
+ adapter (nn.Module): the adapter module.
806
+ interpolation (nn.Module): interpolation to adjust sizes before MLP.
807
+ """
808
+
809
+ def __init__(
810
+ self,
811
+ source_size: tuple[int, ...] | torch.Size,
812
+ target_size: tuple[int, ...] | torch.Size,
813
+ hidden_size_factor: int | float = 1.0,
814
+ ):
815
+ """Initialization function for ConvAdapter.
816
+
817
+ Args:
818
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
819
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
820
+ hidden_size_factor (int | float): the size of hidden dim of feature translator
821
+ as a factor of input feature hidden dim.
822
+ """
823
+ super().__init__()
824
+ if source_size[1] != source_size[2] or target_size[1] != target_size[2]:
825
+ raise NotImplementedError(
826
+ "Currently does not support non-square feature maps like source size"
827
+ "{source_size} and target size {target_size}."
828
+ )
829
+ self.source_size = source_size
830
+ self.target_size = target_size
831
+ self.hidden_size_factor = hidden_size_factor
832
+
833
+ hidden_dim = int(self.source_size[0] * hidden_size_factor)
834
+ source_channel_size = self.source_size[0]
835
+ target_channel_size = self.target_size[0]
836
+
837
+ if self.source_size[1] < 12:
838
+ raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
839
+ elif self.source_size[1] < 16 and self.target_size[1] >= 16: # pad (any, any), any <= 14 to (16, 16)
840
+ self.pad = nn.Sequential(
841
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
842
+ nn.ConvTranspose2d(
843
+ source_channel_size,
844
+ source_channel_size,
845
+ kernel_size=3,
846
+ stride=1,
847
+ output_padding=14 - self.source_size[1],
848
+ ),
849
+ )
850
+ self.source_size = (self.source_size[0], 16, 16)
851
+ elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \
852
+ (self.source_size[1] == 14 and self.target_size[1] == 14):
853
+ # no padding for (16, 16), (64, 64) and (14, 14) <-> (14, 14)
854
+ self.pad = nn.Sequential(
855
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
856
+ )
857
+ elif self.target_size[1] < 14:
858
+ self.pad = nn.Sequential(
859
+ Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
860
+ )
861
+ else:
862
+ raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.")
863
+
864
+ if self.source_size[1] == 16 and self.target_size[1] == 64: # (16, 16) to (64, 64)
865
+ self.adapter = nn.Sequential(
866
+ nn.LayerNorm(self.source_size),
867
+ nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31
868
+ nn.ReLU(),
869
+ nn.LayerNorm([hidden_dim, 31, 31]),
870
+ nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64
871
+ nn.ReLU(),
872
+ nn.LayerNorm([hidden_dim, 64, 64]),
873
+ Rearrange("b c h w-> b (h w) c"),
874
+ nn.Linear(hidden_dim, target_channel_size),
875
+ )
876
+ elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16)
877
+ self.adapter = nn.Sequential(
878
+ nn.LayerNorm(self.source_size),
879
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16
880
+ nn.ReLU(),
881
+ nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
882
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16
883
+ nn.ReLU(),
884
+ nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
885
+ Rearrange("b c h w-> b (h w) c"),
886
+ nn.Linear(hidden_dim, target_channel_size),
887
+ )
888
+ elif self.source_size[1] == 64 and self.target_size[1] == 16: # (64, 64) to (16, 16)
889
+ self.adapter = nn.Sequential(
890
+ nn.LayerNorm(self.source_size),
891
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32
892
+ nn.ReLU(),
893
+ nn.LayerNorm([hidden_dim, 32, 32]),
894
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16
895
+ nn.ReLU(),
896
+ nn.LayerNorm([hidden_dim, 16, 16]),
897
+ Rearrange("b c h w-> b (h w) c"),
898
+ nn.Linear(hidden_dim, target_channel_size),
899
+ )
900
+ elif self.target_size[1] == 7:
901
+ self.adapter = nn.Sequential(
902
+ nn.LayerNorm(self.source_size),
903
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), #14x14 -> 7x7
904
+ nn.ReLU(),
905
+ nn.LayerNorm([hidden_dim, 7, 7]),
906
+ Rearrange("b c h w-> b (h w) c"),
907
+ nn.Linear(hidden_dim, target_channel_size)
908
+ )
909
+ else:
910
+ raise NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.")
911
+
912
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
913
+ """Forward pass for ConvAdapter"""
914
+ # x: [B, (1)+H*W, C]
915
+ if not backbone_no_cls:
916
+ x = x[:, 1:]
917
+ x = self.pad(x)
918
+ x = self.adapter(x)
919
+ return x # [B, H*W, C]
920
+
921
+
922
+ class FeatureTranslator(nn.Module):
923
+ """Base class for the feature translator.
924
+
925
+ The flow is backbone_adapter -> translator_stem -> translator_heads.
926
+
927
+ Attributes:
928
+ backbone_feature_size (torch.Size): the size of features of the backbone.
929
+ target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
930
+ translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
931
+ target_model_names (list[str]): convenient attribute to hold all the names of the target models.
932
+
933
+ backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim.
934
+ translator_stem (nn.Module): the shared stem for all target models.
935
+ translator_heads (nn.ModuleDict): specific heads for different target models.
936
+ """
937
+
938
+ def __init__(
939
+ self,
940
+ backbone_feature_size: torch.Size,
941
+ target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
942
+ translator_hidden_size: int = 1024,
943
+ ) -> None:
944
+ """Initalization function for FeatureTranslator.
945
+
946
+ Args:
947
+ backbone_feature_size (torch.Size): the size of features of the backbone.
948
+ target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
949
+ translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
950
+ """
951
+ super().__init__()
952
+ self.backbone_feature_size = backbone_feature_size # (C, H, W)
953
+ self.target_feature_sizes = target_feature_sizes # [(C, H, W)]
954
+ self.translator_hidden_size = translator_hidden_size # C
955
+ self.target_model_names = list(target_feature_sizes.keys())
956
+ self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names}
957
+ self.translator_heads: nn.ModuleDict = None
958
+
959
+ self.backbone_adapter = nn.Sequential(
960
+ nn.LayerNorm(self.backbone_feature_size[0]), # do a pre-norm
961
+ nn.Linear(
962
+ self.backbone_feature_size[0], # C in [C,H,W]
963
+ self.translator_hidden_size,
964
+ ),
965
+ )
966
+ self.translator_stem: nn.Module = nn.Identity()
967
+ self.build_translator_heads()
968
+
969
+ def build_translator_heads(self) -> None:
970
+ """Build translator heads to match the dimension of each target feature set.
971
+
972
+ Example:
973
+ translator_heads: dict[str, nn.Module] = ...
974
+ self.translator_heads = nn.ModuleDict(translator_heads)
975
+ """
976
+ raise NotImplementedError("build_translator_heads() should be overridden")
977
+
978
+ def forward(
979
+ self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
980
+ ) -> torch.Tensor:
981
+ """Forward pass for a base feature translator.
982
+
983
+ Args:
984
+ x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C].
985
+ (1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C].
986
+ target_model_names (Optional[list[str]]): names of the target models.
987
+ backbone_no_cls (bool): indicate backbone has cls token or not.
988
+ Can use it to customize whether to drop cls.
989
+
990
+ Returns:
991
+ dict[str, torch.Tensor]: predicted features for target models.
992
+ """
993
+ # x: [B, (1)+H*W, C]
994
+ x = self.backbone_adapter(x)
995
+ x = self.translator_stem(x)
996
+ target_model_names = target_model_names if target_model_names is not None else self.target_model_names
997
+ features = {t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls) for t in target_model_names}
998
+ return features
999
+
1000
+
1001
+ class MLPFeatureTranslator(FeatureTranslator):
1002
+ def __init__(
1003
+ self,
1004
+ backbone_feature_size: torch.Size,
1005
+ target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
1006
+ translator_hidden_size: int = 1024,
1007
+ translator_n_layer: int = 3,
1008
+ ) -> None:
1009
+ """Initalization function for MLPFeatureTranslator.
1010
+
1011
+ Args:
1012
+ backbone_feature_size (torch.Size): the size of features of the backbone.
1013
+ target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
1014
+ translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
1015
+ translator_n_layer (int): number of MLP layers. Defaults to 3.
1016
+ """
1017
+ self.translator_n_layer = translator_n_layer
1018
+
1019
+ super().__init__(
1020
+ backbone_feature_size=backbone_feature_size,
1021
+ target_feature_sizes=target_feature_sizes,
1022
+ translator_hidden_size=translator_hidden_size,
1023
+ )
1024
+
1025
+ def build_translator_heads(self) -> nn.ModuleDict:
1026
+ """Build MLP translator heads to match the dimension of each target feature set."""
1027
+ translator_heads = {}
1028
+ source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
1029
+ for target_model, target_size in self.target_feature_sizes.items():
1030
+ head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer)
1031
+ translator_heads[self.legit_target_model_name_map[target_model]] = head
1032
+ self.translator_heads = nn.ModuleDict(translator_heads)
1033
+
1034
+
1035
+ class ConvFeatureTranslator(FeatureTranslator):
1036
+ def __init__(
1037
+ self,
1038
+ backbone_feature_size: torch.Size,
1039
+ target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
1040
+ translator_hidden_size: int = 1024,
1041
+ ) -> None:
1042
+ """Initalization function for ConvFeatureTranslator.
1043
+
1044
+ Args:
1045
+ backbone_feature_size (torch.Size): the size of features of the backbone.
1046
+ target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
1047
+ translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
1048
+ """
1049
+ super().__init__(
1050
+ backbone_feature_size=backbone_feature_size,
1051
+ target_feature_sizes=target_feature_sizes,
1052
+ translator_hidden_size=translator_hidden_size,
1053
+ )
1054
+
1055
+ def build_translator_heads(self) -> nn.ModuleDict:
1056
+ """Build translator heads to match the dimension of each target feature set.
1057
+
1058
+ Returns:
1059
+ nn.ModuleDict: the translator heads.
1060
+ """
1061
+ translator_heads = {}
1062
+ source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
1063
+ for target_model, target_size in self.target_feature_sizes.items():
1064
+ head = ConvAdapterHead(source_size=source_size, target_size=target_size)
1065
+ translator_heads[self.legit_target_model_name_map[target_model]] = head
1066
+ self.translator_heads = nn.ModuleDict(translator_heads)
1067
+
1068
+
1069
+ class LightConvFeatureTranslator(FeatureTranslator):
1070
+ def __init__(
1071
+ self,
1072
+ backbone_feature_size: torch.Size,
1073
+ target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
1074
+ translator_hidden_size: int = 1024,
1075
+ hidden_size_factor: int | float = 1.0,
1076
+ ) -> None:
1077
+ """Initalization function for LightConvFeatureTranslator.
1078
+ It's for a smaller translator compared to ConvFeatureTranslator.
1079
+
1080
+ Args:
1081
+ backbone_feature_size (torch.Size): the size of features of the backbone.
1082
+ target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
1083
+ translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 1024.
1084
+ hidden_size_factor: the size of hidden dim of feature translator
1085
+ as a factor of input feature hidden dim. Defaults to 1.0
1086
+ """
1087
+ self.hidden_size_factor = hidden_size_factor
1088
+ super().__init__(
1089
+ backbone_feature_size=backbone_feature_size,
1090
+ target_feature_sizes=target_feature_sizes,
1091
+ translator_hidden_size=translator_hidden_size,
1092
+ )
1093
+ self.backbone_adapter = nn.Identity()
1094
+
1095
+ def build_translator_heads(self) -> nn.ModuleDict:
1096
+ """Build translator heads to match the dimension of each target feature set.
1097
+
1098
+ Returns:
1099
+ nn.ModuleDict: the translator heads.
1100
+ """
1101
+ translator_heads = {}
1102
+ for target_model, target_size in self.target_feature_sizes.items():
1103
+ if "_cls" in target_model:
1104
+ head = LinearAdapterHead(
1105
+ source_size=self.backbone_feature_size,
1106
+ target_size=target_size
1107
+ )
1108
+ else:
1109
+ head = LightConvAdapterHead(
1110
+ source_size=self.backbone_feature_size,
1111
+ target_size=target_size,
1112
+ hidden_size_factor=self.hidden_size_factor
1113
+ )
1114
+ translator_heads[self.legit_target_model_name_map[target_model]] = head
1115
+ self.translator_heads = nn.ModuleDict(translator_heads)
1116
+
1117
+
1118
+ class TransformerFreatureTranslator(FeatureTranslator):
1119
+ def __init__(
1120
+ self,
1121
+ backbone_feature_size: torch.Size,
1122
+ target_feature_sizes: dict[str, torch.Size | tuple[int, int]],
1123
+ translator_hidden_size: int = 1024,
1124
+ translator_n_layers: int = 2,
1125
+ translator_n_heads: int = 8,
1126
+ translator_activation: str = "gelu",
1127
+ ) -> None:
1128
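+ """Initialization function for TransformerFreatureTranslator.
+
+ Args:
+ backbone_feature_size (torch.Size): the size of features of the backbone.
+ target_feature_sizes (dict[str, torch.Size | tuple[int, int]]): the sizes of features of target models.
+ translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+ translator_n_layers (int): number of transformer decoder layers. Defaults to 2.
+ translator_n_heads (int): number of attention heads per decoder layer. Defaults to 8.
+ translator_activation (str): activation of the decoder feed-forward network. Defaults to "gelu".
+ """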
+ super().__init__(
1129
+ backbone_feature_size=backbone_feature_size,
1130
+ target_feature_sizes=target_feature_sizes,
1131
+ translator_hidden_size=translator_hidden_size,
1132
+ )
1133
+
1134
+ self.translator_stem = nn.TransformerDecoder(
1135
+ nn.TransformerDecoderLayer(
1136
+ d_model=translator_hidden_size,
1137
+ nhead=translator_n_heads,
1138
+ dim_feedforward=translator_hidden_size * 2,
1139
+ activation=translator_activation,
1140
+ batch_first=True,
1141
+ norm_first=True,
1142
+ ),
1143
+ num_layers=translator_n_layers,
1144
+ )
1145
+
1146
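+ # Learnable query tokens (one per spatial position of the backbone feature) that the decoder turns into target features.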
+ self.decode_tokens = nn.Parameter(
1147
+ torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size))
1148
+ )
1149
+
1150
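+ # One learnable embedding per target model, appended to the query tokens to condition the decoder on the requested target.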
+ self.target_model_emb = nn.ParameterDict(
1151
+ {
1152
+ self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size)
1153
+ for t in self.target_model_names
1154
+ }
1155
+ )
1156
+
1157
+ def build_translator_heads(self) -> None:
1158
+ """Build Transformer translator heads to match the dimension of each target feature set."""
1159
+ translator_heads = {}
1160
+ for target_model, target_size in self.target_feature_sizes.items():
1161
+ head = MLPAdapterHead(
1162
+ source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]),
1163
+ target_size=target_size,
1164
+ num_layer=2,
1165
+ )
1166
+ translator_heads[self.legit_target_model_name_map[target_model]] = head
1167
+ self.translator_heads = nn.ModuleDict(translator_heads)
1168
+
1169
+ def forward(
1170
+ self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
1171
+ ) -> dict[str, torch.Tensor]:
1172
+ """Forward pass for a simple linear translator.
1173
+
1174
+ Args:
1175
+ x (torch.Tensor): input features from the backbone.
1176
+ target_model_names (Optional[list[str]]): names of the target models.
1177
+ backbone_no_cls (bool): whether the backbone output comes without a CLS token.
1178
+ Used to decide whether the CLS token should be dropped here.
1179
+
1180
+ Returns:
1181
+ dict[str, torch.Tensor]: predicted features for target models.
1182
+ """
1183
+ if not backbone_no_cls:
1184
+ x = x[:, 1:]
1185
+ x = self.backbone_adapter(x)
1186
+ features = {}
1187
+ target_model_names = target_model_names if target_model_names is not None else self.target_model_names
1188
+ for t in target_model_names:
1189
+ feature = self.translator_stem(
1190
+ torch.cat(
1191
+ [
1192
+ self.decode_tokens.repeat(x.size(0), 1, 1),
1193
+ self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1),
1194
+ ],
1195
+ dim=1,
1196
+ ),
1197
+ memory=x,
1198
+ )[:, 1:, ...]
1199
+ features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature)
1200
+ return features
1201
+
1202
+
1203
+ def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator:
1204
+ """Handy function to build feature translators given the type
1205
+
1206
+ Args:
1207
+ translator_type (str): the type of the translator,
1208
+ one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`).
1209
+ At the moment we are actively using `"lconv"`.
1210
+
1211
+ Returns:
1212
+ FeatureTranslator: the corresponding FeatureTranslator
1213
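+
+ Example (illustrative sketch; `backbone` is assumed to be an already-built backbone, and the
+ target model name and feature size below are placeholders, not required values):
+ >>> translator = build_feature_translator(
+ ... "lconv",
+ ... backbone_feature_size=backbone.get_feature_size(keep_spatial=True),
+ ... target_feature_sizes={"facebook/dinov2-large": (1024, 16, 16)},
+ ... hidden_size_factor=1.0,
+ ... )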
+ """
1214
+ if translator_type == "mlp":
1215
+ return MLPFeatureTranslator(**kwargs)
1216
+ elif translator_type == "conv":
1217
+ return ConvFeatureTranslator(**kwargs)
1218
+ elif translator_type == "lconv":
1219
+ return LightConvFeatureTranslator(**kwargs)
1220
+ elif translator_type == "transformer" or translator_type == "trans":
1221
+ return TransformerFreatureTranslator(**kwargs)
1222
+ else:
1223
+ raise NotImplementedError(f"Requested {translator_type} is not implemented yet.")
1224
+
1225
+
1226
+ class TheiaConfig(PretrainedConfig):
1227
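+ """Configuration for `TheiaModel`.
+
+ Args:
+ backbone (str | nn.Module): name of, or module for, the backbone encoder. Defaults to "facebook/deit-tiny-patch16-224".
+ pretrained (bool): whether to use pretrained weights for the backbone. Defaults to False.
+ target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]): sizes of the target model features; if None, no translator is built.
+ translator_type (str): type of the feature translator, see `build_feature_translator`. Defaults to "lconv".
+ translator_hidden_size_factor (float | int): hidden size factor passed to the translator. Defaults to 1.0.
+ target_loss_weights (Optional[dict[str, float]]): per-target weights for the distillation losses.
+ feature_reduce_method (Optional[str]): how to reduce the backbone feature, if at all. Defaults to None.
+ feature_neck (bool): whether to attach a convolutional neck that compresses the feature into a vector. Defaults to False.
+ feature_neck_hidden_dim (int): hidden (and output) dim of the neck. Defaults to 256.
+ forward_neck (bool): whether `forward()` returns the neck output instead of translated features. Defaults to False.
+ feature_neck_nonlinearity (str): nonlinearity used in the neck, "relu" or "tanh". Defaults to "relu".
+ image_size (int): input image size. Defaults to 224.
+ num_reg_tokens (int): number of register tokens of the backbone. Defaults to 0.
+ """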
+ def __init__(
1228
+ self,
1229
+ backbone: str | nn.Module = "facebook/deit-tiny-patch16-224",
1230
+ pretrained: bool = False,
1231
+ target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None,
1232
+ translator_type: str = "lconv",
1233
+ translator_hidden_size_factor: float | int = 1.0,
1234
+ target_loss_weights: Optional[dict[str, float]] = None,
1235
+ feature_reduce_method: Optional[str] = None,
1236
+ feature_neck: bool = False,
1237
+ feature_neck_hidden_dim: int = 256,
1238
+ forward_neck: bool = False,
1239
+ feature_neck_nonlinearity: str = "relu",
1240
+ image_size: int = 224,
1241
+ num_reg_tokens: int = 0,
1242
+ **kwargs: Any
1243
+ ):
1244
+ self.backbone = backbone
1245
+ self.pretrained = pretrained
1246
+ self.target_feature_sizes = target_feature_sizes
1247
+ self.translator_type = translator_type
1248
+ self.translator_hidden_size_factor = translator_hidden_size_factor
1249
+ self.target_loss_weights = target_loss_weights
1250
+ self.feature_reduce_method = feature_reduce_method
1251
+ self.feature_neck = feature_neck
1252
+ self.feature_neck_hidden_dim = feature_neck_hidden_dim
1253
+ self.forward_neck = forward_neck
1254
+ self.feature_neck_nonlinearity = feature_neck_nonlinearity
1255
+ self.image_size = image_size
1256
+ self.num_reg_tokens = num_reg_tokens
1257
+ super().__init__(**kwargs)
1258
+
1259
+ class TheiaModel(PreTrainedModel):
1260
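+ """Theia robot vision foundation model: an image encoder backbone plus feature translators that predict the features of target vision foundation models."""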
+ config_class = TheiaConfig
1261
+
1262
+ def __init__(self, config: TheiaConfig):
1263
+ super().__init__(config)
1264
+
1265
+ self.target_feature_sizes = config.target_feature_sizes
1266
+ self.preprocessor = None
1267
+ self.pretrained = config.pretrained
1268
+
1269
+ # backbone
1270
+ self.image_size = config.image_size
1271
+ if "reg" in config.backbone:
1272
+ self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, image_size=config.image_size, num_reg_tokens = config.num_reg_tokens)
1273
+ else:
1274
+ self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, image_size=config.image_size)
1275
+
1276
+ # handle output feature (feature reduce)
1277
+ self.feature_reduce_method = config.feature_reduce_method
1278
+ self.no_cls = hasattr(self.backbone, "no_cls")
1279
+ self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0
1280
+
1281
+ # translator
1282
+ backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
1283
+ if self.target_feature_sizes:
1284
+ translator_kwargs = {
1285
+ "hidden_size_factor": config.translator_hidden_size_factor
1286
+ }
1287
+ translator_kwargs["backbone_feature_size"] = backbone_feature_size
1288
+ translator_kwargs["target_feature_sizes"] = config.target_feature_sizes
1289
+ self.translator = build_feature_translator(
1290
+ config.translator_type, **translator_kwargs
1291
+ )
1292
+ else:
1293
+ self.translator = None
1294
+
1295
+ self.feature_neck = config.feature_neck
1296
+ self.feature_neck_hidden_dim = config.feature_neck_hidden_dim
1297
+ self.forward_neck = config.forward_neck
1298
+ if self.feature_neck:
1299
+ num_tokens_edge = self.backbone.model.config.image_size // self.backbone.model.config.patch_size
1300
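+ # Neck: compress spatial tokens [B, H*W, C] into a single [B, feature_neck_hidden_dim] vector.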
+ self.neck = nn.Sequential(
1301
+ Rearrange("b (h w) c -> b c h w", h=num_tokens_edge, w=num_tokens_edge),
1302
+ nn.Conv2d(self.backbone.model.config.hidden_size, self.feature_neck_hidden_dim, kernel_size=4, stride=2, padding=1), #14x14 -> 7x7
1303
+ nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(), # just to keep the same as super class
1304
+ nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=2), #7x7 -> 3x3
1305
+ nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(),
1306
+ nn.Conv2d(self.feature_neck_hidden_dim, self.feature_neck_hidden_dim, kernel_size=3, stride=1), #3x3 -> 1x1
1307
+ nn.ReLU() if config.feature_neck_nonlinearity == 'relu' else nn.Tanh(),
1308
+ nn.Flatten()
1309
+ )
1310
+ else:
1311
+ self.neck = None
1312
+
1313
+ # loss
1314
+ self.mse_loss = nn.MSELoss()
1315
+ self.l1_loss = nn.SmoothL1Loss()
1316
+ self.cos_loss = nn.CosineEmbeddingLoss()
1317
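+ # Constant +1 labels for CosineEmbeddingLoss, i.e. train predicted features to be similar to targets.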
+ self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
1318
+ self.target_loss_weights = config.target_loss_weights
1319
+
1320
+ def load_pretrained_weights(self, checkpoint_path: str) -> None:
1321
+ """
1322
+ Load weights from `checkpoint_path` manually.
1323
+
1324
+ Args:
1325
+ checkpoint_path (str): path to the weights.
1326
+ """
1327
+ # load theia weights
1328
+ if checkpoint_path:
1329
+ weights_dict = torch.load(checkpoint_path, map_location="cpu")
1330
+ # Filter out unnecessary keys
1331
+ pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()}
1332
+ self.load_state_dict(pretrained_dict, strict=False)
1333
+
1334
+ def freeze_translator(self) -> None:
1335
+ """Freeze feature translators `self.translator`."""
1336
+ if self.translator is not None:
1337
+ for param in self.translator.parameters():
1338
+ param.requires_grad = False
1339
+
1340
+ def freeze_backbone(self) -> None:
1341
+ """Freeze backbone (encoder) `self.backbone`. """
1342
+ self.freeze_encoder()
1343
+
1344
+ def freeze_encoder(self) -> None:
1345
+ """Freeze backbone (encoder) `self.backbone`. """
1346
+ for param in self.backbone.parameters():
1347
+ param.requires_grad = False
1348
+
1349
+ def freeze_neck(self) -> None:
1350
+ """Freeze feature neck `self.neck`."""
1351
+ if self.neck is not None:
1352
+ for param in self.neck.parameters():
1353
+ param.requires_grad = False
1354
+
1355
+ def freeze_everything(self) -> None:
1356
+ """Freeze all parameters in the model."""
1357
+ self.freeze_translator()
1358
+ self.freeze_neck()
1359
+ self.freeze_encoder()
1360
+
1361
+ def unfreeze_translator(self) -> None:
1362
+ if self.translator is not None:
1363
+ for param in self.translator.parameters():
1364
+ param.requires_grad = True
1365
+
1366
+ def unfreeze_backbone(self) -> None:
1367
+ "Set parameters in backbone (encoder) `self.backbone` trainable."
1368
+ self.unfreeze_encoder()
1369
+
1370
+ def unfreeze_encoder(self) -> None:
1371
+ "Set parameters in backbone (encoder) `self.backbone` trainable."
1372
+ for param in self.backbone.parameters():
1373
+ param.requires_grad = True
1374
+
1375
+ def unfreeze_neck(self) -> None:
1376
+ "Set parameters in feature neck `self.neck` trainable."
1377
+ if self.neck is not None:
1378
+ for param in self.neck.parameters():
1379
+ param.requires_grad = True
1380
+
1381
+ def unfreeze_everything(self) -> None:
1382
+ """Set all parameters trainable."""
1383
+ self.unfreeze_translator()
1384
+ self.unfreeze_neck()
1385
+ self.unfreeze_encoder()
1386
+
1387
+ def set_forward_neck(self, forward_neck: bool = True) -> None:
1388
+ """
1389
+ Set `self.forward_neck` to `forward_neck` value.
1390
+
1391
+ Args:
1392
+ forward_neck (bool): whether to forward the feature through the (randomly initialized) neck.
1393
+ If set to True, the output of `self.forward()` has shape [batch_size, self.config.feature_neck_hidden_dim].
1394
+ """
1395
+ self.forward_neck = forward_neck
1396
+
1397
+ def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
1398
+ """Forward RVFM feature only (before translators).
1399
+
1400
+ Args:
1401
+ x (torch.Tensor): input image. By default it accepts images
1402
+ in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
1403
+ kwargs (Any): kwargs including mainly those for huggingface preprocessor:
1404
+ `do_resize` (bool) defaults to True.
1405
+ `interpolate_pos_encoding` (Optional[bool]) defaults to None.
1406
+ `do_rescale` (bool) defaults to True.
1407
+ `do_normalize` (bool) defaults to True.
1408
+
1409
+ Returns:
1410
+ torch.Tensor: RVFM feature.
1411
+ """
1412
+ feature = self.backbone(x, **kwargs)
1413
+ # [B, 1+H*W+N, C] if including both CLS and register tokens.
1414
+ # [B, 1+H*W, C] for standard model (N=0).
1415
+ # [B, H*W, C] for model without CLS.
1416
+ return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)
1417
+
1418
+ def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor] | torch.Tensor:
1419
+ """Forward pass of Robot Vision Foundation Model.
1420
+
1421
+ Args:
1422
+ x (torch.Tensor): input image. By default it accepts images
1423
+ in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
1424
+ target_model_names (Optional[list[str]]): names of the target foundation models.
1425
+ kwargs (Any): kwargs including mainly those for huggingface preprocessor:
1426
+ `do_resize` (bool) defaults to True.
1427
+ `interpolate_pos_encoding` (Optional[bool]) defaults to None.
1428
+ `do_rescale` (bool) defaults to True.
1429
+ `do_normalize` (bool) defaults to True.
1430
+
1431
+ Returns:
1432
+ if `self.forward_neck`:
1433
+ torch.Tensor: compact vector feature passed through the neck. [B, C_neck]
1434
+ else:
1435
+ dict[str, torch.Tensor]: features that match to each foundation model.
1436
+ Each feature is in [B, (H*W), C] or [B, C].
1437
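+
+ Example (illustrative sketch; the checkpoint name and target model name are placeholders):
+ >>> model = TheiaModel.from_pretrained("theaiinstitute/theia-tiny-patch16-224-cddsv")
+ >>> images = torch.zeros((1, 224, 224, 3), dtype=torch.uint8)
+ >>> features = model(images, target_model_names=["facebook/dinov2-large"])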
+ """
1438
+ if self.forward_neck:
1439
+ x = self.forward_feature(x, **kwargs)
1440
+ return self.neck(x)
1441
+ else:
1442
+ x = self.backbone(x, **kwargs)
1443
+ if self.num_reg_tokens > 0:
1444
+ x = x[:, :-self.num_reg_tokens] # [B, (1)+H*W, C]
1445
+ features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls) # each is [B, H*W, C] or [B, C]
1446
+ return features
1447
+
1448
+ def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]:
1449
+ """Get loss terms given predictions and targets.
1450
+
1451
+ Args:
1452
+ pred_features (dict[str, torch.Tensor]): predictions.
1453
+ y (dict[str, torch.Tensor]): targets.
1454
+
1455
+ Returns:
1456
+ dict[str, Any]: averaged and per-model loss terms.
1457
+ """
1458
+ mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
1459
+ mse_losses_per_model = {}
1460
+ cos_losses_per_model = {}
1461
+ l1_losses_per_model = {}
1462
+
1463
+ for t in pred_features:
1464
+ pred = pred_features[t]
1465
+ target = y[t]
1466
+
1467
+ # mse loss
1468
+ mse_loss = self.mse_loss(pred, target)
1469
+ weight = self.target_loss_weights[t] if self.target_loss_weights else 1.0 / len(pred_features)
1470
+
1471
+ # l1 loss
1472
+ l1_loss = self.l1_loss(pred, target)
1473
+
1474
+ # cos loss
1475
+ pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
1476
+ target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
1477
+ cos_labels = self.cos_target.repeat(pred.size(0)).to(pred.device)
1478
+ cos_loss = self.cos_loss(pred_norm, target_norm, cos_labels)
1479
+
1480
+ mse_loss_avg += mse_loss * weight
1481
+ cos_loss_avg += cos_loss / len(pred_features) # balance cos by default for meaningful eval
1482
+ l1_loss_avg += l1_loss * weight
1483
+
1484
+ mse_losses_per_model[t] = mse_loss.item()
1485
+ cos_losses_per_model[t] = cos_loss.item()
1486
+ l1_losses_per_model[t] = l1_loss.item()
1487
+
1488
+ return {
1489
+ "mse_loss": mse_loss_avg,
1490
+ "cos_loss": cos_loss_avg,
1491
+ "l1_loss": l1_loss_avg,
1492
+ "mse_losses_per_model": mse_losses_per_model,
1493
+ "cos_losses_per_model": cos_losses_per_model,
1494
+ "l1_losses_per_model": l1_losses_per_model,
1495
+ }