aravindhv10 committed
Commit 951254d
1 Parent(s): 3d207b0

Added MVANet ComfyUI plugin

.gitignore CHANGED
@@ -1,12 +1,21 @@
- **/__pycache__
+ /ComfyUI_MVANet/download.sh
+ /ComfyUI_MVANet/MVANet_inference.class.py
+ /ComfyUI_MVANet/MVANet_inference.execute.py
+ /ComfyUI_MVANet/MVANet_inference.function.py
+ /ComfyUI_MVANet/MVANet_inference.import.py
+ /ComfyUI_MVANet/MVANet_inference.run.sh
+ /ComfyUI_MVANet/MVANet_inference.unify.sh
+ /ComfyUI_MVANet/.#README.org
  data/
- log/
- pretrain_model/
- git_add.txt
- rm.txt
- main.org
- demo/demo_lip.png
- demo/lip-visualization.jpg
- demo/demo_pascal.png
  demo/demo_atr.png
  demo/demo.jpg
+ demo/demo_lip.png
+ demo/demo_pascal.png
+ demo/lip-visualization.jpg
+ /git_add.txt
+ log/
+ /main.org
+ pretrain_model/
+ **/__pycache__
+ /rm.txt
+ /waste.txt
ComfyUI_MVANet/MVANet_inference.py ADDED
@@ -0,0 +1,1548 @@
1
+ #!/usr/bin/python3
2
+ import os
3
+ import sys
4
+
5
+ HOME_DIR = os.environ.get('HOME', '/root')
6
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
7
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
8
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
9
+
10
+ import math
11
+ import numpy as np
12
+ import cv2
13
+ import wget
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch.autograd import Variable
20
+ from torch import nn
21
+ from torchvision import transforms
22
+
23
+ from einops import rearrange
24
+
25
+ from timm.models import load_checkpoint
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+
28
+ torch_device = 'cuda'
29
+ torch_dtype = torch.float16
30
+
31
+
32
+ def check_mkdir(dir_name):
33
+ if not os.path.isdir(dir_name):
34
+ os.makedirs(dir_name)
35
+
36
+
37
+ def SwinT(pretrained=True):
38
+ model = SwinTransformer(embed_dim=96,
39
+ depths=[2, 2, 6, 2],
40
+ num_heads=[3, 6, 12, 24],
41
+ window_size=7)
42
+ if pretrained is True:
43
+ model.load_state_dict(torch.load(
44
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
45
+ map_location='cpu')['model'],
46
+ strict=False)
47
+
48
+ return model
49
+
50
+
51
+ def SwinS(pretrained=True):
52
+ model = SwinTransformer(embed_dim=96,
53
+ depths=[2, 2, 18, 2],
54
+ num_heads=[3, 6, 12, 24],
55
+ window_size=7)
56
+ if pretrained is True:
57
+ model.load_state_dict(torch.load(
58
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
59
+ map_location='cpu')['model'],
60
+ strict=False)
61
+
62
+ return model
63
+
64
+
65
+ def SwinB(pretrained=True):
66
+ model = SwinTransformer(embed_dim=128,
67
+ depths=[2, 2, 18, 2],
68
+ num_heads=[4, 8, 16, 32],
69
+ window_size=12)
70
+ if pretrained is True:
71
+ import os
72
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
73
+ map_location='cpu')['model'],
74
+ strict=False)
75
+ return model
76
+
77
+
78
+ def SwinL(pretrained=True):
79
+ model = SwinTransformer(embed_dim=192,
80
+ depths=[2, 2, 18, 2],
81
+ num_heads=[6, 12, 24, 48],
82
+ window_size=12)
83
+ if pretrained is True:
84
+ model.load_state_dict(torch.load(
85
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
86
+ map_location='cpu')['model'],
87
+ strict=False)
88
+
89
+ return model
90
+
91
+
92
+ def get_activation_fn(activation):
93
+ """Return an activation function given a string"""
94
+ if activation == "relu":
95
+ return F.relu
96
+ if activation == "gelu":
97
+ return F.gelu
98
+ if activation == "glu":
99
+ return F.glu
100
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
101
+
102
+
103
+ def make_cbr(in_dim, out_dim):
104
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
105
+ nn.BatchNorm2d(out_dim), nn.PReLU())
106
+
107
+
108
+ def make_cbg(in_dim, out_dim):
109
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
110
+ nn.BatchNorm2d(out_dim), nn.GELU())
111
+
112
+
113
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
114
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
115
+
116
+
117
+ def resize_as(x, y, interpolation='bilinear'):
118
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
119
+
120
+
121
+ def image2patches(x):
122
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
123
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
124
+ return x
125
+
126
+
127
+ def patches2image(x):
128
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
129
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
130
+ return x
131
+
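A quick shape check for the two patch helpers above (an editor's sketch, not part of the committed file): image2patches splits a tensor into a 2x2 grid of crops and patches2image is its exact inverse.

import torch

x = torch.randn(1, 3, 1024, 1024)
patches = image2patches(x)            # (4, 3, 512, 512): four 512x512 local crops
restored = patches2image(patches)     # back to (1, 3, 1024, 1024)
assert restored.shape == x.shape and torch.equal(restored, x)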
132
+
133
+ def window_partition(x, window_size):
134
+ """
135
+ Args:
136
+ x: (B, H, W, C)
137
+ window_size (int): window size
138
+
139
+ Returns:
140
+ windows: (num_windows*B, window_size, window_size, C)
141
+ """
142
+ B, H, W, C = x.shape
143
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
144
+ C)
145
+ windows = x.permute(0, 1, 3, 2, 4,
146
+ 5).contiguous().view(-1, window_size, window_size, C)
147
+ return windows
148
+
149
+
150
+ def window_reverse(windows, window_size, H, W):
151
+ """
152
+ Args:
153
+ windows: (num_windows*B, window_size, window_size, C)
154
+ window_size (int): Window size
155
+ H (int): Height of image
156
+ W (int): Width of image
157
+
158
+ Returns:
159
+ x: (B, H, W, C)
160
+ """
161
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
162
+ x = windows.view(B, H // window_size, W // window_size, window_size,
163
+ window_size, -1)
164
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
165
+ return x
166
+
167
+
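An editor's sanity check (not in the commit) for the window helpers above, assuming H and W are multiples of the window size, as they are after the padding done in SwinTransformerBlock:

import torch

B, H, W, C, ws = 2, 56, 56, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)     # (B * (H//ws) * (W//ws), ws, ws, C) = (128, 7, 7, 96)
assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)
assert torch.equal(window_reverse(windows, ws, H, W), x)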
168
+ def mkdir_safe(out_path):
169
+ if type(out_path) == str:
170
+ if len(out_path) > 0:
171
+ if not os.path.exists(out_path):
172
+ os.mkdir(out_path)
173
+
174
+
175
+ def get_model_path():
176
+ import folder_paths
177
+ from folder_paths import models_dir
178
+
179
+ path_file_model = models_dir
180
+ mkdir_safe(out_path=path_file_model)
181
+
182
+ path_file_model = os.path.join(path_file_model, 'MVANet')
183
+ mkdir_safe(out_path=path_file_model)
184
+
185
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
186
+
187
+ return path_file_model
188
+
189
+
190
+ def download_model(path):
191
+ if not os.path.exists(path):
192
+ wget.download(
193
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
194
+ out=path)
195
+
196
+
197
+ def load_model(model_checkpoint_path):
198
+ download_model(path=model_checkpoint_path)
199
+ torch.cuda.set_device(0)
200
+
201
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
202
+
203
+ pretrained_dict = torch.load(model_checkpoint_path,
204
+ map_location=torch_device)
205
+
206
+ model_dict = net.state_dict()
207
+ pretrained_dict = {
208
+ k: v
209
+ for k, v in pretrained_dict.items() if k in model_dict
210
+ }
211
+ model_dict.update(pretrained_dict)
212
+ net.load_state_dict(model_dict)
213
+ net = net.to(dtype=torch_dtype, device=torch_device)
214
+ net.eval()
215
+ return net
216
+
217
+
218
+ def do_infer_tensor2tensor(img, net):
219
+
220
+ img_transform = transforms.Compose(
221
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
222
+
223
+ h_, w_ = img.shape[1], img.shape[2]
224
+
225
+ with torch.no_grad():
226
+
227
+ img = rearrange(img, 'B H W C -> B C H W')
228
+
229
+ img_resize = torch.nn.functional.interpolate(input=img,
230
+ size=(1024, 1024),
231
+ mode='bicubic',
232
+ antialias=True)
233
+
234
+ img_var = img_transform(img_resize)
235
+ img_var = Variable(img_var)
236
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
237
+
238
+ mask = []
239
+
240
+ mask.append(net(img_var))
241
+
242
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
243
+ prediction = prediction.sigmoid()
244
+
245
+ prediction = torch.nn.functional.interpolate(input=prediction,
246
+ size=(h_, w_),
247
+ mode='bicubic',
248
+ antialias=True)
249
+
250
+ prediction = prediction.squeeze(0)
251
+ prediction = prediction.clamp(0, 1)
252
+ prediction = prediction.detach()
253
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
254
+
255
+ return prediction
256
+
257
+
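do_infer_tensor2tensor is the whole inference path in one call: normalize, resize to 1024x1024, run the network, and resize the sigmoid mask back to the input resolution. A minimal direct-call sketch (editor's note, not part of the commit; the checkpoint path is hypothetical and a CUDA device is required, since torch_device is 'cuda'):

import torch

net = load_model('/path/to/Model_80.pth')   # hypothetical location; the checkpoint is downloaded there if missing
image = torch.rand(1, 720, 1280, 3)         # ComfyUI IMAGE layout: (B, H, W, C), float in [0, 1]; B must be 1
mask = do_infer_tensor2tensor(img=image, net=net)
print(mask.shape)                           # torch.Size([1, 720, 1280])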
258
+ class Mlp(nn.Module):
259
+ """ Multilayer perceptron."""
260
+
261
+ def __init__(self,
262
+ in_features,
263
+ hidden_features=None,
264
+ out_features=None,
265
+ act_layer=nn.GELU,
266
+ drop=0.):
267
+ super().__init__()
268
+ out_features = out_features or in_features
269
+ hidden_features = hidden_features or in_features
270
+ self.fc1 = nn.Linear(in_features, hidden_features)
271
+ self.act = act_layer()
272
+ self.fc2 = nn.Linear(hidden_features, out_features)
273
+ self.drop = nn.Dropout(drop)
274
+
275
+ def forward(self, x):
276
+ x = self.fc1(x)
277
+ x = self.act(x)
278
+ x = self.drop(x)
279
+ x = self.fc2(x)
280
+ x = self.drop(x)
281
+ return x
282
+
283
+
284
+ class WindowAttention(nn.Module):
285
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
286
+ It supports both of shifted and non-shifted window.
287
+
288
+ Args:
289
+ dim (int): Number of input channels.
290
+ window_size (tuple[int]): The height and width of the window.
291
+ num_heads (int): Number of attention heads.
292
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
293
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
294
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
295
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
296
+ """
297
+
298
+ def __init__(self,
299
+ dim,
300
+ window_size,
301
+ num_heads,
302
+ qkv_bias=True,
303
+ qk_scale=None,
304
+ attn_drop=0.,
305
+ proj_drop=0.):
306
+
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.window_size = window_size # Wh, Ww
310
+ self.num_heads = num_heads
311
+ head_dim = dim // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ # define a parameter table of relative position bias
315
+ self.relative_position_bias_table = nn.Parameter(
316
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
317
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
318
+
319
+ # get pair-wise relative position index for each token inside the window
320
+ coords_h = torch.arange(self.window_size[0])
321
+ coords_w = torch.arange(self.window_size[1])
322
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
323
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
324
+ relative_coords = coords_flatten[:, :,
325
+ None] - coords_flatten[:,
326
+ None, :] # 2, Wh*Ww, Wh*Ww
327
+ relative_coords = relative_coords.permute(
328
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
329
+ relative_coords[:, :,
330
+ 0] += self.window_size[0] - 1 # shift to start from 0
331
+ relative_coords[:, :, 1] += self.window_size[1] - 1
332
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
333
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
334
+ self.register_buffer("relative_position_index",
335
+ relative_position_index)
336
+
337
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
338
+ self.attn_drop = nn.Dropout(attn_drop)
339
+ self.proj = nn.Linear(dim, dim)
340
+ self.proj_drop = nn.Dropout(proj_drop)
341
+
342
+ trunc_normal_(self.relative_position_bias_table, std=.02)
343
+ self.softmax = nn.Softmax(dim=-1)
344
+
345
+ def forward(self, x, mask=None):
346
+ """ Forward function.
347
+
348
+ Args:
349
+ x: input features with shape of (num_windows*B, N, C)
350
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
351
+ """
352
+ x = x.to(dtype=torch_dtype, device=torch_device)
353
+ B_, N, C = x.shape
354
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
355
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
356
+ q, k, v = qkv[0], qkv[1], qkv[
357
+ 2] # make torchscript happy (cannot use tensor as tuple)
358
+
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ relative_position_bias = self.relative_position_bias_table[
363
+ self.relative_position_index.view(-1)].view(
364
+ self.window_size[0] * self.window_size[1],
365
+ self.window_size[0] * self.window_size[1],
366
+ -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(
368
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
369
+ attn = attn + relative_position_bias.unsqueeze(0)
370
+
371
+ if mask is not None:
372
+ nW = mask.shape[0]
373
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
374
+ N) + mask.unsqueeze(1).unsqueeze(0)
375
+ attn = attn.view(-1, self.num_heads, N, N)
376
+ attn = self.softmax(attn)
377
+ else:
378
+ attn = self.softmax(attn)
379
+
380
+ attn = self.attn_drop(attn)
381
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
382
+ v = v.to(dtype=torch_dtype, device=torch_device)
383
+
384
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
385
+ x = self.proj(x)
386
+ x = self.proj_drop(x)
387
+ return x
388
+
389
+
390
+ class SwinTransformerBlock(nn.Module):
391
+ """ Swin Transformer Block.
392
+
393
+ Args:
394
+ dim (int): Number of input channels.
395
+ num_heads (int): Number of attention heads.
396
+ window_size (int): Window size.
397
+ shift_size (int): Shift size for SW-MSA.
398
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
399
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
400
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
401
+ drop (float, optional): Dropout rate. Default: 0.0
402
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
403
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
404
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
405
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
406
+ """
407
+
408
+ def __init__(self,
409
+ dim,
410
+ num_heads,
411
+ window_size=7,
412
+ shift_size=0,
413
+ mlp_ratio=4.,
414
+ qkv_bias=True,
415
+ qk_scale=None,
416
+ drop=0.,
417
+ attn_drop=0.,
418
+ drop_path=0.,
419
+ act_layer=nn.GELU,
420
+ norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.dim = dim
423
+ self.num_heads = num_heads
424
+ self.window_size = window_size
425
+ self.shift_size = shift_size
426
+ self.mlp_ratio = mlp_ratio
427
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.attn = WindowAttention(dim,
431
+ window_size=to_2tuple(self.window_size),
432
+ num_heads=num_heads,
433
+ qkv_bias=qkv_bias,
434
+ qk_scale=qk_scale,
435
+ attn_drop=attn_drop,
436
+ proj_drop=drop)
437
+
438
+ self.drop_path = DropPath(
439
+ drop_path) if drop_path > 0. else nn.Identity()
440
+ self.norm2 = norm_layer(dim)
441
+ mlp_hidden_dim = int(dim * mlp_ratio)
442
+ self.mlp = Mlp(in_features=dim,
443
+ hidden_features=mlp_hidden_dim,
444
+ act_layer=act_layer,
445
+ drop=drop)
446
+
447
+ self.H = None
448
+ self.W = None
449
+
450
+ def forward(self, x, mask_matrix):
451
+ """ Forward function.
452
+
453
+ Args:
454
+ x: Input feature, tensor size (B, H*W, C).
455
+ H, W: Spatial resolution of the input feature.
456
+ mask_matrix: Attention mask for cyclic shift.
457
+ """
458
+ B, L, C = x.shape
459
+ H, W = self.H, self.W
460
+ assert L == H * W, "input feature has wrong size"
461
+
462
+ shortcut = x
463
+ x = self.norm1(x)
464
+ x = x.view(B, H, W, C)
465
+
466
+ # pad feature maps to multiples of window size
467
+ pad_l = pad_t = 0
468
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
469
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
470
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
471
+ _, Hp, Wp, _ = x.shape
472
+
473
+ # cyclic shift
474
+ if self.shift_size > 0:
475
+ shifted_x = torch.roll(x,
476
+ shifts=(-self.shift_size, -self.shift_size),
477
+ dims=(1, 2))
478
+ attn_mask = mask_matrix
479
+ else:
480
+ shifted_x = x
481
+ attn_mask = None
482
+
483
+ # partition windows
484
+ x_windows = window_partition(
485
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
486
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
487
+ C) # nW*B, window_size*window_size, C
488
+
489
+ # W-MSA/SW-MSA
490
+ attn_windows = self.attn(
491
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
492
+
493
+ # merge windows
494
+ attn_windows = attn_windows.view(-1, self.window_size,
495
+ self.window_size, C)
496
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
497
+ Wp) # B H' W' C
498
+
499
+ # reverse cyclic shift
500
+ if self.shift_size > 0:
501
+ x = torch.roll(shifted_x,
502
+ shifts=(self.shift_size, self.shift_size),
503
+ dims=(1, 2))
504
+ else:
505
+ x = shifted_x
506
+
507
+ if pad_r > 0 or pad_b > 0:
508
+ x = x[:, :H, :W, :].contiguous()
509
+
510
+ x = x.view(B, H * W, C)
511
+
512
+ # FFN
513
+ x = shortcut + self.drop_path(x)
514
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
515
+
516
+ return x
517
+
518
+
519
+ class PatchMerging(nn.Module):
520
+ """ Patch Merging Layer
521
+
522
+ Args:
523
+ dim (int): Number of input channels.
524
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
525
+ """
526
+
527
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
528
+ super().__init__()
529
+ self.dim = dim
530
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
531
+ self.norm = norm_layer(4 * dim)
532
+
533
+ def forward(self, x, H, W):
534
+ """ Forward function.
535
+
536
+ Args:
537
+ x: Input feature, tensor size (B, H*W, C).
538
+ H, W: Spatial resolution of the input feature.
539
+ """
540
+ B, L, C = x.shape
541
+ assert L == H * W, "input feature has wrong size"
542
+
543
+ x = x.view(B, H, W, C)
544
+
545
+ # padding
546
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
547
+ if pad_input:
548
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
549
+
550
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
551
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
552
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
553
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
554
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
555
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
556
+
557
+ x = self.norm(x)
558
+ x = self.reduction(x)
559
+
560
+ return x
561
+
562
+
563
+ class BasicLayer(nn.Module):
564
+ """ A basic Swin Transformer layer for one stage.
565
+
566
+ Args:
567
+ dim (int): Number of feature channels
568
+ depth (int): Depths of this stage.
569
+ num_heads (int): Number of attention head.
570
+ window_size (int): Local window size. Default: 7.
571
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
572
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
573
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
574
+ drop (float, optional): Dropout rate. Default: 0.0
575
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
576
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
577
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
578
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
579
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
580
+ """
581
+
582
+ def __init__(self,
583
+ dim,
584
+ depth,
585
+ num_heads,
586
+ window_size=7,
587
+ mlp_ratio=4.,
588
+ qkv_bias=True,
589
+ qk_scale=None,
590
+ drop=0.,
591
+ attn_drop=0.,
592
+ drop_path=0.,
593
+ norm_layer=nn.LayerNorm,
594
+ downsample=None,
595
+ use_checkpoint=False):
596
+ super().__init__()
597
+ self.window_size = window_size
598
+ self.shift_size = window_size // 2
599
+ self.depth = depth
600
+ self.use_checkpoint = use_checkpoint
601
+
602
+ # build blocks
603
+ self.blocks = nn.ModuleList([
604
+ SwinTransformerBlock(dim=dim,
605
+ num_heads=num_heads,
606
+ window_size=window_size,
607
+ shift_size=0 if
608
+ (i % 2 == 0) else window_size // 2,
609
+ mlp_ratio=mlp_ratio,
610
+ qkv_bias=qkv_bias,
611
+ qk_scale=qk_scale,
612
+ drop=drop,
613
+ attn_drop=attn_drop,
614
+ drop_path=drop_path[i] if isinstance(
615
+ drop_path, list) else drop_path,
616
+ norm_layer=norm_layer) for i in range(depth)
617
+ ])
618
+
619
+ # patch merging layer
620
+ if downsample is not None:
621
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
622
+ else:
623
+ self.downsample = None
624
+
625
+ def forward(self, x, H, W):
626
+ """ Forward function.
627
+
628
+ Args:
629
+ x: Input feature, tensor size (B, H*W, C).
630
+ H, W: Spatial resolution of the input feature.
631
+ """
632
+
633
+ # calculate attention mask for SW-MSA
634
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
635
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
636
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
637
+ h_slices = (slice(0, -self.window_size),
638
+ slice(-self.window_size,
639
+ -self.shift_size), slice(-self.shift_size, None))
640
+ w_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size,
642
+ -self.shift_size), slice(-self.shift_size, None))
643
+ cnt = 0
644
+ for h in h_slices:
645
+ for w in w_slices:
646
+ img_mask[:, h, w, :] = cnt
647
+ cnt += 1
648
+
649
+ mask_windows = window_partition(
650
+ img_mask, self.window_size) # nW, window_size, window_size, 1
651
+ mask_windows = mask_windows.view(-1,
652
+ self.window_size * self.window_size)
653
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
654
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
655
+ float(-100.0)).masked_fill(
656
+ attn_mask == 0, float(0.0))
657
+
658
+ for blk in self.blocks:
659
+ blk.H, blk.W = H, W
660
+ if self.use_checkpoint:
661
+ x = checkpoint.checkpoint(blk, x, attn_mask)
662
+ else:
663
+ x = blk(x, attn_mask)
664
+ if self.downsample is not None:
665
+ x_down = self.downsample(x, H, W)
666
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
667
+ return x, H, W, x_down, Wh, Ww
668
+ else:
669
+ return x, H, W, x, H, W
670
+
671
+
672
+ class PatchEmbed(nn.Module):
673
+ """ Image to Patch Embedding
674
+
675
+ Args:
676
+ patch_size (int): Patch token size. Default: 4.
677
+ in_chans (int): Number of input image channels. Default: 3.
678
+ embed_dim (int): Number of linear projection output channels. Default: 96.
679
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
680
+ """
681
+
682
+ def __init__(self,
683
+ patch_size=4,
684
+ in_chans=3,
685
+ embed_dim=96,
686
+ norm_layer=None):
687
+ super().__init__()
688
+ patch_size = to_2tuple(patch_size)
689
+ self.patch_size = patch_size
690
+
691
+ self.in_chans = in_chans
692
+ self.embed_dim = embed_dim
693
+
694
+ self.proj = nn.Conv2d(in_chans,
695
+ embed_dim,
696
+ kernel_size=patch_size,
697
+ stride=patch_size)
698
+ if norm_layer is not None:
699
+ self.norm = norm_layer(embed_dim)
700
+ else:
701
+ self.norm = None
702
+
703
+ def forward(self, x):
704
+ """Forward function."""
705
+ # padding
706
+ _, _, H, W = x.size()
707
+ if W % self.patch_size[1] != 0:
708
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
709
+ if H % self.patch_size[0] != 0:
710
+ x = F.pad(x,
711
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
712
+
713
+ x = self.proj(x) # B C Wh Ww
714
+ if self.norm is not None:
715
+ Wh, Ww = x.size(2), x.size(3)
716
+ x = x.flatten(2).transpose(1, 2)
717
+ x = self.norm(x)
718
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
719
+
720
+ return x
721
+
722
+
723
+ class SwinTransformer(nn.Module):
724
+ """ Swin Transformer backbone.
725
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
726
+ https://arxiv.org/pdf/2103.14030
727
+
728
+ Args:
729
+ pretrain_img_size (int): Input image size for training the pretrained model,
730
+ used in absolute position embedding. Default 224.
731
+ patch_size (int | tuple(int)): Patch size. Default: 4.
732
+ in_chans (int): Number of input image channels. Default: 3.
733
+ embed_dim (int): Number of linear projection output channels. Default: 96.
734
+ depths (tuple[int]): Depths of each Swin Transformer stage.
735
+ num_heads (tuple[int]): Number of attention head of each stage.
736
+ window_size (int): Window size. Default: 7.
737
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
738
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
739
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
740
+ drop_rate (float): Dropout rate.
741
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
742
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
743
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
744
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
745
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
746
+ out_indices (Sequence[int]): Output from which stages.
747
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
748
+ -1 means not freezing any parameters.
749
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
750
+ """
751
+
752
+ def __init__(self,
753
+ pretrain_img_size=224,
754
+ patch_size=4,
755
+ in_chans=3,
756
+ embed_dim=96,
757
+ depths=[2, 2, 6, 2],
758
+ num_heads=[3, 6, 12, 24],
759
+ window_size=7,
760
+ mlp_ratio=4.,
761
+ qkv_bias=True,
762
+ qk_scale=None,
763
+ drop_rate=0.,
764
+ attn_drop_rate=0.,
765
+ drop_path_rate=0.2,
766
+ norm_layer=nn.LayerNorm,
767
+ ape=False,
768
+ patch_norm=True,
769
+ out_indices=(0, 1, 2, 3),
770
+ frozen_stages=-1,
771
+ use_checkpoint=False):
772
+ super().__init__()
773
+
774
+ self.pretrain_img_size = pretrain_img_size
775
+ self.num_layers = len(depths)
776
+ self.embed_dim = embed_dim
777
+ self.ape = ape
778
+ self.patch_norm = patch_norm
779
+ self.out_indices = out_indices
780
+ self.frozen_stages = frozen_stages
781
+
782
+ # split image into non-overlapping patches
783
+ self.patch_embed = PatchEmbed(
784
+ patch_size=patch_size,
785
+ in_chans=in_chans,
786
+ embed_dim=embed_dim,
787
+ norm_layer=norm_layer if self.patch_norm else None)
788
+
789
+ # absolute position embedding
790
+ if self.ape:
791
+ pretrain_img_size = to_2tuple(pretrain_img_size)
792
+ patch_size = to_2tuple(patch_size)
793
+ patches_resolution = [
794
+ pretrain_img_size[0] // patch_size[0],
795
+ pretrain_img_size[1] // patch_size[1]
796
+ ]
797
+
798
+ self.absolute_pos_embed = nn.Parameter(
799
+ torch.zeros(1, embed_dim, patches_resolution[0],
800
+ patches_resolution[1]))
801
+ trunc_normal_(self.absolute_pos_embed, std=.02)
802
+
803
+ self.pos_drop = nn.Dropout(p=drop_rate)
804
+
805
+ # stochastic depth
806
+ dpr = [
807
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
808
+ ] # stochastic depth decay rule
809
+
810
+ # build layers
811
+ self.layers = nn.ModuleList()
812
+ for i_layer in range(self.num_layers):
813
+ layer = BasicLayer(
814
+ dim=int(embed_dim * 2**i_layer),
815
+ depth=depths[i_layer],
816
+ num_heads=num_heads[i_layer],
817
+ window_size=window_size,
818
+ mlp_ratio=mlp_ratio,
819
+ qkv_bias=qkv_bias,
820
+ qk_scale=qk_scale,
821
+ drop=drop_rate,
822
+ attn_drop=attn_drop_rate,
823
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
824
+ norm_layer=norm_layer,
825
+ downsample=PatchMerging if
826
+ (i_layer < self.num_layers - 1) else None,
827
+ use_checkpoint=use_checkpoint)
828
+ self.layers.append(layer)
829
+
830
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
831
+ self.num_features = num_features
832
+
833
+ # add a norm layer for each output
834
+ for i_layer in out_indices:
835
+ layer = norm_layer(num_features[i_layer])
836
+ layer_name = f'norm{i_layer}'
837
+ self.add_module(layer_name, layer)
838
+
839
+ self._freeze_stages()
840
+
841
+ def _freeze_stages(self):
842
+ if self.frozen_stages >= 0:
843
+ self.patch_embed.eval()
844
+ for param in self.patch_embed.parameters():
845
+ param.requires_grad = False
846
+
847
+ if self.frozen_stages >= 1 and self.ape:
848
+ self.absolute_pos_embed.requires_grad = False
849
+
850
+ if self.frozen_stages >= 2:
851
+ self.pos_drop.eval()
852
+ for i in range(0, self.frozen_stages - 1):
853
+ m = self.layers[i]
854
+ m.eval()
855
+ for param in m.parameters():
856
+ param.requires_grad = False
857
+
858
+ def init_weights(self, pretrained=None):
859
+ """Initialize the weights in backbone.
860
+
861
+ Args:
862
+ pretrained (str, optional): Path to pre-trained weights.
863
+ Defaults to None.
864
+ """
865
+
866
+ def _init_weights(m):
867
+ if isinstance(m, nn.Linear):
868
+ trunc_normal_(m.weight, std=.02)
869
+ if isinstance(m, nn.Linear) and m.bias is not None:
870
+ nn.init.constant_(m.bias, 0)
871
+ elif isinstance(m, nn.LayerNorm):
872
+ nn.init.constant_(m.bias, 0)
873
+ nn.init.constant_(m.weight, 1.0)
874
+
875
+ if isinstance(pretrained, str):
876
+ self.apply(_init_weights)
877
+ load_checkpoint(self, pretrained, strict=False, logger=None)
878
+ elif pretrained is None:
879
+ self.apply(_init_weights)
880
+ else:
881
+ raise TypeError('pretrained must be a str or None')
882
+
883
+ def forward(self, x):
884
+ x = self.patch_embed(x)
885
+
886
+ Wh, Ww = x.size(2), x.size(3)
887
+ if self.ape:
888
+ # interpolate the position embedding to the corresponding size
889
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
890
+ size=(Wh, Ww),
891
+ mode='bicubic')
892
+ x = (x + absolute_pos_embed) # B Wh*Ww C
893
+
894
+ outs = [x.contiguous()]
895
+ x = x.flatten(2).transpose(1, 2)
896
+ x = self.pos_drop(x)
897
+ for i in range(self.num_layers):
898
+ layer = self.layers[i]
899
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
900
+
901
+ if i in self.out_indices:
902
+ norm_layer = getattr(self, f'norm{i}')
903
+ x_out = norm_layer(x_out)
904
+
905
+ out = x_out.view(-1, H, W,
906
+ self.num_features[i]).permute(0, 3, 1,
907
+ 2).contiguous()
908
+ outs.append(out)
909
+
910
+ return tuple(outs)
911
+
912
+ def train(self, mode=True):
913
+ """Convert the model into training mode while keeping layers frozen."""
914
+ super(SwinTransformer, self).train(mode)
915
+ self._freeze_stages()
916
+
917
+
918
+ class PositionEmbeddingSine:
919
+
920
+ def __init__(self,
921
+ num_pos_feats=64,
922
+ temperature=10000,
923
+ normalize=False,
924
+ scale=None):
925
+ super().__init__()
926
+ self.num_pos_feats = num_pos_feats
927
+ self.temperature = temperature
928
+ self.normalize = normalize
929
+ if scale is not None and normalize is False:
930
+ raise ValueError("normalize should be True if scale is passed")
931
+ if scale is None:
932
+ scale = 2 * math.pi
933
+ self.scale = scale
934
+ self.dim_t = torch.arange(0,
935
+ self.num_pos_feats,
936
+ dtype=torch_dtype,
937
+ device=torch_device)
938
+
939
+ def __call__(self, b, h, w):
940
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
941
+ assert mask is not None
942
+ not_mask = ~mask
943
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
944
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
945
+ if self.normalize:
946
+ eps = 1e-6
947
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
948
+ self.scale).to(device=torch_device, dtype=torch_dtype)
949
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
950
+ self.scale).to(device=torch_device, dtype=torch_dtype)
951
+
952
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
953
+
954
+ pos_x = x_embed[:, :, :, None] / dim_t
955
+ pos_y = y_embed[:, :, :, None] / dim_t
956
+ pos_x = torch.stack(
957
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
958
+ dim=4).flatten(3)
959
+ pos_y = torch.stack(
960
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
961
+ dim=4).flatten(3)
962
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
963
+
964
+
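For reference (editor's sketch, not in the commit): PositionEmbeddingSine returns a (b, 2 * num_pos_feats, h, w) tensor, built in float16 on the CUDA device because it uses the module-level torch_dtype and torch_device:

pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
emb = pe(1, 16, 16)        # b=1, h=16, w=16
print(emb.shape)           # torch.Size([1, 128, 16, 16])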
965
+ class MCLM(nn.Module):
966
+
967
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
968
+ super(MCLM, self).__init__()
969
+ self.attention = nn.ModuleList([
970
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
971
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
972
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
973
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
974
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
975
+ ])
976
+
977
+ self.linear1 = nn.Linear(d_model, d_model * 2)
978
+ self.linear2 = nn.Linear(d_model * 2, d_model)
979
+ self.linear3 = nn.Linear(d_model, d_model * 2)
980
+ self.linear4 = nn.Linear(d_model * 2, d_model)
981
+ self.norm1 = nn.LayerNorm(d_model)
982
+ self.norm2 = nn.LayerNorm(d_model)
983
+ self.dropout = nn.Dropout(0.1)
984
+ self.dropout1 = nn.Dropout(0.1)
985
+ self.dropout2 = nn.Dropout(0.1)
986
+ self.activation = get_activation_fn('relu')
987
+ self.pool_ratios = pool_ratios
988
+ self.p_poses = []
989
+ self.g_pos = None
990
+ self.positional_encoding = PositionEmbeddingSine(
991
+ num_pos_feats=d_model // 2, normalize=True)
992
+
993
+ def forward(self, l, g):
994
+ """
995
+ l: 4,c,h,w
996
+ g: 1,c,h,w
997
+ """
998
+ b, c, h, w = l.size()
999
+ # 4,c,h,w -> 1,c,2h,2w
1000
+ concated_locs = rearrange(l,
1001
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1002
+ hg=2,
1003
+ wg=2)
1004
+
1005
+ pools = []
1006
+ for pool_ratio in self.pool_ratios:
1007
+ # b,c,h,w
1008
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1009
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1010
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1011
+ if self.g_pos is None:
1012
+ pos_emb = self.positional_encoding(pool.shape[0],
1013
+ pool.shape[2],
1014
+ pool.shape[3])
1015
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1016
+ self.p_poses.append(pos_emb)
1017
+ pools = torch.cat(pools, 0)
1018
+ if self.g_pos is None:
1019
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1020
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1021
+ g.shape[3])
1022
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1023
+
1024
+ # attention between glb (q) & multi-scale pooled concatenated locs (k,v)
1025
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1026
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1027
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1028
+ g_hw_b_c = self.norm1(g_hw_b_c)
1029
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1030
+ self.linear2(
1031
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1032
+ g_hw_b_c = self.norm2(g_hw_b_c)
1033
+
1034
+ # attention between original locs (q) & refreshed glb (k,v)
1035
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1036
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1037
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1038
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1039
+ ng=2,
1040
+ nw=2)
1041
+ outputs_re = []
1042
+ for i, (_l, _g) in enumerate(
1043
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1044
+ outputs_re.append(self.attention[i + 1](_l, _g,
1045
+ _g)[0]) # (h w) 1 c
1046
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1047
+
1048
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1049
+ l_hw_b_c = self.norm1(l_hw_b_c)
1050
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1051
+ self.linear4(
1052
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1053
+ l_hw_b_c = self.norm2(l_hw_b_c)
1054
+
1055
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1056
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) # (5, c, h, w)
1057
+
1058
+
1059
+ class inf_MCLM(nn.Module):
1060
+
1061
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1062
+ super(inf_MCLM, self).__init__()
1063
+ self.attention = nn.ModuleList([
1064
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1065
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1066
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1067
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1068
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1069
+ ])
1070
+
1071
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1072
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1073
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1074
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1075
+ self.norm1 = nn.LayerNorm(d_model)
1076
+ self.norm2 = nn.LayerNorm(d_model)
1077
+ self.dropout = nn.Dropout(0.1)
1078
+ self.dropout1 = nn.Dropout(0.1)
1079
+ self.dropout2 = nn.Dropout(0.1)
1080
+ self.activation = get_activation_fn('relu')
1081
+ self.pool_ratios = pool_ratios
1082
+ self.p_poses = []
1083
+ self.g_pos = None
1084
+ self.positional_encoding = PositionEmbeddingSine(
1085
+ num_pos_feats=d_model // 2, normalize=True)
1086
+
1087
+ def forward(self, l, g):
1088
+ """
1089
+ l: 4,c,h,w
1090
+ g: 1,c,h,w
1091
+ """
1092
+ b, c, h, w = l.size()
1093
+ # 4,c,h,w -> 1,c,2h,2w
1094
+ concated_locs = rearrange(l,
1095
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1096
+ hg=2,
1097
+ wg=2)
1098
+ self.p_poses = []
1099
+ pools = []
1100
+ for pool_ratio in self.pool_ratios:
1101
+ # b,c,h,w
1102
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1103
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1104
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1105
+ # if self.g_pos is None:
1106
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1107
+ pool.shape[3])
1108
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1109
+ self.p_poses.append(pos_emb)
1110
+ pools = torch.cat(pools, 0)
1111
+ # if self.g_pos is None:
1112
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1113
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1114
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1115
+
1116
+ # attention between glb (q) & multi-scale pooled concatenated locs (k,v)
1117
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1118
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1119
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1120
+ g_hw_b_c = self.norm1(g_hw_b_c)
1121
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1122
+ self.linear2(
1123
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1124
+ g_hw_b_c = self.norm2(g_hw_b_c)
1125
+
1126
+ # attention between original locs (q) & refreshed glb (k,v)
1127
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1128
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1129
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1130
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1131
+ ng=2,
1132
+ nw=2)
1133
+ outputs_re = []
1134
+ for i, (_l, _g) in enumerate(
1135
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1136
+ outputs_re.append(self.attention[i + 1](_l, _g,
1137
+ _g)[0]) # (h w) 1 c
1138
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1139
+
1140
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1141
+ l_hw_b_c = self.norm1(l_hw_b_c)
1142
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1143
+ self.linear4(
1144
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1145
+ l_hw_b_c = self.norm2(l_hw_b_c)
1146
+
1147
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1148
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) # (5, c, h, w)
1149
+
1150
+
1151
+ class MCRM(nn.Module):
1152
+
1153
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1154
+ super(MCRM, self).__init__()
1155
+ self.attention = nn.ModuleList([
1156
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1157
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1158
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1159
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1160
+ ])
1161
+
1162
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1163
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1164
+ self.norm1 = nn.LayerNorm(d_model)
1165
+ self.norm2 = nn.LayerNorm(d_model)
1166
+ self.dropout = nn.Dropout(0.1)
1167
+ self.dropout1 = nn.Dropout(0.1)
1168
+ self.dropout2 = nn.Dropout(0.1)
1169
+ self.sigmoid = nn.Sigmoid()
1170
+ self.activation = get_activation_fn('relu')
1171
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1172
+ self.pool_ratios = pool_ratios
1173
+ self.positional_encoding = PositionEmbeddingSine(
1174
+ num_pos_feats=d_model // 2, normalize=True)
1175
+
1176
+ def forward(self, x):
1177
+ b, c, h, w = x.size()
1178
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1179
+ # b(4),c,h,w
1180
+ patched_glb = rearrange(glb,
1181
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1182
+ hg=2,
1183
+ wg=2)
1184
+
1185
+ # generate token attention map
1186
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1187
+ token_attention_map = F.interpolate(token_attention_map,
1188
+ size=patches2image(loc).shape[-2:],
1189
+ mode='nearest')
1190
+ loc = loc * rearrange(token_attention_map,
1191
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1192
+ hg=2,
1193
+ wg=2)
1194
+ pools = []
1195
+ for pool_ratio in self.pool_ratios:
1196
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1197
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1198
+ pools.append(rearrange(pool,
1199
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1200
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1201
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1202
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1203
+ outputs = []
1204
+ for i, q in enumerate(
1205
+ loc_.unbind(dim=0)): # traverse all local patches
1206
+ # np*hw,1,c
1207
+ v = pools[i]
1208
+ k = v
1209
+ outputs.append(self.attention[i](q, k, v)[0])
1210
+ outputs = torch.cat(outputs, 1)
1211
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1212
+ src = self.norm1(src)
1213
+ src = src + self.dropout2(
1214
+ self.linear4(
1215
+ self.dropout(self.activation(self.linear3(src)).clone())))
1216
+ src = self.norm2(src)
1217
+
1218
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1219
+ glb = glb + F.interpolate(patches2image(src),
1220
+ size=glb.shape[-2:],
1221
+ mode='nearest') # refreshed glb
1222
+ return torch.cat((src, glb), 0), token_attention_map
1223
+
1224
+
1225
+ class inf_MCRM(nn.Module):
1226
+
1227
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1228
+ super(inf_MCRM, self).__init__()
1229
+ self.attention = nn.ModuleList([
1230
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1231
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1232
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1233
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1234
+ ])
1235
+
1236
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1237
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1238
+ self.norm1 = nn.LayerNorm(d_model)
1239
+ self.norm2 = nn.LayerNorm(d_model)
1240
+ self.dropout = nn.Dropout(0.1)
1241
+ self.dropout1 = nn.Dropout(0.1)
1242
+ self.dropout2 = nn.Dropout(0.1)
1243
+ self.sigmoid = nn.Sigmoid()
1244
+ self.activation = get_activation_fn('relu')
1245
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1246
+ self.pool_ratios = pool_ratios
1247
+ self.positional_encoding = PositionEmbeddingSine(
1248
+ num_pos_feats=d_model // 2, normalize=True)
1249
+
1250
+ def forward(self, x):
1251
+ b, c, h, w = x.size()
1252
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1253
+ # b(4),c,h,w
1254
+ patched_glb = rearrange(glb,
1255
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1256
+ hg=2,
1257
+ wg=2)
1258
+
1259
+ # generate token attention map
1260
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1261
+ token_attention_map = F.interpolate(token_attention_map,
1262
+ size=patches2image(loc).shape[-2:],
1263
+ mode='nearest')
1264
+ loc = loc * rearrange(token_attention_map,
1265
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1266
+ hg=2,
1267
+ wg=2)
1268
+ pools = []
1269
+ for pool_ratio in self.pool_ratios:
1270
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1271
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1272
+ pools.append(rearrange(pool,
1273
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1274
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1275
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1276
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1277
+ outputs = []
1278
+ for i, q in enumerate(
1279
+ loc_.unbind(dim=0)): # traverse all local patches
1280
+ # np*hw,1,c
1281
+ v = pools[i]
1282
+ k = v
1283
+ outputs.append(self.attention[i](q, k, v)[0])
1284
+ outputs = torch.cat(outputs, 1)
1285
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1286
+ src = self.norm1(src)
1287
+ src = src + self.dropout2(
1288
+ self.linear4(
1289
+ self.dropout(self.activation(self.linear3(src)).clone())))
1290
+ src = self.norm2(src)
1291
+
1292
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1293
+ glb = glb + F.interpolate(patches2image(src),
1294
+ size=glb.shape[-2:],
1295
+ mode='nearest') # refreshed glb
1296
+ return torch.cat((src, glb), 0)
1297
+
1298
+
1299
+ # model for single-scale training
1300
+ class MVANet(nn.Module):
1301
+
1302
+ def __init__(self):
1303
+ super().__init__()
1304
+ self.backbone = SwinB(pretrained=True)
1305
+ emb_dim = 128
1306
+ self.sideout5 = nn.Sequential(
1307
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1308
+ self.sideout4 = nn.Sequential(
1309
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1310
+ self.sideout3 = nn.Sequential(
1311
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1312
+ self.sideout2 = nn.Sequential(
1313
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1314
+ self.sideout1 = nn.Sequential(
1315
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1316
+
1317
+ self.output5 = make_cbr(1024, emb_dim)
1318
+ self.output4 = make_cbr(512, emb_dim)
1319
+ self.output3 = make_cbr(256, emb_dim)
1320
+ self.output2 = make_cbr(128, emb_dim)
1321
+ self.output1 = make_cbr(128, emb_dim)
1322
+
1323
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1324
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1325
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1326
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1327
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1328
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1329
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1330
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1331
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1332
+
1333
+ self.insmask_head = nn.Sequential(
1334
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1335
+ nn.BatchNorm2d(384), nn.PReLU(),
1336
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1337
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1338
+
1339
+ self.shallow = nn.Sequential(
1340
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1341
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1342
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1343
+ self.output = nn.Sequential(
1344
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1345
+
1346
+ for m in self.modules():
1347
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1348
+ m.inplace = True
1349
+
1350
+ def forward(self, x):
1351
+ x = x.to(dtype=torch_dtype, device=torch_device)
1352
+ shallow = self.shallow(x)
1353
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1354
+ loc = image2patches(x)
1355
+ input = torch.cat((loc, glb), dim=0)
1356
+ feature = self.backbone(input)
1357
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1358
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1359
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1360
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1361
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1362
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1363
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1364
+
1365
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1366
+ e4 = self.conv4(e4)
1367
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1368
+ e3 = self.conv3(e3)
1369
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1370
+ e2 = self.conv2(e2)
1371
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1372
+ e1 = self.conv1(e1)
1373
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1374
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1375
+ # add glb feat in
1376
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1377
+ # merge
1378
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1379
+ # shallow feature merge
1380
+ final_output = final_output + resize_as(shallow, final_output)
1381
+ final_output = self.upsample1(rescale_to(final_output))
1382
+ final_output = rescale_to(final_output +
1383
+ resize_as(shallow, final_output))
1384
+ final_output = self.upsample2(final_output)
1385
+ final_output = self.output(final_output)
1386
+ ####
1387
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1388
+ sideout4 = self.sideout4(e4)
1389
+ sideout3 = self.sideout3(e3)
1390
+ sideout2 = self.sideout2(e2)
1391
+ sideout1 = self.sideout1(e1)
1392
+ #######glb_sideouts ######
1393
+ glb5 = self.sideout5(glb_e5)
1394
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1395
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1396
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1397
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1398
+ ####### concat 4 to 1 #######
1399
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1400
+ device=torch_device)
1401
+ sideout2 = patches2image(sideout2[:-1]).to(
1402
+ dtype=torch_dtype,
1403
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1404
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1405
+ device=torch_device)
1406
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1407
+ device=torch_device)
1408
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1409
+ device=torch_device)
1410
+ if self.training:
1411
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1412
+ else:
1413
+ return final_output
1414
+
1415
+
1416
+ # model for multi-scale testing
1417
+ class inf_MVANet(nn.Module):
1418
+
1419
+ def __init__(self):
1420
+ super().__init__()
1421
+ # self.backbone = SwinB(pretrained=True)
1422
+ self.backbone = SwinB(pretrained=False)
1423
+
1424
+ emb_dim = 128
1425
+ self.output5 = make_cbr(1024, emb_dim)
1426
+ self.output4 = make_cbr(512, emb_dim)
1427
+ self.output3 = make_cbr(256, emb_dim)
1428
+ self.output2 = make_cbr(128, emb_dim)
1429
+ self.output1 = make_cbr(128, emb_dim)
1430
+
1431
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1432
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1433
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1434
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1435
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1436
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1437
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1438
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1439
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1440
+
1441
+ self.insmask_head = nn.Sequential(
1442
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1443
+ nn.BatchNorm2d(384), nn.PReLU(),
1444
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1445
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1446
+
1447
+ self.shallow = nn.Sequential(
1448
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1449
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1450
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1451
+ self.output = nn.Sequential(
1452
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1453
+
1454
+ for m in self.modules():
1455
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1456
+ m.inplace = True
1457
+
1458
+ def forward(self, x):
1459
+ shallow = self.shallow(x)
1460
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1461
+ loc = image2patches(x)
1462
+ input = torch.cat((loc, glb), dim=0)
1463
+ feature = self.backbone(input)
1464
+ e5 = self.output5(feature[4])
1465
+ e4 = self.output4(feature[3])
1466
+ e3 = self.output3(feature[2])
1467
+ e2 = self.output2(feature[1])
1468
+ e1 = self.output1(feature[0])
1469
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1470
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1471
+
1472
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1473
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1474
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1475
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1476
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1477
+ # after decoder, concat loc features to a whole one, and merge
1478
+ output1_cat = patches2image(loc_e1)
1479
+ # add glb feat in
1480
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1481
+ # merge
1482
+ final_output = self.insmask_head(output1_cat)
1483
+ # shallow feature merge
1484
+ final_output = final_output + resize_as(shallow, final_output)
1485
+ final_output = self.upsample1(rescale_to(final_output))
1486
+ final_output = rescale_to(final_output +
1487
+ resize_as(shallow, final_output))
1488
+ final_output = self.upsample2(final_output)
1489
+ final_output = self.output(final_output)
1490
+ return final_output
1491
+
1492
+
1493
+ class load_MVANet_Model:
1494
+
1495
+ def __init__(self):
1496
+ pass
1497
+
1498
+ @classmethod
1499
+ def INPUT_TYPES(s):
1500
+ return {
1501
+ "required": {},
1502
+ }
1503
+
1504
+ RETURN_TYPES = ("MVANet_Model", )
1505
+ FUNCTION = "test"
1506
+ CATEGORY = "MVANet"
1507
+
1508
+ def test(self):
1509
+ return (load_model(get_model_path()), )
1510
+
1511
+
1512
+ class run_MVANet_inference:
1513
+
1514
+ def __init__(self):
1515
+ pass
1516
+
1517
+ @classmethod
1518
+ def INPUT_TYPES(s):
1519
+ return {
1520
+ "required": {
1521
+ "image": ("IMAGE", ),
1522
+ "MVANet_Model": ("MVANet_Model", ),
1523
+ },
1524
+ }
1525
+
1526
+ RETURN_TYPES = ("MASK", )
1527
+ FUNCTION = "test"
1528
+ CATEGORY = "MVANet"
1529
+
1530
+ def test(
1531
+ self,
1532
+ image,
1533
+ MVANet_Model,
1534
+ ):
1535
+ ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
1536
+
1537
+ return (ret, )
1538
+
1539
+
1540
+ NODE_CLASS_MAPPINGS = {
1541
+ "load_MVANet_Model": load_MVANet_Model,
1542
+ "run_MVANet_inference": run_MVANet_inference
1543
+ }
1544
+
1545
+ NODE_DISPLAY_NAME_MAPPINGS = {
1546
+ "load_MVANet_Model": "load MVANet Model",
1547
+ "load_MVANet_Model": "load MVANet Model"
1548
+ }
ComfyUI_MVANet/README.org ADDED
@@ -0,0 +1,1694 @@
1
+ * COMMENT Sample
2
+
3
+ ** Shell script to download
4
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
5
+ #+end_src
6
+
7
+ ** MVANet_inference import
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
9
+ #+end_src
10
+
11
+ ** MVANet_inference function
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
13
+ #+end_src
14
+
15
+ ** MVANet_inference class
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
17
+ #+end_src
18
+
19
+ ** MVANet_inference execute
20
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
21
+ #+end_src
22
+
23
+ ** MVANet_inference unify
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
25
+ #+end_src
26
+
27
+ * Download the code:
28
+
29
+ ** Function to download
30
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
31
+ get_repo(){
32
+ DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
33
+ DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
34
+ mkdir -pv -- "${DIR_BASE}"
35
+ cd "${DIR_BASE}"
36
+ git clone "${1}"
37
+ cd "${DIR_REPO}"
38
+ git pull
39
+ git submodule update --recursive --init
40
+ }
41
+ #+end_src
42
+
43
+ ** Download
44
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
45
+ get_repo 'https://github.com/qianyu-dlut/MVANet.git'
46
+ #+end_src
47
+
48
+ * Dependencies
49
+ #+begin_src conf :tangle ./requirements.txt
50
+ timm
51
+ einops
52
+ wget
53
+ #+end_src
54
+
55
+ * Python inference
56
+
57
+ ** Important configs
58
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
59
+ import os
60
+ import sys
61
+
62
+ HOME_DIR = os.environ.get('HOME', '/root')
63
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
64
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
65
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
66
+ #+end_src
67
+
68
+ ** MVANet_inference import
69
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
70
+ import math
71
+ import numpy as np
72
+ import cv2
73
+ import wget
74
+
75
+ import torch
76
+ import torch.nn as nn
77
+ import torch.nn.functional as F
78
+ import torch.utils.checkpoint as checkpoint
79
+ from torch.autograd import Variable
80
+ from torch import nn
81
+ from torchvision import transforms
82
+
83
+ from einops import rearrange
84
+
85
+ from timm.models import load_checkpoint
86
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
87
+
88
+ torch_device = 'cuda'
89
+ torch_dtype = torch.float16
90
+ #+end_src
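+ 
+ The two globals above are hard-coded for a CUDA GPU running in float16. A minimal sketch (illustration only, not tangled into the plugin) of a CPU/float32 fallback for machines without CUDA:
+ #+begin_src python
+ # Hypothetical fallback: keep the same global names, but degrade gracefully to CPU/float32.
+ import torch
+ 
+ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ torch_dtype = torch.float16 if torch_device == 'cuda' else torch.float32
+ #+end_src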
91
+
92
+ ** COMMENT Load image using CV
93
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
94
+ def load_image(input_image_path):
95
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
96
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97
+ return img
98
+
99
+
100
+ def load_image_torch(input_image_path):
101
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
102
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
+ img = torch.from_numpy(img)
104
+ img = img.to(dtype=torch.float32)
105
+ img /= 255.0
106
+ img = img.unsqueeze(0)
107
+ return img
108
+
109
+
110
+ def save_mask(output_image_path, mask):
111
+ cv2.imwrite(output_image_path, mask)
112
+
113
+
114
+ def save_mask_torch(output_image_path, mask):
115
+ mask = mask.detach().cpu()
116
+ mask *= 255.0
117
+ mask = mask.clamp(0, 255)
118
+ print(mask.shape)
119
+ mask = mask.squeeze(0)
120
+ mask = mask.to(dtype=torch.uint8)
121
+ print(mask.shape)
122
+ mask = mask.numpy()
123
+ print(mask.shape)
124
+ cv2.imwrite(output_image_path, mask)
125
+ #+end_src
126
+
127
+ ** MVANet_inference function
128
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
129
+ def check_mkdir(dir_name):
130
+ if not os.path.isdir(dir_name):
131
+ os.makedirs(dir_name)
132
+
133
+
134
+ def SwinT(pretrained=True):
135
+ model = SwinTransformer(embed_dim=96,
136
+ depths=[2, 2, 6, 2],
137
+ num_heads=[3, 6, 12, 24],
138
+ window_size=7)
139
+ if pretrained is True:
140
+ model.load_state_dict(torch.load(
141
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
142
+ map_location='cpu')['model'],
143
+ strict=False)
144
+
145
+ return model
146
+
147
+
148
+ def SwinS(pretrained=True):
149
+ model = SwinTransformer(embed_dim=96,
150
+ depths=[2, 2, 18, 2],
151
+ num_heads=[3, 6, 12, 24],
152
+ window_size=7)
153
+ if pretrained is True:
154
+ model.load_state_dict(torch.load(
155
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
156
+ map_location='cpu')['model'],
157
+ strict=False)
158
+
159
+ return model
160
+
161
+
162
+ def SwinB(pretrained=True):
163
+ model = SwinTransformer(embed_dim=128,
164
+ depths=[2, 2, 18, 2],
165
+ num_heads=[4, 8, 16, 32],
166
+ window_size=12)
167
+ if pretrained is True:
168
+ import os
169
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
170
+ map_location='cpu')['model'],
171
+ strict=False)
172
+ return model
173
+
174
+
175
+ def SwinL(pretrained=True):
176
+ model = SwinTransformer(embed_dim=192,
177
+ depths=[2, 2, 18, 2],
178
+ num_heads=[6, 12, 24, 48],
179
+ window_size=12)
180
+ if pretrained is True:
181
+ model.load_state_dict(torch.load(
182
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
183
+ map_location='cpu')['model'],
184
+ strict=False)
185
+
186
+ return model
187
+
188
+
189
+ def get_activation_fn(activation):
190
+ """Return an activation function given a string"""
191
+ if activation == "relu":
192
+ return F.relu
193
+ if activation == "gelu":
194
+ return F.gelu
195
+ if activation == "glu":
196
+ return F.glu
197
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
198
+
199
+
200
+ def make_cbr(in_dim, out_dim):
201
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
202
+ nn.BatchNorm2d(out_dim), nn.PReLU())
203
+
204
+
205
+ def make_cbg(in_dim, out_dim):
206
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
207
+ nn.BatchNorm2d(out_dim), nn.GELU())
208
+
209
+
210
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
211
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
212
+
213
+
214
+ def resize_as(x, y, interpolation='bilinear'):
215
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
216
+
217
+
218
+ def image2patches(x):
219
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
220
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
221
+ return x
222
+
223
+
224
+ def patches2image(x):
225
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
226
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
227
+ return x
228
+
229
+
230
+ def window_partition(x, window_size):
231
+ """
232
+ Args:
233
+ x: (B, H, W, C)
234
+ window_size (int): window size
235
+
236
+ Returns:
237
+ windows: (num_windows*B, window_size, window_size, C)
238
+ """
239
+ B, H, W, C = x.shape
240
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
241
+ C)
242
+ windows = x.permute(0, 1, 3, 2, 4,
243
+ 5).contiguous().view(-1, window_size, window_size, C)
244
+ return windows
245
+
246
+
247
+ def window_reverse(windows, window_size, H, W):
248
+ """
249
+ Args:
250
+ windows: (num_windows*B, window_size, window_size, C)
251
+ window_size (int): Window size
252
+ H (int): Height of image
253
+ W (int): Width of image
254
+
255
+ Returns:
256
+ x: (B, H, W, C)
257
+ """
258
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
259
+ x = windows.view(B, H // window_size, W // window_size, window_size,
260
+ window_size, -1)
261
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
262
+ return x
263
+ #+end_src
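+ 
+ A minimal sketch (illustration only, not tangled) of the patch helpers above: image2patches cuts a batch into a 2x2 grid of quarter-size patches and patches2image puts them back, so the round trip reproduces the input exactly.
+ #+begin_src python
+ import torch
+ from einops import rearrange
+ 
+ x = torch.randn(1, 3, 64, 64)          # b c (2*h) (2*w)
+ p = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
+ assert p.shape == (4, 3, 32, 32)       # four 32x32 patches
+ y = rearrange(p, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
+ assert torch.equal(x, y)               # exact round trip
+ #+end_src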
264
+
265
+ ** MVANet_inference class
266
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
267
+ class Mlp(nn.Module):
268
+ """ Multilayer perceptron."""
269
+
270
+ def __init__(self,
271
+ in_features,
272
+ hidden_features=None,
273
+ out_features=None,
274
+ act_layer=nn.GELU,
275
+ drop=0.):
276
+ super().__init__()
277
+ out_features = out_features or in_features
278
+ hidden_features = hidden_features or in_features
279
+ self.fc1 = nn.Linear(in_features, hidden_features)
280
+ self.act = act_layer()
281
+ self.fc2 = nn.Linear(hidden_features, out_features)
282
+ self.drop = nn.Dropout(drop)
283
+
284
+ def forward(self, x):
285
+ x = self.fc1(x)
286
+ x = self.act(x)
287
+ x = self.drop(x)
288
+ x = self.fc2(x)
289
+ x = self.drop(x)
290
+ return x
291
+
292
+
293
+ class WindowAttention(nn.Module):
294
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
295
+ It supports both of shifted and non-shifted window.
296
+
297
+ Args:
298
+ dim (int): Number of input channels.
299
+ window_size (tuple[int]): The height and width of the window.
300
+ num_heads (int): Number of attention heads.
301
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
302
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
303
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
304
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
305
+ """
306
+
307
+ def __init__(self,
308
+ dim,
309
+ window_size,
310
+ num_heads,
311
+ qkv_bias=True,
312
+ qk_scale=None,
313
+ attn_drop=0.,
314
+ proj_drop=0.):
315
+
316
+ super().__init__()
317
+ self.dim = dim
318
+ self.window_size = window_size # Wh, Ww
319
+ self.num_heads = num_heads
320
+ head_dim = dim // num_heads
321
+ self.scale = qk_scale or head_dim**-0.5
322
+
323
+ # define a parameter table of relative position bias
324
+ self.relative_position_bias_table = nn.Parameter(
325
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
326
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
327
+
328
+ # get pair-wise relative position index for each token inside the window
329
+ coords_h = torch.arange(self.window_size[0])
330
+ coords_w = torch.arange(self.window_size[1])
331
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
332
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
333
+ relative_coords = coords_flatten[:, :,
334
+ None] - coords_flatten[:,
335
+ None, :] # 2, Wh*Ww, Wh*Ww
336
+ relative_coords = relative_coords.permute(
337
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
338
+ relative_coords[:, :,
339
+ 0] += self.window_size[0] - 1 # shift to start from 0
340
+ relative_coords[:, :, 1] += self.window_size[1] - 1
341
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
342
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
343
+ self.register_buffer("relative_position_index",
344
+ relative_position_index)
345
+
346
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
347
+ self.attn_drop = nn.Dropout(attn_drop)
348
+ self.proj = nn.Linear(dim, dim)
349
+ self.proj_drop = nn.Dropout(proj_drop)
350
+
351
+ trunc_normal_(self.relative_position_bias_table, std=.02)
352
+ self.softmax = nn.Softmax(dim=-1)
353
+
354
+ def forward(self, x, mask=None):
355
+ """ Forward function.
356
+
357
+ Args:
358
+ x: input features with shape of (num_windows*B, N, C)
359
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
360
+ """
361
+ x = x.to(dtype=torch_dtype, device=torch_device)
362
+ B_, N, C = x.shape
363
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
364
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
365
+ q, k, v = qkv[0], qkv[1], qkv[
366
+ 2] # make torchscript happy (cannot use tensor as tuple)
367
+
368
+ q = q * self.scale
369
+ attn = (q @ k.transpose(-2, -1))
370
+
371
+ relative_position_bias = self.relative_position_bias_table[
372
+ self.relative_position_index.view(-1)].view(
373
+ self.window_size[0] * self.window_size[1],
374
+ self.window_size[0] * self.window_size[1],
375
+ -1) # Wh*Ww,Wh*Ww,nH
376
+ relative_position_bias = relative_position_bias.permute(
377
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
378
+ attn = attn + relative_position_bias.unsqueeze(0)
379
+
380
+ if mask is not None:
381
+ nW = mask.shape[0]
382
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
383
+ N) + mask.unsqueeze(1).unsqueeze(0)
384
+ attn = attn.view(-1, self.num_heads, N, N)
385
+ attn = self.softmax(attn)
386
+ else:
387
+ attn = self.softmax(attn)
388
+
389
+ attn = self.attn_drop(attn)
390
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
391
+ v = v.to(dtype=torch_dtype, device=torch_device)
392
+
393
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
394
+ x = self.proj(x)
395
+ x = self.proj_drop(x)
396
+ return x
397
+
398
+
399
+ class SwinTransformerBlock(nn.Module):
400
+ """ Swin Transformer Block.
401
+
402
+ Args:
403
+ dim (int): Number of input channels.
404
+ num_heads (int): Number of attention heads.
405
+ window_size (int): Window size.
406
+ shift_size (int): Shift size for SW-MSA.
407
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
408
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
409
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
410
+ drop (float, optional): Dropout rate. Default: 0.0
411
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
412
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
413
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
414
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
415
+ """
416
+
417
+ def __init__(self,
418
+ dim,
419
+ num_heads,
420
+ window_size=7,
421
+ shift_size=0,
422
+ mlp_ratio=4.,
423
+ qkv_bias=True,
424
+ qk_scale=None,
425
+ drop=0.,
426
+ attn_drop=0.,
427
+ drop_path=0.,
428
+ act_layer=nn.GELU,
429
+ norm_layer=nn.LayerNorm):
430
+ super().__init__()
431
+ self.dim = dim
432
+ self.num_heads = num_heads
433
+ self.window_size = window_size
434
+ self.shift_size = shift_size
435
+ self.mlp_ratio = mlp_ratio
436
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
437
+
438
+ self.norm1 = norm_layer(dim)
439
+ self.attn = WindowAttention(dim,
440
+ window_size=to_2tuple(self.window_size),
441
+ num_heads=num_heads,
442
+ qkv_bias=qkv_bias,
443
+ qk_scale=qk_scale,
444
+ attn_drop=attn_drop,
445
+ proj_drop=drop)
446
+
447
+ self.drop_path = DropPath(
448
+ drop_path) if drop_path > 0. else nn.Identity()
449
+ self.norm2 = norm_layer(dim)
450
+ mlp_hidden_dim = int(dim * mlp_ratio)
451
+ self.mlp = Mlp(in_features=dim,
452
+ hidden_features=mlp_hidden_dim,
453
+ act_layer=act_layer,
454
+ drop=drop)
455
+
456
+ self.H = None
457
+ self.W = None
458
+
459
+ def forward(self, x, mask_matrix):
460
+ """ Forward function.
461
+
462
+ Args:
463
+ x: Input feature, tensor size (B, H*W, C).
464
+ H, W: Spatial resolution of the input feature.
465
+ mask_matrix: Attention mask for cyclic shift.
466
+ """
467
+ B, L, C = x.shape
468
+ H, W = self.H, self.W
469
+ assert L == H * W, "input feature has wrong size"
470
+
471
+ shortcut = x
472
+ x = self.norm1(x)
473
+ x = x.view(B, H, W, C)
474
+
475
+ # pad feature maps to multiples of window size
476
+ pad_l = pad_t = 0
477
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
478
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
479
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
480
+ _, Hp, Wp, _ = x.shape
481
+
482
+ # cyclic shift
483
+ if self.shift_size > 0:
484
+ shifted_x = torch.roll(x,
485
+ shifts=(-self.shift_size, -self.shift_size),
486
+ dims=(1, 2))
487
+ attn_mask = mask_matrix
488
+ else:
489
+ shifted_x = x
490
+ attn_mask = None
491
+
492
+ # partition windows
493
+ x_windows = window_partition(
494
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
495
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
496
+ C) # nW*B, window_size*window_size, C
497
+
498
+ # W-MSA/SW-MSA
499
+ attn_windows = self.attn(
500
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
501
+
502
+ # merge windows
503
+ attn_windows = attn_windows.view(-1, self.window_size,
504
+ self.window_size, C)
505
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
506
+ Wp) # B H' W' C
507
+
508
+ # reverse cyclic shift
509
+ if self.shift_size > 0:
510
+ x = torch.roll(shifted_x,
511
+ shifts=(self.shift_size, self.shift_size),
512
+ dims=(1, 2))
513
+ else:
514
+ x = shifted_x
515
+
516
+ if pad_r > 0 or pad_b > 0:
517
+ x = x[:, :H, :W, :].contiguous()
518
+
519
+ x = x.view(B, H * W, C)
520
+
521
+ # FFN
522
+ x = shortcut + self.drop_path(x)
523
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
524
+
525
+ return x
526
+
527
+
528
+ class PatchMerging(nn.Module):
529
+ """ Patch Merging Layer
530
+
531
+ Args:
532
+ dim (int): Number of input channels.
533
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
534
+ """
535
+
536
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
537
+ super().__init__()
538
+ self.dim = dim
539
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
540
+ self.norm = norm_layer(4 * dim)
541
+
542
+ def forward(self, x, H, W):
543
+ """ Forward function.
544
+
545
+ Args:
546
+ x: Input feature, tensor size (B, H*W, C).
547
+ H, W: Spatial resolution of the input feature.
548
+ """
549
+ B, L, C = x.shape
550
+ assert L == H * W, "input feature has wrong size"
551
+
552
+ x = x.view(B, H, W, C)
553
+
554
+ # padding
555
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
556
+ if pad_input:
557
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
558
+
559
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
560
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
561
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
562
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
563
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
564
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
565
+
566
+ x = self.norm(x)
567
+ x = self.reduction(x)
568
+
569
+ return x
570
+
571
+
572
+ class BasicLayer(nn.Module):
573
+ """ A basic Swin Transformer layer for one stage.
574
+
575
+ Args:
576
+ dim (int): Number of feature channels
577
+ depth (int): Depths of this stage.
578
+ num_heads (int): Number of attention head.
579
+ window_size (int): Local window size. Default: 7.
580
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
581
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
582
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
583
+ drop (float, optional): Dropout rate. Default: 0.0
584
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
585
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
586
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
587
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
588
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
589
+ """
590
+
591
+ def __init__(self,
592
+ dim,
593
+ depth,
594
+ num_heads,
595
+ window_size=7,
596
+ mlp_ratio=4.,
597
+ qkv_bias=True,
598
+ qk_scale=None,
599
+ drop=0.,
600
+ attn_drop=0.,
601
+ drop_path=0.,
602
+ norm_layer=nn.LayerNorm,
603
+ downsample=None,
604
+ use_checkpoint=False):
605
+ super().__init__()
606
+ self.window_size = window_size
607
+ self.shift_size = window_size // 2
608
+ self.depth = depth
609
+ self.use_checkpoint = use_checkpoint
610
+
611
+ # build blocks
612
+ self.blocks = nn.ModuleList([
613
+ SwinTransformerBlock(dim=dim,
614
+ num_heads=num_heads,
615
+ window_size=window_size,
616
+ shift_size=0 if
617
+ (i % 2 == 0) else window_size // 2,
618
+ mlp_ratio=mlp_ratio,
619
+ qkv_bias=qkv_bias,
620
+ qk_scale=qk_scale,
621
+ drop=drop,
622
+ attn_drop=attn_drop,
623
+ drop_path=drop_path[i] if isinstance(
624
+ drop_path, list) else drop_path,
625
+ norm_layer=norm_layer) for i in range(depth)
626
+ ])
627
+
628
+ # patch merging layer
629
+ if downsample is not None:
630
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
631
+ else:
632
+ self.downsample = None
633
+
634
+ def forward(self, x, H, W):
635
+ """ Forward function.
636
+
637
+ Args:
638
+ x: Input feature, tensor size (B, H*W, C).
639
+ H, W: Spatial resolution of the input feature.
640
+ """
641
+
642
+ # calculate attention mask for SW-MSA
643
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
644
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
645
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
646
+ h_slices = (slice(0, -self.window_size),
647
+ slice(-self.window_size,
648
+ -self.shift_size), slice(-self.shift_size, None))
649
+ w_slices = (slice(0, -self.window_size),
650
+ slice(-self.window_size,
651
+ -self.shift_size), slice(-self.shift_size, None))
652
+ cnt = 0
653
+ for h in h_slices:
654
+ for w in w_slices:
655
+ img_mask[:, h, w, :] = cnt
656
+ cnt += 1
657
+
658
+ mask_windows = window_partition(
659
+ img_mask, self.window_size) # nW, window_size, window_size, 1
660
+ mask_windows = mask_windows.view(-1,
661
+ self.window_size * self.window_size)
662
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
663
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
664
+ float(-100.0)).masked_fill(
665
+ attn_mask == 0, float(0.0))
666
+
667
+ for blk in self.blocks:
668
+ blk.H, blk.W = H, W
669
+ if self.use_checkpoint:
670
+ x = checkpoint.checkpoint(blk, x, attn_mask)
671
+ else:
672
+ x = blk(x, attn_mask)
673
+ if self.downsample is not None:
674
+ x_down = self.downsample(x, H, W)
675
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
676
+ return x, H, W, x_down, Wh, Ww
677
+ else:
678
+ return x, H, W, x, H, W
679
+
680
+
681
+ class PatchEmbed(nn.Module):
682
+ """ Image to Patch Embedding
683
+
684
+ Args:
685
+ patch_size (int): Patch token size. Default: 4.
686
+ in_chans (int): Number of input image channels. Default: 3.
687
+ embed_dim (int): Number of linear projection output channels. Default: 96.
688
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
689
+ """
690
+
691
+ def __init__(self,
692
+ patch_size=4,
693
+ in_chans=3,
694
+ embed_dim=96,
695
+ norm_layer=None):
696
+ super().__init__()
697
+ patch_size = to_2tuple(patch_size)
698
+ self.patch_size = patch_size
699
+
700
+ self.in_chans = in_chans
701
+ self.embed_dim = embed_dim
702
+
703
+ self.proj = nn.Conv2d(in_chans,
704
+ embed_dim,
705
+ kernel_size=patch_size,
706
+ stride=patch_size)
707
+ if norm_layer is not None:
708
+ self.norm = norm_layer(embed_dim)
709
+ else:
710
+ self.norm = None
711
+
712
+ def forward(self, x):
713
+ """Forward function."""
714
+ # padding
715
+ _, _, H, W = x.size()
716
+ if W % self.patch_size[1] != 0:
717
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
718
+ if H % self.patch_size[0] != 0:
719
+ x = F.pad(x,
720
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
721
+
722
+ x = self.proj(x) # B C Wh Ww
723
+ if self.norm is not None:
724
+ Wh, Ww = x.size(2), x.size(3)
725
+ x = x.flatten(2).transpose(1, 2)
726
+ x = self.norm(x)
727
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
728
+
729
+ return x
730
+
731
+
732
+ class SwinTransformer(nn.Module):
733
+ """ Swin Transformer backbone.
734
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
735
+ https://arxiv.org/pdf/2103.14030
736
+
737
+ Args:
738
+ pretrain_img_size (int): Input image size for training the pretrained model,
739
+ used in absolute position embedding. Default 224.
740
+ patch_size (int | tuple(int)): Patch size. Default: 4.
741
+ in_chans (int): Number of input image channels. Default: 3.
742
+ embed_dim (int): Number of linear projection output channels. Default: 96.
743
+ depths (tuple[int]): Depths of each Swin Transformer stage.
744
+ num_heads (tuple[int]): Number of attention head of each stage.
745
+ window_size (int): Window size. Default: 7.
746
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
747
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
748
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
749
+ drop_rate (float): Dropout rate.
750
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
751
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
752
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
753
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
754
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
755
+ out_indices (Sequence[int]): Output from which stages.
756
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
757
+ -1 means not freezing any parameters.
758
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
759
+ """
760
+
761
+ def __init__(self,
762
+ pretrain_img_size=224,
763
+ patch_size=4,
764
+ in_chans=3,
765
+ embed_dim=96,
766
+ depths=[2, 2, 6, 2],
767
+ num_heads=[3, 6, 12, 24],
768
+ window_size=7,
769
+ mlp_ratio=4.,
770
+ qkv_bias=True,
771
+ qk_scale=None,
772
+ drop_rate=0.,
773
+ attn_drop_rate=0.,
774
+ drop_path_rate=0.2,
775
+ norm_layer=nn.LayerNorm,
776
+ ape=False,
777
+ patch_norm=True,
778
+ out_indices=(0, 1, 2, 3),
779
+ frozen_stages=-1,
780
+ use_checkpoint=False):
781
+ super().__init__()
782
+
783
+ self.pretrain_img_size = pretrain_img_size
784
+ self.num_layers = len(depths)
785
+ self.embed_dim = embed_dim
786
+ self.ape = ape
787
+ self.patch_norm = patch_norm
788
+ self.out_indices = out_indices
789
+ self.frozen_stages = frozen_stages
790
+
791
+ # split image into non-overlapping patches
792
+ self.patch_embed = PatchEmbed(
793
+ patch_size=patch_size,
794
+ in_chans=in_chans,
795
+ embed_dim=embed_dim,
796
+ norm_layer=norm_layer if self.patch_norm else None)
797
+
798
+ # absolute position embedding
799
+ if self.ape:
800
+ pretrain_img_size = to_2tuple(pretrain_img_size)
801
+ patch_size = to_2tuple(patch_size)
802
+ patches_resolution = [
803
+ pretrain_img_size[0] // patch_size[0],
804
+ pretrain_img_size[1] // patch_size[1]
805
+ ]
806
+
807
+ self.absolute_pos_embed = nn.Parameter(
808
+ torch.zeros(1, embed_dim, patches_resolution[0],
809
+ patches_resolution[1]))
810
+ trunc_normal_(self.absolute_pos_embed, std=.02)
811
+
812
+ self.pos_drop = nn.Dropout(p=drop_rate)
813
+
814
+ # stochastic depth
815
+ dpr = [
816
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
817
+ ] # stochastic depth decay rule
818
+
819
+ # build layers
820
+ self.layers = nn.ModuleList()
821
+ for i_layer in range(self.num_layers):
822
+ layer = BasicLayer(
823
+ dim=int(embed_dim * 2**i_layer),
824
+ depth=depths[i_layer],
825
+ num_heads=num_heads[i_layer],
826
+ window_size=window_size,
827
+ mlp_ratio=mlp_ratio,
828
+ qkv_bias=qkv_bias,
829
+ qk_scale=qk_scale,
830
+ drop=drop_rate,
831
+ attn_drop=attn_drop_rate,
832
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
833
+ norm_layer=norm_layer,
834
+ downsample=PatchMerging if
835
+ (i_layer < self.num_layers - 1) else None,
836
+ use_checkpoint=use_checkpoint)
837
+ self.layers.append(layer)
838
+
839
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
840
+ self.num_features = num_features
841
+
842
+ # add a norm layer for each output
843
+ for i_layer in out_indices:
844
+ layer = norm_layer(num_features[i_layer])
845
+ layer_name = f'norm{i_layer}'
846
+ self.add_module(layer_name, layer)
847
+
848
+ self._freeze_stages()
849
+
850
+ def _freeze_stages(self):
851
+ if self.frozen_stages >= 0:
852
+ self.patch_embed.eval()
853
+ for param in self.patch_embed.parameters():
854
+ param.requires_grad = False
855
+
856
+ if self.frozen_stages >= 1 and self.ape:
857
+ self.absolute_pos_embed.requires_grad = False
858
+
859
+ if self.frozen_stages >= 2:
860
+ self.pos_drop.eval()
861
+ for i in range(0, self.frozen_stages - 1):
862
+ m = self.layers[i]
863
+ m.eval()
864
+ for param in m.parameters():
865
+ param.requires_grad = False
866
+
867
+ def init_weights(self, pretrained=None):
868
+ """Initialize the weights in backbone.
869
+
870
+ Args:
871
+ pretrained (str, optional): Path to pre-trained weights.
872
+ Defaults to None.
873
+ """
874
+
875
+ def _init_weights(m):
876
+ if isinstance(m, nn.Linear):
877
+ trunc_normal_(m.weight, std=.02)
878
+ if isinstance(m, nn.Linear) and m.bias is not None:
879
+ nn.init.constant_(m.bias, 0)
880
+ elif isinstance(m, nn.LayerNorm):
881
+ nn.init.constant_(m.bias, 0)
882
+ nn.init.constant_(m.weight, 1.0)
883
+
884
+ if isinstance(pretrained, str):
885
+ self.apply(_init_weights)
886
+ load_checkpoint(self, pretrained, strict=False, logger=None)
887
+ elif pretrained is None:
888
+ self.apply(_init_weights)
889
+ else:
890
+ raise TypeError('pretrained must be a str or None')
891
+
892
+ def forward(self, x):
893
+ x = self.patch_embed(x)
894
+
895
+ Wh, Ww = x.size(2), x.size(3)
896
+ if self.ape:
897
+ # interpolate the position embedding to the corresponding size
898
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
899
+ size=(Wh, Ww),
900
+ mode='bicubic')
901
+ x = (x + absolute_pos_embed) # B Wh*Ww C
902
+
903
+ outs = [x.contiguous()]
904
+ x = x.flatten(2).transpose(1, 2)
905
+ x = self.pos_drop(x)
906
+ for i in range(self.num_layers):
907
+ layer = self.layers[i]
908
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
909
+
910
+ if i in self.out_indices:
911
+ norm_layer = getattr(self, f'norm{i}')
912
+ x_out = norm_layer(x_out)
913
+
914
+ out = x_out.view(-1, H, W,
915
+ self.num_features[i]).permute(0, 3, 1,
916
+ 2).contiguous()
917
+ outs.append(out)
918
+
919
+ return tuple(outs)
920
+
921
+ def train(self, mode=True):
922
+ """Convert the model into training mode while keep layers freezed."""
923
+ super(SwinTransformer, self).train(mode)
924
+ self._freeze_stages()
925
+
926
+
927
+ class PositionEmbeddingSine:
928
+
929
+ def __init__(self,
930
+ num_pos_feats=64,
931
+ temperature=10000,
932
+ normalize=False,
933
+ scale=None):
934
+ super().__init__()
935
+ self.num_pos_feats = num_pos_feats
936
+ self.temperature = temperature
937
+ self.normalize = normalize
938
+ if scale is not None and normalize is False:
939
+ raise ValueError("normalize should be True if scale is passed")
940
+ if scale is None:
941
+ scale = 2 * math.pi
942
+ self.scale = scale
943
+ self.dim_t = torch.arange(0,
944
+ self.num_pos_feats,
945
+ dtype=torch_dtype,
946
+ device=torch_device)
947
+
948
+ def __call__(self, b, h, w):
949
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
950
+ assert mask is not None
951
+ not_mask = ~mask
952
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
953
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
954
+ if self.normalize:
955
+ eps = 1e-6
956
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
957
+ self.scale).to(device=torch_device, dtype=torch_dtype)
958
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
959
+ self.scale).to(device=torch_device, dtype=torch_dtype)
960
+
961
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
962
+
963
+ pos_x = x_embed[:, :, :, None] / dim_t
964
+ pos_y = y_embed[:, :, :, None] / dim_t
965
+ pos_x = torch.stack(
966
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
967
+ dim=4).flatten(3)
968
+ pos_y = torch.stack(
969
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
970
+ dim=4).flatten(3)
971
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
972
+
973
+
974
+ class MCLM(nn.Module):
975
+
976
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
977
+ super(MCLM, self).__init__()
978
+ self.attention = nn.ModuleList([
979
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
980
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
981
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
982
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
983
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
984
+ ])
985
+
986
+ self.linear1 = nn.Linear(d_model, d_model * 2)
987
+ self.linear2 = nn.Linear(d_model * 2, d_model)
988
+ self.linear3 = nn.Linear(d_model, d_model * 2)
989
+ self.linear4 = nn.Linear(d_model * 2, d_model)
990
+ self.norm1 = nn.LayerNorm(d_model)
991
+ self.norm2 = nn.LayerNorm(d_model)
992
+ self.dropout = nn.Dropout(0.1)
993
+ self.dropout1 = nn.Dropout(0.1)
994
+ self.dropout2 = nn.Dropout(0.1)
995
+ self.activation = get_activation_fn('relu')
996
+ self.pool_ratios = pool_ratios
997
+ self.p_poses = []
998
+ self.g_pos = None
999
+ self.positional_encoding = PositionEmbeddingSine(
1000
+ num_pos_feats=d_model // 2, normalize=True)
1001
+
1002
+ def forward(self, l, g):
1003
+ """
1004
+ l: 4,c,h,w
1005
+ g: 1,c,h,w
1006
+ """
1007
+ b, c, h, w = l.size()
1008
+ # 4,c,h,w -> 1,c,2h,2w
1009
+ concated_locs = rearrange(l,
1010
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1011
+ hg=2,
1012
+ wg=2)
1013
+
1014
+ pools = []
1015
+ for pool_ratio in self.pool_ratios:
1016
+ # b,c,h,w
1017
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1018
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1019
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1020
+ if self.g_pos is None:
1021
+ pos_emb = self.positional_encoding(pool.shape[0],
1022
+ pool.shape[2],
1023
+ pool.shape[3])
1024
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1025
+ self.p_poses.append(pos_emb)
1026
+ pools = torch.cat(pools, 0)
1027
+ if self.g_pos is None:
1028
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1029
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1030
+ g.shape[3])
1031
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1032
+
1033
+ # attention between glb (q) and the multi-scale pooled concatenated locs (k, v)
1034
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1035
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1036
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1037
+ g_hw_b_c = self.norm1(g_hw_b_c)
1038
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1039
+ self.linear2(
1040
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1041
+ g_hw_b_c = self.norm2(g_hw_b_c)
1042
+
1043
+ # attention between the original locs (q) and the refreshed glb (k, v)
1044
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1045
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1046
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1047
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1048
+ ng=2,
1049
+ nw=2)
1050
+ outputs_re = []
1051
+ for i, (_l, _g) in enumerate(
1052
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1053
+ outputs_re.append(self.attention[i + 1](_l, _g,
1054
+ _g)[0]) # (h w) 1 c
1055
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1056
+
1057
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1058
+ l_hw_b_c = self.norm1(l_hw_b_c)
1059
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1060
+ self.linear4(
1061
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1062
+ l_hw_b_c = self.norm2(l_hw_b_c)
1063
+
1064
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1065
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
1066
+
1067
+
1068
+ class inf_MCLM(nn.Module):
1069
+
1070
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1071
+ super(inf_MCLM, self).__init__()
1072
+ self.attention = nn.ModuleList([
1073
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1074
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1075
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1076
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1077
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1078
+ ])
1079
+
1080
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1081
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1082
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1083
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1084
+ self.norm1 = nn.LayerNorm(d_model)
1085
+ self.norm2 = nn.LayerNorm(d_model)
1086
+ self.dropout = nn.Dropout(0.1)
1087
+ self.dropout1 = nn.Dropout(0.1)
1088
+ self.dropout2 = nn.Dropout(0.1)
1089
+ self.activation = get_activation_fn('relu')
1090
+ self.pool_ratios = pool_ratios
1091
+ self.p_poses = []
1092
+ self.g_pos = None
1093
+ self.positional_encoding = PositionEmbeddingSine(
1094
+ num_pos_feats=d_model // 2, normalize=True)
1095
+
1096
+ def forward(self, l, g):
1097
+ """
1098
+ l: 4,c,h,w
1099
+ g: 1,c,h,w
1100
+ """
1101
+ b, c, h, w = l.size()
1102
+ # 4,c,h,w -> 1,c,2h,2w
1103
+ concated_locs = rearrange(l,
1104
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1105
+ hg=2,
1106
+ wg=2)
1107
+ self.p_poses = []
1108
+ pools = []
1109
+ for pool_ratio in self.pool_ratios:
1110
+ # b,c,h,w
1111
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1112
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1113
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1114
+ # if self.g_pos is None:
1115
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1116
+ pool.shape[3])
1117
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1118
+ self.p_poses.append(pos_emb)
1119
+ pools = torch.cat(pools, 0)
1120
+ # if self.g_pos is None:
1121
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1122
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1123
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1124
+
1125
+ # attention between glb (q) and the multi-scale pooled concatenated locs (k, v)
1126
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1127
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1128
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1129
+ g_hw_b_c = self.norm1(g_hw_b_c)
1130
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1131
+ self.linear2(
1132
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1133
+ g_hw_b_c = self.norm2(g_hw_b_c)
1134
+
1135
+ # attention between the original locs (q) and the refreshed glb (k, v)
1136
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1137
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1138
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1139
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1140
+ ng=2,
1141
+ nw=2)
1142
+ outputs_re = []
1143
+ for i, (_l, _g) in enumerate(
1144
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1145
+ outputs_re.append(self.attention[i + 1](_l, _g,
1146
+ _g)[0]) # (h w) 1 c
1147
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1148
+
1149
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1150
+ l_hw_b_c = self.norm1(l_hw_b_c)
1151
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1152
+ self.linear4(
1153
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1154
+ l_hw_b_c = self.norm2(l_hw_b_c)
1155
+
1156
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1157
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
1158
+
1159
+
1160
+ class MCRM(nn.Module):
1161
+
1162
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1163
+ super(MCRM, self).__init__()
1164
+ self.attention = nn.ModuleList([
1165
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1166
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1167
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1168
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1169
+ ])
1170
+
1171
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1172
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1173
+ self.norm1 = nn.LayerNorm(d_model)
1174
+ self.norm2 = nn.LayerNorm(d_model)
1175
+ self.dropout = nn.Dropout(0.1)
1176
+ self.dropout1 = nn.Dropout(0.1)
1177
+ self.dropout2 = nn.Dropout(0.1)
1178
+ self.sigmoid = nn.Sigmoid()
1179
+ self.activation = get_activation_fn('relu')
1180
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1181
+ self.pool_ratios = pool_ratios
1182
+ self.positional_encoding = PositionEmbeddingSine(
1183
+ num_pos_feats=d_model // 2, normalize=True)
1184
+
1185
+ def forward(self, x):
1186
+ b, c, h, w = x.size()
1187
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1188
+ # b(4),c,h,w
1189
+ patched_glb = rearrange(glb,
1190
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1191
+ hg=2,
1192
+ wg=2)
1193
+
1194
+ # generate token attention map
1195
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1196
+ token_attention_map = F.interpolate(token_attention_map,
1197
+ size=patches2image(loc).shape[-2:],
1198
+ mode='nearest')
1199
+ loc = loc * rearrange(token_attention_map,
1200
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1201
+ hg=2,
1202
+ wg=2)
1203
+ pools = []
1204
+ for pool_ratio in self.pool_ratios:
1205
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1206
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1207
+ pools.append(rearrange(pool,
1208
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1209
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1210
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1211
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1212
+ outputs = []
1213
+ for i, q in enumerate(
1214
+ loc_.unbind(dim=0)): # traverse all local patches
1215
+ # np*hw,1,c
1216
+ v = pools[i]
1217
+ k = v
1218
+ outputs.append(self.attention[i](q, k, v)[0])
1219
+ outputs = torch.cat(outputs, 1)
1220
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1221
+ src = self.norm1(src)
1222
+ src = src + self.dropout2(
1223
+ self.linear4(
1224
+ self.dropout(self.activation(self.linear3(src)).clone())))
1225
+ src = self.norm2(src)
1226
+
1227
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1228
+ glb = glb + F.interpolate(patches2image(src),
1229
+ size=glb.shape[-2:],
1230
+ mode='nearest') # refreshed glb
1231
+ return torch.cat((src, glb), 0), token_attention_map
1232
+
1233
+
1234
+ class inf_MCRM(nn.Module):
1235
+
1236
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1237
+ super(inf_MCRM, self).__init__()
1238
+ self.attention = nn.ModuleList([
1239
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1240
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1241
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1242
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1243
+ ])
1244
+
1245
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1246
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1247
+ self.norm1 = nn.LayerNorm(d_model)
1248
+ self.norm2 = nn.LayerNorm(d_model)
1249
+ self.dropout = nn.Dropout(0.1)
1250
+ self.dropout1 = nn.Dropout(0.1)
1251
+ self.dropout2 = nn.Dropout(0.1)
1252
+ self.sigmoid = nn.Sigmoid()
1253
+ self.activation = get_activation_fn('relu')
1254
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1255
+ self.pool_ratios = pool_ratios
1256
+ self.positional_encoding = PositionEmbeddingSine(
1257
+ num_pos_feats=d_model // 2, normalize=True)
1258
+
1259
+ def forward(self, x):
1260
+ b, c, h, w = x.size()
1261
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1262
+ # b(4),c,h,w
1263
+ patched_glb = rearrange(glb,
1264
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1265
+ hg=2,
1266
+ wg=2)
1267
+
1268
+ # generate token attention map
1269
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1270
+ token_attention_map = F.interpolate(token_attention_map,
1271
+ size=patches2image(loc).shape[-2:],
1272
+ mode='nearest')
1273
+ loc = loc * rearrange(token_attention_map,
1274
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1275
+ hg=2,
1276
+ wg=2)
1277
+ pools = []
1278
+ for pool_ratio in self.pool_ratios:
1279
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1280
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1281
+ pools.append(rearrange(pool,
1282
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1283
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1284
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1285
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1286
+ outputs = []
1287
+ for i, q in enumerate(
1288
+ loc_.unbind(dim=0)): # traverse all local patches
1289
+ # np*hw,1,c
1290
+ v = pools[i]
1291
+ k = v
1292
+ outputs.append(self.attention[i](q, k, v)[0])
1293
+ outputs = torch.cat(outputs, 1)
1294
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1295
+ src = self.norm1(src)
1296
+ src = src + self.dropout2(
1297
+ self.linear4(
1298
+ self.dropout(self.activation(self.linear3(src)).clone())))
1299
+ src = self.norm2(src)
1300
+
1301
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1302
+ glb = glb + F.interpolate(patches2image(src),
1303
+ size=glb.shape[-2:],
1304
+ mode='nearest') # refreshed glb
1305
+ return torch.cat((src, glb), 0)
1306
+
1307
+
1308
+ # model for single-scale training
1309
+ class MVANet(nn.Module):
1310
+
1311
+ def __init__(self):
1312
+ super().__init__()
1313
+ self.backbone = SwinB(pretrained=True)
1314
+ emb_dim = 128
1315
+ self.sideout5 = nn.Sequential(
1316
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1317
+ self.sideout4 = nn.Sequential(
1318
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1319
+ self.sideout3 = nn.Sequential(
1320
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1321
+ self.sideout2 = nn.Sequential(
1322
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1323
+ self.sideout1 = nn.Sequential(
1324
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1325
+
1326
+ self.output5 = make_cbr(1024, emb_dim)
1327
+ self.output4 = make_cbr(512, emb_dim)
1328
+ self.output3 = make_cbr(256, emb_dim)
1329
+ self.output2 = make_cbr(128, emb_dim)
1330
+ self.output1 = make_cbr(128, emb_dim)
1331
+
1332
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1333
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1334
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1335
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1336
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1337
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1338
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1339
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1340
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1341
+
1342
+ self.insmask_head = nn.Sequential(
1343
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1344
+ nn.BatchNorm2d(384), nn.PReLU(),
1345
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1346
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1347
+
1348
+ self.shallow = nn.Sequential(
1349
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1350
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1351
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1352
+ self.output = nn.Sequential(
1353
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1354
+
1355
+ for m in self.modules():
1356
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1357
+ m.inplace = True
1358
+
1359
+ def forward(self, x):
1360
+ x = x.to(dtype=torch_dtype, device=torch_device)
1361
+ shallow = self.shallow(x)
1362
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1363
+ loc = image2patches(x)
1364
+ input = torch.cat((loc, glb), dim=0)
1365
+ feature = self.backbone(input)
1366
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1367
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1368
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1369
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1370
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1371
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1372
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1373
+
1374
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1375
+ e4 = self.conv4(e4)
1376
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1377
+ e3 = self.conv3(e3)
1378
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1379
+ e2 = self.conv2(e2)
1380
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1381
+ e1 = self.conv1(e1)
1382
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1383
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1384
+ # add glb feat in
1385
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1386
+ # merge
1387
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1388
+ # shallow feature merge
1389
+ final_output = final_output + resize_as(shallow, final_output)
1390
+ final_output = self.upsample1(rescale_to(final_output))
1391
+ final_output = rescale_to(final_output +
1392
+ resize_as(shallow, final_output))
1393
+ final_output = self.upsample2(final_output)
1394
+ final_output = self.output(final_output)
1395
+ ####
1396
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1397
+ sideout4 = self.sideout4(e4)
1398
+ sideout3 = self.sideout3(e3)
1399
+ sideout2 = self.sideout2(e2)
1400
+ sideout1 = self.sideout1(e1)
1401
+ #######glb_sideouts ######
1402
+ glb5 = self.sideout5(glb_e5)
1403
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1404
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1405
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1406
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1407
+ ####### concat 4 to 1 #######
1408
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1409
+ device=torch_device)
1410
+ sideout2 = patches2image(sideout2[:-1]).to(
1411
+ dtype=torch_dtype,
1412
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1413
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1414
+ device=torch_device)
1415
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1416
+ device=torch_device)
1417
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1418
+ device=torch_device)
1419
+ if self.training:
1420
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1421
+ else:
1422
+ return final_output
1423
+
1424
+
1425
+ # model for multi-scale testing
1426
+ class inf_MVANet(nn.Module):
1427
+
1428
+ def __init__(self):
1429
+ super().__init__()
1430
+ # self.backbone = SwinB(pretrained=True)
1431
+ self.backbone = SwinB(pretrained=False)
1432
+
1433
+ emb_dim = 128
1434
+ self.output5 = make_cbr(1024, emb_dim)
1435
+ self.output4 = make_cbr(512, emb_dim)
1436
+ self.output3 = make_cbr(256, emb_dim)
1437
+ self.output2 = make_cbr(128, emb_dim)
1438
+ self.output1 = make_cbr(128, emb_dim)
1439
+
1440
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1441
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1442
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1443
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1444
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1445
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1446
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1447
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1448
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1449
+
1450
+ self.insmask_head = nn.Sequential(
1451
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1452
+ nn.BatchNorm2d(384), nn.PReLU(),
1453
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1454
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1455
+
1456
+ self.shallow = nn.Sequential(
1457
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1458
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1459
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1460
+ self.output = nn.Sequential(
1461
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1462
+
1463
+ for m in self.modules():
1464
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1465
+ m.inplace = True
1466
+
1467
+ def forward(self, x):
1468
+ shallow = self.shallow(x)
1469
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1470
+ loc = image2patches(x)
1471
+ input = torch.cat((loc, glb), dim=0)
1472
+ feature = self.backbone(input)
1473
+ e5 = self.output5(feature[4])
1474
+ e4 = self.output4(feature[3])
1475
+ e3 = self.output3(feature[2])
1476
+ e2 = self.output2(feature[1])
1477
+ e1 = self.output1(feature[0])
1478
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1479
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1480
+
1481
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1482
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1483
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1484
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1485
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1486
+ # after decoder, concat loc features to a whole one, and merge
1487
+ output1_cat = patches2image(loc_e1)
1488
+ # add glb feat in
1489
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1490
+ # merge
1491
+ final_output = self.insmask_head(output1_cat)
1492
+ # shallow feature merge
1493
+ final_output = final_output + resize_as(shallow, final_output)
1494
+ final_output = self.upsample1(rescale_to(final_output))
1495
+ final_output = rescale_to(final_output +
1496
+ resize_as(shallow, final_output))
1497
+ final_output = self.upsample2(final_output)
1498
+ final_output = self.output(final_output)
1499
+ return final_output
1500
+ #+end_src
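+
+ A quick sanity check of the tensor layout (illustrative sketch, not tangled): the model patches the input internally, so the caller supplies a single B=1, C=3 image whose height and width are divisible by 2 (1024x1024 is what the inference wrapper below uses) and gets back a single-channel logit map of the same spatial size. This assumes a CUDA device and the fp16 globals defined earlier, and uses random weights; real weights are loaded via load_model() in the next section.
+ #+begin_src python
+ # Untested shape check for inf_MVANet (illustrative only).
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device).eval()
+ with torch.no_grad():
+     dummy = torch.zeros(1, 3, 1024, 1024, dtype=torch_dtype, device=torch_device)
+     logits = net(dummy)
+ print(logits.shape)  # expected: torch.Size([1, 1, 1024, 1024])
+ #+end_src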
1501
+
1502
+ ** Function to load model
1503
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1504
+ def mkdir_safe(out_path):
1505
+ if type(out_path) == str:
1506
+ if len(out_path) > 0:
1507
+ if not os.path.exists(out_path):
1508
+ os.mkdir(out_path)
1509
+
1510
+
1511
+ def get_model_path():
1512
+ import folder_paths
1513
+ from folder_paths import models_dir
1514
+
1515
+ path_file_model = models_dir
1516
+ mkdir_safe(out_path=path_file_model)
1517
+
1518
+ path_file_model = os.path.join(path_file_model, 'MVANet')
1519
+ mkdir_safe(out_path=path_file_model)
1520
+
1521
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
1522
+
1523
+ return path_file_model
1524
+
1525
+
1526
+ def download_model(path):
1527
+ if not os.path.exists(path):
1528
+ wget.download(
1529
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
1530
+ out=path)
1531
+
1532
+
1533
+ def load_model(model_checkpoint_path):
1534
+ download_model(path=model_checkpoint_path)
1535
+ torch.cuda.set_device(0)
1536
+
1537
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1538
+
1539
+ pretrained_dict = torch.load(model_checkpoint_path,
1540
+ map_location=torch_device)
1541
+
1542
+ model_dict = net.state_dict()
1543
+ pretrained_dict = {
1544
+ k: v
1545
+ for k, v in pretrained_dict.items() if k in model_dict
1546
+ }
1547
+ model_dict.update(pretrained_dict)
1548
+ net.load_state_dict(model_dict)
1549
+ net = net.to(dtype=torch_dtype, device=torch_device)
1550
+ net.eval()
1551
+ return net
1552
+ #+end_src
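+
+ Outside of ComfyUI, get_model_path() is unavailable because it imports folder_paths, but load_model() itself only needs a writable checkpoint path (the file is downloaded there on first use). A minimal sketch, assuming a hypothetical local directory and a CUDA device:
+ #+begin_src python
+ # Illustrative only, not tangled: standalone model loading.
+ ckpt_dir = '/tmp/MVANet'  # hypothetical location
+ mkdir_safe(out_path=ckpt_dir)
+ net = load_model(os.path.join(ckpt_dir, 'Model_80.pth'))
+ #+end_src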
1553
+
1554
+ ** Function for modular inference CV
1555
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1556
+ def do_infer_tensor2tensor(img, net):
1557
+
1558
+ img_transform = transforms.Compose(
1559
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
1560
+
1561
+ h_, w_ = img.shape[1], img.shape[2]
1562
+
1563
+ with torch.no_grad():
1564
+
1565
+ img = rearrange(img, 'B H W C -> B C H W')
1566
+
1567
+ img_resize = torch.nn.functional.interpolate(input=img,
1568
+ size=(1024, 1024),
1569
+ mode='bicubic',
1570
+ antialias=True)
1571
+
1572
+ img_var = img_transform(img_resize)
1573
+ img_var = Variable(img_var)
1574
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1575
+
1576
+ mask = []
1577
+
1578
+ mask.append(net(img_var))
1579
+
1580
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1581
+ prediction = prediction.sigmoid()
1582
+
1583
+ prediction = torch.nn.functional.interpolate(input=prediction,
1584
+ size=(h_, w_),
1585
+ mode='bicubic',
1586
+ antialias=True)
1587
+
1588
+ prediction = prediction.squeeze(0)
1589
+ prediction = prediction.clamp(0, 1)
1590
+ prediction = prediction.detach()
1591
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
1592
+
1593
+ return prediction
1594
+ #+end_src
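+
+ The helper expects the ComfyUI IMAGE layout: a float tensor of shape (B, H, W, C) with RGB values in [0, 1], and it returns a (1, H, W) float32 mask on the CPU. A standalone sketch (untested; the file names are placeholders) for running it on an image from disk with the net returned by load_model():
+ #+begin_src python
+ # Illustrative only, not tangled: file -> ComfyUI-style tensor -> mask PNG.
+ bgr = cv2.imread('input.jpg')
+ rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(rgb).unsqueeze(0).float() / 255.0   # (1, H, W, 3)
+ mask = do_infer_tensor2tensor(img=img, net=net)            # (1, H, W)
+ cv2.imwrite('mask.png', (mask[0].numpy() * 255.0).astype(np.uint8))
+ #+end_src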
1595
+
1596
+ ** ComfyUI wrapper classes
1597
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
1598
+ class load_MVANet_Model:
1599
+
1600
+ def __init__(self):
1601
+ pass
1602
+
1603
+ @classmethod
1604
+ def INPUT_TYPES(s):
1605
+ return {
1606
+ "required": {},
1607
+ }
1608
+
1609
+ RETURN_TYPES = ("MVANet_Model", )
1610
+ FUNCTION = "test"
1611
+ CATEGORY = "MVANet"
1612
+
1613
+ def test(self):
1614
+ return (load_model(get_model_path()), )
1615
+
1616
+
1617
+ class run_MVANet_inference:
1618
+
1619
+ def __init__(self):
1620
+ pass
1621
+
1622
+ @classmethod
1623
+ def INPUT_TYPES(s):
1624
+ return {
1625
+ "required": {
1626
+ "image": ("IMAGE", ),
1627
+ "MVANet_Model": ("MVANet_Model", ),
1628
+ },
1629
+ }
1630
+
1631
+ RETURN_TYPES = ("MASK", )
1632
+ FUNCTION = "test"
1633
+ CATEGORY = "MVANet"
1634
+
1635
+ def test(
1636
+ self,
1637
+ image,
1638
+ MVANet_Model,
1639
+ ):
1640
+ ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
1641
+
1642
+ return (ret, )
1643
+ #+end_src
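+
+ Inside a ComfyUI environment the two nodes can also be exercised directly from Python, which makes for a quick smoke test before wiring them into a graph (sketch only; image is assumed to be a (1, H, W, 3) IMAGE tensor as above):
+ #+begin_src python
+ # Illustrative only, not tangled: call the node classes without a ComfyUI graph.
+ (model, ) = load_MVANet_Model().test()
+ (mask, ) = run_MVANet_inference().test(image=image, MVANet_Model=model)
+ #+end_src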
1644
+
1645
+ ** MVANet_inference execute
1646
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
1647
+ NODE_CLASS_MAPPINGS = {
1648
+ "load_MVANet_Model": load_MVANet_Model,
1649
+ "run_MVANet_inference": run_MVANet_inference
1650
+ }
1651
+
1652
+ NODE_DISPLAY_NAME_MAPPINGS = {
1653
+ "load_MVANet_Model": "load MVANet Model",
1654
+ "load_MVANet_Model": "load MVANet Model"
1655
+ }
1656
+ #+end_src
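+
+ ComfyUI discovers these nodes by importing NODE_CLASS_MAPPINGS (and, optionally, NODE_DISPLAY_NAME_MAPPINGS) from the plugin package's __init__.py, which the unify script below assembles. A quick import check (sketch only; assumes the custom_nodes directory is on sys.path):
+ #+begin_src python
+ # Illustrative only, not tangled: confirm the generated package exports the mappings.
+ import importlib
+ mod = importlib.import_module('ComfyUI_MVANet')
+ print(sorted(mod.NODE_CLASS_MAPPINGS))  # ['load_MVANet_Model', 'run_MVANet_inference']
+ #+end_src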
1657
+
1658
+ ** MVANet_inference unify
1659
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
1660
+ . "${HOME}/dbnew.sh"
1661
+
1662
+ (
1663
+ echo '#!/usr/bin/python3'
1664
+ cat \
1665
+ './MVANet_inference.import.py' \
1666
+ './MVANet_inference.function.py' \
1667
+ './MVANet_inference.class.py' \
1668
+ './MVANet_inference.execute.py' \
1669
+ | expand | yapf3 \
1670
+ | grep -v '#!/usr/bin/python3' \
1671
+ ;
1672
+ ) > './MVANet_inference.py' \
1673
+ ;
1674
+
1675
+ cp './MVANet_inference.py' '__init__.py'
1676
+ #+end_src
1677
+
1678
+ * WORK SPACE
1679
+
1680
+ ** elisp
1681
+ #+begin_src elisp
1682
+ (save-buffer)
1683
+ (org-babel-tangle)
1684
+ (shell-command "./MVANet_inference.unify.sh")
1685
+ #+end_src
1686
+
1687
+ #+RESULTS:
1688
+ : 0
1689
+
1690
+ ** sh
1691
+ #+begin_src sh :shebang #!/bin/sh :results output
1692
+ realpath .
1693
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
1694
+ #+end_src
ComfyUI_MVANet/__init__.py ADDED
@@ -0,0 +1,1548 @@
1
+ #!/usr/bin/python3
2
+ import os
3
+ import sys
4
+
5
+ HOME_DIR = os.environ.get('HOME', '/root')
6
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
7
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
8
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
9
+
10
+ import math
11
+ import numpy as np
12
+ import cv2
13
+ import wget
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch.autograd import Variable
20
+ from torch import nn
21
+ from torchvision import transforms
22
+
23
+ from einops import rearrange
24
+
25
+ from timm.models import load_checkpoint
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+
28
+ torch_device = 'cuda'
29
+ torch_dtype = torch.float16
30
+
31
+
32
+ def check_mkdir(dir_name):
33
+ if not os.path.isdir(dir_name):
34
+ os.makedirs(dir_name)
35
+
36
+
37
+ def SwinT(pretrained=True):
38
+ model = SwinTransformer(embed_dim=96,
39
+ depths=[2, 2, 6, 2],
40
+ num_heads=[3, 6, 12, 24],
41
+ window_size=7)
42
+ if pretrained is True:
43
+ model.load_state_dict(torch.load(
44
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
45
+ map_location='cpu')['model'],
46
+ strict=False)
47
+
48
+ return model
49
+
50
+
51
+ def SwinS(pretrained=True):
52
+ model = SwinTransformer(embed_dim=96,
53
+ depths=[2, 2, 18, 2],
54
+ num_heads=[3, 6, 12, 24],
55
+ window_size=7)
56
+ if pretrained is True:
57
+ model.load_state_dict(torch.load(
58
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
59
+ map_location='cpu')['model'],
60
+ strict=False)
61
+
62
+ return model
63
+
64
+
65
+ def SwinB(pretrained=True):
66
+ model = SwinTransformer(embed_dim=128,
67
+ depths=[2, 2, 18, 2],
68
+ num_heads=[4, 8, 16, 32],
69
+ window_size=12)
70
+ if pretrained is True:
71
+ import os
72
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
73
+ map_location='cpu')['model'],
74
+ strict=False)
75
+ return model
76
+
77
+
78
+ def SwinL(pretrained=True):
79
+ model = SwinTransformer(embed_dim=192,
80
+ depths=[2, 2, 18, 2],
81
+ num_heads=[6, 12, 24, 48],
82
+ window_size=12)
83
+ if pretrained is True:
84
+ model.load_state_dict(torch.load(
85
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
86
+ map_location='cpu')['model'],
87
+ strict=False)
88
+
89
+ return model
90
+
91
+
92
+ def get_activation_fn(activation):
93
+ """Return an activation function given a string"""
94
+ if activation == "relu":
95
+ return F.relu
96
+ if activation == "gelu":
97
+ return F.gelu
98
+ if activation == "glu":
99
+ return F.glu
100
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
101
+
102
+
103
+ def make_cbr(in_dim, out_dim):
104
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
105
+ nn.BatchNorm2d(out_dim), nn.PReLU())
106
+
107
+
108
+ def make_cbg(in_dim, out_dim):
109
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
110
+ nn.BatchNorm2d(out_dim), nn.GELU())
111
+
112
+
113
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
114
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
115
+
116
+
117
+ def resize_as(x, y, interpolation='bilinear'):
118
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
119
+
120
+
121
+ def image2patches(x):
122
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
123
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
124
+ return x
125
+
126
+
127
+ def patches2image(x):
128
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
129
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
130
+ return x
131
+
132
+
133
+ def window_partition(x, window_size):
134
+ """
135
+ Args:
136
+ x: (B, H, W, C)
137
+ window_size (int): window size
138
+
139
+ Returns:
140
+ windows: (num_windows*B, window_size, window_size, C)
141
+ """
142
+ B, H, W, C = x.shape
143
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
144
+ C)
145
+ windows = x.permute(0, 1, 3, 2, 4,
146
+ 5).contiguous().view(-1, window_size, window_size, C)
147
+ return windows
148
+
149
+
150
+ def window_reverse(windows, window_size, H, W):
151
+ """
152
+ Args:
153
+ windows: (num_windows*B, window_size, window_size, C)
154
+ window_size (int): Window size
155
+ H (int): Height of image
156
+ W (int): Width of image
157
+
158
+ Returns:
159
+ x: (B, H, W, C)
160
+ """
161
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
162
+ x = windows.view(B, H // window_size, W // window_size, window_size,
163
+ window_size, -1)
164
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
165
+ return x
166
+
167
+
168
+ def mkdir_safe(out_path):
169
+ if type(out_path) == str:
170
+ if len(out_path) > 0:
171
+ if not os.path.exists(out_path):
172
+ os.mkdir(out_path)
173
+
174
+
175
+ def get_model_path():
176
+ import folder_paths
177
+ from folder_paths import models_dir
178
+
179
+ path_file_model = models_dir
180
+ mkdir_safe(out_path=path_file_model)
181
+
182
+ path_file_model = os.path.join(path_file_model, 'MVANet')
183
+ mkdir_safe(out_path=path_file_model)
184
+
185
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
186
+
187
+ return path_file_model
188
+
189
+
190
+ def download_model(path):
191
+ if not os.path.exists(path):
192
+ wget.download(
193
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
194
+ out=path)
195
+
196
+
197
+ def load_model(model_checkpoint_path):
198
+ download_model(path=model_checkpoint_path)
199
+ torch.cuda.set_device(0)
200
+
201
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
202
+
203
+ pretrained_dict = torch.load(model_checkpoint_path,
204
+ map_location=torch_device)
205
+
206
+ model_dict = net.state_dict()
207
+ pretrained_dict = {
208
+ k: v
209
+ for k, v in pretrained_dict.items() if k in model_dict
210
+ }
211
+ model_dict.update(pretrained_dict)
212
+ net.load_state_dict(model_dict)
213
+ net = net.to(dtype=torch_dtype, device=torch_device)
214
+ net.eval()
215
+ return net
216
+
217
+
218
+ def do_infer_tensor2tensor(img, net):
219
+
220
+ img_transform = transforms.Compose(
221
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
222
+
223
+ h_, w_ = img.shape[1], img.shape[2]
224
+
225
+ with torch.no_grad():
226
+
227
+ img = rearrange(img, 'B H W C -> B C H W')
228
+
229
+ img_resize = torch.nn.functional.interpolate(input=img,
230
+ size=(1024, 1024),
231
+ mode='bicubic',
232
+ antialias=True)
233
+
234
+ img_var = img_transform(img_resize)
235
+ img_var = Variable(img_var)
236
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
237
+
238
+ mask = []
239
+
240
+ mask.append(net(img_var))
241
+
242
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
243
+ prediction = prediction.sigmoid()
244
+
245
+ prediction = torch.nn.functional.interpolate(input=prediction,
246
+ size=(h_, w_),
247
+ mode='bicubic',
248
+ antialias=True)
249
+
250
+ prediction = prediction.squeeze(0)
251
+ prediction = prediction.clamp(0, 1)
252
+ prediction = prediction.detach()
253
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
254
+
255
+ return prediction
256
+
257
+
258
+ class Mlp(nn.Module):
259
+ """ Multilayer perceptron."""
260
+
261
+ def __init__(self,
262
+ in_features,
263
+ hidden_features=None,
264
+ out_features=None,
265
+ act_layer=nn.GELU,
266
+ drop=0.):
267
+ super().__init__()
268
+ out_features = out_features or in_features
269
+ hidden_features = hidden_features or in_features
270
+ self.fc1 = nn.Linear(in_features, hidden_features)
271
+ self.act = act_layer()
272
+ self.fc2 = nn.Linear(hidden_features, out_features)
273
+ self.drop = nn.Dropout(drop)
274
+
275
+ def forward(self, x):
276
+ x = self.fc1(x)
277
+ x = self.act(x)
278
+ x = self.drop(x)
279
+ x = self.fc2(x)
280
+ x = self.drop(x)
281
+ return x
282
+
283
+
284
+ class WindowAttention(nn.Module):
285
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
286
+ It supports both of shifted and non-shifted window.
287
+
288
+ Args:
289
+ dim (int): Number of input channels.
290
+ window_size (tuple[int]): The height and width of the window.
291
+ num_heads (int): Number of attention heads.
292
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
293
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
294
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
295
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
296
+ """
297
+
298
+ def __init__(self,
299
+ dim,
300
+ window_size,
301
+ num_heads,
302
+ qkv_bias=True,
303
+ qk_scale=None,
304
+ attn_drop=0.,
305
+ proj_drop=0.):
306
+
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.window_size = window_size # Wh, Ww
310
+ self.num_heads = num_heads
311
+ head_dim = dim // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ # define a parameter table of relative position bias
315
+ self.relative_position_bias_table = nn.Parameter(
316
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
317
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
318
+
319
+ # get pair-wise relative position index for each token inside the window
320
+ coords_h = torch.arange(self.window_size[0])
321
+ coords_w = torch.arange(self.window_size[1])
322
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
323
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
324
+ relative_coords = coords_flatten[:, :,
325
+ None] - coords_flatten[:,
326
+ None, :] # 2, Wh*Ww, Wh*Ww
327
+ relative_coords = relative_coords.permute(
328
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
329
+ relative_coords[:, :,
330
+ 0] += self.window_size[0] - 1 # shift to start from 0
331
+ relative_coords[:, :, 1] += self.window_size[1] - 1
332
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
333
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
334
+ self.register_buffer("relative_position_index",
335
+ relative_position_index)
336
+
337
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
338
+ self.attn_drop = nn.Dropout(attn_drop)
339
+ self.proj = nn.Linear(dim, dim)
340
+ self.proj_drop = nn.Dropout(proj_drop)
341
+
342
+ trunc_normal_(self.relative_position_bias_table, std=.02)
343
+ self.softmax = nn.Softmax(dim=-1)
344
+
345
+ def forward(self, x, mask=None):
346
+ """ Forward function.
347
+
348
+ Args:
349
+ x: input features with shape of (num_windows*B, N, C)
350
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
351
+ """
352
+ x = x.to(dtype=torch_dtype, device=torch_device)
353
+ B_, N, C = x.shape
354
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
355
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
356
+ q, k, v = qkv[0], qkv[1], qkv[
357
+ 2] # make torchscript happy (cannot use tensor as tuple)
358
+
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ relative_position_bias = self.relative_position_bias_table[
363
+ self.relative_position_index.view(-1)].view(
364
+ self.window_size[0] * self.window_size[1],
365
+ self.window_size[0] * self.window_size[1],
366
+ -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(
368
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
369
+ attn = attn + relative_position_bias.unsqueeze(0)
370
+
371
+ if mask is not None:
372
+ nW = mask.shape[0]
373
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
374
+ N) + mask.unsqueeze(1).unsqueeze(0)
375
+ attn = attn.view(-1, self.num_heads, N, N)
376
+ attn = self.softmax(attn)
377
+ else:
378
+ attn = self.softmax(attn)
379
+
380
+ attn = self.attn_drop(attn)
381
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
382
+ v = v.to(dtype=torch_dtype, device=torch_device)
383
+
384
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
385
+ x = self.proj(x)
386
+ x = self.proj_drop(x)
387
+ return x
388
+
389
+
390
+ class SwinTransformerBlock(nn.Module):
391
+ """ Swin Transformer Block.
392
+
393
+ Args:
394
+ dim (int): Number of input channels.
395
+ num_heads (int): Number of attention heads.
396
+ window_size (int): Window size.
397
+ shift_size (int): Shift size for SW-MSA.
398
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
399
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
400
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
401
+ drop (float, optional): Dropout rate. Default: 0.0
402
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
403
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
404
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
405
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
406
+ """
407
+
408
+ def __init__(self,
409
+ dim,
410
+ num_heads,
411
+ window_size=7,
412
+ shift_size=0,
413
+ mlp_ratio=4.,
414
+ qkv_bias=True,
415
+ qk_scale=None,
416
+ drop=0.,
417
+ attn_drop=0.,
418
+ drop_path=0.,
419
+ act_layer=nn.GELU,
420
+ norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.dim = dim
423
+ self.num_heads = num_heads
424
+ self.window_size = window_size
425
+ self.shift_size = shift_size
426
+ self.mlp_ratio = mlp_ratio
427
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.attn = WindowAttention(dim,
431
+ window_size=to_2tuple(self.window_size),
432
+ num_heads=num_heads,
433
+ qkv_bias=qkv_bias,
434
+ qk_scale=qk_scale,
435
+ attn_drop=attn_drop,
436
+ proj_drop=drop)
437
+
438
+ self.drop_path = DropPath(
439
+ drop_path) if drop_path > 0. else nn.Identity()
440
+ self.norm2 = norm_layer(dim)
441
+ mlp_hidden_dim = int(dim * mlp_ratio)
442
+ self.mlp = Mlp(in_features=dim,
443
+ hidden_features=mlp_hidden_dim,
444
+ act_layer=act_layer,
445
+ drop=drop)
446
+
447
+ self.H = None
448
+ self.W = None
449
+
450
+ def forward(self, x, mask_matrix):
451
+ """ Forward function.
452
+
453
+ Args:
454
+ x: Input feature, tensor size (B, H*W, C).
455
+ H, W: Spatial resolution of the input feature.
456
+ mask_matrix: Attention mask for cyclic shift.
457
+ """
458
+ B, L, C = x.shape
459
+ H, W = self.H, self.W
460
+ assert L == H * W, "input feature has wrong size"
461
+
462
+ shortcut = x
463
+ x = self.norm1(x)
464
+ x = x.view(B, H, W, C)
465
+
466
+ # pad feature maps to multiples of window size
467
+ pad_l = pad_t = 0
468
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
469
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
470
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
471
+ _, Hp, Wp, _ = x.shape
472
+
473
+ # cyclic shift
474
+ if self.shift_size > 0:
475
+ shifted_x = torch.roll(x,
476
+ shifts=(-self.shift_size, -self.shift_size),
477
+ dims=(1, 2))
478
+ attn_mask = mask_matrix
479
+ else:
480
+ shifted_x = x
481
+ attn_mask = None
482
+
483
+ # partition windows
484
+ x_windows = window_partition(
485
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
486
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
487
+ C) # nW*B, window_size*window_size, C
488
+
489
+ # W-MSA/SW-MSA
490
+ attn_windows = self.attn(
491
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
492
+
493
+ # merge windows
494
+ attn_windows = attn_windows.view(-1, self.window_size,
495
+ self.window_size, C)
496
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
497
+ Wp) # B H' W' C
498
+
499
+ # reverse cyclic shift
500
+ if self.shift_size > 0:
501
+ x = torch.roll(shifted_x,
502
+ shifts=(self.shift_size, self.shift_size),
503
+ dims=(1, 2))
504
+ else:
505
+ x = shifted_x
506
+
507
+ if pad_r > 0 or pad_b > 0:
508
+ x = x[:, :H, :W, :].contiguous()
509
+
510
+ x = x.view(B, H * W, C)
511
+
512
+ # FFN
513
+ x = shortcut + self.drop_path(x)
514
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
515
+
516
+ return x
517
+
518
+
519
+ class PatchMerging(nn.Module):
520
+ """ Patch Merging Layer
521
+
522
+ Args:
523
+ dim (int): Number of input channels.
524
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
525
+ """
526
+
527
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
528
+ super().__init__()
529
+ self.dim = dim
530
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
531
+ self.norm = norm_layer(4 * dim)
532
+
533
+ def forward(self, x, H, W):
534
+ """ Forward function.
535
+
536
+ Args:
537
+ x: Input feature, tensor size (B, H*W, C).
538
+ H, W: Spatial resolution of the input feature.
539
+ """
540
+ B, L, C = x.shape
541
+ assert L == H * W, "input feature has wrong size"
542
+
543
+ x = x.view(B, H, W, C)
544
+
545
+ # padding
546
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
547
+ if pad_input:
548
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
549
+
550
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
551
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
552
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
553
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
554
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
555
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
556
+
557
+ x = self.norm(x)
558
+ x = self.reduction(x)
559
+
560
+ return x
561
+
562
+
563
+ class BasicLayer(nn.Module):
564
+ """ A basic Swin Transformer layer for one stage.
565
+
566
+ Args:
567
+ dim (int): Number of feature channels
568
+ depth (int): Depths of this stage.
569
+ num_heads (int): Number of attention head.
570
+ window_size (int): Local window size. Default: 7.
571
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
572
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
573
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
574
+ drop (float, optional): Dropout rate. Default: 0.0
575
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
576
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
577
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
578
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
579
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
580
+ """
581
+
582
+ def __init__(self,
583
+ dim,
584
+ depth,
585
+ num_heads,
586
+ window_size=7,
587
+ mlp_ratio=4.,
588
+ qkv_bias=True,
589
+ qk_scale=None,
590
+ drop=0.,
591
+ attn_drop=0.,
592
+ drop_path=0.,
593
+ norm_layer=nn.LayerNorm,
594
+ downsample=None,
595
+ use_checkpoint=False):
596
+ super().__init__()
597
+ self.window_size = window_size
598
+ self.shift_size = window_size // 2
599
+ self.depth = depth
600
+ self.use_checkpoint = use_checkpoint
601
+
602
+ # build blocks
603
+ self.blocks = nn.ModuleList([
604
+ SwinTransformerBlock(dim=dim,
605
+ num_heads=num_heads,
606
+ window_size=window_size,
607
+ shift_size=0 if
608
+ (i % 2 == 0) else window_size // 2,
609
+ mlp_ratio=mlp_ratio,
610
+ qkv_bias=qkv_bias,
611
+ qk_scale=qk_scale,
612
+ drop=drop,
613
+ attn_drop=attn_drop,
614
+ drop_path=drop_path[i] if isinstance(
615
+ drop_path, list) else drop_path,
616
+ norm_layer=norm_layer) for i in range(depth)
617
+ ])
618
+
619
+ # patch merging layer
620
+ if downsample is not None:
621
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
622
+ else:
623
+ self.downsample = None
624
+
625
+ def forward(self, x, H, W):
626
+ """ Forward function.
627
+
628
+ Args:
629
+ x: Input feature, tensor size (B, H*W, C).
630
+ H, W: Spatial resolution of the input feature.
631
+ """
632
+
633
+ # calculate attention mask for SW-MSA
634
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
635
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
636
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
637
+ h_slices = (slice(0, -self.window_size),
638
+ slice(-self.window_size,
639
+ -self.shift_size), slice(-self.shift_size, None))
640
+ w_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size,
642
+ -self.shift_size), slice(-self.shift_size, None))
643
+ cnt = 0
644
+ for h in h_slices:
645
+ for w in w_slices:
646
+ img_mask[:, h, w, :] = cnt
647
+ cnt += 1
648
+
649
+ mask_windows = window_partition(
650
+ img_mask, self.window_size) # nW, window_size, window_size, 1
651
+ mask_windows = mask_windows.view(-1,
652
+ self.window_size * self.window_size)
653
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
654
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
655
+ float(-100.0)).masked_fill(
656
+ attn_mask == 0, float(0.0))
657
+
658
+ for blk in self.blocks:
659
+ blk.H, blk.W = H, W
660
+ if self.use_checkpoint:
661
+ x = checkpoint.checkpoint(blk, x, attn_mask)
662
+ else:
663
+ x = blk(x, attn_mask)
664
+ if self.downsample is not None:
665
+ x_down = self.downsample(x, H, W)
666
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
667
+ return x, H, W, x_down, Wh, Ww
668
+ else:
669
+ return x, H, W, x, H, W
670
+
671
+
672
+ class PatchEmbed(nn.Module):
673
+ """ Image to Patch Embedding
674
+
675
+ Args:
676
+ patch_size (int): Patch token size. Default: 4.
677
+ in_chans (int): Number of input image channels. Default: 3.
678
+ embed_dim (int): Number of linear projection output channels. Default: 96.
679
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
680
+ """
681
+
682
+ def __init__(self,
683
+ patch_size=4,
684
+ in_chans=3,
685
+ embed_dim=96,
686
+ norm_layer=None):
687
+ super().__init__()
688
+ patch_size = to_2tuple(patch_size)
689
+ self.patch_size = patch_size
690
+
691
+ self.in_chans = in_chans
692
+ self.embed_dim = embed_dim
693
+
694
+ self.proj = nn.Conv2d(in_chans,
695
+ embed_dim,
696
+ kernel_size=patch_size,
697
+ stride=patch_size)
698
+ if norm_layer is not None:
699
+ self.norm = norm_layer(embed_dim)
700
+ else:
701
+ self.norm = None
702
+
703
+ def forward(self, x):
704
+ """Forward function."""
705
+ # padding
706
+ _, _, H, W = x.size()
707
+ if W % self.patch_size[1] != 0:
708
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
709
+ if H % self.patch_size[0] != 0:
710
+ x = F.pad(x,
711
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
712
+
713
+ x = self.proj(x) # B C Wh Ww
714
+ if self.norm is not None:
715
+ Wh, Ww = x.size(2), x.size(3)
716
+ x = x.flatten(2).transpose(1, 2)
717
+ x = self.norm(x)
718
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
719
+
720
+ return x
721
+
722
+
723
+ class SwinTransformer(nn.Module):
724
+ """ Swin Transformer backbone.
725
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
726
+ https://arxiv.org/pdf/2103.14030
727
+
728
+ Args:
729
+ pretrain_img_size (int): Input image size for training the pretrained model,
730
+ used in absolute postion embedding. Default 224.
731
+ patch_size (int | tuple(int)): Patch size. Default: 4.
732
+ in_chans (int): Number of input image channels. Default: 3.
733
+ embed_dim (int): Number of linear projection output channels. Default: 96.
734
+ depths (tuple[int]): Depths of each Swin Transformer stage.
735
+ num_heads (tuple[int]): Number of attention head of each stage.
736
+ window_size (int): Window size. Default: 7.
737
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
738
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
739
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
740
+ drop_rate (float): Dropout rate.
741
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
742
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
743
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
744
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
745
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
746
+ out_indices (Sequence[int]): Output from which stages.
747
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
748
+ -1 means not freezing any parameters.
749
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
750
+ """
751
+
752
+ def __init__(self,
753
+ pretrain_img_size=224,
754
+ patch_size=4,
755
+ in_chans=3,
756
+ embed_dim=96,
757
+ depths=[2, 2, 6, 2],
758
+ num_heads=[3, 6, 12, 24],
759
+ window_size=7,
760
+ mlp_ratio=4.,
761
+ qkv_bias=True,
762
+ qk_scale=None,
763
+ drop_rate=0.,
764
+ attn_drop_rate=0.,
765
+ drop_path_rate=0.2,
766
+ norm_layer=nn.LayerNorm,
767
+ ape=False,
768
+ patch_norm=True,
769
+ out_indices=(0, 1, 2, 3),
770
+ frozen_stages=-1,
771
+ use_checkpoint=False):
772
+ super().__init__()
773
+
774
+ self.pretrain_img_size = pretrain_img_size
775
+ self.num_layers = len(depths)
776
+ self.embed_dim = embed_dim
777
+ self.ape = ape
778
+ self.patch_norm = patch_norm
779
+ self.out_indices = out_indices
780
+ self.frozen_stages = frozen_stages
781
+
782
+ # split image into non-overlapping patches
783
+ self.patch_embed = PatchEmbed(
784
+ patch_size=patch_size,
785
+ in_chans=in_chans,
786
+ embed_dim=embed_dim,
787
+ norm_layer=norm_layer if self.patch_norm else None)
788
+
789
+ # absolute position embedding
790
+ if self.ape:
791
+ pretrain_img_size = to_2tuple(pretrain_img_size)
792
+ patch_size = to_2tuple(patch_size)
793
+ patches_resolution = [
794
+ pretrain_img_size[0] // patch_size[0],
795
+ pretrain_img_size[1] // patch_size[1]
796
+ ]
797
+
798
+ self.absolute_pos_embed = nn.Parameter(
799
+ torch.zeros(1, embed_dim, patches_resolution[0],
800
+ patches_resolution[1]))
801
+ trunc_normal_(self.absolute_pos_embed, std=.02)
802
+
803
+ self.pos_drop = nn.Dropout(p=drop_rate)
804
+
805
+ # stochastic depth
806
+ dpr = [
807
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
808
+ ] # stochastic depth decay rule
809
+
810
+ # build layers
811
+ self.layers = nn.ModuleList()
812
+ for i_layer in range(self.num_layers):
813
+ layer = BasicLayer(
814
+ dim=int(embed_dim * 2**i_layer),
815
+ depth=depths[i_layer],
816
+ num_heads=num_heads[i_layer],
817
+ window_size=window_size,
818
+ mlp_ratio=mlp_ratio,
819
+ qkv_bias=qkv_bias,
820
+ qk_scale=qk_scale,
821
+ drop=drop_rate,
822
+ attn_drop=attn_drop_rate,
823
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
824
+ norm_layer=norm_layer,
825
+ downsample=PatchMerging if
826
+ (i_layer < self.num_layers - 1) else None,
827
+ use_checkpoint=use_checkpoint)
828
+ self.layers.append(layer)
829
+
830
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
831
+ self.num_features = num_features
832
+
833
+ # add a norm layer for each output
834
+ for i_layer in out_indices:
835
+ layer = norm_layer(num_features[i_layer])
836
+ layer_name = f'norm{i_layer}'
837
+ self.add_module(layer_name, layer)
838
+
839
+ self._freeze_stages()
840
+
841
+ def _freeze_stages(self):
842
+ if self.frozen_stages >= 0:
843
+ self.patch_embed.eval()
844
+ for param in self.patch_embed.parameters():
845
+ param.requires_grad = False
846
+
847
+ if self.frozen_stages >= 1 and self.ape:
848
+ self.absolute_pos_embed.requires_grad = False
849
+
850
+ if self.frozen_stages >= 2:
851
+ self.pos_drop.eval()
852
+ for i in range(0, self.frozen_stages - 1):
853
+ m = self.layers[i]
854
+ m.eval()
855
+ for param in m.parameters():
856
+ param.requires_grad = False
857
+
858
+ def init_weights(self, pretrained=None):
859
+ """Initialize the weights in backbone.
860
+
861
+ Args:
862
+ pretrained (str, optional): Path to pre-trained weights.
863
+ Defaults to None.
864
+ """
865
+
866
+ def _init_weights(m):
867
+ if isinstance(m, nn.Linear):
868
+ trunc_normal_(m.weight, std=.02)
869
+ if isinstance(m, nn.Linear) and m.bias is not None:
870
+ nn.init.constant_(m.bias, 0)
871
+ elif isinstance(m, nn.LayerNorm):
872
+ nn.init.constant_(m.bias, 0)
873
+ nn.init.constant_(m.weight, 1.0)
874
+
875
+ if isinstance(pretrained, str):
876
+ self.apply(_init_weights)
877
+ load_checkpoint(self, pretrained, strict=False, logger=None)
878
+ elif pretrained is None:
879
+ self.apply(_init_weights)
880
+ else:
881
+ raise TypeError('pretrained must be a str or None')
882
+
883
+ def forward(self, x):
884
+ x = self.patch_embed(x)
885
+
886
+ Wh, Ww = x.size(2), x.size(3)
887
+ if self.ape:
888
+ # interpolate the position embedding to the corresponding size
889
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
890
+ size=(Wh, Ww),
891
+ mode='bicubic')
892
+ x = (x + absolute_pos_embed) # B Wh*Ww C
893
+
894
+ outs = [x.contiguous()]
895
+ x = x.flatten(2).transpose(1, 2)
896
+ x = self.pos_drop(x)
897
+ for i in range(self.num_layers):
898
+ layer = self.layers[i]
899
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
900
+
901
+ if i in self.out_indices:
902
+ norm_layer = getattr(self, f'norm{i}')
903
+ x_out = norm_layer(x_out)
904
+
905
+ out = x_out.view(-1, H, W,
906
+ self.num_features[i]).permute(0, 3, 1,
907
+ 2).contiguous()
908
+ outs.append(out)
909
+
910
+ return tuple(outs)
911
+
912
+ def train(self, mode=True):
913
+ """Convert the model into training mode while keep layers freezed."""
914
+ super(SwinTransformer, self).train(mode)
915
+ self._freeze_stages()
916
+
917
+
918
+ class PositionEmbeddingSine:
919
+
920
+ def __init__(self,
921
+ num_pos_feats=64,
922
+ temperature=10000,
923
+ normalize=False,
924
+ scale=None):
925
+ super().__init__()
926
+ self.num_pos_feats = num_pos_feats
927
+ self.temperature = temperature
928
+ self.normalize = normalize
929
+ if scale is not None and normalize is False:
930
+ raise ValueError("normalize should be True if scale is passed")
931
+ if scale is None:
932
+ scale = 2 * math.pi
933
+ self.scale = scale
934
+ self.dim_t = torch.arange(0,
935
+ self.num_pos_feats,
936
+ dtype=torch_dtype,
937
+ device=torch_device)
938
+
939
+ def __call__(self, b, h, w):
940
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
941
+ assert mask is not None
942
+ not_mask = ~mask
943
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
944
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
945
+ if self.normalize:
946
+ eps = 1e-6
947
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
948
+ self.scale).to(device=torch_device, dtype=torch_dtype)
949
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
950
+ self.scale).to(device=torch_device, dtype=torch_dtype)
951
+
952
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
953
+
954
+ pos_x = x_embed[:, :, :, None] / dim_t
955
+ pos_y = y_embed[:, :, :, None] / dim_t
956
+ pos_x = torch.stack(
957
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
958
+ dim=4).flatten(3)
959
+ pos_y = torch.stack(
960
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
961
+ dim=4).flatten(3)
962
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
963
+
964
+
965
+ class MCLM(nn.Module):
966
+
967
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
968
+ super(MCLM, self).__init__()
969
+ self.attention = nn.ModuleList([
970
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
971
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
972
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
973
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
974
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
975
+ ])
976
+
977
+ self.linear1 = nn.Linear(d_model, d_model * 2)
978
+ self.linear2 = nn.Linear(d_model * 2, d_model)
979
+ self.linear3 = nn.Linear(d_model, d_model * 2)
980
+ self.linear4 = nn.Linear(d_model * 2, d_model)
981
+ self.norm1 = nn.LayerNorm(d_model)
982
+ self.norm2 = nn.LayerNorm(d_model)
983
+ self.dropout = nn.Dropout(0.1)
984
+ self.dropout1 = nn.Dropout(0.1)
985
+ self.dropout2 = nn.Dropout(0.1)
986
+ self.activation = get_activation_fn('relu')
987
+ self.pool_ratios = pool_ratios
988
+ self.p_poses = []
989
+ self.g_pos = None
990
+ self.positional_encoding = PositionEmbeddingSine(
991
+ num_pos_feats=d_model // 2, normalize=True)
992
+
993
+ def forward(self, l, g):
994
+ """
995
+ l: 4,c,h,w
996
+ g: 1,c,h,w
997
+ """
998
+ b, c, h, w = l.size()
999
+ # 4,c,h,w -> 1,c,2h,2w
1000
+ concated_locs = rearrange(l,
1001
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1002
+ hg=2,
1003
+ wg=2)
1004
+
1005
+ pools = []
1006
+ for pool_ratio in self.pool_ratios:
1007
+ # b,c,h,w
1008
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1009
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1010
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1011
+ if self.g_pos is None:
1012
+ pos_emb = self.positional_encoding(pool.shape[0],
1013
+ pool.shape[2],
1014
+ pool.shape[3])
1015
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1016
+ self.p_poses.append(pos_emb)
1017
+ pools = torch.cat(pools, 0)
1018
+ if self.g_pos is None:
1019
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1020
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1021
+ g.shape[3])
1022
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1023
+
1024
+ # attention between glb (q) & multisensory concated-locs (k,v)
1025
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1026
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1027
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1028
+ g_hw_b_c = self.norm1(g_hw_b_c)
1029
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1030
+ self.linear2(
1031
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1032
+ g_hw_b_c = self.norm2(g_hw_b_c)
1033
+
1034
+ # attention between origin locs (q) & freashed glb (k,v)
1035
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1036
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1037
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1038
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1039
+ ng=2,
1040
+ nw=2)
1041
+ outputs_re = []
1042
+ for i, (_l, _g) in enumerate(
1043
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1044
+ outputs_re.append(self.attention[i + 1](_l, _g,
1045
+ _g)[0]) # (h w) 1 c
1046
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1047
+
1048
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1049
+ l_hw_b_c = self.norm1(l_hw_b_c)
1050
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1051
+ self.linear4(
1052
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1053
+ l_hw_b_c = self.norm2(l_hw_b_c)
1054
+
1055
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1056
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1057
+
1058
+
1059
+ class inf_MCLM(nn.Module):
1060
+
1061
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1062
+ super(inf_MCLM, self).__init__()
1063
+ self.attention = nn.ModuleList([
1064
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1065
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1066
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1067
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1068
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1069
+ ])
1070
+
1071
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1072
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1073
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1074
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1075
+ self.norm1 = nn.LayerNorm(d_model)
1076
+ self.norm2 = nn.LayerNorm(d_model)
1077
+ self.dropout = nn.Dropout(0.1)
1078
+ self.dropout1 = nn.Dropout(0.1)
1079
+ self.dropout2 = nn.Dropout(0.1)
1080
+ self.activation = get_activation_fn('relu')
1081
+ self.pool_ratios = pool_ratios
1082
+ self.p_poses = []
1083
+ self.g_pos = None
1084
+ self.positional_encoding = PositionEmbeddingSine(
1085
+ num_pos_feats=d_model // 2, normalize=True)
1086
+
1087
+ def forward(self, l, g):
1088
+ """
1089
+ l: 4,c,h,w
1090
+ g: 1,c,h,w
1091
+ """
1092
+ b, c, h, w = l.size()
1093
+ # 4,c,h,w -> 1,c,2h,2w
1094
+ concated_locs = rearrange(l,
1095
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1096
+ hg=2,
1097
+ wg=2)
1098
+ self.p_poses = []
1099
+ pools = []
1100
+ for pool_ratio in self.pool_ratios:
1101
+ # b,c,h,w
1102
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1103
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1104
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1105
+ # if self.g_pos is None:
1106
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1107
+ pool.shape[3])
1108
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1109
+ self.p_poses.append(pos_emb)
1110
+ pools = torch.cat(pools, 0)
1111
+ # if self.g_pos is None:
1112
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1113
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1114
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1115
+
1116
+ # attention between glb (q) & multisensory concated-locs (k,v)
1117
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1118
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1119
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1120
+ g_hw_b_c = self.norm1(g_hw_b_c)
1121
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1122
+ self.linear2(
1123
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1124
+ g_hw_b_c = self.norm2(g_hw_b_c)
1125
+
1126
+ # attention between origin locs (q) & freashed glb (k,v)
1127
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1128
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1129
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1130
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1131
+ ng=2,
1132
+ nw=2)
1133
+ outputs_re = []
1134
+ for i, (_l, _g) in enumerate(
1135
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1136
+ outputs_re.append(self.attention[i + 1](_l, _g,
1137
+ _g)[0]) # (h w) 1 c
1138
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1139
+
1140
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1141
+ l_hw_b_c = self.norm1(l_hw_b_c)
1142
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1143
+ self.linear4(
1144
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1145
+ l_hw_b_c = self.norm2(l_hw_b_c)
1146
+
1147
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1148
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1149
+
1150
+
1151
+ class MCRM(nn.Module):
1152
+
1153
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1154
+ super(MCRM, self).__init__()
1155
+ self.attention = nn.ModuleList([
1156
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1157
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1158
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1159
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1160
+ ])
1161
+
1162
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1163
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1164
+ self.norm1 = nn.LayerNorm(d_model)
1165
+ self.norm2 = nn.LayerNorm(d_model)
1166
+ self.dropout = nn.Dropout(0.1)
1167
+ self.dropout1 = nn.Dropout(0.1)
1168
+ self.dropout2 = nn.Dropout(0.1)
1169
+ self.sigmoid = nn.Sigmoid()
1170
+ self.activation = get_activation_fn('relu')
1171
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1172
+ self.pool_ratios = pool_ratios
1173
+ self.positional_encoding = PositionEmbeddingSine(
1174
+ num_pos_feats=d_model // 2, normalize=True)
1175
+
1176
+ def forward(self, x):
1177
+ b, c, h, w = x.size()
1178
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1179
+ # b(4),c,h,w
1180
+ patched_glb = rearrange(glb,
1181
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1182
+ hg=2,
1183
+ wg=2)
1184
+
1185
+ # generate token attention map
1186
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1187
+ token_attention_map = F.interpolate(token_attention_map,
1188
+ size=patches2image(loc).shape[-2:],
1189
+ mode='nearest')
1190
+ loc = loc * rearrange(token_attention_map,
1191
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1192
+ hg=2,
1193
+ wg=2)
1194
+ pools = []
1195
+ for pool_ratio in self.pool_ratios:
1196
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1197
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1198
+ pools.append(rearrange(pool,
1199
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1200
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1201
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1202
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1203
+ outputs = []
1204
+ for i, q in enumerate(
1205
+ loc_.unbind(dim=0)): # traverse all local patches
1206
+ # np*hw,1,c
1207
+ v = pools[i]
1208
+ k = v
1209
+ outputs.append(self.attention[i](q, k, v)[0])
1210
+ outputs = torch.cat(outputs, 1)
1211
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1212
+ src = self.norm1(src)
1213
+ src = src + self.dropout2(
1214
+ self.linear4(
1215
+ self.dropout(self.activation(self.linear3(src)).clone())))
1216
+ src = self.norm2(src)
1217
+
1218
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
1219
+ glb = glb + F.interpolate(patches2image(src),
1220
+ size=glb.shape[-2:],
1221
+ mode='nearest') # freshed glb
1222
+ return torch.cat((src, glb), 0), token_attention_map
1223
+
1224
+
1225
+ class inf_MCRM(nn.Module):
1226
+
1227
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1228
+ super(inf_MCRM, self).__init__()
1229
+ self.attention = nn.ModuleList([
1230
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1231
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1232
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1233
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1234
+ ])
1235
+
1236
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1237
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1238
+ self.norm1 = nn.LayerNorm(d_model)
1239
+ self.norm2 = nn.LayerNorm(d_model)
1240
+ self.dropout = nn.Dropout(0.1)
1241
+ self.dropout1 = nn.Dropout(0.1)
1242
+ self.dropout2 = nn.Dropout(0.1)
1243
+ self.sigmoid = nn.Sigmoid()
1244
+ self.activation = get_activation_fn('relu')
1245
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1246
+ self.pool_ratios = pool_ratios
1247
+ self.positional_encoding = PositionEmbeddingSine(
1248
+ num_pos_feats=d_model // 2, normalize=True)
1249
+
1250
+ def forward(self, x):
1251
+ b, c, h, w = x.size()
1252
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1253
+ # b(4),c,h,w
1254
+ patched_glb = rearrange(glb,
1255
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1256
+ hg=2,
1257
+ wg=2)
1258
+
1259
+ # generate token attention map
1260
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1261
+ token_attention_map = F.interpolate(token_attention_map,
1262
+ size=patches2image(loc).shape[-2:],
1263
+ mode='nearest')
1264
+ loc = loc * rearrange(token_attention_map,
1265
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1266
+ hg=2,
1267
+ wg=2)
1268
+ pools = []
1269
+ for pool_ratio in self.pool_ratios:
1270
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1271
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1272
+ pools.append(rearrange(pool,
1273
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1274
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1275
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1276
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1277
+ outputs = []
1278
+ for i, q in enumerate(
1279
+ loc_.unbind(dim=0)): # traverse all local patches
1280
+ # np*hw,1,c
1281
+ v = pools[i]
1282
+ k = v
1283
+ outputs.append(self.attention[i](q, k, v)[0])
1284
+ outputs = torch.cat(outputs, 1)
1285
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1286
+ src = self.norm1(src)
1287
+ src = src + self.dropout2(
1288
+ self.linear4(
1289
+ self.dropout(self.activation(self.linear3(src)).clone())))
1290
+ src = self.norm2(src)
1291
+
1292
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
1293
+ glb = glb + F.interpolate(patches2image(src),
1294
+ size=glb.shape[-2:],
1295
+ mode='nearest') # freshed glb
1296
+ return torch.cat((src, glb), 0)
1297
+
1298
+
1299
+ # model for single-scale training
1300
+ class MVANet(nn.Module):
1301
+
1302
+ def __init__(self):
1303
+ super().__init__()
1304
+ self.backbone = SwinB(pretrained=True)
1305
+ emb_dim = 128
1306
+ self.sideout5 = nn.Sequential(
1307
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1308
+ self.sideout4 = nn.Sequential(
1309
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1310
+ self.sideout3 = nn.Sequential(
1311
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1312
+ self.sideout2 = nn.Sequential(
1313
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1314
+ self.sideout1 = nn.Sequential(
1315
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1316
+
1317
+ self.output5 = make_cbr(1024, emb_dim)
1318
+ self.output4 = make_cbr(512, emb_dim)
1319
+ self.output3 = make_cbr(256, emb_dim)
1320
+ self.output2 = make_cbr(128, emb_dim)
1321
+ self.output1 = make_cbr(128, emb_dim)
1322
+
1323
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1324
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1325
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1326
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1327
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1328
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1329
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1330
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1331
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1332
+
1333
+ self.insmask_head = nn.Sequential(
1334
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1335
+ nn.BatchNorm2d(384), nn.PReLU(),
1336
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1337
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1338
+
1339
+ self.shallow = nn.Sequential(
1340
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1341
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1342
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1343
+ self.output = nn.Sequential(
1344
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1345
+
1346
+ for m in self.modules():
1347
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1348
+ m.inplace = True
1349
+
1350
+     def forward(self, x):
+         x = x.to(dtype=torch_dtype, device=torch_device)
+         shallow = self.shallow(x)
+         glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
+         loc = image2patches(x)
+         input = torch.cat((loc, glb), dim=0)
+         feature = self.backbone(input)
+         e5 = self.output5(feature[4])  # (5,128,16,16)
+         e4 = self.output4(feature[3])  # (5,128,32,32)
+         e3 = self.output3(feature[2])  # (5,128,64,64)
+         e2 = self.output2(feature[1])  # (5,128,128,128)
+         e1 = self.output1(feature[0])  # (5,128,128,128)
+         loc_e5, glb_e5 = e5.split([4, 1], dim=0)
+         e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (4,128,16,16)
+
+         e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
+         e4 = self.conv4(e4)
+         e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
+         e3 = self.conv3(e3)
+         e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
+         e2 = self.conv2(e2)
+         e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
+         e1 = self.conv1(e1)
+         loc_e1, glb_e1 = e1.split([4, 1], dim=0)
+         output1_cat = patches2image(loc_e1)  # (1,128,256,256)
+         # add glb feat in
+         output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
+         # merge
+         final_output = self.insmask_head(output1_cat)  # (1,128,256,256)
+         # shallow feature merge
+         final_output = final_output + resize_as(shallow, final_output)
+         final_output = self.upsample1(rescale_to(final_output))
+         final_output = rescale_to(final_output +
+                                   resize_as(shallow, final_output))
+         final_output = self.upsample2(final_output)
+         final_output = self.output(final_output)
+         ####
+         sideout5 = self.sideout5(e5).to(dtype=torch_dtype,
+                                         device=torch_device)
+         sideout4 = self.sideout4(e4)
+         sideout3 = self.sideout3(e3)
+         sideout2 = self.sideout2(e2)
+         sideout1 = self.sideout1(e1)
+         ####### glb sideouts #######
+         glb5 = self.sideout5(glb_e5)
+         glb4 = sideout4[-1, :, :, :].unsqueeze(0)
+         glb3 = sideout3[-1, :, :, :].unsqueeze(0)
+         glb2 = sideout2[-1, :, :, :].unsqueeze(0)
+         glb1 = sideout1[-1, :, :, :].unsqueeze(0)
+         ####### concat 4 to 1 #######
+         # (5,c,h,w) -> (1,c,2h,2w)
+         sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
+                                                    device=torch_device)
+         sideout2 = patches2image(sideout2[:-1]).to(dtype=torch_dtype,
+                                                    device=torch_device)
+         sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
+                                                    device=torch_device)
+         sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
+                                                    device=torch_device)
+         sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
+                                                    device=torch_device)
+         if self.training:
+             return (sideout5, sideout4, sideout3, sideout2, sideout1,
+                     final_output, glb5, glb4, glb3, glb2, glb1,
+                     tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1)
+         else:
+             return final_output
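+
+ # Minimal usage sketch (a comment only, not part of the original training
+ # code). It assumes the torch_dtype/torch_device globals defined earlier in
+ # this file and a 1024x1024 RGB input, the resolution implied by the shape
+ # comments above:
+ #
+ #     net = MVANet().to(device=torch_device, dtype=torch_dtype).eval()
+ #     with torch.no_grad():
+ #         pred = net(torch.rand(1, 3, 1024, 1024))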
+
+
+ # model for multi-scale testing
+ class inf_MVANet(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         # self.backbone = SwinB(pretrained=True)
+         self.backbone = SwinB(pretrained=False)
+
+         emb_dim = 128
+         self.output5 = make_cbr(1024, emb_dim)
+         self.output4 = make_cbr(512, emb_dim)
+         self.output3 = make_cbr(256, emb_dim)
+         self.output2 = make_cbr(128, emb_dim)
+         self.output1 = make_cbr(128, emb_dim)
+
+         self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
+         self.conv1 = make_cbr(emb_dim, emb_dim)
+         self.conv2 = make_cbr(emb_dim, emb_dim)
+         self.conv3 = make_cbr(emb_dim, emb_dim)
+         self.conv4 = make_cbr(emb_dim, emb_dim)
+         self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
+         self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
+         self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
+         self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
+
+         self.insmask_head = nn.Sequential(
+             nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
+             nn.BatchNorm2d(384), nn.PReLU(),
+             nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
+             nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
+
+         self.shallow = nn.Sequential(
+             nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
+         self.upsample1 = make_cbg(emb_dim, emb_dim)
+         self.upsample2 = make_cbg(emb_dim, emb_dim)
+         self.output = nn.Sequential(
+             nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+
+         for m in self.modules():
+             if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
+                 m.inplace = True
+
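+     # Note: the backbone is built with pretrained=False here; the full set of
+     # weights is expected to be loaded afterwards from a checkpoint (see
+     # load_model/get_model_path used by the ComfyUI node below).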
+     def forward(self, x):
+         shallow = self.shallow(x)
+         glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
+         loc = image2patches(x)
+         input = torch.cat((loc, glb), dim=0)
+         feature = self.backbone(input)
+         e5 = self.output5(feature[4])
+         e4 = self.output4(feature[3])
+         e3 = self.output3(feature[2])
+         e2 = self.output2(feature[1])
+         e1 = self.output1(feature[0])
+         loc_e5, glb_e5 = e5.split([4, 1], dim=0)
+         e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
+
+         e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
+         e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
+         e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
+         e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
+         loc_e1, glb_e1 = e1.split([4, 1], dim=0)
+         # after decoder, concat loc features to a whole one, and merge
+         output1_cat = patches2image(loc_e1)
+         # add glb feat in
+         output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
+         # merge
+         final_output = self.insmask_head(output1_cat)
+         # shallow feature merge
+         final_output = final_output + resize_as(shallow, final_output)
+         final_output = self.upsample1(rescale_to(final_output))
+         final_output = rescale_to(final_output +
+                                   resize_as(shallow, final_output))
+         final_output = self.upsample2(final_output)
+         final_output = self.output(final_output)
+         return final_output
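+
+ # Unlike MVANet above, the inference-only blocks (inf_MCLM/inf_MCRM) return
+ # just the refined features, so there are no side outputs or token attention
+ # maps here and forward() yields only the final prediction map.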
+
+
+ class load_MVANet_Model:
+
+     def __init__(self):
+         pass
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {},
+         }
+
+     RETURN_TYPES = ("MVANet_Model", )
+     FUNCTION = "test"
+     CATEGORY = "MVANet"
+
+     def test(self):
+         return (load_model(get_model_path()), )
+
+
+ class run_MVANet_inference:
+
+     def __init__(self):
+         pass
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {
+                 "image": ("IMAGE", ),
+                 "MVANet_Model": ("MVANet_Model", ),
+             },
+         }
+
+     RETURN_TYPES = ("MASK", )
+     FUNCTION = "test"
+     CATEGORY = "MVANet"
+
+     def test(
+         self,
+         image,
+         MVANet_Model,
+     ):
+         ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
+
+         return (ret, )
+
+
+ NODE_CLASS_MAPPINGS = {
+     "load_MVANet_Model": load_MVANet_Model,
+     "run_MVANet_inference": run_MVANet_inference
+ }
+
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "load_MVANet_Model": "load MVANet Model",
+     "run_MVANet_inference": "run MVANet inference"
+ }
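+
+ # ComfyUI reads these two dictionaries to register "load MVANet Model" and
+ # "run MVANet inference" as nodes in the MVANet category once the plugin
+ # folder sits under custom_nodes/ (assuming the package exposes the mappings,
+ # e.g. via an __init__.py that is not part of this diff).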
ComfyUI_MVANet/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ timm
+ einops
+ wget
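+ # assumed setup: install into the ComfyUI Python environment with
+ # "pip install -r requirements.txt"; wget is presumably used to fetch the
+ # pretrained MVANet checkpoint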