Add application file
This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- .idea/.gitignore +8 -0
- .idea/deployment.xml +15 -0
- .idea/inspectionProfiles/Project_Default.xml +23 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/selfmask_demo.iml +8 -0
- .idea/sonarlint/issuestore/index.pb +0 -0
- .idea/webServers.xml +14 -0
- __pycache__/bilateral_solver.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +134 -0
- bilateral_solver.py +206 -0
- duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml +56 -0
- networks/__init__.py +0 -0
- networks/__pycache__/__init__.cpython-38.pyc +0 -0
- networks/__pycache__/timm_deit.cpython-38.pyc +0 -0
- networks/__pycache__/timm_vit.cpython-38.pyc +0 -0
- networks/__pycache__/vision_transformer.cpython-38.pyc +0 -0
- networks/maskformer/__pycache__/maskformer.cpython-38.pyc +0 -0
- networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc +0 -0
- networks/maskformer/maskformer.py +267 -0
- networks/maskformer/positional_embedding.py +48 -0
- networks/maskformer/transformer_decoder.py +376 -0
- networks/module_helper.py +176 -0
- networks/resnet.py +60 -0
- networks/resnet_backbone.py +194 -0
- networks/resnet_models.py +273 -0
- networks/timm_deit.py +254 -0
- networks/timm_vit.py +819 -0
- networks/vision_transformer.py +569 -0
- resources/.DS_Store +0 -0
- resources/0053.jpg +0 -0
- resources/0236.jpg +0 -0
- resources/0239.jpg +0 -0
- resources/0403.jpg +0 -0
- resources/0412.jpg +0 -0
- resources/ILSVRC2012_test_00005309.jpg +0 -0
- resources/ILSVRC2012_test_00012622.jpg +0 -0
- resources/ILSVRC2012_test_00022698.jpg +0 -0
- resources/ILSVRC2012_test_00040725.jpg +0 -0
- resources/ILSVRC2012_test_00075738.jpg +0 -0
- resources/ILSVRC2012_test_00080683.jpg +0 -0
- resources/ILSVRC2012_test_00085874.jpg +0 -0
- resources/im052.jpg +0 -0
- resources/sun_ainjbonxmervsvpv.jpg +0 -0
- resources/sun_alfntqzssslakmss.jpg +0 -0
- resources/sun_amnrcxhisjfrliwa.jpg +0 -0
- resources/sun_bvyxpvkouzlfwwod.jpg +0 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/deployment.xml
ADDED
@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="Always" serverName="mydev" remoteFilesAllowedToDisappearOnAutoupload="false">
    <serverData>
      <paths name="mydev">
        <serverdata>
          <mappings>
            <mapping deploy="/" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
    <option name="myAutoUpload" value="ALWAYS" />
  </component>
</project>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,23 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="10">
            <item index="0" class="java.lang.String" itemvalue="prettytable" />
            <item index="1" class="java.lang.String" itemvalue="interrogate" />
            <item index="2" class="java.lang.String" itemvalue="pytest" />
            <item index="3" class="java.lang.String" itemvalue="yapf" />
            <item index="4" class="java.lang.String" itemvalue="cityscapesscripts" />
            <item index="5" class="java.lang.String" itemvalue="Wand" />
            <item index="6" class="java.lang.String" itemvalue="isort" />
            <item index="7" class="java.lang.String" itemvalue="xdoctest" />
            <item index="8" class="java.lang.String" itemvalue="codecov" />
            <item index="9" class="java.lang.String" itemvalue="flake8" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/selfmask_demo.iml" filepath="$PROJECT_DIR$/.idea/selfmask_demo.iml" />
    </modules>
  </component>
</project>
.idea/selfmask_demo.iml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/sonarlint/issuestore/index.pb
ADDED
File without changes
.idea/webServers.xml
ADDED
@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="WebServers">
    <option name="servers">
      <webServer id="12e2cf4d-3b81-4241-9665-54a333f70567" name="mydev">
        <fileTransfer rootFolder="/users/gyungin/selfmask_demo" accessType="SFTP" host="mydev" port="22" sshConfigId="3e23a652-ab3c-4dc2-a117-84c2bf217891" sshConfig="gyungin@mydev:22 password">
          <advancedOptions>
            <advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
          </advancedOptions>
        </fileTransfer>
      </webServer>
    </option>
  </component>
</project>
__pycache__/bilateral_solver.cpython-38.pyc
ADDED
Binary file (6.76 kB)
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.9 kB)
app.py
ADDED
@@ -0,0 +1,134 @@
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_model
from bilateral_solver import bilateral_solver_output
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
    map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
)["model"]

parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
)

# parser.add_argument(
#     "--p_state_dict",
#     type=str,
#     default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
# )
#
# parser.add_argument(
#     "--dataset_name", '-dn', type=str, default="duts",
#     choices=["dut_omron", "duts", "ecssd"]
# )

# independent variables
# parser.add_argument("--use_gpu", type=bool, default=True)
# parser.add_argument('--seed', default=0, type=int)
# parser.add_argument("--dir_root", type=str, default="..")
# parser.add_argument("--gpu_id", type=int, default=2)
# parser.add_argument("--suffix", type=str, default='')
args: Namespace = parser.parse_args()
base_args = yaml.safe_load(open(f"{args.config}", 'r'))
base_args.pop("dataset_name")
args: dict = vars(args)
args.update(base_args)
args: Namespace = Namespace(**args)

model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()


@torch.no_grad()
def main(
        image: Image.Image,
        size: int = 384,
        max_size: int = 512,
        mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
        std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
):
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
    dict_outputs = model(image[None].to(device))

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]

    if len(batch_pred_masks.shape) == 5:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer

        if batch_objectness is not None:
            # b x n_layers x n_queries x 1 -> b x n_queries x 1
            batch_objectness = batch_objectness[:, -1, ...]

    # resize prediction to original resolution
    # note: upsampling by 4 and cutting the padded region allows for a better result
    H, W = image.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # iterate over batch dimension
    for batch_index, pred_masks in enumerate(batch_pred_masks):
        # n_queries x 1 -> n_queries
        objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
        ranks = torch.argsort(objectness, descending=True)  # n_queries
        pred_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
        pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255

        pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
        pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

        attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
        super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
        return super_imposed_img
        # return pred_mask_bi

demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil"),
    outputs="image",
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod"
    ]],
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True
)
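As a quick sanity check outside the Gradio UI, the prediction function can also be called directly on one of the bundled example images. The snippet below is a minimal sketch, assuming it is placed at the bottom of app.py (before demo.launch) so that the loaded model and main are already in scope, and that the output path is writable:

    # run the demo's prediction function on a bundled example image;
    # `main` returns an RGB numpy array with the predicted salient mask
    # blended over the (resized) input image
    result = main(Image.open("resources/0053.jpg").convert("RGB"))

    # OpenCV expects BGR channel order when writing to disk
    cv2.imwrite("prediction.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))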
bilateral_solver.py
ADDED
@@ -0,0 +1,206 @@
from scipy.sparse import diags
from scipy.sparse.linalg import cg
from scipy.sparse import csr_matrix
import numpy as np
from skimage.io import imread
from scipy import ndimage
import torch
import PIL.Image as Image
import os
from argparse import ArgumentParser, Namespace
from typing import Dict, Union
from collections import defaultdict
import yaml
import ujson as json
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


RGB_TO_YUV = np.array([
    [0.299, 0.587, 0.114],
    [-0.168736, -0.331264, 0.5],
    [0.5, -0.418688, -0.081312]])
YUV_TO_RGB = np.array([
    [1.0, 0.0, 1.402],
    [1.0, -0.34414, -0.71414],
    [1.0, 1.772, 0.0]])
YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
MAX_VAL = 255.0


def rgb2yuv(im):
    return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET


def yuv2rgb(im):
    return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))


def get_valid_idx(valid, candidates):
    """Find which values are present in a list and where they are located"""
    locs = np.searchsorted(valid, candidates)
    # Handle edge case where the candidate is larger than all valid values
    locs = np.clip(locs, 0, len(valid) - 1)
    # Identify which values are actually present
    valid_idx = np.flatnonzero(valid[locs] == candidates)
    locs = locs[valid_idx]
    return valid_idx, locs


class BilateralGrid(object):
    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
        im_yuv = rgb2yuv(im)
        # Compute 5-dimensional XYLUV bilateral-space coordinates
        Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]]
        x_coords = (Ix / sigma_spatial).astype(int)
        y_coords = (Iy / sigma_spatial).astype(int)
        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
        coords_flat = coords.reshape(-1, coords.shape[-1])
        self.npixels, self.dim = coords_flat.shape
        # Hacky "hash vector" for coordinates,
        # Requires all scaled coordinates be < MAX_VAL
        self.hash_vec = (MAX_VAL ** np.arange(self.dim))
        # Construct S and B matrix
        self._compute_factorization(coords_flat)

    def _compute_factorization(self, coords_flat):
        # Hash each coordinate in grid to a unique value
        hashed_coords = self._hash_coords(coords_flat)
        unique_hashes, unique_idx, idx = \
            np.unique(hashed_coords, return_index=True, return_inverse=True)
        # Identify unique set of vertices
        unique_coords = coords_flat[unique_idx]
        self.nvertices = len(unique_coords)
        # Construct sparse splat matrix that maps from pixels to vertices
        self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
        # Construct sparse blur matrices.
        # Note that these represent [1 0 1] blurs, excluding the central element
        self.blurs = []
        for d in range(self.dim):
            blur = 0.0
            for offset in (-1, 1):
                offset_vec = np.zeros((1, self.dim))
                offset_vec[:, d] = offset
                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
                valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
                blur = blur + csr_matrix((np.ones((len(valid_coord),)),
                                          (valid_coord, idx)),
                                         shape=(self.nvertices, self.nvertices))
            self.blurs.append(blur)

    def _hash_coords(self, coord):
        """Hacky function to turn a coordinate into a unique value"""
        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)

    def splat(self, x):
        return self.S.dot(x)

    def slice(self, y):
        return self.S.T.dot(y)

    def blur(self, x):
        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
        assert x.shape[0] == self.nvertices
        out = 2 * self.dim * x
        for blur in self.blurs:
            out = out + blur.dot(x)
        return out

    def filter(self, x):
        """Apply bilateral filter to an input x"""
        return self.slice(self.blur(self.splat(x))) / \
               self.slice(self.blur(self.splat(np.ones_like(x))))


def bistochastize(grid, maxiter=10):
    """Compute diagonal matrices to bistochastize a bilateral grid"""
    m = grid.splat(np.ones(grid.npixels))
    n = np.ones(grid.nvertices)
    for i in range(maxiter):
        n = np.sqrt(n * m / grid.blur(n))
    # Correct m to satisfy the assumption of bistochastization regardless
    # of how many iterations have been run.
    m = n * grid.blur(n)
    Dm = diags(m, 0)
    Dn = diags(n, 0)
    return Dn, Dm


class BilateralSolver(object):
    def __init__(self, grid, params):
        self.grid = grid
        self.params = params
        self.Dn, self.Dm = bistochastize(grid)

    def solve(self, x, w):
        # Check that w is a vector or a nx1 matrix
        if w.ndim == 2:
            assert (w.shape[1] == 1)
        elif w.dim == 1:
            w = w.reshape(w.shape[0], 1)
        A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn)))
        w_splat = self.grid.splat(w)
        A_data = diags(w_splat[:, 0], 0)
        A = self.params["lam"] * A_smooth + A_data
        xw = x * w
        b = self.grid.splat(xw)
        # Use simple Jacobi preconditioner
        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
        M = diags(1 / A_diag, 0)
        # Flat initialization
        y0 = self.grid.splat(xw) / w_splat
        yhat = np.empty_like(y0)
        for d in range(x.shape[-1]):
            yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"],
                                    tol=self.params["cg_tol"])
        xhat = self.grid.slice(yhat)
        return xhat


def bilateral_solver_output(
        img: Image.Image,
        target: np.ndarray,
        sigma_spatial=16,
        sigma_luma=16,
        sigma_chroma=8
):
    reference = np.array(img)
    h, w = target.shape
    confidence = np.ones((h, w)) * 0.999

    grid_params = {
        'sigma_luma': sigma_luma,  # Brightness bandwidth
        'sigma_chroma': sigma_chroma,  # Color bandwidth
        'sigma_spatial': sigma_spatial  # Spatial bandwidth
    }

    bs_params = {
        'lam': 256,  # The strength of the smoothness parameter
        'A_diag_min': 1e-5,  # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
        'cg_tol': 1e-5,  # The tolerance on the convergence in PCG
        'cg_maxiter': 25  # The number of PCG iterations
    }

    grid = BilateralGrid(reference, **grid_params)

    t = target.reshape(-1, 1).astype(np.double)
    c = confidence.reshape(-1, 1).astype(np.double)

    ## output solver, which is a soft value
    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))

    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
    labeled, nr_objects = ndimage.label(binary_solver)

    nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
    pixel_order = np.argsort(nb_pixel)
    try:
        binary_solver = labeled == pixel_order[-2]
    except:
        binary_solver = np.ones((h, w), dtype=bool)

    return output_solver, binary_solver
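The entry point app.py uses is bilateral_solver_output, which takes an RGB PIL image and a coarse mask of the same height and width and returns a refined soft mask plus a single-component binary mask. A minimal usage sketch (the all-zero mask is only a placeholder for a real coarse prediction in {0, 255}):

    import numpy as np
    from PIL import Image
    from bilateral_solver import bilateral_solver_output

    image = Image.open("resources/0053.jpg").convert("RGB")
    # placeholder coarse mask with the same (H, W) as the image; in the demo
    # this comes from the model's top-ranked query mask, scaled to {0, 255}
    rough_mask = np.zeros(image.size[::-1], dtype=np.uint8)

    # soft_mask is a float array (roughly in [0, 255]); binary_mask keeps one
    # connected component of the thresholded, hole-filled result
    soft_mask, binary_mask = bilateral_solver_output(img=image, target=rough_mask)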
duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml
ADDED
@@ -0,0 +1,56 @@
# augmentations
use_copy_paste: false
scale_range: [ 0.1, 1.0 ]
repeat_image: false

# base directories
dir_ckpt: "/users/gyungin/selfmask/ckpt"  # "/work/gyungin/selfmask/ckpt"
dir_dataset: "/scratch/shared/beegfs/gyungin/datasets"

# clustering
k: [2, 3, 4]
clustering_mode: "spectral"
use_gpu: true  # if you want to use gpu-accelerated code for clustering
scale_factor: 2  # "how much you want to upsample encoder features before clustering"

# dataset
dataset_name: "duts"
use_pseudo_masks: true
train_image_size: 224
eval_image_size: 224
n_percent: 100
n_copy_pastes: null
pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json"

# dataloader:
batch_size: 8
num_workers: 4
pin_memory: true

# networks
abs_2d_pe_init: false
arch: "vit_small"
lateral_connection: false
learnable_pixel_decoder: false  # if False, use the bilinear interpolation
use_binary_classifier: true  # if True, use a binary classifier to get an objectness for each query from transformer decoder
n_decoder_layers: 6
n_queries: 20
num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
patch_size: 8
training_method: "dino"  # "supervised", "deit", "dino", "mocov2", "swav"

# objective
loss_every_decoder_layer: true
weight_dice_loss: 1.0
weight_focal_loss: 0.0

# optimizer
lr: 0.000006  # default: 0.00006
lr_warmup_duration: 0  # 5
momentum: 0.9
n_epochs: 12
weight_decay: 0.01
optimizer_type: "adamw"

# validation
benchmarks: null
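app.py reads this file with yaml.safe_load, drops the dataset_name key, and merges the rest into its parsed CLI arguments before building the model. A small sketch of inspecting the network-related entries the demo relies on:

    import yaml

    # load the config shipped with the Space and print the keys that drive
    # model construction (values follow the YAML above)
    with open("duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["arch"], cfg["patch_size"], cfg["n_queries"])  # vit_small 8 20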
networks/__init__.py
ADDED
File without changes
networks/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (146 Bytes)
networks/__pycache__/timm_deit.cpython-38.pyc
ADDED
Binary file (7.08 kB)
networks/__pycache__/timm_vit.cpython-38.pyc
ADDED
Binary file (27.7 kB)
networks/__pycache__/vision_transformer.cpython-38.pyc
ADDED
Binary file (15.8 kB)
networks/maskformer/__pycache__/maskformer.cpython-38.pyc
ADDED
Binary file (8.51 kB)
networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc
ADDED
Binary file (8.83 kB)
networks/maskformer/maskformer.py
ADDED
@@ -0,0 +1,267 @@
from typing import Dict, List
from math import sqrt, log
import torch
import torch.nn as nn
import torch.nn.functional as F

from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
from utils import get_model


class MaskFormer(nn.Module):
    def __init__(
            self,
            n_queries: int = 100,
            arch: str = "vit_small",
            patch_size: int = 8,
            training_method: str = "dino",
            n_decoder_layers: int = 6,
            normalize_before: bool = False,
            return_intermediate: bool = False,
            learnable_pixel_decoder: bool = False,
            lateral_connection: bool = False,
            scale_factor: int = 2,
            abs_2d_pe_init: bool = False,
            use_binary_classifier: bool = False
    ):
        """Define a encoder and decoder along with queries to be learned through the decoder."""
        super(MaskFormer, self).__init__()

        if arch == "vit_small":
            self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
            n_dims: int = self.encoder.n_embs
            n_heads: int = self.encoder.n_heads
            mlp_ratio: int = self.encoder.mlp_ratio
        else:
            self.encoder = get_model(arch=arch, training_method=training_method)
            n_dims_resnet: int = self.encoder.n_embs
            n_dims: int = 384
            n_heads: int = 6
            mlp_ratio: int = 4
            self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)

        decoder_layer = TransformerDecoderLayer(
            n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            n_decoder_layers,
            norm=nn.LayerNorm(n_dims),
            return_intermediate=return_intermediate
        )

        self.query_embed = nn.Embedding(n_queries, n_dims).weight  # initialized with gaussian(0, 1)

        if use_binary_classifier:
            # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
            # self.linear_classifier = nn.Linear(n_dims, 1)
            self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
            # self.norm = nn.LayerNorm(n_dims)
        else:
            # self.ffn = None
            # self.linear_classifier = None
            # self.norm = None
            self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
            self.linear_classifier = nn.Linear(n_dims, 2)
            self.norm = nn.LayerNorm(n_dims)

        self.arch = arch
        self.use_binary_classifier = use_binary_classifier
        self.lateral_connection = lateral_connection
        self.learnable_pixel_decoder = learnable_pixel_decoder
        self.scale_factor = scale_factor

    # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
    @staticmethod
    def positional_encoding_2d(n_dims: int, height: int, width: int):
        """
        :param n_dims: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model*height*width position matrix
        """
        if n_dims % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dimension (got dim={:d})".format(n_dims))
        pe = torch.zeros(n_dims, height, width)
        # Each dimension use half of d_model
        d_model = int(n_dims / 2)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe

    def forward_encoder(self, x: torch.Tensor):
        """
        :param x: b x c x h x w
        :return patch_tokens: b x depth x hw x n_dims
        """
        if self.arch == "vit_small":
            encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x)  # [:, 1:, :]
            all_patch_tokens: List[torch.Tensor] = list()
            for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
                patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :]  # b x hw x n_dims
                all_patch_tokens.append(patch_tokens)

            all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0)  # depth x b x hw x n_dims
            all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2)  # b x depth x n_dims x hw
            return all_patch_tokens
        else:
            encoder_outputs = self.linear_layer(self.encoder(x)[-1])  # b x n_dims x h x w
            return encoder_outputs

    def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
        """Forward transformer decoder given patch tokens from the encoder's last layer.
        :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
        :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
        between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
        experiment.
        :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        """
        b = patch_tokens.shape[0]
        patch_tokens = patch_tokens.permute(2, 0, 1)  # b x n_dims x hw -> hw x b x n_dims

        # n_queries x n_dims -> n_queries x b x n_dims
        queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
        queries: torch.Tensor = self.decoder.forward(
            tgt=torch.zeros_like(queries),
            memory=patch_tokens,
            query_pos=queries
        ).squeeze(dim=0)

        if len(queries.shape) == 3:
            queries: torch.Tensor = queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims
        elif len(queries.shape) == 4:
            # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
            queries: torch.Tensor = queries.permute(2, 0, 1, 3)
        return queries

    def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
        """ Upsample patch tokens by self.scale_factor and produce mask predictions
        :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
        :param queries: b x n_queries x n_dims
        :return mask_predictions: b x n_queries x h x w
        """

        if input_size is None:
            # assume square shape features
            hw = patch_tokens.shape[-1]
            h = w = int(sqrt(hw))
        else:
            # arbitrary shape features
            h, w = input_size
        patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)

        assert len(patch_tokens.shape) == 4
        patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
        return patch_tokens

    def forward(self, x, encoder_only=False, skip_decoder: bool = False):
        """
        x: b x c x h x w
        patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
        query_emb: n_queries x n_dims -> n_queries x b x n_dims
        """
        dict_outputs: dict = dict()

        # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
        features: torch.Tensor = self.forward_encoder(x)

        if self.arch == "vit_small":
            # extract the last layer for decoder input
            last_layer_features: torch.Tensor = features[:, -1, ...]  # b x n_dims x hw
        else:
            # transform the shape of the features to the one compatible with transformer decoder
            b, n_dims, h, w = features.shape
            last_layer_features: torch.Tensor = features.view(b, n_dims, h * w)  # b x n_dims x hw

        if encoder_only:
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size

            b, n_dims, hw = last_layer_features.shape
            dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
            return dict_outputs

        # transformer decoder forward
        queries: torch.Tensor = self.forward_transformer_decoder(
            last_layer_features,
            skip_decoder=skip_decoder
        )  # b x n_queries x n_dims or b x n_layers x n_queries x n_dims

        # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
        if self.arch == "vit_small":
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
        else:
            _h, _w = h, w
        features: torch.Tensor = self.forward_pixel_decoder(
            patch_tokens=features if self.lateral_connection else last_layer_features,
            input_size=(_h, _w)
        )  # b x n_dims x h x w

        # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        # features: b x n_dims x h x w
        # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
        if len(queries.shape) == 3:
            mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
        else:
            if self.use_binary_classifier:
                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
            else:
                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))

        if self.use_binary_classifier:
            # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
            queries = queries.permute(1, 0, 2, 3)
            objectness: List[torch.Tensor] = list()
            for n_layer, queries_per_layer in enumerate(queries):  # queries_per_layer: b x n_queries x n_dims
                # objectness_per_layer = self.linear_classifier(
                #     self.ffn(self.norm(queries_per_layer))
                # )  # b x n_queries x 1
                objectness_per_layer = self.ffn(queries_per_layer)  # b x n_queries x 1
                objectness.append(objectness_per_layer)
            # n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1
            objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
            dict_outputs.update({
                "objectness": torch.sigmoid(objectness),
                "mask_pred": mask_pred
            })

        return dict_outputs


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
        super(UpsampleBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.ReLU()
        )
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
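The core of MaskFormer.forward is the einsum between the per-layer decoder queries and the upsampled pixel features. A standalone sketch with dummy tensors (shapes are illustrative, matching the n_queries=20 / 384-dim ViT-S configuration used by the demo) shows the shapes involved:

    import torch

    b, n_layers, n_queries, n_dims, h, w = 1, 6, 20, 384, 28, 28

    # decoder output: one set of query embeddings per decoder layer
    queries = torch.randn(b, n_layers, n_queries, n_dims)
    # pixel features after the (bilinear) pixel decoder
    features = torch.randn(b, n_dims, h, w)

    # each query yields one mask map per layer via a dot product over the
    # embedding dimension, followed by a sigmoid
    mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
    print(mask_pred.shape)  # torch.Size([1, 6, 20, 28, 28])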
networks/maskformer/positional_embedding.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
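PositionEmbeddingSine produces a fixed sinusoidal embedding with num_pos_feats channels per spatial axis, so the output has 2 * num_pos_feats channels. A minimal sketch of calling it on a dummy feature map (import path follows the file location above; dimensions are illustrative):

    import torch
    from networks.maskformer.positional_embedding import PositionEmbeddingSine

    # 192 features per axis -> 384 channels overall, matching a ViT-S embedding
    pos_embed = PositionEmbeddingSine(num_pos_feats=192, normalize=True)

    feature_map = torch.randn(1, 384, 28, 28)  # b x c x h x w
    pos = pos_embed(feature_map)
    print(pos.shape)  # torch.Size([1, 384, 28, 28])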
networks/maskformer/transformer_decoder.py
ADDED
@@ -0,0 +1,376 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
"""
Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",  # noel - dino used GeLU
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers)
        self.num_layers: int = num_layers
        self.norm = norm
        self.return_intermediate: bool = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)

        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
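MaskFormer above wires its decoder from TransformerDecoderLayer and TransformerDecoder with return_intermediate=True, so every decoder layer's output is kept. A minimal standalone sketch of that wiring (dimensions are illustrative):

    import torch
    import torch.nn as nn
    from networks.maskformer.transformer_decoder import TransformerDecoder, TransformerDecoderLayer

    d_model, n_heads, n_queries, hw, batch = 384, 6, 20, 784, 1

    layer = TransformerDecoderLayer(d_model, n_heads, d_model * 4, 0., activation="relu")
    decoder = TransformerDecoder(layer, num_layers=6, norm=nn.LayerNorm(d_model), return_intermediate=True)

    memory = torch.randn(hw, batch, d_model)             # encoder patch tokens (hw x b x n_dims)
    query_pos = torch.randn(n_queries, batch, d_model)   # learnable query embeddings

    # with return_intermediate=True the decoder stacks the output of every layer
    out = decoder(tgt=torch.zeros_like(query_pos), memory=memory, query_pos=query_pos)
    print(out.shape)  # torch.Size([6, 20, 1, 384])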
networks/module_helper.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You (youansheng@gmail.com)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


class FixedBatchNorm(nn.BatchNorm2d):
    def forward(self, input):
        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias,
                            training=False, eps=self.eps)


class ModuleHelper(object):
    @staticmethod
    def BNReLU(num_features, norm_type=None, **kwargs):
        if norm_type == 'batchnorm':
            return nn.Sequential(
                nn.BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return nn.Sequential(
                BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'instancenorm':
            return nn.Sequential(
                nn.InstanceNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'fixed_batchnorm':
            return nn.Sequential(
                FixedBatchNorm(num_features, **kwargs),
                nn.ReLU()
            )
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm3d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm3d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm3d
            return BatchNorm3d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm3d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm2d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm2d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return BatchNorm2d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm2d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm1d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm1d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm1d
            return BatchNorm1d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm1d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
            if not os.path.exists(pretrained):
                raise FileNotFoundError('{} does not exist.'.format(pretrained))

        print('Loading pretrained model:{}'.format(pretrained))
        if all_match:
            pretrained_dict = torch.load(pretrained, map_location=map_location)
            model_dict = model.state_dict()
            load_dict = dict()
            for k, v in pretrained_dict.items():
                if 'prefix.{}'.format(k) in model_dict:
                    load_dict['prefix.{}'.format(k)] = v
                else:
                    load_dict[k] = v
            model.load_state_dict(load_dict)

        else:
            pretrained_dict = torch.load(pretrained)
            model_dict = model.state_dict()
            load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            print('Matched Keys: {}'.format(load_dict.keys()))
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model

    @staticmethod
    def load_url(url, map_location=None):
        model_dir = os.path.join('~', '.TorchCV', 'model')
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        filename = url.split('/')[-1]
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)

        print('Loading pretrained model:{}'.format(cached_file))
        return torch.load(cached_file, map_location=map_location)

    @staticmethod
    def constant_init(module, val, bias=0):
        nn.init.constant_(module.weight, val)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def xavier_init(module, gain=1, bias=0, distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def normal_init(module, mean=0, std=1, bias=0):
        nn.init.normal_(module.weight, mean, std)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def uniform_init(module, a=0, b=1, bias=0):
        nn.init.uniform_(module.weight, a, b)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def kaiming_init(module,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     bias=0,
                     distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)
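Note (not part of the diff): a minimal, hypothetical sketch of how ModuleHelper is typically used by the backbone code below; the norm_type strings are the ones handled in networks/module_helper.py, everything else (layer sizes, init arguments) is illustrative.

# Hypothetical usage sketch for ModuleHelper.
import torch.nn as nn

bn_relu = ModuleHelper.BNReLU(64, norm_type='batchnorm')     # BatchNorm2d + ReLU block
Norm2d = ModuleHelper.BatchNorm2d(norm_type='instancenorm')  # returns the norm *class*
norm = Norm2d(128)                                           # instantiate like nn.InstanceNorm2d(128)
conv = nn.Conv2d(3, 64, 3, padding=1)
ModuleHelper.kaiming_init(conv, mode='fan_out', nonlinearity='relu')
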
networks/resnet.py
ADDED
@@ -0,0 +1,60 @@
import os

import torch
import torch.nn as nn
from .resnet_backbone import ResNetBackbone


class ResNet50(nn.Module):
    def __init__(
            self,
            weight_type: str = "supervised",
            use_dilated_resnet: bool = True
    ):
        super(ResNet50, self).__init__()
        self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type)

    def _load_pretrained(self, training_method: str) -> None:
        curr_state_dict = self.network.state_dict()
        if training_method == "mocov2":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]

            # drop the MoCo v2 projection head
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]):
                    state_dict.pop(k)

        elif training_method == "swav":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar")
            # drop the SwAV projection head and prototypes
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]):
                    state_dict.pop(k)

        elif training_method == "supervised":
            # Note - the torchvision resnet50 model doesn't have num_batches_tracked layers. Need to know why.
            # for k in list(curr_state_dict.keys()):
            #     if k.find("num_batches_tracked") != -1:
            #         curr_state_dict.pop(k)
            # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth")

            from torchvision.models.resnet import resnet50
            resnet50_supervised = resnet50(True, True)
            state_dict = resnet50_supervised.state_dict()
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]):
                    state_dict.pop(k)

        else:
            # guard against an unbound state_dict for an unknown weight type
            raise ValueError(f"Unknown weight_type: {training_method}")

        assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
        for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()):
            curr_state_dict[k_curr].copy_(state_dict[k])
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} initialised with {training_method} weights is loaded.")
        return

    def forward(self, x):
        return self.network(x)


if __name__ == '__main__':
    resnet = ResNet50("mocov2")
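Note (not part of the diff): the mocov2/swav branches above rely on hard-coded local checkpoint paths, so a minimal sketch is only realistic for the "supervised" branch, which pulls torchvision's ImageNet weights (network access assumed); shapes below are illustrative.

# Hypothetical usage sketch for the ResNet50 wrapper.
import torch

backbone = ResNet50(weight_type="supervised", use_dilated_resnet=True)
feats = backbone(torch.randn(1, 3, 224, 224))   # list of 4 stage outputs (layer1..layer4)
print([f.shape for f in feats])
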
networks/resnet_backbone.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You (youansheng@gmail.com)


import torch.nn as nn
from networks.resnet_models import *


class NormalResnetBackbone(nn.Module):
    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()

        self.num_features = 2048
        # take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()
        x = self.prefix(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


class DilatedResnetBackbone(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()

        self.num_features = 2048
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))

        elif dilate_scale == 16:
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))

        # Take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()

        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    arch = backbone

    if arch == 'resnet18':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet18_dilated8':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated8':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated16':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet50':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet50_dilated8':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet50_dilated16':
        orig_resnet = resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet50_dilated8':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
            # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth"
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50_dilated16':
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'resnet101':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet101_dilated8':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet101_dilated16':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet101_dilated8':
        if pretrained:
            pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101_dilated16':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    else:
        raise Exception('Architecture undefined!')

    return arch_net
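Note (not part of the diff): a minimal sketch of the backbone factory above; with the *_dilated8 variants, layer3/layer4 keep an output stride of 8 instead of 16/32. Shapes in the comment are approximate and assume a 224x224 input.

# Hypothetical usage sketch for ResNetBackbone.
import torch

net = ResNetBackbone(backbone='resnet50_dilated8', pretrained=None)
c1, c2, c3, c4 = net(torch.randn(1, 3, 224, 224))
print(c1.shape, c2.shape, c3.shape, c4.shape)
# roughly (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 28, 28), (1, 2048, 28, 28)
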
networks/resnet_models.py
ADDED
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You (youansheng@gmail.com)
import math
import torch.nn as nn
from collections import OrderedDict
from .module_helper import ModuleHelper


model_urls = {
    'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None):
        super(ResNet, self).__init__()
        self.inplanes = 128 if deep_base else int(64 * width_multiplier)
        self.width_multiplier = width_multiplier
        if deep_base:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu1', nn.ReLU(inplace=False)),
                ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu2', nn.ReLU(inplace=False)),
                ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu3', nn.ReLU(inplace=False))]
            ))
        else:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu', nn.ReLU(inplace=False))]
            ))

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)  # change.

        self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type)
        self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type)
        self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type)
        self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                ModuleHelper.BatchNorm2d(norm_type=norm_type)(int(planes * block.expansion * self.width_multiplier)),
            )

        layers = []
        layers.append(block(self.inplanes, planes,
                            stride, downsample, norm_type=norm_type))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_type=norm_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
        norm_type (str): choose norm type
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type,
                   width_multiplier=kwargs["width_multiplier"])
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model
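Note (not part of the diff): all constructors above follow the same build-then-load pattern; resnet50 is the only one that forwards width_multiplier, so it must be passed explicitly. A minimal sketch:

# Hypothetical usage sketch for the model constructors above.
import torch

model = resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', width_multiplier=1.0)
out = model(torch.randn(2, 3, 224, 224))   # classification logits of shape (2, 1000)
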
networks/timm_deit.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.nn as nn
from functools import partial

from networks.timm_vit import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

    def interpolate_pos_encoding(self, x, pos_embed):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + 1 + N patches) x dim_embedding
        pos_embed: B x (1 + 1 + N patches) x dim_embedding

        return interpolated positional embedding
        """

        npatch = x.shape[1] - 2  # (H // patch_size * W // patch_size)
        N = pos_embed.shape[1] - 2  # 784 (= 28 x 28)

        if npatch == N:
            return pos_embed

        class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
            scale_factor=math.sqrt(npatch / N) + 1e-5,  # noel: this can be a float, but the output shape will be integer.
            recompute_scale_factor=True,
            mode='bicubic'
        )
        # print("pos_embed", pos_embed.shape, npatch, N, math.sqrt(npatch/N), math.sqrt(npatch/N) * int(math.sqrt(N)))
        # exit(12)
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    def get_tokens(
            self,
            x,
            layers: list,
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim

        else:
            return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2)  # exclude distil token.


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
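Note (not part of the diff): a minimal sketch of pulling per-layer [CLS] tokens out of a distilled DeiT with get_tokens() defined above; it assumes the default 224x224 evaluation size, and the layer indices and batch size are illustrative.

# Hypothetical usage sketch for DistilledVisionTransformer.get_tokens().
import torch

model = deit_small_distilled_patch16_224(pretrained=False).eval()
with torch.no_grad():
    cls_per_layer = model.get_tokens(torch.randn(1, 3, 224, 224), layers=[3, 7, 11])
print(cls_per_layer.shape)   # (1, 3, 384): one [CLS] token per requested block
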
networks/timm_vit.py
ADDED
@@ -0,0 +1,819 @@
""" Vision Transformer (ViT) in PyTorch

A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # patch models (my experiments)
    'vit_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),

    # patch models (weights ported from official Google JAX impl)
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_base_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_base_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

    # patch models, imagenet21k (weights ported from official Google JAX impl)
    'vit_base_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_huge_patch14_224_in21k': _cfg(
        hf_hub='timm/vit_huge_patch14_224_in21k',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # deit models (FB weights)
    'vit_deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'vit_deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'vit_deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
    'vit_deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),

    # ViT ImageNet-21K-P pretraining
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'vit_base_patch16_224_miil': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
            '/vit_base_patch16_224_1k_miil_84_4.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
    ),
}


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init='',
                 # noel
                 img_size_eval: int = 224):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if weight_init.startswith('jax'):
            # leave cls token as zeros to match jax impl
            for n, m in self.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

        # noel
        self.depth = depth
        self.distilled = distilled
        self.patch_size = patch_size
        self.patch_embed.img_size = (img_size_eval, img_size_eval)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)
|
286 |
+
|
287 |
+
@torch.jit.ignore
|
288 |
+
def no_weight_decay(self):
|
289 |
+
return {'pos_embed', 'cls_token', 'dist_token'}
|
290 |
+
|
291 |
+
def get_classifier(self):
|
292 |
+
if self.dist_token is None:
|
293 |
+
return self.head
|
294 |
+
else:
|
295 |
+
return self.head, self.head_dist
|
296 |
+
|
297 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
298 |
+
self.num_classes = num_classes
|
299 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
300 |
+
if self.num_tokens == 2:
|
301 |
+
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
302 |
+
|
303 |
+
def forward_features(self, x):
|
304 |
+
x = self.patch_embed(x)
|
305 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
306 |
+
if self.dist_token is None:
|
307 |
+
x = torch.cat((cls_token, x), dim=1)
|
308 |
+
else:
|
309 |
+
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
310 |
+
x = self.pos_drop(x + self.pos_embed)
|
311 |
+
x = self.blocks(x)
|
312 |
+
x = self.norm(x)
|
313 |
+
if self.dist_token is None:
|
314 |
+
return self.pre_logits(x[:, 0])
|
315 |
+
else:
|
316 |
+
return x[:, 0], x[:, 1]
|
317 |
+
|
318 |
+
# def forward(self, x):
|
319 |
+
# x = self.forward_features(x)
|
320 |
+
# if self.head_dist is not None:
|
321 |
+
# x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
|
322 |
+
# if self.training and not torch.jit.is_scripting():
|
323 |
+
# # during inference, return the average of both classifier predictions
|
324 |
+
# return x, x_dist
|
325 |
+
# else:
|
326 |
+
# return (x + x_dist) / 2
|
327 |
+
# else:
|
328 |
+
# x = self.head(x)
|
329 |
+
# return x
|
330 |
+
|
331 |
+
# noel - start
|
332 |
+
def make_square(self, x: torch.Tensor):
|
333 |
+
"""Pad some pixels to make the input size divisible by the patch size."""
|
334 |
+
B, _, H_0, W_0 = x.shape
|
335 |
+
pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
|
336 |
+
pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
|
337 |
+
x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean())
|
338 |
+
|
339 |
+
H_p, W_p = H_0 + pad_h, W_0 + pad_w
|
340 |
+
x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean())
|
341 |
+
return x
|
342 |
+
|
343 |
+
def interpolate_pos_encoding(self, x, pos_embed, size=None):
|
344 |
+
"""Interpolate the learnable positional encoding to match the number of patches.
|
345 |
+
|
346 |
+
x: B x (1 + N patches) x dim_embedding
|
347 |
+
pos_embed: B x (1 + N patches) x dim_embedding
|
348 |
+
|
349 |
+
return interpolated positional embedding
|
350 |
+
"""
|
351 |
+
npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
|
352 |
+
N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
|
353 |
+
if npatch == N:
|
354 |
+
return pos_embed
|
355 |
+
class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable [CLS] token, learnable patch position embeddings
if size is None:  # assumption: callers that omit `size` (e.g. get_tokens) have a square patch grid
    size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))
|
356 |
+
|
357 |
+
dim = x.shape[-1] # dimension of embeddings
|
358 |
+
pos_embed = nn.functional.interpolate(
|
359 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
|
360 |
+
size=size,
|
361 |
+
mode='bicubic',
|
362 |
+
align_corners=False
|
363 |
+
)
|
364 |
+
|
365 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
366 |
+
pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
367 |
+
return pos_embed
|
368 |
+
|
369 |
+
# def interpolate_pos_encoding(self, x, pos_embed):
|
370 |
+
# """Interpolate the learnable positional encoding to match the number of patches.
|
371 |
+
#
|
372 |
+
# x: B x (1 + N patches) x dim_embedding
|
373 |
+
# pos_embed: B x (1 + N patches) x dim_embedding
|
374 |
+
#
|
375 |
+
# return interpolated positional embedding
|
376 |
+
# """
|
377 |
+
# npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
|
378 |
+
# N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
|
379 |
+
# if npatch == N:
|
380 |
+
# return pos_embed
|
381 |
+
# class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings
|
382 |
+
#
|
383 |
+
# dim = x.shape[-1] # dimension of embeddings
|
384 |
+
# pos_embed = nn.functional.interpolate(
|
385 |
+
# pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
|
386 |
+
# scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer.
|
387 |
+
# recompute_scale_factor=True,
|
388 |
+
# mode='bicubic',
|
389 |
+
# align_corners=False
|
390 |
+
# )
|
391 |
+
#
|
392 |
+
# pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
393 |
+
# pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
394 |
+
# return pos_embed
|
395 |
+
|
396 |
+
def prepare_tokens(self, x):
|
397 |
+
B, nc, h, w = x.shape
|
398 |
+
patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
399 |
+
x = self.patch_embed(x) # patch linear embedding
|
400 |
+
|
401 |
+
# add the [CLS] token to the embed patch tokens
|
402 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
403 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
404 |
+
|
405 |
+
# add positional encoding to each token
|
406 |
+
x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
|
407 |
+
return self.pos_drop(x)
|
408 |
+
|
409 |
+
def get_tokens(
|
410 |
+
self,
|
411 |
+
x,
|
412 |
+
layers: list,
|
413 |
+
patch_tokens: bool = False,
|
414 |
+
norm: bool = True,
|
415 |
+
input_tokens: bool = False,
|
416 |
+
post_pe: bool = False
|
417 |
+
):
|
418 |
+
"""Return intermediate tokens."""
|
419 |
+
list_tokens: list = []
|
420 |
+
|
421 |
+
B = x.shape[0]
|
422 |
+
x = self.patch_embed(x)
|
423 |
+
|
424 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
425 |
+
|
426 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
427 |
+
|
428 |
+
if input_tokens:
|
429 |
+
list_tokens.append(x)
|
430 |
+
|
431 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
432 |
+
x = x + pos_embed
|
433 |
+
|
434 |
+
if post_pe:
|
435 |
+
list_tokens.append(x)
|
436 |
+
|
437 |
+
x = self.pos_drop(x)
|
438 |
+
|
439 |
+
for i, blk in enumerate(self.blocks):
|
440 |
+
x = blk(x) # B x # patches x dim
|
441 |
+
if layers is None or i in layers:
|
442 |
+
list_tokens.append(self.norm(x) if norm else x)
|
443 |
+
|
444 |
+
tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
|
445 |
+
|
446 |
+
if not patch_tokens:
|
447 |
+
return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
|
448 |
+
|
449 |
+
else:
|
450 |
+
return tokens
|
451 |
+
|
452 |
+
def forward(self, x, layer: str = None):
|
453 |
+
x = self.prepare_tokens(x)
|
454 |
+
|
455 |
+
features: dict = {}
|
456 |
+
for i, blk in enumerate(self.blocks):
|
457 |
+
x = blk(x)
|
458 |
+
features[f"layer{i + 1}"] = self.norm(x)
|
459 |
+
|
460 |
+
if layer is not None:
|
461 |
+
return features[layer]
|
462 |
+
else:
|
463 |
+
return features["layer12"]
|
464 |
+
# noel - end
|
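For reference, a minimal usage sketch (assumptions: the import path networks.timm_vit, the 30 x 30 target grid, and that this file's module-level imports such as timm's PatchEmbed resolve) showing how interpolate_pos_encoding above resizes the pretrained 14 x 14 positional grid:

import torch
from networks.timm_vit import VisionTransformer  # assumed import path

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                        num_heads=12, num_classes=0)
dummy = torch.zeros(1, 1 + 30 * 30, 768)       # [CLS] + a 30 x 30 patch grid
pe = vit.interpolate_pos_encoding(dummy, vit.pos_embed, size=(30, 30))
print(pe.shape)  # torch.Size([1, 901, 768]): the 14 x 14 grid bicubically resized to 30 x 30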
465 |
+
|
466 |
+
|
467 |
+
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
|
468 |
+
""" ViT weight initialization
|
469 |
+
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
470 |
+
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
471 |
+
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
472 |
+
"""
|
473 |
+
if isinstance(m, nn.Linear):
|
474 |
+
if n.startswith('head'):
|
475 |
+
nn.init.zeros_(m.weight)
|
476 |
+
nn.init.constant_(m.bias, head_bias)
|
477 |
+
elif n.startswith('pre_logits'):
|
478 |
+
lecun_normal_(m.weight)
|
479 |
+
nn.init.zeros_(m.bias)
|
480 |
+
else:
|
481 |
+
if jax_impl:
|
482 |
+
nn.init.xavier_uniform_(m.weight)
|
483 |
+
if m.bias is not None:
|
484 |
+
if 'mlp' in n:
|
485 |
+
nn.init.normal_(m.bias, std=1e-6)
|
486 |
+
else:
|
487 |
+
nn.init.zeros_(m.bias)
|
488 |
+
else:
|
489 |
+
trunc_normal_(m.weight, std=.02)
|
490 |
+
if m.bias is not None:
|
491 |
+
nn.init.zeros_(m.bias)
|
492 |
+
elif jax_impl and isinstance(m, nn.Conv2d):
|
493 |
+
# NOTE conv was left to pytorch default in my original init
|
494 |
+
lecun_normal_(m.weight)
|
495 |
+
if m.bias is not None:
|
496 |
+
nn.init.zeros_(m.bias)
|
497 |
+
elif isinstance(m, nn.LayerNorm):
|
498 |
+
nn.init.zeros_(m.bias)
|
499 |
+
nn.init.ones_(m.weight)
|
500 |
+
|
501 |
+
|
502 |
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
503 |
+
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
504 |
+
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
505 |
+
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
506 |
+
ntok_new = posemb_new.shape[1]
|
507 |
+
if num_tokens:
|
508 |
+
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
509 |
+
ntok_new -= num_tokens
|
510 |
+
else:
|
511 |
+
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
512 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
513 |
+
if not len(gs_new): # backwards compatibility
|
514 |
+
gs_new = [int(math.sqrt(ntok_new))] * 2
|
515 |
+
assert len(gs_new) >= 2
|
516 |
+
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
|
517 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
518 |
+
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
|
519 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
520 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
521 |
+
return posemb
|
522 |
+
|
523 |
+
|
524 |
+
def checkpoint_filter_fn(state_dict, model):
|
525 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
526 |
+
out_dict = {}
|
527 |
+
if 'model' in state_dict:
|
528 |
+
# For deit models
|
529 |
+
state_dict = state_dict['model']
|
530 |
+
for k, v in state_dict.items():
|
531 |
+
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
532 |
+
# For old models that I trained prior to conv based patchification
|
533 |
+
O, I, H, W = model.patch_embed.proj.weight.shape
|
534 |
+
v = v.reshape(O, -1, H, W)
|
535 |
+
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
536 |
+
# To resize pos embedding when using model at different size from pretrained weights
|
537 |
+
v = resize_pos_embed(
|
538 |
+
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
539 |
+
out_dict[k] = v
|
540 |
+
return out_dict
|
541 |
+
|
542 |
+
|
543 |
+
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
544 |
+
default_cfg = default_cfg or default_cfgs[variant]
|
545 |
+
if kwargs.get('features_only', None):
|
546 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
547 |
+
|
548 |
+
# NOTE this extra code to support handling of repr size for in21k pretrained models
|
549 |
+
default_num_classes = default_cfg['num_classes']
|
550 |
+
num_classes = kwargs.get('num_classes', default_num_classes)
|
551 |
+
repr_size = kwargs.pop('representation_size', None)
|
552 |
+
if repr_size is not None and num_classes != default_num_classes:
|
553 |
+
# Remove representation layer if fine-tuning. This may not always be the desired action,
|
554 |
+
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
|
555 |
+
_logger.warning("Removing representation layer for fine-tuning.")
|
556 |
+
repr_size = None
|
557 |
+
|
558 |
+
model = build_model_with_cfg(
|
559 |
+
VisionTransformer, variant, pretrained,
|
560 |
+
default_cfg=default_cfg,
|
561 |
+
representation_size=repr_size,
|
562 |
+
pretrained_filter_fn=checkpoint_filter_fn,
|
563 |
+
**kwargs)
|
564 |
+
return model
|
565 |
+
|
566 |
+
|
567 |
+
@register_model
|
568 |
+
def vit_small_patch16_224(pretrained=False, **kwargs):
|
569 |
+
""" My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
|
570 |
+
NOTE:
|
571 |
+
* this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
|
572 |
+
* this model does not have a bias for QKV (unlike the official ViT and DeiT models)
|
573 |
+
"""
|
574 |
+
model_kwargs = dict(
|
575 |
+
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
|
576 |
+
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
|
577 |
+
if pretrained:
|
578 |
+
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
579 |
+
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
|
580 |
+
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
581 |
+
return model
|
582 |
+
|
583 |
+
|
584 |
+
@register_model
|
585 |
+
def vit_base_patch16_224(pretrained=False, **kwargs):
|
586 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
587 |
+
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
588 |
+
"""
|
589 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
590 |
+
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
591 |
+
return model
|
592 |
+
|
593 |
+
|
594 |
+
@register_model
|
595 |
+
def vit_base_patch32_224(pretrained=False, **kwargs):
|
596 |
+
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
597 |
+
"""
|
598 |
+
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
599 |
+
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
|
600 |
+
return model
|
601 |
+
|
602 |
+
|
603 |
+
@register_model
|
604 |
+
def vit_base_patch16_384(pretrained=False, **kwargs):
|
605 |
+
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
606 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
607 |
+
"""
|
608 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
609 |
+
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
610 |
+
return model
|
611 |
+
|
612 |
+
|
613 |
+
@register_model
|
614 |
+
def vit_base_patch32_384(pretrained=False, **kwargs):
|
615 |
+
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
616 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
617 |
+
"""
|
618 |
+
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
619 |
+
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
|
620 |
+
return model
|
621 |
+
|
622 |
+
|
623 |
+
@register_model
|
624 |
+
def vit_large_patch16_224(pretrained=False, **kwargs):
|
625 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
626 |
+
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
627 |
+
"""
|
628 |
+
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
629 |
+
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
630 |
+
return model
|
631 |
+
|
632 |
+
|
633 |
+
@register_model
|
634 |
+
def vit_large_patch32_224(pretrained=False, **kwargs):
|
635 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
636 |
+
"""
|
637 |
+
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
638 |
+
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
|
639 |
+
return model
|
640 |
+
|
641 |
+
|
642 |
+
@register_model
|
643 |
+
def vit_large_patch16_384(pretrained=False, **kwargs):
|
644 |
+
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
645 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
646 |
+
"""
|
647 |
+
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
648 |
+
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
649 |
+
return model
|
650 |
+
|
651 |
+
|
652 |
+
@register_model
|
653 |
+
def vit_large_patch32_384(pretrained=False, **kwargs):
|
654 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
655 |
+
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
656 |
+
"""
|
657 |
+
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
658 |
+
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
|
659 |
+
return model
|
660 |
+
|
661 |
+
|
662 |
+
@register_model
|
663 |
+
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
|
664 |
+
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
665 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
666 |
+
"""
|
667 |
+
model_kwargs = dict(
|
668 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
669 |
+
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
670 |
+
return model
|
671 |
+
|
672 |
+
|
673 |
+
@register_model
|
674 |
+
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
|
675 |
+
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
676 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
677 |
+
"""
|
678 |
+
model_kwargs = dict(
|
679 |
+
patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
680 |
+
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
681 |
+
return model
|
682 |
+
|
683 |
+
|
684 |
+
@register_model
|
685 |
+
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
|
686 |
+
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
687 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
688 |
+
"""
|
689 |
+
model_kwargs = dict(
|
690 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
|
691 |
+
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
692 |
+
return model
|
693 |
+
|
694 |
+
|
695 |
+
@register_model
|
696 |
+
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
|
697 |
+
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
698 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
699 |
+
"""
|
700 |
+
model_kwargs = dict(
|
701 |
+
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
|
702 |
+
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
703 |
+
return model
|
704 |
+
|
705 |
+
|
706 |
+
@register_model
|
707 |
+
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
|
708 |
+
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
709 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
710 |
+
NOTE: converted weights not currently available, too large for github release hosting.
|
711 |
+
"""
|
712 |
+
model_kwargs = dict(
|
713 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
|
714 |
+
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
|
715 |
+
return model
|
716 |
+
|
717 |
+
|
718 |
+
@register_model
|
719 |
+
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
|
720 |
+
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
721 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
722 |
+
"""
|
723 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
724 |
+
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
725 |
+
return model
|
726 |
+
|
727 |
+
|
728 |
+
@register_model
|
729 |
+
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
730 |
+
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
731 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
732 |
+
"""
|
733 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
734 |
+
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
735 |
+
return model
|
736 |
+
|
737 |
+
|
738 |
+
@register_model
|
739 |
+
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
|
740 |
+
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
741 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
742 |
+
"""
|
743 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
744 |
+
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
745 |
+
return model
|
746 |
+
|
747 |
+
|
748 |
+
@register_model
|
749 |
+
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
|
750 |
+
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
751 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
752 |
+
"""
|
753 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
754 |
+
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
755 |
+
return model
|
756 |
+
|
757 |
+
|
758 |
+
@register_model
|
759 |
+
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
760 |
+
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
761 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
762 |
+
"""
|
763 |
+
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
764 |
+
model = _create_vision_transformer(
|
765 |
+
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
766 |
+
return model
|
767 |
+
|
768 |
+
|
769 |
+
@register_model
|
770 |
+
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
771 |
+
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
772 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
773 |
+
"""
|
774 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
775 |
+
model = _create_vision_transformer(
|
776 |
+
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
777 |
+
return model
|
778 |
+
|
779 |
+
|
780 |
+
@register_model
|
781 |
+
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
782 |
+
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
783 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
784 |
+
"""
|
785 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
786 |
+
model = _create_vision_transformer(
|
787 |
+
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
788 |
+
return model
|
789 |
+
|
790 |
+
|
791 |
+
@register_model
|
792 |
+
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
793 |
+
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
794 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
795 |
+
"""
|
796 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
797 |
+
model = _create_vision_transformer(
|
798 |
+
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
799 |
+
return model
|
800 |
+
|
801 |
+
|
802 |
+
@register_model
|
803 |
+
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
|
804 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
805 |
+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
806 |
+
"""
|
807 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
808 |
+
model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
|
809 |
+
return model
|
810 |
+
|
811 |
+
|
812 |
+
@register_model
|
813 |
+
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
|
814 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
815 |
+
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
816 |
+
"""
|
817 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
818 |
+
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
|
819 |
+
return model
|
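For reference, a minimal usage sketch of the ViT defined in this file (assumptions: the import path, the dummy 224 x 224 input, and the choice of the last four blocks are illustrative; timm must be installed for the module-level imports to resolve):

import torch
from networks.timm_vit import VisionTransformer  # assumed import path

model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                          num_heads=12, num_classes=0)  # num_classes=0 -> nn.Identity() head
model.eval()

x = torch.randn(1, 3, 224, 224)                           # dummy image batch
with torch.no_grad():
    feats = model(x, layer="layer12")                     # B x (1 + 196) x 768, normed tokens of the last block
    cls_tok = model.get_tokens(x, layers=[8, 9, 10, 11])  # [CLS] tokens of the last 4 blocks: B x 4 x 768
print(feats.shape, cls_tok.shape)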
networks/vision_transformer.py
ADDED
@@ -0,0 +1,569 @@
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Mostly copy-paste from timm library.
|
4 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
5 |
+
"""
|
6 |
+
from typing import Optional
|
7 |
+
import math
import warnings  # needed for the out-of-range warning in _no_grad_trunc_normal_ below
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
|
14 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
15 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
16 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
17 |
+
def norm_cdf(x):
|
18 |
+
# Computes standard normal cumulative distribution function
|
19 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
20 |
+
|
21 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
22 |
+
warnings.warn(
|
23 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.",
|
24 |
+
stacklevel=2
|
25 |
+
)
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
# Values are generated by using a truncated uniform distribution and
|
29 |
+
# then using the inverse CDF for the normal distribution.
|
30 |
+
# Get upper and lower cdf values
|
31 |
+
l = norm_cdf((a - mean) / std)
|
32 |
+
u = norm_cdf((b - mean) / std)
|
33 |
+
|
34 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
35 |
+
# [2l-1, 2u-1].
|
36 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
37 |
+
|
38 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
39 |
+
# standard normal
|
40 |
+
tensor.erfinv_()
|
41 |
+
|
42 |
+
# Transform to proper mean, std
|
43 |
+
tensor.mul_(std * math.sqrt(2.))
|
44 |
+
tensor.add_(mean)
|
45 |
+
|
46 |
+
# Clamp to ensure it's in the proper range
|
47 |
+
tensor.clamp_(min=a, max=b)
|
48 |
+
return tensor
|
49 |
+
|
50 |
+
|
51 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
52 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
53 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
54 |
+
|
55 |
+
|
56 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
57 |
+
if drop_prob == 0. or not training:
|
58 |
+
return x
|
59 |
+
keep_prob = 1 - drop_prob
|
60 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
61 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
62 |
+
random_tensor.floor_() # binarize
|
63 |
+
output = x.div(keep_prob) * random_tensor
|
64 |
+
return output
|
65 |
+
|
66 |
+
|
67 |
+
class DropPath(nn.Module):
|
68 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
69 |
+
"""
|
70 |
+
def __init__(self, drop_prob=None):
|
71 |
+
super(DropPath, self).__init__()
|
72 |
+
self.drop_prob = drop_prob
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
return drop_path(x, self.drop_prob, self.training)
|
76 |
+
|
77 |
+
|
78 |
+
class Mlp(nn.Module):
|
79 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
80 |
+
super().__init__()
|
81 |
+
out_features = out_features or in_features
|
82 |
+
hidden_features = hidden_features or in_features
|
83 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
84 |
+
self.act = act_layer()
|
85 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
86 |
+
self.drop = nn.Dropout(drop)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
x = self.fc1(x)
|
90 |
+
x = self.act(x)
|
91 |
+
x = self.drop(x)
|
92 |
+
x = self.fc2(x)
|
93 |
+
x = self.drop(x)
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class Attention(nn.Module):
|
98 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
99 |
+
super().__init__()
|
100 |
+
self.num_heads = num_heads
|
101 |
+
head_dim = dim // num_heads
|
102 |
+
self.scale = qk_scale or head_dim ** -0.5 # square root of dimension for normalisation
|
103 |
+
|
104 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
105 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
106 |
+
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
B, N, C = x.shape # B x (cls token + # patch tokens) x dim
|
112 |
+
|
113 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
114 |
+
# qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
115 |
+
|
116 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
117 |
+
# q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
118 |
+
|
119 |
+
# q: B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
120 |
+
# k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens)
|
121 |
+
# attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
122 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # @ operator is for matrix multiplication
|
123 |
+
attn = attn.softmax(dim=-1) # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
124 |
+
attn = self.attn_drop(attn)
|
125 |
+
|
126 |
+
# attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
|
127 |
+
# v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
128 |
+
# attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
|
129 |
+
# (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh)
|
130 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # B x (cls token + # patch tokens) x dim
|
131 |
+
x = self.proj(x) # B x (cls token + # patch tokens) x dim
|
132 |
+
x = self.proj_drop(x)
|
133 |
+
return x, attn
|
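As a quick check of the shape comments above, a minimal sketch (batch size, token count and head count are illustrative assumptions) of running this Attention module on a batch of [CLS] + patch tokens:

import torch
from networks.vision_transformer import Attention  # assumed import path

attn_layer = Attention(dim=384, num_heads=6)
tokens = torch.randn(2, 1 + 14 * 14, 384)   # B x (cls token + # patch tokens) x dim
out, attn = attn_layer(tokens)
print(out.shape)   # torch.Size([2, 197, 384])
print(attn.shape)  # torch.Size([2, 6, 197, 197]), per-head token-to-token attention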
134 |
+
|
135 |
+
|
136 |
+
class Block(nn.Module):
|
137 |
+
def __init__(self,
|
138 |
+
dim, num_heads,
|
139 |
+
mlp_ratio=4.,
|
140 |
+
qkv_bias=False,
|
141 |
+
qk_scale=None,
|
142 |
+
drop=0.,
|
143 |
+
attn_drop=0.,
|
144 |
+
drop_path=0.,
|
145 |
+
act_layer=nn.GELU,
|
146 |
+
norm_layer=nn.LayerNorm):
|
147 |
+
super().__init__()
|
148 |
+
self.norm1 = norm_layer(dim)
|
149 |
+
self.attn = Attention(
|
150 |
+
dim,
|
151 |
+
num_heads=num_heads,
|
152 |
+
qkv_bias=qkv_bias,
|
153 |
+
qk_scale=qk_scale,
|
154 |
+
attn_drop=attn_drop,
|
155 |
+
proj_drop=drop
|
156 |
+
)
|
157 |
+
|
158 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
159 |
+
|
160 |
+
self.norm2 = norm_layer(dim)
|
161 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
162 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
163 |
+
|
164 |
+
def forward(self, x, return_attention=False):
|
165 |
+
y, attn = self.attn(self.norm1(x))
|
166 |
+
if return_attention:
|
167 |
+
return attn
|
168 |
+
x = x + self.drop_path(y)
|
169 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class PatchEmbed(nn.Module):
|
174 |
+
""" Image to Patch Embedding"""
|
175 |
+
def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
|
176 |
+
super().__init__()
|
177 |
+
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
178 |
+
self.img_size = img_size
|
179 |
+
self.patch_size = patch_size
|
180 |
+
self.num_patches = num_patches
|
181 |
+
|
182 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
B, C, H, W = x.shape
|
186 |
+
x = self.proj(x)
|
187 |
+
x = x.flatten(2).transpose(1, 2) # B x (P_H * P_W) x C
|
188 |
+
return x
|
189 |
+
|
190 |
+
|
191 |
+
class VisionTransformer(nn.Module):
|
192 |
+
""" Vision Transformer """
|
193 |
+
def __init__(self,
|
194 |
+
img_size=(224, 224),
|
195 |
+
patch_size=16,
|
196 |
+
in_chans=3,
|
197 |
+
num_classes=0,
|
198 |
+
embed_dim=768,
|
199 |
+
depth=12,
|
200 |
+
num_heads=12,
|
201 |
+
mlp_ratio=4.,
|
202 |
+
qkv_bias=False,
|
203 |
+
qk_scale=None,
|
204 |
+
drop_rate=0.,
|
205 |
+
attn_drop_rate=0.,
|
206 |
+
drop_path_rate=0.,
|
207 |
+
norm_layer=nn.LayerNorm):
|
208 |
+
super().__init__()
|
209 |
+
self.num_features = self.embed_dim = embed_dim
|
210 |
+
|
211 |
+
self.patch_embed = PatchEmbed(
|
212 |
+
img_size=(224, 224),  # noel: fixed to 224 x 224 so that pretrained weights can be loaded.
|
213 |
+
patch_size=patch_size,
|
214 |
+
in_chans=in_chans,
|
215 |
+
embed_dim=embed_dim
|
216 |
+
)
|
217 |
+
num_patches = self.patch_embed.num_patches
|
218 |
+
|
219 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
220 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
221 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
222 |
+
|
223 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
224 |
+
self.blocks = nn.ModuleList([
|
225 |
+
Block(
|
226 |
+
dim=embed_dim,
|
227 |
+
num_heads=num_heads,
|
228 |
+
mlp_ratio=mlp_ratio,
|
229 |
+
qkv_bias=qkv_bias,
|
230 |
+
qk_scale=qk_scale,
|
231 |
+
drop=drop_rate,
|
232 |
+
attn_drop=attn_drop_rate,
|
233 |
+
drop_path=dpr[i],
|
234 |
+
norm_layer=norm_layer
|
235 |
+
) for i in range(depth)])
|
236 |
+
self.norm = norm_layer(embed_dim)
|
237 |
+
|
238 |
+
# Classifier head
|
239 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
240 |
+
|
241 |
+
trunc_normal_(self.pos_embed, std=.02)
|
242 |
+
trunc_normal_(self.cls_token, std=.02)
|
243 |
+
self.apply(self._init_weights)
|
244 |
+
|
245 |
+
self.depth = depth
|
246 |
+
self.embed_dim = self.n_embs = embed_dim
|
247 |
+
self.mlp_ratio = mlp_ratio
|
248 |
+
self.n_heads = num_heads
|
249 |
+
self.patch_size = patch_size
|
250 |
+
|
251 |
+
def _init_weights(self, m):
|
252 |
+
if isinstance(m, nn.Linear):
|
253 |
+
trunc_normal_(m.weight, std=.02)
|
254 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
255 |
+
nn.init.constant_(m.bias, 0)
|
256 |
+
elif isinstance(m, nn.LayerNorm):
|
257 |
+
nn.init.constant_(m.bias, 0)
|
258 |
+
nn.init.constant_(m.weight, 1.0)
|
259 |
+
|
260 |
+
def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
|
261 |
+
"""Pad some pixels to make the input size divisible by the patch size."""
|
262 |
+
B, _, H_0, W_0 = x.shape
|
263 |
+
pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
|
264 |
+
pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
|
265 |
+
|
266 |
+
x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
|
267 |
+
return x
|
268 |
+
|
269 |
+
def prepare_tokens(self, x):
|
270 |
+
B, nc, h, w = x.shape
|
271 |
+
x: torch.Tensor = self.make_input_divisible(x)
|
272 |
+
patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
|
273 |
+
|
274 |
+
x = self.patch_embed(x) # patch linear embedding
|
275 |
+
|
276 |
+
# add positional encoding to each token
|
277 |
+
# add the [CLS] token to the embed patch tokens
|
278 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
279 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
280 |
+
x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
|
281 |
+
return self.pos_drop(x)
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def split_token(x, token_type: str):
|
285 |
+
if token_type == "cls":
|
286 |
+
return x[:, 0, :]
|
287 |
+
elif token_type == "patch":
|
288 |
+
return x[:, 1:, :]
|
289 |
+
else:
|
290 |
+
return x
|
291 |
+
|
292 |
+
# noel
|
293 |
+
def forward(self, x, layer: Optional[str] = None):
|
294 |
+
x: torch.Tensor = self.prepare_tokens(x)
|
295 |
+
|
296 |
+
features: dict = {}
|
297 |
+
for i, blk in enumerate(self.blocks):
|
298 |
+
x = blk(x)
|
299 |
+
features[f"layer{i + 1}"] = self.norm(x)
|
300 |
+
|
301 |
+
if layer is not None:
|
302 |
+
return features[layer]
|
303 |
+
else:
|
304 |
+
return features
|
305 |
+
|
306 |
+
# noel - for DINO's visual
|
307 |
+
def get_last_selfattention(self, x):
|
308 |
+
x = self.prepare_tokens(x)
|
309 |
+
for i, blk in enumerate(self.blocks):
|
310 |
+
if i < len(self.blocks) - 1:
|
311 |
+
x = blk(x)
|
312 |
+
else:
|
313 |
+
# return attention of the last block
|
314 |
+
return blk(x, return_attention=True)
|
315 |
+
|
316 |
+
def get_tokens(
|
317 |
+
self,
|
318 |
+
x,
|
319 |
+
layers: list,
|
320 |
+
patch_tokens: bool = False,
|
321 |
+
norm: bool = True,
|
322 |
+
input_tokens: bool = False,
|
323 |
+
post_pe: bool = False
|
324 |
+
):
|
325 |
+
"""Return intermediate tokens."""
|
326 |
+
list_tokens: list = []
|
327 |
+
|
328 |
+
B = x.shape[0]
|
329 |
+
x = self.patch_embed(x)
|
330 |
+
|
331 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
332 |
+
|
333 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
334 |
+
|
335 |
+
if input_tokens:
|
336 |
+
list_tokens.append(x)
|
337 |
+
|
338 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
339 |
+
x = x + pos_embed
|
340 |
+
|
341 |
+
if post_pe:
|
342 |
+
list_tokens.append(x)
|
343 |
+
|
344 |
+
x = self.pos_drop(x)
|
345 |
+
|
346 |
+
for i, blk in enumerate(self.blocks):
|
347 |
+
x = blk(x) # B x # patches x dim
|
348 |
+
if layers is None or i in layers:
|
349 |
+
list_tokens.append(self.norm(x) if norm else x)
|
350 |
+
|
351 |
+
tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
|
352 |
+
|
353 |
+
if not patch_tokens:
|
354 |
+
return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
|
355 |
+
|
356 |
+
else:
|
357 |
+
return tokens
|
358 |
+
|
359 |
+
def forward_features(self, x):
|
360 |
+
B = x.shape[0]
|
361 |
+
x = self.patch_embed(x)
|
362 |
+
|
363 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
364 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
365 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
366 |
+
x = x + pos_embed
|
367 |
+
x = self.pos_drop(x)
|
368 |
+
|
369 |
+
for blk in self.blocks:
|
370 |
+
x = blk(x)
|
371 |
+
|
372 |
+
if self.norm is not None:
|
373 |
+
x = self.norm(x)
|
374 |
+
|
375 |
+
return x[:, 0]
|
376 |
+
|
377 |
+
def interpolate_pos_encoding(self, x, pos_embed, size=None):
|
378 |
+
"""Interpolate the learnable positional encoding to match the number of patches.
|
379 |
+
|
380 |
+
x: B x (1 + N patches) x dim_embedding
|
381 |
+
pos_embed: B x (1 + N patches) x dim_embedding
|
382 |
+
|
383 |
+
return interpolated positional embedding
|
384 |
+
"""
|
385 |
+
npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
|
386 |
+
N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
|
387 |
+
if npatch == N:
|
388 |
+
return pos_embed
|
389 |
+
class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable [CLS] token, learnable patch position embeddings
if size is None:  # assumption: callers that omit `size` (e.g. get_tokens, forward_features) have a square patch grid
    size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))
|
390 |
+
|
391 |
+
dim = x.shape[-1] # dimension of embeddings
|
392 |
+
pos_embed = nn.functional.interpolate(
|
393 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
|
394 |
+
size=size,
|
395 |
+
mode='bicubic',
|
396 |
+
align_corners=False
|
397 |
+
)
|
398 |
+
|
399 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
400 |
+
pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
401 |
+
return pos_embed
|
402 |
+
|
403 |
+
def forward_selfattention(self, x, return_interm_attn=False):
|
404 |
+
B, nc, w, h = x.shape
|
405 |
+
N = self.pos_embed.shape[1] - 1
|
406 |
+
x = self.patch_embed(x)
|
407 |
+
|
408 |
+
# interpolate patch embeddings
|
409 |
+
dim = x.shape[-1]
|
410 |
+
w0 = w // self.patch_embed.patch_size
|
411 |
+
h0 = h // self.patch_embed.patch_size
|
412 |
+
class_pos_embed = self.pos_embed[:, 0]
|
413 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
414 |
+
patch_pos_embed = nn.functional.interpolate(
|
415 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
416 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
417 |
+
mode='bicubic'
|
418 |
+
)
|
419 |
+
if w0 != patch_pos_embed.shape[-2]:
|
420 |
+
helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
|
421 |
+
patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
|
422 |
+
if h0 != patch_pos_embed.shape[-1]:
|
423 |
+
helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
|
424 |
+
patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)  # fixed: extend patch_pos_embed, not pos_embed
|
425 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
426 |
+
pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
427 |
+
|
428 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # self.cls_token: 1 x 1 x emb_dim -> ?
|
429 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
430 |
+
x = x + pos_embed
|
431 |
+
x = self.pos_drop(x)
|
432 |
+
|
433 |
+
if return_interm_attn:
|
434 |
+
list_attn = []
|
435 |
+
for i, blk in enumerate(self.blocks):
|
436 |
+
attn = blk(x, return_attention=True)
|
437 |
+
x = blk(x)
|
438 |
+
list_attn.append(attn)
|
439 |
+
return torch.cat(list_attn, dim=0)
|
440 |
+
|
441 |
+
else:
|
442 |
+
for i, blk in enumerate(self.blocks):
|
443 |
+
if i < len(self.blocks) - 1:
|
444 |
+
x = blk(x)
|
445 |
+
else:
|
446 |
+
return blk(x, return_attention=True)
|
447 |
+
|
448 |
+
def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
|
449 |
+
B = x.shape[0]
|
450 |
+
x = self.patch_embed(x)
|
451 |
+
|
452 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
453 |
+
|
454 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
455 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
456 |
+
x = x + pos_embed
|
457 |
+
x = self.pos_drop(x)
|
458 |
+
|
459 |
+
# we will return the [CLS] tokens from the `n` last blocks
|
460 |
+
output = []
|
461 |
+
for i, blk in enumerate(self.blocks):
|
462 |
+
x = blk(x)
|
463 |
+
if len(self.blocks) - i <= n:
|
464 |
+
# get only CLS token (B x dim)
|
465 |
+
output.append(self.norm(x)[:, 0])
|
466 |
+
if return_patch_avgpool:
|
467 |
+
x = self.norm(x)
|
468 |
+
# In addition to the [CLS] tokens from the `n` last blocks, we also return
|
469 |
+
# the patch tokens from the last block. This is useful for linear eval.
|
470 |
+
output.append(torch.mean(x[:, 1:], dim=1))
|
471 |
+
return torch.cat(output, dim=-1)
|
472 |
+
|
473 |
+
def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
|
474 |
+
"""Return intermediate patch embeddings, rather than CLS token, from the last n blocks."""
|
475 |
+
B = x.shape[0]
|
476 |
+
x = self.patch_embed(x)
|
477 |
+
|
478 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
479 |
+
|
480 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
481 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
482 |
+
x = x + pos_embed
|
483 |
+
x = self.pos_drop(x)
|
484 |
+
|
485 |
+
# we will return the patch tokens from the `n` last blocks
|
486 |
+
output = []
|
487 |
+
for i, blk in enumerate(self.blocks):
|
488 |
+
x = blk(x)
|
489 |
+
if len(self.blocks) - i <= n:
|
490 |
+
output.append(self.norm(x)[:, 1:])  # get only the patch tokens (B x n_patches x dim)
|
491 |
+
|
492 |
+
if return_patch_avgpool:
|
493 |
+
x = self.norm(x)
|
494 |
+
# In addition to the [CLS] tokens from the `n` last blocks, we also return
|
495 |
+
# the patch tokens from the last block. This is useful for linear eval.
|
496 |
+
output.append(torch.mean(x[:, 1:], dim=1))
|
497 |
+
return torch.stack(output, dim=-1) # B x n_patches x dim x n
|
498 |
+
|
499 |
+
|
500 |
+
def deit_tiny(patch_size=16, **kwargs):
|
501 |
+
model = VisionTransformer(
|
502 |
+
patch_size=patch_size,
|
503 |
+
embed_dim=192,
|
504 |
+
depth=12,
|
505 |
+
num_heads=3,
|
506 |
+
mlp_ratio=4,
|
507 |
+
qkv_bias=True,
|
508 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
509 |
+
**kwargs)
|
510 |
+
return model
|
511 |
+
|
512 |
+
|
513 |
+
def deit_small(patch_size=16, **kwargs):
|
514 |
+
depth = kwargs.pop("depth") if "depth" in kwargs else 12
|
515 |
+
model = VisionTransformer(
|
516 |
+
patch_size=patch_size,
|
517 |
+
embed_dim=384,
|
518 |
+
depth=depth,
|
519 |
+
num_heads=6,
|
520 |
+
mlp_ratio=4,
|
521 |
+
qkv_bias=True,
|
522 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
523 |
+
**kwargs
|
524 |
+
)
|
525 |
+
return model
|
526 |
+
|
527 |
+
|
528 |
+
def vit_base(patch_size=16, **kwargs):
|
529 |
+
model = VisionTransformer(
|
530 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
531 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
532 |
+
return model
|
533 |
+
|
534 |
+
|
535 |
+
class DINOHead(nn.Module):
|
536 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
|
537 |
+
super().__init__()
|
538 |
+
nlayers = max(nlayers, 1)
|
539 |
+
if nlayers == 1:
|
540 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
541 |
+
else:
|
542 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
543 |
+
if use_bn:
|
544 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
545 |
+
layers.append(nn.GELU())
|
546 |
+
for _ in range(nlayers - 2):
|
547 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
548 |
+
if use_bn:
|
549 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
550 |
+
layers.append(nn.GELU())
|
551 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
552 |
+
self.mlp = nn.Sequential(*layers)
|
553 |
+
self.apply(self._init_weights)
|
554 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
555 |
+
self.last_layer.weight_g.data.fill_(1)
|
556 |
+
if norm_last_layer:
|
557 |
+
self.last_layer.weight_g.requires_grad = False
|
558 |
+
|
559 |
+
def _init_weights(self, m):
|
560 |
+
if isinstance(m, nn.Linear):
|
561 |
+
trunc_normal_(m.weight, std=.02)
|
562 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
563 |
+
nn.init.constant_(m.bias, 0)
|
564 |
+
|
565 |
+
def forward(self, x):
|
566 |
+
x = self.mlp(x)
|
567 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
568 |
+
x = self.last_layer(x)
|
569 |
+
return x
|
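For reference, a minimal sketch of DINO-style usage of this file (assumptions: the import path, the 224 x 224 input, and the 65536-dimensional head output are illustrative, not taken from the diff):

import torch
from networks.vision_transformer import deit_small, DINOHead  # assumed import path

model = deit_small(patch_size=16)
model.eval()

img = torch.randn(1, 3, 224, 224)                  # 224 is divisible by the 16-pixel patch size
with torch.no_grad():
    attn = model.get_last_selfattention(img)       # B x heads x (1 + 196) x (1 + 196)
    feats = model(img, layer="layer12")            # B x 197 x 384, normed tokens of block 12

# [CLS]-to-patch attention of the last block, one 14 x 14 map per head
cls_attn = attn[0, :, 0, 1:].reshape(6, 14, 14)

# a DINO-style projection head applied to the [CLS] token (output dimension is illustrative)
head = DINOHead(in_dim=384, out_dim=65536)
logits = head(feats[:, 0])
print(cls_attn.shape, logits.shape)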
resources/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
resources/0053.jpg
ADDED
resources/0236.jpg
ADDED
resources/0239.jpg
ADDED
resources/0403.jpg
ADDED
resources/0412.jpg
ADDED
resources/ILSVRC2012_test_00005309.jpg
ADDED
resources/ILSVRC2012_test_00012622.jpg
ADDED
resources/ILSVRC2012_test_00022698.jpg
ADDED
resources/ILSVRC2012_test_00040725.jpg
ADDED
resources/ILSVRC2012_test_00075738.jpg
ADDED
resources/ILSVRC2012_test_00080683.jpg
ADDED
resources/ILSVRC2012_test_00085874.jpg
ADDED
resources/im052.jpg
ADDED
resources/sun_ainjbonxmervsvpv.jpg
ADDED
resources/sun_alfntqzssslakmss.jpg
ADDED
resources/sun_amnrcxhisjfrliwa.jpg
ADDED
resources/sun_bvyxpvkouzlfwwod.jpg
ADDED