aravindhv10 committed
Commit 0010613 · 1 Parent(s): 0be46a0

Routine updates

Files changed (2)
  1. MVANet_Inference/README.org +2179 -0
  2. main.org +7 -29
MVANet_Inference/README.org ADDED
@@ -0,0 +1,2179 @@
1
+ * COMMENT Sample
2
+
3
+ ** Shell script to download
4
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
5
+ #+end_src
6
+
7
+ ** MVANet_inference import
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
9
+ #+end_src
10
+
11
+ ** MVANet_inference function
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
13
+ #+end_src
14
+
15
+ ** MVANet_inference class
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
17
+ #+end_src
18
+
19
+ ** MVANet_inference execute
20
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
21
+ #+end_src
22
+
23
+ ** MVANet_inference unify
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
25
+ #+end_src
26
+
27
+ ** MVANet_inference run
28
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
29
+ #+end_src
30
+
31
+ * Download the code:
32
+
33
+ ** Function to download
34
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
35
+ get_repo(){
36
+ DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
37
+ DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
38
+ mkdir -pv -- "${DIR_BASE}"
39
+ cd "${DIR_BASE}"
40
+ git clone "${1}"
41
+ cd "${DIR_REPO}"
42
+ git pull
43
+ git submodule update --recursive --init
44
+ }
45
+ #+end_src
46
+
47
+ ** Download
48
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
49
+ get_repo 'https://github.com/qianyu-dlut/MVANet.git'
50
+ #+end_src
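+
+ After running ~download.sh~, the repository ends up in ~$HOME/GITHUB/qianyu-dlut/MVANet~, which is the layout the configuration below assumes.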
51
+
52
+ * Dependencies
53
+ pip3 install mmdet==2.23.0
54
+ pip3 install mmcv==1.4.8
55
+ pip3 install ttach
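+
+ The inference code below also imports ~torch~, ~torchvision~, ~timm~, ~einops~ and ~cv2~ (opencv-python), so those packages must be installed as well. A quick, optional import check (a sketch, not part of the original instructions):
+ #+begin_src python
+ # Print the version of every package the inference scripts rely on.
+ import importlib
+ for name in ('torch', 'torchvision', 'timm', 'einops', 'cv2', 'ttach', 'mmcv', 'mmdet'):
+     module = importlib.import_module(name)
+     print(name, getattr(module, '__version__', 'ok'))
+ #+end_src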
56
+
57
+ * Python inference
58
+
59
+ ** Important configs
60
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
61
+ import os
62
+ import sys
63
+
64
+ HOME_DIR = os.environ.get('HOME', '/root')
65
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
66
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
67
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
68
+ #+end_src
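+
+ These paths assume that ~download.sh~ has been run and that the two checkpoints (~Model_80.pth~ and the pretrained SwinB weights) have been placed in the repository's ~model/~ directory. A small optional check (a sketch):
+ #+begin_src python
+ # Warn early if the expected checkpoint files are missing.
+ for path in (finetuned_MVANet_model_path, pretrained_SwinB_model_path):
+     print(path, 'found' if os.path.isfile(path) else 'MISSING')
+ #+end_src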
69
+
70
+ ** MVANet_inference import
71
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
72
+ import math
73
+ import numpy as np
74
+ from PIL import Image
75
+ import time
76
+ # import ttach as tta
77
+ import cv2
78
+
79
+ import torch
80
+ import torch.nn as nn
81
+ import torch.nn.functional as F
82
+ import torch.utils.checkpoint as checkpoint
83
+ from torch.autograd import Variable
84
+ from torch import nn
85
+ from torchvision import transforms
86
+
87
+ from einops import rearrange
88
+
89
+ from timm.models import load_checkpoint
90
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
91
+ #+end_src
92
+
93
+ ** Load image using CV
94
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
95
+ def load_image(input_image_path):
96
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
97
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
98
+ return img
99
+
100
+
101
+ def load_image_torch(input_image_path):
102
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
103
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
104
+ img = torch.from_numpy(img)
105
+ img = img.to(dtype=torch.float32)
106
+ img /= 255.0
107
+ img = img.unsqueeze(0)
108
+ return img
109
+
110
+
111
+ def save_mask(output_image_path, mask):
112
+ cv2.imwrite(output_image_path, mask)
113
+
114
+
115
+ def save_mask_torch(output_image_path, mask):
116
+ mask = mask.detach().cpu()
117
+ mask *= 255.0
118
+ mask = mask.clamp(0, 255)
119
+ print(mask.shape)
120
+ mask = mask.squeeze(0)
121
+ mask = mask.to(dtype=torch.uint8)
122
+ print(mask.shape)
123
+ mask = mask.numpy()
124
+ print(mask.shape)
125
+ cv2.imwrite(output_image_path, mask)
126
+ #+end_src
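+
+ A short usage sketch for the helpers above (file names are placeholders):
+ #+begin_src python
+ # Round trip with the OpenCV helpers: read an RGB image, write a dummy mask.
+ img = load_image('input.png')                      # H x W x 3 RGB uint8 array
+ mask = np.full(img.shape[:2], 255, dtype=np.uint8)
+ save_mask('mask.png', mask)
+
+ # Tensor variants: load_image_torch returns a 1 x H x W x 3 float tensor in [0, 1].
+ img_t = load_image_torch('input.png')
+ save_mask_torch('mask_torch.png', img_t[..., 0])   # writes one channel as a grayscale mask
+ #+end_src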
127
+
128
+ ** Device configs
129
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
130
+ torch_device = 'cuda'
131
+ torch_dtype = torch.float16
132
+ #+end_src
133
+ Every tensor and module used below is moved with ~.to(dtype=torch_dtype, device=torch_device)~, so these two settings control the device and precision used for inference.
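+
+ On a machine without a CUDA GPU, a possible fallback (an assumption, not part of the original configuration) is:
+ #+begin_src python
+ # CPU fallback: float16 is poorly supported on CPU, so use float32 there.
+ torch_device = 'cpu'
+ torch_dtype = torch.float32
+ #+end_src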
134
+
135
+ ** MVANet_inference function
136
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
137
+ def check_mkdir(dir_name):
138
+ if not os.path.isdir(dir_name):
139
+ os.makedirs(dir_name)
140
+
141
+
142
+ def SwinT(pretrained=True):
143
+ model = SwinTransformer(embed_dim=96,
144
+ depths=[2, 2, 6, 2],
145
+ num_heads=[3, 6, 12, 24],
146
+ window_size=7)
147
+ if pretrained is True:
148
+ model.load_state_dict(torch.load(
149
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
150
+ map_location='cpu')['model'],
151
+ strict=False)
152
+
153
+ return model
154
+
155
+
156
+ def SwinS(pretrained=True):
157
+ model = SwinTransformer(embed_dim=96,
158
+ depths=[2, 2, 18, 2],
159
+ num_heads=[3, 6, 12, 24],
160
+ window_size=7)
161
+ if pretrained is True:
162
+ model.load_state_dict(torch.load(
163
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
164
+ map_location='cpu')['model'],
165
+ strict=False)
166
+
167
+ return model
168
+
169
+
170
+ def SwinB(pretrained=True):
171
+ model = SwinTransformer(embed_dim=128,
172
+ depths=[2, 2, 18, 2],
173
+ num_heads=[4, 8, 16, 32],
174
+ window_size=12)
175
+ if pretrained is True:
176
+ import os
177
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
178
+ map_location='cpu')['model'],
179
+ strict=False)
180
+ return model
181
+
182
+
183
+ def SwinL(pretrained=True):
184
+ model = SwinTransformer(embed_dim=192,
185
+ depths=[2, 2, 18, 2],
186
+ num_heads=[6, 12, 24, 48],
187
+ window_size=12)
188
+ if pretrained is True:
189
+ model.load_state_dict(torch.load(
190
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
191
+ map_location='cpu')['model'],
192
+ strict=False)
193
+
194
+ return model
195
+
196
+
197
+ def get_activation_fn(activation):
198
+ """Return an activation function given a string"""
199
+ if activation == "relu":
200
+ return F.relu
201
+ if activation == "gelu":
202
+ return F.gelu
203
+ if activation == "glu":
204
+ return F.glu
205
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
206
+
207
+
208
+ def make_cbr(in_dim, out_dim):
209
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
210
+ nn.BatchNorm2d(out_dim), nn.PReLU())
211
+
212
+
213
+ def make_cbg(in_dim, out_dim):
214
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
215
+ nn.BatchNorm2d(out_dim), nn.GELU())
216
+
217
+
218
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
219
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
220
+
221
+
222
+ def resize_as(x, y, interpolation='bilinear'):
223
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
224
+
225
+
226
+ def image2patches(x):
227
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
228
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
229
+ return x
230
+
231
+
232
+ def patches2image(x):
233
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
234
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
235
+ return x
236
+
237
+
238
+ def window_partition(x, window_size):
239
+ """
240
+ Args:
241
+ x: (B, H, W, C)
242
+ window_size (int): window size
243
+
244
+ Returns:
245
+ windows: (num_windows*B, window_size, window_size, C)
246
+ """
247
+ B, H, W, C = x.shape
248
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
249
+ C)
250
+ windows = x.permute(0, 1, 3, 2, 4,
251
+ 5).contiguous().view(-1, window_size, window_size, C)
252
+ return windows
253
+
254
+
255
+ def window_reverse(windows, window_size, H, W):
256
+ """
257
+ Args:
258
+ windows: (num_windows*B, window_size, window_size, C)
259
+ window_size (int): Window size
260
+ H (int): Height of image
261
+ W (int): Width of image
262
+
263
+ Returns:
264
+ x: (B, H, W, C)
265
+ """
266
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
267
+ x = windows.view(B, H // window_size, W // window_size, window_size,
268
+ window_size, -1)
269
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
270
+ return x
271
+ #+end_src
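+
+ A small shape sketch for the patch helpers above (the 512x512 size is illustrative):
+ #+begin_src python
+ # image2patches splits one image into a 2x2 grid of patches stacked along the
+ # batch dimension; patches2image is its exact inverse.
+ x = torch.randn(1, 3, 512, 512)
+ patches = image2patches(x)         # (4, 3, 256, 256)
+ restored = patches2image(patches)  # (1, 3, 512, 512)
+ print(patches.shape, restored.shape, torch.allclose(x, restored))
+ #+end_src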
272
+
273
+ ** MVANet_inference class
274
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
275
+ class Mlp(nn.Module):
276
+ """ Multilayer perceptron."""
277
+
278
+ def __init__(self,
279
+ in_features,
280
+ hidden_features=None,
281
+ out_features=None,
282
+ act_layer=nn.GELU,
283
+ drop=0.):
284
+ super().__init__()
285
+ out_features = out_features or in_features
286
+ hidden_features = hidden_features or in_features
287
+ self.fc1 = nn.Linear(in_features, hidden_features)
288
+ self.act = act_layer()
289
+ self.fc2 = nn.Linear(hidden_features, out_features)
290
+ self.drop = nn.Dropout(drop)
291
+
292
+ def forward(self, x):
293
+ x = self.fc1(x)
294
+ x = self.act(x)
295
+ x = self.drop(x)
296
+ x = self.fc2(x)
297
+ x = self.drop(x)
298
+ return x
299
+
300
+
301
+ class WindowAttention(nn.Module):
302
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
303
+ It supports both of shifted and non-shifted window.
304
+
305
+ Args:
306
+ dim (int): Number of input channels.
307
+ window_size (tuple[int]): The height and width of the window.
308
+ num_heads (int): Number of attention heads.
309
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
310
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
311
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
312
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
313
+ """
314
+
315
+ def __init__(self,
316
+ dim,
317
+ window_size,
318
+ num_heads,
319
+ qkv_bias=True,
320
+ qk_scale=None,
321
+ attn_drop=0.,
322
+ proj_drop=0.):
323
+
324
+ super().__init__()
325
+ self.dim = dim
326
+ self.window_size = window_size # Wh, Ww
327
+ self.num_heads = num_heads
328
+ head_dim = dim // num_heads
329
+ self.scale = qk_scale or head_dim**-0.5
330
+
331
+ # define a parameter table of relative position bias
332
+ self.relative_position_bias_table = nn.Parameter(
333
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
334
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
335
+
336
+ # get pair-wise relative position index for each token inside the window
337
+ coords_h = torch.arange(self.window_size[0])
338
+ coords_w = torch.arange(self.window_size[1])
339
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
340
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
341
+ relative_coords = coords_flatten[:, :,
342
+ None] - coords_flatten[:,
343
+ None, :] # 2, Wh*Ww, Wh*Ww
344
+ relative_coords = relative_coords.permute(
345
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
346
+ relative_coords[:, :,
347
+ 0] += self.window_size[0] - 1 # shift to start from 0
348
+ relative_coords[:, :, 1] += self.window_size[1] - 1
349
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
350
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
351
+ self.register_buffer("relative_position_index",
352
+ relative_position_index)
353
+
354
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
355
+ self.attn_drop = nn.Dropout(attn_drop)
356
+ self.proj = nn.Linear(dim, dim)
357
+ self.proj_drop = nn.Dropout(proj_drop)
358
+
359
+ trunc_normal_(self.relative_position_bias_table, std=.02)
360
+ self.softmax = nn.Softmax(dim=-1)
361
+
362
+ def forward(self, x, mask=None):
363
+ """ Forward function.
364
+
365
+ Args:
366
+ x: input features with shape of (num_windows*B, N, C)
367
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
368
+ """
369
+ x = x.to(dtype=torch_dtype, device=torch_device)
370
+ B_, N, C = x.shape
371
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
372
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
373
+ q, k, v = qkv[0], qkv[1], qkv[
374
+ 2] # make torchscript happy (cannot use tensor as tuple)
375
+
376
+ q = q * self.scale
377
+ attn = (q @ k.transpose(-2, -1))
378
+
379
+ relative_position_bias = self.relative_position_bias_table[
380
+ self.relative_position_index.view(-1)].view(
381
+ self.window_size[0] * self.window_size[1],
382
+ self.window_size[0] * self.window_size[1],
383
+ -1) # Wh*Ww,Wh*Ww,nH
384
+ relative_position_bias = relative_position_bias.permute(
385
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
386
+ attn = attn + relative_position_bias.unsqueeze(0)
387
+
388
+ if mask is not None:
389
+ nW = mask.shape[0]
390
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
391
+ N) + mask.unsqueeze(1).unsqueeze(0)
392
+ attn = attn.view(-1, self.num_heads, N, N)
393
+ attn = self.softmax(attn)
394
+ else:
395
+ attn = self.softmax(attn)
396
+
397
+ attn = self.attn_drop(attn)
398
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
399
+ v = v.to(dtype=torch_dtype, device=torch_device)
400
+
401
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
402
+ x = self.proj(x)
403
+ x = self.proj_drop(x)
404
+ return x
405
+
406
+
407
+ class SwinTransformerBlock(nn.Module):
408
+ """ Swin Transformer Block.
409
+
410
+ Args:
411
+ dim (int): Number of input channels.
412
+ num_heads (int): Number of attention heads.
413
+ window_size (int): Window size.
414
+ shift_size (int): Shift size for SW-MSA.
415
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
416
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
417
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
418
+ drop (float, optional): Dropout rate. Default: 0.0
419
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
420
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
421
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
422
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
423
+ """
424
+
425
+ def __init__(self,
426
+ dim,
427
+ num_heads,
428
+ window_size=7,
429
+ shift_size=0,
430
+ mlp_ratio=4.,
431
+ qkv_bias=True,
432
+ qk_scale=None,
433
+ drop=0.,
434
+ attn_drop=0.,
435
+ drop_path=0.,
436
+ act_layer=nn.GELU,
437
+ norm_layer=nn.LayerNorm):
438
+ super().__init__()
439
+ self.dim = dim
440
+ self.num_heads = num_heads
441
+ self.window_size = window_size
442
+ self.shift_size = shift_size
443
+ self.mlp_ratio = mlp_ratio
444
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
445
+
446
+ self.norm1 = norm_layer(dim)
447
+ self.attn = WindowAttention(dim,
448
+ window_size=to_2tuple(self.window_size),
449
+ num_heads=num_heads,
450
+ qkv_bias=qkv_bias,
451
+ qk_scale=qk_scale,
452
+ attn_drop=attn_drop,
453
+ proj_drop=drop)
454
+
455
+ self.drop_path = DropPath(
456
+ drop_path) if drop_path > 0. else nn.Identity()
457
+ self.norm2 = norm_layer(dim)
458
+ mlp_hidden_dim = int(dim * mlp_ratio)
459
+ self.mlp = Mlp(in_features=dim,
460
+ hidden_features=mlp_hidden_dim,
461
+ act_layer=act_layer,
462
+ drop=drop)
463
+
464
+ self.H = None
465
+ self.W = None
466
+
467
+ def forward(self, x, mask_matrix):
468
+ """ Forward function.
469
+
470
+ Args:
471
+ x: Input feature, tensor size (B, H*W, C).
472
+ H, W: Spatial resolution of the input feature.
473
+ mask_matrix: Attention mask for cyclic shift.
474
+ """
475
+ B, L, C = x.shape
476
+ H, W = self.H, self.W
477
+ assert L == H * W, "input feature has wrong size"
478
+
479
+ shortcut = x
480
+ x = self.norm1(x)
481
+ x = x.view(B, H, W, C)
482
+
483
+ # pad feature maps to multiples of window size
484
+ pad_l = pad_t = 0
485
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
486
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
487
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
488
+ _, Hp, Wp, _ = x.shape
489
+
490
+ # cyclic shift
491
+ if self.shift_size > 0:
492
+ shifted_x = torch.roll(x,
493
+ shifts=(-self.shift_size, -self.shift_size),
494
+ dims=(1, 2))
495
+ attn_mask = mask_matrix
496
+ else:
497
+ shifted_x = x
498
+ attn_mask = None
499
+
500
+ # partition windows
501
+ x_windows = window_partition(
502
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
503
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
504
+ C) # nW*B, window_size*window_size, C
505
+
506
+ # W-MSA/SW-MSA
507
+ attn_windows = self.attn(
508
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
509
+
510
+ # merge windows
511
+ attn_windows = attn_windows.view(-1, self.window_size,
512
+ self.window_size, C)
513
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
514
+ Wp) # B H' W' C
515
+
516
+ # reverse cyclic shift
517
+ if self.shift_size > 0:
518
+ x = torch.roll(shifted_x,
519
+ shifts=(self.shift_size, self.shift_size),
520
+ dims=(1, 2))
521
+ else:
522
+ x = shifted_x
523
+
524
+ if pad_r > 0 or pad_b > 0:
525
+ x = x[:, :H, :W, :].contiguous()
526
+
527
+ x = x.view(B, H * W, C)
528
+
529
+ # FFN
530
+ x = shortcut + self.drop_path(x)
531
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
532
+
533
+ return x
534
+
535
+
536
+ class PatchMerging(nn.Module):
537
+ """ Patch Merging Layer
538
+
539
+ Args:
540
+ dim (int): Number of input channels.
541
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
542
+ """
543
+
544
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
545
+ super().__init__()
546
+ self.dim = dim
547
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
548
+ self.norm = norm_layer(4 * dim)
549
+
550
+ def forward(self, x, H, W):
551
+ """ Forward function.
552
+
553
+ Args:
554
+ x: Input feature, tensor size (B, H*W, C).
555
+ H, W: Spatial resolution of the input feature.
556
+ """
557
+ B, L, C = x.shape
558
+ assert L == H * W, "input feature has wrong size"
559
+
560
+ x = x.view(B, H, W, C)
561
+
562
+ # padding
563
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
564
+ if pad_input:
565
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
566
+
567
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
568
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
569
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
570
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
571
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
572
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
573
+
574
+ x = self.norm(x)
575
+ x = self.reduction(x)
576
+
577
+ return x
578
+
579
+
580
+ class BasicLayer(nn.Module):
581
+ """ A basic Swin Transformer layer for one stage.
582
+
583
+ Args:
584
+ dim (int): Number of feature channels
585
+ depth (int): Depth of this stage.
586
+ num_heads (int): Number of attention heads.
587
+ window_size (int): Local window size. Default: 7.
588
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
589
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
590
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
591
+ drop (float, optional): Dropout rate. Default: 0.0
592
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
593
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
595
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
596
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
597
+ """
598
+
599
+ def __init__(self,
600
+ dim,
601
+ depth,
602
+ num_heads,
603
+ window_size=7,
604
+ mlp_ratio=4.,
605
+ qkv_bias=True,
606
+ qk_scale=None,
607
+ drop=0.,
608
+ attn_drop=0.,
609
+ drop_path=0.,
610
+ norm_layer=nn.LayerNorm,
611
+ downsample=None,
612
+ use_checkpoint=False):
613
+ super().__init__()
614
+ self.window_size = window_size
615
+ self.shift_size = window_size // 2
616
+ self.depth = depth
617
+ self.use_checkpoint = use_checkpoint
618
+
619
+ # build blocks
620
+ self.blocks = nn.ModuleList([
621
+ SwinTransformerBlock(dim=dim,
622
+ num_heads=num_heads,
623
+ window_size=window_size,
624
+ shift_size=0 if
625
+ (i % 2 == 0) else window_size // 2,
626
+ mlp_ratio=mlp_ratio,
627
+ qkv_bias=qkv_bias,
628
+ qk_scale=qk_scale,
629
+ drop=drop,
630
+ attn_drop=attn_drop,
631
+ drop_path=drop_path[i] if isinstance(
632
+ drop_path, list) else drop_path,
633
+ norm_layer=norm_layer) for i in range(depth)
634
+ ])
635
+
636
+ # patch merging layer
637
+ if downsample is not None:
638
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
639
+ else:
640
+ self.downsample = None
641
+
642
+ def forward(self, x, H, W):
643
+ """ Forward function.
644
+
645
+ Args:
646
+ x: Input feature, tensor size (B, H*W, C).
647
+ H, W: Spatial resolution of the input feature.
648
+ """
649
+
650
+ # calculate attention mask for SW-MSA
651
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
652
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
653
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
654
+ h_slices = (slice(0, -self.window_size),
655
+ slice(-self.window_size,
656
+ -self.shift_size), slice(-self.shift_size, None))
657
+ w_slices = (slice(0, -self.window_size),
658
+ slice(-self.window_size,
659
+ -self.shift_size), slice(-self.shift_size, None))
660
+ cnt = 0
661
+ for h in h_slices:
662
+ for w in w_slices:
663
+ img_mask[:, h, w, :] = cnt
664
+ cnt += 1
665
+
666
+ mask_windows = window_partition(
667
+ img_mask, self.window_size) # nW, window_size, window_size, 1
668
+ mask_windows = mask_windows.view(-1,
669
+ self.window_size * self.window_size)
670
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
671
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
672
+ float(-100.0)).masked_fill(
673
+ attn_mask == 0, float(0.0))
674
+
675
+ for blk in self.blocks:
676
+ blk.H, blk.W = H, W
677
+ if self.use_checkpoint:
678
+ x = checkpoint.checkpoint(blk, x, attn_mask)
679
+ else:
680
+ x = blk(x, attn_mask)
681
+ if self.downsample is not None:
682
+ x_down = self.downsample(x, H, W)
683
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
684
+ return x, H, W, x_down, Wh, Ww
685
+ else:
686
+ return x, H, W, x, H, W
687
+
688
+
689
+ class PatchEmbed(nn.Module):
690
+ """ Image to Patch Embedding
691
+
692
+ Args:
693
+ patch_size (int): Patch token size. Default: 4.
694
+ in_chans (int): Number of input image channels. Default: 3.
695
+ embed_dim (int): Number of linear projection output channels. Default: 96.
696
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
697
+ """
698
+
699
+ def __init__(self,
700
+ patch_size=4,
701
+ in_chans=3,
702
+ embed_dim=96,
703
+ norm_layer=None):
704
+ super().__init__()
705
+ patch_size = to_2tuple(patch_size)
706
+ self.patch_size = patch_size
707
+
708
+ self.in_chans = in_chans
709
+ self.embed_dim = embed_dim
710
+
711
+ self.proj = nn.Conv2d(in_chans,
712
+ embed_dim,
713
+ kernel_size=patch_size,
714
+ stride=patch_size)
715
+ if norm_layer is not None:
716
+ self.norm = norm_layer(embed_dim)
717
+ else:
718
+ self.norm = None
719
+
720
+ def forward(self, x):
721
+ """Forward function."""
722
+ # padding
723
+ _, _, H, W = x.size()
724
+ if W % self.patch_size[1] != 0:
725
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
726
+ if H % self.patch_size[0] != 0:
727
+ x = F.pad(x,
728
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
729
+
730
+ x = self.proj(x) # B C Wh Ww
731
+ if self.norm is not None:
732
+ Wh, Ww = x.size(2), x.size(3)
733
+ x = x.flatten(2).transpose(1, 2)
734
+ x = self.norm(x)
735
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
736
+
737
+ return x
738
+
739
+
740
+ class SwinTransformer(nn.Module):
741
+ """ Swin Transformer backbone.
742
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
743
+ https://arxiv.org/pdf/2103.14030
744
+
745
+ Args:
746
+ pretrain_img_size (int): Input image size for training the pretrained model,
747
+ used in absolute position embedding. Default: 224.
748
+ patch_size (int | tuple(int)): Patch size. Default: 4.
749
+ in_chans (int): Number of input image channels. Default: 3.
750
+ embed_dim (int): Number of linear projection output channels. Default: 96.
751
+ depths (tuple[int]): Depths of each Swin Transformer stage.
752
+ num_heads (tuple[int]): Number of attention head of each stage.
753
+ window_size (int): Window size. Default: 7.
754
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
755
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
756
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
757
+ drop_rate (float): Dropout rate.
758
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
759
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
760
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
761
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
762
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
763
+ out_indices (Sequence[int]): Output from which stages.
764
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
765
+ -1 means not freezing any parameters.
766
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
767
+ """
768
+
769
+ def __init__(self,
770
+ pretrain_img_size=224,
771
+ patch_size=4,
772
+ in_chans=3,
773
+ embed_dim=96,
774
+ depths=[2, 2, 6, 2],
775
+ num_heads=[3, 6, 12, 24],
776
+ window_size=7,
777
+ mlp_ratio=4.,
778
+ qkv_bias=True,
779
+ qk_scale=None,
780
+ drop_rate=0.,
781
+ attn_drop_rate=0.,
782
+ drop_path_rate=0.2,
783
+ norm_layer=nn.LayerNorm,
784
+ ape=False,
785
+ patch_norm=True,
786
+ out_indices=(0, 1, 2, 3),
787
+ frozen_stages=-1,
788
+ use_checkpoint=False):
789
+ super().__init__()
790
+
791
+ self.pretrain_img_size = pretrain_img_size
792
+ self.num_layers = len(depths)
793
+ self.embed_dim = embed_dim
794
+ self.ape = ape
795
+ self.patch_norm = patch_norm
796
+ self.out_indices = out_indices
797
+ self.frozen_stages = frozen_stages
798
+
799
+ # split image into non-overlapping patches
800
+ self.patch_embed = PatchEmbed(
801
+ patch_size=patch_size,
802
+ in_chans=in_chans,
803
+ embed_dim=embed_dim,
804
+ norm_layer=norm_layer if self.patch_norm else None)
805
+
806
+ # absolute position embedding
807
+ if self.ape:
808
+ pretrain_img_size = to_2tuple(pretrain_img_size)
809
+ patch_size = to_2tuple(patch_size)
810
+ patches_resolution = [
811
+ pretrain_img_size[0] // patch_size[0],
812
+ pretrain_img_size[1] // patch_size[1]
813
+ ]
814
+
815
+ self.absolute_pos_embed = nn.Parameter(
816
+ torch.zeros(1, embed_dim, patches_resolution[0],
817
+ patches_resolution[1]))
818
+ trunc_normal_(self.absolute_pos_embed, std=.02)
819
+
820
+ self.pos_drop = nn.Dropout(p=drop_rate)
821
+
822
+ # stochastic depth
823
+ dpr = [
824
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
825
+ ] # stochastic depth decay rule
826
+
827
+ # build layers
828
+ self.layers = nn.ModuleList()
829
+ for i_layer in range(self.num_layers):
830
+ layer = BasicLayer(
831
+ dim=int(embed_dim * 2**i_layer),
832
+ depth=depths[i_layer],
833
+ num_heads=num_heads[i_layer],
834
+ window_size=window_size,
835
+ mlp_ratio=mlp_ratio,
836
+ qkv_bias=qkv_bias,
837
+ qk_scale=qk_scale,
838
+ drop=drop_rate,
839
+ attn_drop=attn_drop_rate,
840
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
841
+ norm_layer=norm_layer,
842
+ downsample=PatchMerging if
843
+ (i_layer < self.num_layers - 1) else None,
844
+ use_checkpoint=use_checkpoint)
845
+ self.layers.append(layer)
846
+
847
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
848
+ self.num_features = num_features
849
+
850
+ # add a norm layer for each output
851
+ for i_layer in out_indices:
852
+ layer = norm_layer(num_features[i_layer])
853
+ layer_name = f'norm{i_layer}'
854
+ self.add_module(layer_name, layer)
855
+
856
+ self._freeze_stages()
857
+
858
+ def _freeze_stages(self):
859
+ if self.frozen_stages >= 0:
860
+ self.patch_embed.eval()
861
+ for param in self.patch_embed.parameters():
862
+ param.requires_grad = False
863
+
864
+ if self.frozen_stages >= 1 and self.ape:
865
+ self.absolute_pos_embed.requires_grad = False
866
+
867
+ if self.frozen_stages >= 2:
868
+ self.pos_drop.eval()
869
+ for i in range(0, self.frozen_stages - 1):
870
+ m = self.layers[i]
871
+ m.eval()
872
+ for param in m.parameters():
873
+ param.requires_grad = False
874
+
875
+ def init_weights(self, pretrained=None):
876
+ """Initialize the weights in backbone.
877
+
878
+ Args:
879
+ pretrained (str, optional): Path to pre-trained weights.
880
+ Defaults to None.
881
+ """
882
+
883
+ def _init_weights(m):
884
+ if isinstance(m, nn.Linear):
885
+ trunc_normal_(m.weight, std=.02)
886
+ if isinstance(m, nn.Linear) and m.bias is not None:
887
+ nn.init.constant_(m.bias, 0)
888
+ elif isinstance(m, nn.LayerNorm):
889
+ nn.init.constant_(m.bias, 0)
890
+ nn.init.constant_(m.weight, 1.0)
891
+
892
+ if isinstance(pretrained, str):
893
+ self.apply(_init_weights)
894
+ load_checkpoint(self, pretrained, strict=False, logger=None)
895
+ elif pretrained is None:
896
+ self.apply(_init_weights)
897
+ else:
898
+ raise TypeError('pretrained must be a str or None')
899
+
900
+ def forward(self, x):
901
+ x = self.patch_embed(x)
902
+
903
+ Wh, Ww = x.size(2), x.size(3)
904
+ if self.ape:
905
+ # interpolate the position embedding to the corresponding size
906
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
907
+ size=(Wh, Ww),
908
+ mode='bicubic')
909
+ x = (x + absolute_pos_embed) # B Wh*Ww C
910
+
911
+ outs = [x.contiguous()]
912
+ x = x.flatten(2).transpose(1, 2)
913
+ x = self.pos_drop(x)
914
+ for i in range(self.num_layers):
915
+ layer = self.layers[i]
916
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
917
+
918
+ if i in self.out_indices:
919
+ norm_layer = getattr(self, f'norm{i}')
920
+ x_out = norm_layer(x_out)
921
+
922
+ out = x_out.view(-1, H, W,
923
+ self.num_features[i]).permute(0, 3, 1,
924
+ 2).contiguous()
925
+ outs.append(out)
926
+
927
+ return tuple(outs)
928
+
929
+ def train(self, mode=True):
930
+ """Convert the model into training mode while keep layers freezed."""
931
+ super(SwinTransformer, self).train(mode)
932
+ self._freeze_stages()
933
+
934
+
935
+ class PositionEmbeddingSine:
936
+
937
+ def __init__(self,
938
+ num_pos_feats=64,
939
+ temperature=10000,
940
+ normalize=False,
941
+ scale=None):
942
+ super().__init__()
943
+ self.num_pos_feats = num_pos_feats
944
+ self.temperature = temperature
945
+ self.normalize = normalize
946
+ if scale is not None and normalize is False:
947
+ raise ValueError("normalize should be True if scale is passed")
948
+ if scale is None:
949
+ scale = 2 * math.pi
950
+ self.scale = scale
951
+ self.dim_t = torch.arange(0,
952
+ self.num_pos_feats,
953
+ dtype=torch_dtype,
954
+ device=torch_device)
955
+
956
+ def __call__(self, b, h, w):
957
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
958
+ assert mask is not None
959
+ not_mask = ~mask
960
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
961
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
962
+ if self.normalize:
963
+ eps = 1e-6
964
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
965
+ self.scale).to(device=torch_device, dtype=torch_dtype)
966
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
967
+ self.scale).to(device=torch_device, dtype=torch_dtype)
968
+
969
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
970
+
971
+ pos_x = x_embed[:, :, :, None] / dim_t
972
+ pos_y = y_embed[:, :, :, None] / dim_t
973
+ pos_x = torch.stack(
974
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
975
+ dim=4).flatten(3)
976
+ pos_y = torch.stack(
977
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
978
+ dim=4).flatten(3)
979
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
980
+
981
+
982
+ class MCLM(nn.Module):
983
+
984
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
985
+ super(MCLM, self).__init__()
986
+ self.attention = nn.ModuleList([
987
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
988
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
989
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
990
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
991
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
992
+ ])
993
+
994
+ self.linear1 = nn.Linear(d_model, d_model * 2)
995
+ self.linear2 = nn.Linear(d_model * 2, d_model)
996
+ self.linear3 = nn.Linear(d_model, d_model * 2)
997
+ self.linear4 = nn.Linear(d_model * 2, d_model)
998
+ self.norm1 = nn.LayerNorm(d_model)
999
+ self.norm2 = nn.LayerNorm(d_model)
1000
+ self.dropout = nn.Dropout(0.1)
1001
+ self.dropout1 = nn.Dropout(0.1)
1002
+ self.dropout2 = nn.Dropout(0.1)
1003
+ self.activation = get_activation_fn('relu')
1004
+ self.pool_ratios = pool_ratios
1005
+ self.p_poses = []
1006
+ self.g_pos = None
1007
+ self.positional_encoding = PositionEmbeddingSine(
1008
+ num_pos_feats=d_model // 2, normalize=True)
1009
+
1010
+ def forward(self, l, g):
1011
+ """
1012
+ l: 4,c,h,w
1013
+ g: 1,c,h,w
1014
+ """
1015
+ b, c, h, w = l.size()
1016
+ # 4,c,h,w -> 1,c,2h,2w
1017
+ concated_locs = rearrange(l,
1018
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1019
+ hg=2,
1020
+ wg=2)
1021
+
1022
+ pools = []
1023
+ for pool_ratio in self.pool_ratios:
1024
+ # b,c,h,w
1025
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1026
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1027
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1028
+ if self.g_pos is None:
1029
+ pos_emb = self.positional_encoding(pool.shape[0],
1030
+ pool.shape[2],
1031
+ pool.shape[3])
1032
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1033
+ self.p_poses.append(pos_emb)
1034
+ pools = torch.cat(pools, 0)
1035
+ if self.g_pos is None:
1036
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1037
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1038
+ g.shape[3])
1039
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1040
+
1041
+ # attention between glb (q) & multisensory concated-locs (k,v)
1042
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1043
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1044
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1045
+ g_hw_b_c = self.norm1(g_hw_b_c)
1046
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1047
+ self.linear2(
1048
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1049
+ g_hw_b_c = self.norm2(g_hw_b_c)
1050
+
1051
+ # attention between origin locs (q) & freashed glb (k,v)
1052
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1053
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1054
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1055
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1056
+ ng=2,
1057
+ nw=2)
1058
+ outputs_re = []
1059
+ for i, (_l, _g) in enumerate(
1060
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1061
+ outputs_re.append(self.attention[i + 1](_l, _g,
1062
+ _g)[0]) # (h w) 1 c
1063
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1064
+
1065
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1066
+ l_hw_b_c = self.norm1(l_hw_b_c)
1067
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1068
+ self.linear4(
1069
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1070
+ l_hw_b_c = self.norm2(l_hw_b_c)
1071
+
1072
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1073
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1074
+
1075
+
1076
+ class inf_MCLM(nn.Module):
1077
+
1078
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1079
+ super(inf_MCLM, self).__init__()
1080
+ self.attention = nn.ModuleList([
1081
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1082
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1083
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1084
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1085
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1086
+ ])
1087
+
1088
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1089
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1090
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1091
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1092
+ self.norm1 = nn.LayerNorm(d_model)
1093
+ self.norm2 = nn.LayerNorm(d_model)
1094
+ self.dropout = nn.Dropout(0.1)
1095
+ self.dropout1 = nn.Dropout(0.1)
1096
+ self.dropout2 = nn.Dropout(0.1)
1097
+ self.activation = get_activation_fn('relu')
1098
+ self.pool_ratios = pool_ratios
1099
+ self.p_poses = []
1100
+ self.g_pos = None
1101
+ self.positional_encoding = PositionEmbeddingSine(
1102
+ num_pos_feats=d_model // 2, normalize=True)
1103
+
1104
+ def forward(self, l, g):
1105
+ """
1106
+ l: 4,c,h,w
1107
+ g: 1,c,h,w
1108
+ """
1109
+ b, c, h, w = l.size()
1110
+ # 4,c,h,w -> 1,c,2h,2w
1111
+ concated_locs = rearrange(l,
1112
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1113
+ hg=2,
1114
+ wg=2)
1115
+ self.p_poses = []
1116
+ pools = []
1117
+ for pool_ratio in self.pool_ratios:
1118
+ # b,c,h,w
1119
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1120
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1121
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1122
+ # if self.g_pos is None:
1123
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1124
+ pool.shape[3])
1125
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1126
+ self.p_poses.append(pos_emb)
1127
+ pools = torch.cat(pools, 0)
1128
+ # if self.g_pos is None:
1129
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1130
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1131
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1132
+
1133
+ # attention between glb (q) & multisensory concated-locs (k,v)
1134
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1135
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1136
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1137
+ g_hw_b_c = self.norm1(g_hw_b_c)
1138
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1139
+ self.linear2(
1140
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1141
+ g_hw_b_c = self.norm2(g_hw_b_c)
1142
+
1143
+ # attention between origin locs (q) & freashed glb (k,v)
1144
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1145
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1146
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1147
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1148
+ ng=2,
1149
+ nw=2)
1150
+ outputs_re = []
1151
+ for i, (_l, _g) in enumerate(
1152
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1153
+ outputs_re.append(self.attention[i + 1](_l, _g,
1154
+ _g)[0]) # (h w) 1 c
1155
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1156
+
1157
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1158
+ l_hw_b_c = self.norm1(l_hw_b_c)
1159
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1160
+ self.linear4(
1161
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1162
+ l_hw_b_c = self.norm2(l_hw_b_c)
1163
+
1164
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1165
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1166
+
1167
+
1168
+ class MCRM(nn.Module):
1169
+
1170
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1171
+ super(MCRM, self).__init__()
1172
+ self.attention = nn.ModuleList([
1173
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1174
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1175
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1176
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1177
+ ])
1178
+
1179
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1180
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1181
+ self.norm1 = nn.LayerNorm(d_model)
1182
+ self.norm2 = nn.LayerNorm(d_model)
1183
+ self.dropout = nn.Dropout(0.1)
1184
+ self.dropout1 = nn.Dropout(0.1)
1185
+ self.dropout2 = nn.Dropout(0.1)
1186
+ self.sigmoid = nn.Sigmoid()
1187
+ self.activation = get_activation_fn('relu')
1188
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1189
+ self.pool_ratios = pool_ratios
1190
+ self.positional_encoding = PositionEmbeddingSine(
1191
+ num_pos_feats=d_model // 2, normalize=True)
1192
+
1193
+ def forward(self, x):
1194
+ b, c, h, w = x.size()
1195
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1196
+ # b(4),c,h,w
1197
+ patched_glb = rearrange(glb,
1198
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1199
+ hg=2,
1200
+ wg=2)
1201
+
1202
+ # generate token attention map
1203
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1204
+ token_attention_map = F.interpolate(token_attention_map,
1205
+ size=patches2image(loc).shape[-2:],
1206
+ mode='nearest')
1207
+ loc = loc * rearrange(token_attention_map,
1208
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1209
+ hg=2,
1210
+ wg=2)
1211
+ pools = []
1212
+ for pool_ratio in self.pool_ratios:
1213
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1214
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1215
+ pools.append(rearrange(pool,
1216
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1217
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1218
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1219
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1220
+ outputs = []
1221
+ for i, q in enumerate(
1222
+ loc_.unbind(dim=0)): # traverse all local patches
1223
+ # np*hw,1,c
1224
+ v = pools[i]
1225
+ k = v
1226
+ outputs.append(self.attention[i](q, k, v)[0])
1227
+ outputs = torch.cat(outputs, 1)
1228
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1229
+ src = self.norm1(src)
1230
+ src = src + self.dropout2(
1231
+ self.linear4(
1232
+ self.dropout(self.activation(self.linear3(src)).clone())))
1233
+ src = self.norm2(src)
1234
+
1235
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
1236
+ glb = glb + F.interpolate(patches2image(src),
1237
+ size=glb.shape[-2:],
1238
+ mode='nearest') # freshed glb
1239
+ return torch.cat((src, glb), 0), token_attention_map
1240
+
1241
+
1242
+ class inf_MCRM(nn.Module):
1243
+
1244
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1245
+ super(inf_MCRM, self).__init__()
1246
+ self.attention = nn.ModuleList([
1247
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1248
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1249
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1250
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1251
+ ])
1252
+
1253
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1254
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1255
+ self.norm1 = nn.LayerNorm(d_model)
1256
+ self.norm2 = nn.LayerNorm(d_model)
1257
+ self.dropout = nn.Dropout(0.1)
1258
+ self.dropout1 = nn.Dropout(0.1)
1259
+ self.dropout2 = nn.Dropout(0.1)
1260
+ self.sigmoid = nn.Sigmoid()
1261
+ self.activation = get_activation_fn('relu')
1262
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1263
+ self.pool_ratios = pool_ratios
1264
+ self.positional_encoding = PositionEmbeddingSine(
1265
+ num_pos_feats=d_model // 2, normalize=True)
1266
+
1267
+ def forward(self, x):
1268
+ b, c, h, w = x.size()
1269
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1270
+ # b(4),c,h,w
1271
+ patched_glb = rearrange(glb,
1272
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1273
+ hg=2,
1274
+ wg=2)
1275
+
1276
+ # generate token attention map
1277
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1278
+ token_attention_map = F.interpolate(token_attention_map,
1279
+ size=patches2image(loc).shape[-2:],
1280
+ mode='nearest')
1281
+ loc = loc * rearrange(token_attention_map,
1282
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1283
+ hg=2,
1284
+ wg=2)
1285
+ pools = []
1286
+ for pool_ratio in self.pool_ratios:
1287
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1288
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1289
+ pools.append(rearrange(pool,
1290
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1291
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1292
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1293
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1294
+ outputs = []
1295
+ for i, q in enumerate(
1296
+ loc_.unbind(dim=0)): # traverse all local patches
1297
+ # np*hw,1,c
1298
+ v = pools[i]
1299
+ k = v
1300
+ outputs.append(self.attention[i](q, k, v)[0])
1301
+ outputs = torch.cat(outputs, 1)
1302
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1303
+ src = self.norm1(src)
1304
+ src = src + self.dropout2(
1305
+ self.linear4(
1306
+ self.dropout(self.activation(self.linear3(src)).clone())))
1307
+ src = self.norm2(src)
1308
+
1309
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
1310
+ glb = glb + F.interpolate(patches2image(src),
1311
+ size=glb.shape[-2:],
1312
+ mode='nearest') # freshed glb
1313
+ return torch.cat((src, glb), 0)
1314
+
1315
+
1316
+ # model for single-scale training
1317
+ class MVANet(nn.Module):
1318
+
1319
+ def __init__(self):
1320
+ super().__init__()
1321
+ self.backbone = SwinB(pretrained=True)
1322
+ emb_dim = 128
1323
+ self.sideout5 = nn.Sequential(
1324
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1325
+ self.sideout4 = nn.Sequential(
1326
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1327
+ self.sideout3 = nn.Sequential(
1328
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1329
+ self.sideout2 = nn.Sequential(
1330
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1331
+ self.sideout1 = nn.Sequential(
1332
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1333
+
1334
+ self.output5 = make_cbr(1024, emb_dim)
1335
+ self.output4 = make_cbr(512, emb_dim)
1336
+ self.output3 = make_cbr(256, emb_dim)
1337
+ self.output2 = make_cbr(128, emb_dim)
1338
+ self.output1 = make_cbr(128, emb_dim)
1339
+
1340
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1341
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1342
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1343
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1344
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1345
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1346
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1347
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1348
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1349
+
1350
+ self.insmask_head = nn.Sequential(
1351
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1352
+ nn.BatchNorm2d(384), nn.PReLU(),
1353
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1354
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1355
+
1356
+ self.shallow = nn.Sequential(
1357
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1358
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1359
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1360
+ self.output = nn.Sequential(
1361
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1362
+
1363
+ for m in self.modules():
1364
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1365
+ m.inplace = True
1366
+
1367
+ def forward(self, x):
1368
+ x = x.to(dtype=torch_dtype, device=torch_device)
1369
+ shallow = self.shallow(x)
1370
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1371
+ loc = image2patches(x)
1372
+ input = torch.cat((loc, glb), dim=0)
1373
+ feature = self.backbone(input)
1374
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1375
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1376
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1377
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1378
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1379
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1380
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1381
+
1382
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1383
+ e4 = self.conv4(e4)
1384
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1385
+ e3 = self.conv3(e3)
1386
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1387
+ e2 = self.conv2(e2)
1388
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1389
+ e1 = self.conv1(e1)
1390
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1391
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1392
+ # add glb feat in
1393
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1394
+ # merge
1395
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1396
+ # shallow feature merge
1397
+ final_output = final_output + resize_as(shallow, final_output)
1398
+ final_output = self.upsample1(rescale_to(final_output))
1399
+ final_output = rescale_to(final_output +
1400
+ resize_as(shallow, final_output))
1401
+ final_output = self.upsample2(final_output)
1402
+ final_output = self.output(final_output)
1403
+ ####
1404
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1405
+ sideout4 = self.sideout4(e4)
1406
+ sideout3 = self.sideout3(e3)
1407
+ sideout2 = self.sideout2(e2)
1408
+ sideout1 = self.sideout1(e1)
1409
+ #######glb_sideouts ######
1410
+ glb5 = self.sideout5(glb_e5)
1411
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1412
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1413
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1414
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1415
+ ####### concat 4 to 1 #######
1416
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1417
+ device=torch_device)
1418
+ sideout2 = patches2image(sideout2[:-1]).to(
1419
+ dtype=torch_dtype,
1420
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1421
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1422
+ device=torch_device)
1423
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1424
+ device=torch_device)
1425
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1426
+ device=torch_device)
1427
+ if self.training:
1428
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1429
+ else:
1430
+ return final_output
1431
+
1432
+
1433
+ # model for multi-scale testing
1434
+ class inf_MVANet(nn.Module):
1435
+
1436
+ def __init__(self):
1437
+ super().__init__()
1438
+ # self.backbone = SwinB(pretrained=True)
1439
+ self.backbone = SwinB(pretrained=False)
1440
+
1441
+ emb_dim = 128
1442
+ self.output5 = make_cbr(1024, emb_dim)
1443
+ self.output4 = make_cbr(512, emb_dim)
1444
+ self.output3 = make_cbr(256, emb_dim)
1445
+ self.output2 = make_cbr(128, emb_dim)
1446
+ self.output1 = make_cbr(128, emb_dim)
1447
+
1448
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1449
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1450
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1451
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1452
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1453
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1454
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1455
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1456
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1457
+
1458
+ self.insmask_head = nn.Sequential(
1459
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1460
+ nn.BatchNorm2d(384), nn.PReLU(),
1461
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1462
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1463
+
1464
+ self.shallow = nn.Sequential(
1465
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1466
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1467
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1468
+ self.output = nn.Sequential(
1469
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1470
+
1471
+ for m in self.modules():
1472
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1473
+ m.inplace = True
1474
+
1475
+ def forward(self, x):
1476
+ shallow = self.shallow(x)
1477
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1478
+ loc = image2patches(x)
1479
+ input = torch.cat((loc, glb), dim=0)
1480
+ feature = self.backbone(input)
1481
+ e5 = self.output5(feature[4])
1482
+ e4 = self.output4(feature[3])
1483
+ e3 = self.output3(feature[2])
1484
+ e2 = self.output2(feature[1])
1485
+ e1 = self.output1(feature[0])
1486
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1487
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1488
+
1489
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1490
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1491
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1492
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1493
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1494
+ # after decoder, concat loc features to a whole one, and merge
1495
+ output1_cat = patches2image(loc_e1)
1496
+ # add glb feat in
1497
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1498
+ # merge
1499
+ final_output = self.insmask_head(output1_cat)
1500
+ # shallow feature merge
1501
+ final_output = final_output + resize_as(shallow, final_output)
1502
+ final_output = self.upsample1(rescale_to(final_output))
1503
+ final_output = rescale_to(final_output +
1504
+ resize_as(shallow, final_output))
1505
+ final_output = self.upsample2(final_output)
1506
+ final_output = self.output(final_output)
1507
+ return final_output
1508
+ #+end_src
1509
+
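+ A quick, untangled sanity check of the model wiring (a sketch only; it assumes the SwinB backbone and the helper layers defined earlier in this file are in scope, and it loads no checkpoint):
+ #+begin_src python
+ # Illustrative only: push a random 1024x1024 batch through inf_MVANet and
+ # print the shape of the single-channel mask logits it returns.
+ import torch
+ 
+ net = inf_MVANet().eval()
+ dummy = torch.randn(1, 3, 1024, 1024)
+ with torch.no_grad():
+     out = net(dummy)
+ print(out.shape)
+ #+end_src
+ 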
1510
+ ** Function to load model
1511
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1512
+ def load_model(model_checkpoint_path):
1513
+ torch.cuda.set_device(0)
1514
+
1515
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1516
+
1517
+ pretrained_dict = torch.load(model_checkpoint_path,
1518
+ map_location=torch_device)
1519
+
1520
+ model_dict = net.state_dict()
1521
+ pretrained_dict = {
1522
+ k: v
1523
+ for k, v in pretrained_dict.items() if k in model_dict
1524
+ }
1525
+ model_dict.update(pretrained_dict)
1526
+ net.load_state_dict(model_dict)
1527
+ net = net.to(dtype=torch_dtype, device=torch_device)
1528
+ net.eval()
1529
+ return net
1530
+
1531
+
1532
+ def load_transforms_stripped():
1533
+ img_transform = transforms.Compose([
1534
+ # transforms.ToTensor(),
1535
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1536
+ ])
1537
+
1538
+ return img_transform
1539
+
1540
+
1541
+ def load_transforms():
1542
+ img_transform = transforms.Compose([
1543
+ # transforms.ToTensor(),
1544
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1545
+ ])
1546
+
1547
+ depth_transform = transforms.ToTensor()
1548
+ target_transform = transforms.ToTensor()
1549
+ to_pil = transforms.ToPILImage()
1550
+
1551
+ transforms_var = tta.Compose([
1552
+ tta.HorizontalFlip(),
1553
+ tta.Scale(scales=[0.75, 1, 1.25],
1554
+ interpolation='bilinear',
1555
+ align_corners=False),
1556
+ ])
1557
+
1558
+ return (img_transform, depth_transform, target_transform, to_pil,
1559
+ transforms_var)
1560
+ #+end_src
1561
+
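+ Note that load_model keeps only the checkpoint entries whose keys already exist in the model and silently drops the rest. A small sketch (illustrative; the helper name is hypothetical) for surfacing what would be skipped:
+ #+begin_src python
+ # Illustrative only: report checkpoint keys that load_model would silently ignore.
+ def report_skipped_keys(model_checkpoint_path):
+     net = inf_MVANet()
+     pretrained_dict = torch.load(model_checkpoint_path, map_location='cpu')
+     model_keys = set(net.state_dict().keys())
+     skipped = [k for k in pretrained_dict if k not in model_keys]
+     print('skipped', len(skipped), 'of', len(pretrained_dict), 'checkpoint keys')
+     for k in skipped:
+         print('   ', k)
+ #+end_src
+ 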
1562
+ ** Function for modular inference CV
1563
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1564
+ def do_infer_tensor2tensor(img, net):
1565
+
1566
+ img_transform = transforms.Compose(
1567
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
1568
+
1569
+ h_, w_ = img.shape[1], img.shape[2]
1570
+
1571
+ with torch.no_grad():
1572
+
1573
+ img = rearrange(img, 'B H W C -> B C H W')
1574
+
1575
+ img_resize = torch.nn.functional.interpolate(input=img,
1576
+ size=(1024, 1024),
1577
+ mode='bicubic',
1578
+ antialias=True)
1579
+
1580
+ img_var = img_transform(img_resize)
1581
+ img_var = Variable(img_var)
1582
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1583
+
1584
+ mask = []
1585
+
1586
+ mask.append(net(img_var))
1587
+
1588
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1589
+ prediction = prediction.sigmoid()
1590
+
1591
+ prediction = torch.nn.functional.interpolate(input=prediction,
1592
+ size=(h_, w_),
1593
+ mode='bicubic',
1594
+ antialias=True)
1595
+
1596
+ prediction = prediction.squeeze(0)
1597
+ prediction = prediction.clamp(0, 1)
1598
+
1599
+ return prediction
1600
+
1601
+
1602
+ def do_infer_modular_cv(input_image_path, output_mask_path, net,
1603
+ all_transforms):
1604
+
1605
+ (img_transform, depth_transform, target_transform, to_pil,
1606
+ transforms_var) = all_transforms
1607
+
1608
+ img = load_image_torch(input_image_path)
1609
+
1610
+ h_, w_ = img.shape[1], img.shape[2]
1611
+
1612
+ with torch.no_grad():
1613
+
1614
+ img = rearrange(img, 'B H W C -> B C H W')
1615
+
1616
+ img_resize = torch.nn.functional.interpolate(input=img,
1617
+ size=(1024, 1024),
1618
+ mode='bicubic',
1619
+ antialias=True)
1620
+
1621
+ img_var = img_transform(img_resize)
1622
+ img_var = Variable(img_var)
1623
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1624
+
1625
+ mask = []
1626
+
1627
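+ # NOTE: the TTA transformer is not applied in this variant; the same un-augmented input is scored once per entry in transforms_var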
+ for transformer in transforms_var:
1628
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1629
+ mask.append(net(rgb_trans))
1630
+
1631
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1632
+ prediction = prediction.sigmoid()
1633
+
1634
+ prediction = torch.nn.functional.interpolate(input=prediction,
1635
+ size=(h_, w_),
1636
+ mode='bicubic',
1637
+ antialias=True)
1638
+
1639
+ prediction = prediction.squeeze(0)
1640
+ prediction = prediction.clamp(0, 1)
1641
+
1642
+ save_mask_torch(output_image_path=output_mask_path, mask=prediction)
1643
+
1644
+
1645
+ def do_infer_modular_cv_2(input_image_path, output_mask_path, net,
1646
+ all_transforms):
1647
+
1648
+ (img_transform, depth_transform, target_transform, to_pil,
1649
+ transforms_var) = all_transforms
1650
+
1651
+ img = load_image(input_image_path)
1652
+ h_, w_ = img.shape[0], img.shape[1]  # shape[0] is rows (height), shape[1] is columns (width)
1653
+ img_resize = cv2.resize(img, (1024, 1024), interpolation=cv2.INTER_CUBIC)
1654
+
1655
+ with torch.no_grad():
1656
+
1657
+ # rgb_png_path = input_image_path
1658
+ # img = Image.open(rgb_png_path).convert('RGB')
1659
+ # w_, h_ = img.size
1660
+
1661
+ # img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1662
+
1663
+ # img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1664
+ # dtype=torch_dtype, device=torch_device)
1665
+
1666
+ img_resize = torch.from_numpy(img_resize)
1667
+ img_resize = img_resize.to(dtype=torch.float32)
1668
+ img_resize /= 255.0
1669
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1670
+ img_var = img_transform(img_resize)
1671
+ img_var = img_var.unsqueeze(0)
1672
+ img_var = Variable(img_var)
1673
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1674
+
1675
+ mask = []
1676
+
1677
+ for transformer in transforms_var:
1678
+ rgb_trans = transformer.augment_image(img_var)
1679
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1680
+ model_output = net(rgb_trans)
1681
+ deaug_mask = transformer.deaugment_mask(model_output)
1682
+ mask.append(deaug_mask)
1683
+
1684
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1685
+ prediction = prediction.sigmoid()
1686
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1687
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1688
+ prediction.save(output_mask_path)
1689
+
1690
+
1691
+ def do_infer_modular_cv_3(input_image_path, output_mask_path, net,
1692
+ all_transforms):
1693
+
1694
+ (img_transform, depth_transform, target_transform, to_pil,
1695
+ transforms_var) = all_transforms
1696
+
1697
+ img = load_image(input_image_path)
1698
+ h_, w_ = img.shape[0], img.shape[1]
1699
+
1700
+ with torch.no_grad():
1701
+
1702
+ # rgb_png_path = input_image_path
1703
+ # img = Image.open(rgb_png_path).convert('RGB')
1704
+ # w_, h_ = img.size
1705
+
1706
+ # img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1707
+
1708
+ # img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1709
+ # dtype=torch_dtype, device=torch_device)
1710
+
1711
+ img_resize = torch.from_numpy(img)
1712
+ img_resize = img_resize.to(dtype=torch.float32)
1713
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1714
+ img_resize = img_resize.unsqueeze(0)
1715
+
1716
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1717
+ size=(1024, 1024),
1718
+ mode='bicubic',
1719
+ antialias=True)
1720
+
1721
+ img_resize = img_resize.squeeze(0)
1722
+ img_resize = rearrange(img_resize, 'C H W -> H W C')
1723
+
1724
+ img_resize = img_resize.to(dtype=torch.float32)
1725
+ img_resize /= 255.0
1726
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1727
+ img_var = img_transform(img_resize)
1728
+ img_var = img_var.unsqueeze(0)
1729
+ img_var = Variable(img_var)
1730
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1731
+
1732
+ mask = []
1733
+
1734
+ for transformer in transforms_var:
1735
+ rgb_trans = transformer.augment_image(img_var)
1736
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1737
+ model_output = net(rgb_trans)
1738
+ deaug_mask = transformer.deaugment_mask(model_output)
1739
+ mask.append(deaug_mask)
1740
+
1741
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1742
+ prediction = prediction.sigmoid()
1743
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1744
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1745
+ prediction.save(output_mask_path)
1746
+
1747
+
1748
+ def do_infer_modular_cv_4(input_image_path, output_mask_path, net,
1749
+ all_transforms):
1750
+
1751
+ (img_transform, depth_transform, target_transform, to_pil,
1752
+ transforms_var) = all_transforms
1753
+
1754
+ img = load_image(input_image_path)
1755
+ h_, w_ = img.shape[0], img.shape[1]
1756
+
1757
+ with torch.no_grad():
1758
+
1759
+ img_resize = torch.from_numpy(img)
1760
+ img_resize = img_resize.to(dtype=torch.float32)
1761
+ img_resize /= 255.0
1762
+ img_resize = img_resize.unsqueeze(0)
1763
+
1764
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1765
+
1766
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1767
+ size=(1024, 1024),
1768
+ mode='bicubic',
1769
+ antialias=True)
1770
+
1771
+ img_resize = img_resize.squeeze(0)
1772
+ img_var = img_transform(img_resize)
1773
+ img_var = img_var.unsqueeze(0)
1774
+ img_var = Variable(img_var)
1775
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1776
+
1777
+ mask = []
1778
+
1779
+ for transformer in transforms_var:
1780
+ rgb_trans = transformer.augment_image(img_var)
1781
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1782
+ model_output = net(rgb_trans)
1783
+ deaug_mask = transformer.deaugment_mask(model_output)
1784
+ mask.append(deaug_mask)
1785
+
1786
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1787
+ prediction = prediction.sigmoid()
1788
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1789
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1790
+ prediction.save(output_mask_path)
1791
+
1792
+
1793
+ def do_infer_modular_cv_5(input_image_path, output_mask_path, net,
1794
+ all_transforms):
1795
+
1796
+ (img_transform, depth_transform, target_transform, to_pil,
1797
+ transforms_var) = all_transforms
1798
+
1799
+ img = load_image(input_image_path)
1800
+ h_, w_ = img.shape[0], img.shape[1]
1801
+
1802
+ with torch.no_grad():
1803
+
1804
+ img_resize = torch.from_numpy(img)
1805
+ img_resize = img_resize.to(dtype=torch.float32)
1806
+ img_resize /= 255.0
1807
+ img_resize = img_resize.unsqueeze(0)
1808
+
1809
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1810
+
1811
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1812
+ size=(1024, 1024),
1813
+ mode='bicubic',
1814
+ antialias=True)
1815
+
1816
+ img_var = img_transform(img_resize)
1817
+ img_var = Variable(img_var)
1818
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1819
+
1820
+ mask = []
1821
+
1822
+ for transformer in transforms_var:
1823
+ rgb_trans = transformer.augment_image(img_var)
1824
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1825
+ model_output = net(rgb_trans)
1826
+ deaug_mask = transformer.deaugment_mask(model_output)
1827
+ mask.append(deaug_mask)
1828
+
1829
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1830
+ prediction = prediction.sigmoid()
1831
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1832
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1833
+ prediction.save(output_mask_path)
1834
+
1835
+
1836
+ def do_infer_modular_cv_6(input_image_path, output_mask_path, net,
1837
+ all_transforms):
1838
+
1839
+ (img_transform, depth_transform, target_transform, to_pil,
1840
+ transforms_var) = all_transforms
1841
+
1842
+ img = load_image(input_image_path)
1843
+ h_, w_ = img.shape[0], img.shape[1]
1844
+
1845
+ with torch.no_grad():
1846
+
1847
+ img_resize = torch.from_numpy(img)
1848
+ img_resize = img_resize.to(dtype=torch.float32)
1849
+ img_resize /= 255.0
1850
+ img_resize = img_resize.unsqueeze(0)
1851
+
1852
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1853
+
1854
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1855
+ size=(1024, 1024),
1856
+ mode='bicubic',
1857
+ antialias=True)
1858
+
1859
+ img_var = img_transform(img_resize)
1860
+ img_var = Variable(img_var)
1861
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1862
+
1863
+ mask = []
1864
+
1865
+ for transformer in transforms_var:
1866
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1867
+ mask.append(net(rgb_trans))
1868
+
1869
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1870
+ prediction = prediction.sigmoid()
1871
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1872
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1873
+ prediction.save(output_mask_path)
1874
+
1875
+
1876
+ def do_infer_modular_cv_7(input_image_path, output_mask_path, net,
1877
+ all_transforms):
1878
+
1879
+ (img_transform, depth_transform, target_transform, to_pil,
1880
+ transforms_var) = all_transforms
1881
+
1882
+ img = load_image_torch(input_image_path)
1883
+
1884
+ h_, w_ = img.shape[1], img.shape[2]
1885
+
1886
+ with torch.no_grad():
1887
+
1888
+ img = rearrange(img, 'B H W C -> B C H W')
1889
+
1890
+ img_resize = torch.nn.functional.interpolate(input=img,
1891
+ size=(1024, 1024),
1892
+ mode='bicubic',
1893
+ antialias=True)
1894
+
1895
+ img_var = img_transform(img_resize)
1896
+ img_var = Variable(img_var)
1897
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1898
+
1899
+ mask = []
1900
+
1901
+ for transformer in transforms_var:
1902
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1903
+ mask.append(net(rgb_trans))
1904
+
1905
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1906
+ prediction = prediction.sigmoid()
1907
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1908
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1909
+ prediction.save(output_mask_path)
1910
+ #+end_src
1911
+
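+ A short usage sketch for do_infer_tensor2tensor (illustrative; the file and checkpoint paths are placeholders, and the sketch builds the B H W C float tensor in [0, 1] that the function expects):
+ #+begin_src python
+ # Illustrative only: run do_infer_tensor2tensor on an image loaded with OpenCV.
+ import cv2
+ import torch
+ 
+ frame = cv2.imread('input.jpg', cv2.IMREAD_COLOR)      # hypothetical path, H x W x 3 uint8 BGR
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)         # the normalization constants assume RGB
+ img = torch.from_numpy(frame).to(dtype=torch.float32) / 255.0
+ img = img.unsqueeze(0)                                 # B H W C in [0, 1]
+ 
+ net = load_model('Model_14.pth')                       # hypothetical checkpoint path
+ mask = do_infer_tensor2tensor(img, net)                # (1, H, W) mask in [0, 1]
+ cv2.imwrite('mask.png', (mask[0] * 255.0).to(torch.uint8).cpu().numpy())
+ #+end_src
+ 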
1912
+ ** Function for modular inference
1913
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1914
+ def do_infer_modular(input_image_path, output_mask_path, net, all_transforms):
1915
+ # net = load_model(finetuned_MVANet_model_path)
1916
+
1917
+ (img_transform, depth_transform, target_transform, to_pil,
1918
+ transforms_var) = all_transforms
1919
+
1920
+ with torch.no_grad():
1921
+ rgb_png_path = input_image_path
1922
+ img = Image.open(rgb_png_path).convert('RGB')
1923
+
1924
+ w_, h_ = img.size
1925
+ # img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
1926
+ img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1927
+ # img_resize = img
1928
+ img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1929
+ dtype=torch_dtype, device=torch_device)
1930
+ mask = []
1931
+ for transformer in transforms_var:
1932
+ rgb_trans = transformer.augment_image(img_var)
1933
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1934
+ model_output = net(rgb_trans)
1935
+ deaug_mask = transformer.deaugment_mask(model_output)
1936
+ mask.append(deaug_mask)
1937
+
1938
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1939
+ prediction = prediction.sigmoid()
1940
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1941
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1942
+ prediction.save(output_mask_path)
1943
+ #+end_src
1944
+
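+ A usage sketch for do_infer_modular (illustrative; the checkpoint path is a placeholder and the torchvision transforms import from MVANet_inference.import.py is assumed). It feeds a PIL image straight into img_transform, so the transform it receives must include transforms.ToTensor(), unlike the tensor-based load_transforms() above:
+ #+begin_src python
+ # Illustrative only: single-image TTA inference with do_infer_modular.
+ pil_img_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ _, depth_t, target_t, to_pil, tta_t = load_transforms()
+ net = load_model('finetuned_MVANet.pth')               # hypothetical checkpoint path
+ do_infer_modular('input.jpg', 'mask.png', net,
+                  (pil_img_transform, depth_t, target_t, to_pil, tta_t))
+ #+end_src
+ 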
1945
+ ** Function for inference
1946
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1947
+ def do_infer():
1948
+ torch.cuda.set_device(0)
1949
+ args = {'crf_refine': True, 'save_results': True}
1950
+
1951
+ img_transform = transforms.Compose([
1952
+ transforms.ToTensor(),
1953
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1954
+ ])
1955
+
1956
+ depth_transform = transforms.ToTensor()
1957
+ target_transform = transforms.ToTensor()
1958
+ to_pil = transforms.ToPILImage()
1959
+
1960
+ transforms_var = tta.Compose([
1961
+ tta.HorizontalFlip(),
1962
+ tta.Scale(scales=[0.75, 1, 1.25],
1963
+ interpolation='bilinear',
1964
+ align_corners=False),
1965
+ ])
1966
+
1967
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1968
+ pretrained_dict = torch.load(finetuned_MVANet_model_path,
1969
+ map_location=torch_device)
1970
+ model_dict = net.state_dict()
1971
+ pretrained_dict = {
1972
+ k: v
1973
+ for k, v in pretrained_dict.items() if k in model_dict
1974
+ }
1975
+ model_dict.update(pretrained_dict)
1976
+ net.load_state_dict(model_dict)
1977
+ net = net.to(dtype=torch_dtype, device=torch_device)
1978
+ net.eval()
1979
+ with torch.no_grad():
1980
+ rgb_png_path = '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
1981
+ img = Image.open(rgb_png_path).convert('RGB')
1982
+ w_, h_ = img.size
1983
+ # img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
1984
+ img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1985
+ # img_resize = img
1986
+ # 'volatile' was removed from torch.autograd.Variable; torch.no_grad() above already disables autograd
+ img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
+     dtype=torch_dtype, device=torch_device)
1988
+ mask = []
1989
+ for transformer in transforms_var:
1990
+ rgb_trans = transformer.augment_image(img_var)
1991
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1992
+ model_output = net(rgb_trans)
1993
+ deaug_mask = transformer.deaugment_mask(model_output)
1994
+ mask.append(deaug_mask)
1995
+
1996
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1997
+ prediction = prediction.sigmoid()
1998
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1999
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
2000
+ prediction.save('./tmp.png')
2001
+ #+end_src
2002
+
2003
+ ** MVANet_inference function
2004
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
2005
+ def main(item):
2006
+ net = inf_MVANet().cuda()
2007
+ pretrained_dict = torch.load(os.path.join(ckpt_path, item + '.pth'),
2008
+ map_location='cuda')
2009
+ model_dict = net.state_dict()
2010
+ pretrained_dict = {
2011
+ k: v
2012
+ for k, v in pretrained_dict.items() if k in model_dict
2013
+ }
2014
+ model_dict.update(pretrained_dict)
2015
+ net.load_state_dict(model_dict)
2016
+ net.eval()
2017
+ with torch.no_grad():
2018
+ for name, root in to_test.items():
2019
+ root1 = os.path.join(root, 'images')
2020
+ img_list = [os.path.splitext(f) for f in os.listdir(root1)]
2021
+ for idx, img_name in enumerate(img_list):
2022
+
2023
+ print('predicting for %s: %d / %d' %
2024
+ (name, idx + 1, len(img_list)))
2025
+ rgb_png_path = os.path.join(root, 'images',
2026
+ img_name[0] + '.png')
2027
+ rgb_jpg_path = os.path.join(root, 'images',
2028
+ img_name[0] + '.jpg')
2029
+ if os.path.exists(rgb_png_path):
2030
+ img = Image.open(rgb_png_path).convert('RGB')
2031
+ else:
2032
+ img = Image.open(rgb_jpg_path).convert('RGB')
2033
+ w_, h_ = img.size
2034
+ img_resize = img.resize([1024, 1024], Image.BILINEAR)
2035
+ # 'volatile' was removed from torch.autograd.Variable; torch.no_grad() above already disables autograd
+ img_var = Variable(img_transform(img_resize).unsqueeze(0)).cuda()
2037
+ mask = []
2038
+ for transformer in transforms_var:
2039
+ rgb_trans = transformer.augment_image(img_var)
2040
+ model_output = net(rgb_trans)
2041
+ deaug_mask = transformer.deaugment_mask(model_output)
2042
+ mask.append(deaug_mask)
2043
+
2044
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
2045
+ prediction = prediction.sigmoid()
2046
+ prediction = to_pil(prediction.data.squeeze(0))
2047
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
2048
+ if args['save_results']:
2049
+ check_mkdir(os.path.join(ckpt_path, item, name))
2050
+ prediction.save(
2051
+ os.path.join(ckpt_path, item, name,
2052
+ img_name[0] + '.png'))
2053
+ #+end_src
2054
+
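+ The main function above relies on module-level names from the original MVANet test script (ckpt_path, to_test, args, img_transform, transforms_var, to_pil, check_mkdir); img_transform, transforms_var and to_pil can be built exactly as in do_infer above. A minimal sketch of the remaining setup it expects, with placeholder values:
+ #+begin_src python
+ # Illustrative only: placeholder globals that main() looks up at run time.
+ import os
+ 
+ ckpt_path = './saved_model'                  # hypothetical: directory holding '<item>.pth'
+ to_test = {'val': './data/val'}              # hypothetical: name -> dataset root containing an 'images' subdir
+ args = {'crf_refine': True, 'save_results': True}
+ 
+ def check_mkdir(dir_name):                   # stand-in for the original misc.check_mkdir helper
+     os.makedirs(dir_name, exist_ok=True)
+ #+end_src
+ 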
2055
+ ** MVANet_inference execute
2056
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
2057
+ def do_merge(path_image, path_mask, path_out):
2058
+ image = cv2.imread(path_image, cv2.IMREAD_COLOR)
2059
+ mask = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
2060
+ mask = (mask > 127).astype(dtype=np.uint8) * 255
2061
+ out = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
2062
+ out[:, :, 0:3] = image
2063
+ out[:, :, 3] = mask
2064
+ cv2.imwrite(path_out, out)
2065
+
2066
+
2067
+ if __name__ == '__main__':
2068
+
2069
+ # do_infer_modular_cv(
2070
+ # input_image_path=
2071
+ # '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png',
2072
+ # output_mask_path='./tmp.png',
2073
+ # net=load_model(finetuned_MVANet_model_path),
2074
+ # all_transforms=load_transforms(),
2075
+ # )
2076
+
2077
+ # net = load_model(
2078
+ # HOME_DIR + '/dreambooth_experiments/MVANet/MVANet_cloth_segment_14.pth')
2079
+
2080
+ # net = load_model(
2081
+ # HOME_DIR +
2082
+ # '/dreambooth_experiments/MVANet/new_type_crop_with_midshot.pth')
2083
+
2084
+ # net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/1/Model_4.pth')
2085
+
2086
+ net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/3/Model_14.pth')
2087
+
2088
+
2089
+ # net = load_model(HOME_DIR +
2090
+ # '/dreambooth_experiments/MVANet/mvanet_normal_crop_2.pth')
2091
+
2092
+ DATA_DIR_BASE = HOME_DIR + '/DATASETS/cloth_segmentation_test_images.dir/cloth_segmentation_test_images/'
2093
+
2094
+ images = (
2095
+ '1370', '1371', '1372', '1373', '1374', '1375', '1376', '1377', '1378',
2096
+ '1379', '1380', '1381', '1382', '1383', '1384', '1385', '1386', '1387',
2097
+ '1388', '1389', '1390', '1391', '1392', '1393', '1394', '1395', '1396',
2098
+ '1397', '1398', '1399', '1400', '1401', '1402', '1403', '1404', '1405',
2099
+ '1406', '1407', '1408', '1409', '1410', '1411', '1412', '1413', '1414',
2100
+ '1415', '1539', '1541', '1542', '1543', '17320', '4129', '4190',
2101
+ '4191', '4192', '4193', '4202', '4203', '4204', '4207', '4208', '4209',
2102
+ '4210', '4213', '4214', '4221', '4222', '4223', '4224', '4225', '4226',
2103
+ '4227', '4228', '4229', '4230', '4231', '4232', '4233', '4234', '4235',
2104
+ '4236', '4237', '4238', '4239', '4240', '4241', '4242', '4251', '4252',
2105
+ '4253', '4254', '4255', '4256', '4257', '4258', '4259', '4260', '4261',
2106
+ '4262', '4263', '4264', '6581', '6642', '6647', '6656', '6660', '6690',
2107
+ '6696', '6724', '6767', '6771', '6788', '6791', '6807', '6821', '6824',
2108
+ '6833', '6847', '6850', '6879', '6941', '7001', '7070', '7083', '7092',
2109
+ '7093', '7119', '7191', '7220', '7252', '7264', '7276', '7278', '7281',
2110
+ '7290', '7301', '7312', '7340', '7398', '7404', '7412', '7429', '7439',
2111
+ '7478', '7491', '7631', '7687', '7699', '7719', '7770', '7784', '7793',
2112
+ '7811', '7829', '7861', '7864', '7868', '7980', '7987', '7990', '8069',
2113
+ '8083', '8100', '8108', '8227', '8323', '8329', '8358', '8383', '8401',
2114
+ '8415', '8488', '8515', '8518', '8560', '8565', '8595', '8639', '8676',
2115
+ '8690', '8691', '8701', '8703', '8723', '8726', '8756', '8783', '8801',
2116
+ '8820', '8826', '8842', '8865', '8874', '8875', '8882', '8911', '8946',
2117
+ '8947', '8969', '8979', '8983')
2118
+
2119
+ masks = [DATA_DIR_BASE + i + '/garment_mask.png' for i in images]
2120
+ out = [DATA_DIR_BASE + i + '/garment_transparent.png' for i in images]
2121
+
2122
+ images = [DATA_DIR_BASE + i + '/original.jpg' for i in images]
2123
+
2124
+ for i in range(len(images)):
2125
+ image = images[i]
2126
+ image = load_image_torch(image)
2127
+ mask = do_infer_tensor2tensor(image, net)
2128
+ save_mask_torch(output_image_path=masks[i], mask=mask)
2129
+ do_merge(path_image=images[i], path_mask=masks[i], path_out=out[i])
2130
+
2131
+ # img = load_image_torch(
2132
+ # '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
2133
+ # )
2134
+ # # all_transforms = load_transforms()
2135
+ # masks = do_infer_tensor2tensor(img, net)
2136
+ # save_mask_torch(output_image_path='./tmp.png', mask=masks)
2137
+ #+end_src
2138
+
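+ do_merge writes a 4-channel PNG whose alpha channel is the binarized mask. A small sketch (illustrative; the file names follow the original.jpg / garment_transparent.png convention above) for previewing such a cutout over a white background:
+ #+begin_src python
+ # Illustrative only: composite the RGBA cutout produced by do_merge over white.
+ import cv2
+ import numpy as np
+ 
+ rgba = cv2.imread('garment_transparent.png', cv2.IMREAD_UNCHANGED)  # H x W x 4
+ alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
+ preview = rgba[:, :, :3].astype(np.float32) * alpha + 255.0 * (1.0 - alpha)
+ cv2.imwrite('preview.png', preview.astype(np.uint8))
+ #+end_src
+ 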
2139
+ ** MVANet_inference unify
2140
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
2141
+ . "${HOME}/dbnew.sh"
2142
+
2143
+ (
2144
+ echo '#!/usr/bin/python3'
2145
+ cat \
2146
+ './MVANet_inference.import.py' \
2147
+ './MVANet_inference.function.py' \
2148
+ './MVANet_inference.class.py' \
2149
+ './MVANet_inference.execute.py' \
2150
+ | expand | yapf3 \
2151
+ | grep -v '#!/usr/bin/python3' \
2152
+ ;
2153
+ ) > './MVANet_inference.py' \
2154
+ ;
2155
+ #+end_src
2156
+
2157
+ ** MVANet_inference run
2158
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
2159
+ . "${HOME}/dbnew.sh"
2160
+ python3 './MVANet_inference.py'
2161
+ #+end_src
2162
+
2163
+ * WORK SPACE
2164
+
2165
+ ** elisp
2166
+ #+begin_src elisp
2167
+ (save-buffer)
2168
+ (org-babel-tangle)
2169
+ (shell-command "./MVANet_inference.unify.sh")
2170
+ #+end_src
2171
+
2172
+ #+RESULTS:
2173
+ : 0
2174
+
2175
+ ** sh
2176
+ #+begin_src sh :shebang #!/bin/sh :results output
2177
+ realpath .
2178
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
2179
+ #+end_src
main.org CHANGED
@@ -4,8 +4,9 @@ cd $HOME/HUGGINGFACE/aravindhv10/Self-Correction-Human-Parsing
4
  ** ELISP
5
  #+begin_src elisp
6
  (save-buffer)
 
7
  (org-babel-tangle)
8
- (shell-command "./work.sh")
9
  #+end_src
10
 
11
  #+RESULTS:
@@ -13,38 +14,14 @@ cd $HOME/HUGGINGFACE/aravindhv10/Self-Correction-Human-Parsing
13
 
14
  ** ELISP
15
  #+begin_src elisp
16
- (shell-command "./commit_and_push.sh")
17
  #+end_src
18
 
19
- ** SHELL
20
- #+begin_src sh :shebang #!/bin/sh :results output
21
- git status
22
  #+end_src
23
 
24
- #+RESULTS:
25
- #+begin_example
26
- On branch main
27
- Your branch is up to date with 'origin/main'.
28
-
29
- Changes to be committed:
30
- (use "git restore --staged <file>..." to unstage)
31
- modified: .gitattributes
32
- modified: .gitignore
33
- new file: ComfyUI_AEMatter/AEMatter.run.sh
34
- new file: ComfyUI_MVANet/MVANet_inference.run.sh
35
- new file: ComfyUI_MVANet/download.sh
36
- new file: checkpoints/MVANet/garment.pth
37
- new file: checkpoints/MVANet/skin.pth
38
- new file: demo/demo.jpg
39
- new file: demo/demo_atr.png
40
- new file: demo/demo_lip.png
41
- new file: demo/demo_pascal.png
42
- new file: demo/lip-visualization.jpg
43
- new file: main.org
44
- new file: training_code/MVANet/README.org
45
-
46
- #+end_example
47
-
48
  * Commit and push
49
  #+begin_src sh :shebang #!/bin/sh :results output :tangle ./commit_and_push.sh
50
  git commit -m 'Routine updates'
@@ -608,6 +585,7 @@ Changes to be committed:
608
  utils/soft_dice_loss.py
609
  utils/transforms.py
610
  utils/warmup_scheduler.py
 
611
  #+end_src
612
 
613
  * List of files to remove
 
4
  ** ELISP
5
  #+begin_src elisp
6
  (save-buffer)
7
+ (save-some-buffers)
8
  (org-babel-tangle)
9
+ (shell-command "./work.sh" "output_log_work")
10
  #+end_src
11
 
12
  #+RESULTS:
 
14
 
15
  ** ELISP
16
  #+begin_src elisp
17
+ (shell-command "git status" "output_log_git_status")
18
  #+end_src
19
 
20
+ ** ELISP
21
+ #+begin_src elisp
22
+ (shell-command "./commit_and_push.sh" "output_log_commit_and_push")
23
  #+end_src
24
 
25
  * Commit and push
26
  #+begin_src sh :shebang #!/bin/sh :results output :tangle ./commit_and_push.sh
27
  git commit -m 'Routine updates'
 
585
  utils/soft_dice_loss.py
586
  utils/transforms.py
587
  utils/warmup_scheduler.py
588
+ MVANet_Inference/README.org
589
  #+end_src
590
 
591
  * List of files to remove