aravindhv10 committed on
Commit
fb73232
·
1 Parent(s): 224fe79

Added AEMatter ComfyUI node

.gitignore CHANGED
@@ -19,3 +19,10 @@ pretrain_model/
19
  **/__pycache__
20
  /rm.txt
21
  /waste.txt
22
+ ComfyUI_AEMatter/AEMatter.execute.py
23
+ ComfyUI_AEMatter/__pycache__/__init__.cpython-310.pyc
24
+ ComfyUI_AEMatter/AEMatter.run.sh
25
+ ComfyUI_AEMatter/AEMatter.class.py
26
+ ComfyUI_AEMatter/AEMatter.import.py
27
+ ComfyUI_AEMatter/AEMatter.function.py
28
+ ComfyUI_AEMatter/AEMatter.unify.sh
ComfyUI_AEMatter/AEMatter.py ADDED
@@ -0,0 +1,1248 @@
1
+ #!/usr/bin/python3
2
+ import cv2
3
+ import math
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import wget
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ from collections import OrderedDict
16
+ from einops import rearrange, repeat
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ import folder_paths
20
+ from folder_paths import models_dir
21
+
22
+
23
+ #!/usr/bin/python3
24
+ def mkdir_safe(out_path):
25
+ if type(out_path) == str:
26
+ if len(out_path) > 0:
27
+ if not os.path.exists(out_path):
28
+ os.mkdir(out_path)
29
+
30
+
31
+ def get_model_path():
32
+ import folder_paths
33
+ from folder_paths import models_dir
34
+
35
+ path_file_model = models_dir
36
+ mkdir_safe(out_path=path_file_model)
37
+
38
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
39
+ mkdir_safe(out_path=path_file_model)
40
+
41
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
42
+
43
+ return path_file_model
44
+
45
+
46
+ def download_model(path):
47
+ if not os.path.exists(path):
48
+ wget.download(
49
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
50
+ out=path)
51
+
52
+
53
+ def from_torch_image(image):
54
+ image = image.cpu().numpy() * 255.0
55
+ image = np.clip(image, 0, 255).astype(np.uint8)
56
+ return image
57
+
58
+
59
+ def to_torch_image(image):
60
+ image = image.astype(dtype=np.float32)
61
+ image /= 255.0
62
+ image = torch.from_numpy(image)
63
+ return image
64
+
65
+
66
+ def window_partition(x, window_size):
67
+ """
68
+ Args:
69
+ x: (B, H, W, C)
70
+ window_size (int): window size
71
+ Returns:
72
+ windows: (num_windows*B, window_size, window_size, C)
73
+ """
74
+ B, H, W, C = x.shape
75
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
76
+ C)
77
+ windows = x.permute(0, 1, 3, 2, 4,
78
+ 5).contiguous().view(-1, window_size, window_size, C)
79
+ return windows
80
+
81
+
82
+ def window_reverse(windows, window_size, H, W):
83
+ """
84
+ Args:
85
+ windows: (num_windows*B, window_size, window_size, C)
86
+ window_size (int): Window size
87
+ H (int): Height of image
88
+ W (int): Width of image
89
+ Returns:
90
+ x: (B, H, W, C)
91
+ """
92
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
93
+ x = windows.view(B, H // window_size, W // window_size, window_size,
94
+ window_size, -1)
95
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
96
+ return x
97
+
98
+
99
+ def get_AEMatter_model(path_model_checkpoint):
100
+
101
+ download_model(path=path_model_checkpoint)
102
+
103
+ matmodel = AEMatter()
104
+ matmodel.load_state_dict(
105
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
106
+
107
+ matmodel = matmodel.cuda()
108
+ matmodel.eval()
109
+
110
+ return matmodel
111
+
112
+
113
+ def do_infer(rawimg, trimap, matmodel):
114
+ trimap_nonp = trimap.copy()
115
+ h, w, c = rawimg.shape
116
+ nonph, nonpw, _ = rawimg.shape
117
+ newh = (((h - 1) // 32) + 1) * 32
118
+ neww = (((w - 1) // 32) + 1) * 32
119
+ padh = newh - h
120
+ padh1 = int(padh / 2)
121
+ padh2 = padh - padh1
122
+ padw = neww - w
123
+ padw1 = int(padw / 2)
124
+ padw2 = padw - padw1
125
+
126
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
127
+ cv2.BORDER_REFLECT)
128
+
129
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
130
+ cv2.BORDER_REFLECT)
131
+
132
+ h_pad, w_pad, _ = rawimg_pad.shape
133
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
134
+ tritemp[:, :, 0] = (trimap_pad == 0)
135
+ tritemp[:, :, 1] = (trimap_pad == 128)
136
+ tritemp[:, :, 2] = (trimap_pad == 255)
137
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
138
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
139
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
140
+ img = np.array(img, np.float32)
141
+ img = img / 255.
142
+ img = torch.from_numpy(img).cuda()
143
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
144
+ with torch.no_grad():
145
+ pred = matmodel(img, tritempimgs)
146
+ pred = pred.detach().cpu().numpy()[0]
147
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
148
+ preda = pred[
149
+ 0:1,
150
+ ] * 255
151
+ preda = np.transpose(preda, (1, 2, 0))
152
+ preda = preda * (trimap_nonp[:, :, None]
153
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
154
+ preda = np.array(preda, np.uint8)
155
+ return preda
156
+
157
+
158
+ def main():
159
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
160
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
161
+ p_outs = 'alpha.png'
162
+
163
+ matmodel = get_AEMatter_model(
164
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
165
+
166
+ # matmodel = AEMatter()
167
+ # matmodel.load_state_dict(
168
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
169
+ # map_location='cpu')['model'])
170
+
171
+ # matmodel = matmodel.cuda()
172
+ # matmodel.eval()
173
+
174
+ rawimg = pimgs
175
+ trimap = ptrimap
176
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
177
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
178
+ trimap_nonp = trimap.copy()
179
+ h, w, c = rawimg.shape
180
+ nonph, nonpw, _ = rawimg.shape
181
+ newh = (((h - 1) // 32) + 1) * 32
182
+ neww = (((w - 1) // 32) + 1) * 32
183
+ padh = newh - h
184
+ padh1 = int(padh / 2)
185
+ padh2 = padh - padh1
186
+ padw = neww - w
187
+ padw1 = int(padw / 2)
188
+ padw2 = padw - padw1
189
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
190
+ cv2.BORDER_REFLECT)
191
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
192
+ cv2.BORDER_REFLECT)
193
+ h_pad, w_pad, _ = rawimg_pad.shape
194
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
195
+ tritemp[:, :, 0] = (trimap_pad == 0)
196
+ tritemp[:, :, 1] = (trimap_pad == 128)
197
+ tritemp[:, :, 2] = (trimap_pad == 255)
198
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
199
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
200
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
201
+ img = np.array(img, np.float32)
202
+ img = img / 255.
203
+ img = torch.from_numpy(img).cuda()
204
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
205
+ with torch.no_grad():
206
+ pred = matmodel(img, tritempimgs)
207
+ pred = pred.detach().cpu().numpy()[0]
208
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
209
+ preda = pred[
210
+ 0:1,
211
+ ] * 255
212
+ preda = np.transpose(preda, (1, 2, 0))
213
+ preda = preda * (trimap_nonp[:, :, None]
214
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
215
+ preda = np.array(preda, np.uint8)
216
+ cv2.imwrite(p_outs, preda)
217
+
218
+
219
+ #!/usr/bin/python3
220
+ class WindowAttention(nn.Module):
221
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
222
+ It supports both shifted and non-shifted windows.
223
+ Args:
224
+ dim (int): Number of input channels.
225
+ window_size (tuple[int]): The height and width of the window.
226
+ num_heads (int): Number of attention heads.
227
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
228
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
229
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
230
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
231
+ """
232
+
233
+ def __init__(self,
234
+ dim,
235
+ window_size,
236
+ num_heads,
237
+ qkv_bias=True,
238
+ qk_scale=None,
239
+ attn_drop=0.,
240
+ proj_drop=0.):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.window_size = window_size # Wh, Ww
245
+ self.num_heads = num_heads
246
+ head_dim = dim // num_heads
247
+ self.scale = qk_scale or head_dim**-0.5
248
+
249
+ # define a parameter table of relative position bias
250
+ self.relative_position_bias_table = nn.Parameter(
251
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
252
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
253
+
254
+ # get pair-wise relative position index for each token inside the window
255
+ coords_h = torch.arange(self.window_size[0])
256
+ coords_w = torch.arange(self.window_size[1])
257
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
258
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
259
+ relative_coords = coords_flatten[:, :,
260
+ None] - coords_flatten[:,
261
+ None, :] # 2, Wh*Ww, Wh*Ww
262
+ relative_coords = relative_coords.permute(
263
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
264
+ relative_coords[:, :,
265
+ 0] += self.window_size[0] - 1 # shift to start from 0
266
+ relative_coords[:, :, 1] += self.window_size[1] - 1
267
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
268
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
269
+ self.register_buffer("relative_position_index",
270
+ relative_position_index)
271
+
272
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
273
+ self.attn_drop = nn.Dropout(attn_drop)
274
+ self.proj = nn.Linear(dim, dim)
275
+ self.proj_drop = nn.Dropout(proj_drop)
276
+
277
+ trunc_normal_(self.relative_position_bias_table, std=.02)
278
+ self.softmax = nn.Softmax(dim=-1)
279
+
280
+ def forward(self, x, mask=None):
281
+ """ Forward function.
282
+ Args:
283
+ x: input features with shape of (num_windows*B, N, C)
284
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
285
+ """
286
+ B_, N, C = x.shape
287
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
288
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[
290
+ 2] # make torchscript happy (cannot use tensor as tuple)
291
+
292
+ q = q * self.scale
293
+ attn = (q @ k.transpose(-2, -1))
294
+
295
+ relative_position_bias = self.relative_position_bias_table[
296
+ self.relative_position_index.view(-1)].view(
297
+ self.window_size[0] * self.window_size[1],
298
+ self.window_size[0] * self.window_size[1],
299
+ -1) # Wh*Ww,Wh*Ww,nH
300
+ relative_position_bias = relative_position_bias.permute(
301
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
302
+ attn = attn + relative_position_bias.unsqueeze(0)
303
+
304
+ if mask is not None:
305
+ nW = mask.shape[0]
306
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
307
+ N) + mask.unsqueeze(1).unsqueeze(0)
308
+ attn = attn.view(-1, self.num_heads, N, N)
309
+ attn = self.softmax(attn)
310
+ else:
311
+ attn = self.softmax(attn)
312
+
313
+ attn = self.attn_drop(attn)
314
+
315
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
316
+ x = self.proj(x)
317
+ x = self.proj_drop(x)
318
+ return x
319
+
320
+
321
+ class SwinTransformerBlock(nn.Module):
322
+ """ Swin Transformer Block.
323
+ Args:
324
+ dim (int): Number of input channels.
325
+ num_heads (int): Number of attention heads.
326
+ window_size (int): Window size.
327
+ shift_size (int): Shift size for SW-MSA.
328
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
329
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
330
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
331
+ drop (float, optional): Dropout rate. Default: 0.0
332
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
333
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
334
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
335
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
336
+ """
337
+
338
+ def __init__(self,
339
+ dim,
340
+ num_heads,
341
+ window_size=7,
342
+ shift_size=0,
343
+ mlp_ratio=4.,
344
+ qkv_bias=True,
345
+ qk_scale=None,
346
+ drop=0.,
347
+ attn_drop=0.,
348
+ drop_path=0.,
349
+ act_layer=nn.GELU,
350
+ norm_layer=nn.LayerNorm):
351
+ super().__init__()
352
+ self.dim = dim
353
+ self.num_heads = num_heads
354
+ self.window_size = window_size
355
+ self.shift_size = shift_size
356
+ self.mlp_ratio = mlp_ratio
357
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
358
+
359
+ self.norm1 = norm_layer(dim)
360
+ self.attn = WindowAttention(dim,
361
+ window_size=to_2tuple(self.window_size),
362
+ num_heads=num_heads,
363
+ qkv_bias=qkv_bias,
364
+ qk_scale=qk_scale,
365
+ attn_drop=attn_drop,
366
+ proj_drop=drop)
367
+
368
+ self.drop_path = DropPath(
369
+ drop_path) if drop_path > 0. else nn.Identity()
370
+ self.norm2 = norm_layer(dim)
371
+ mlp_hidden_dim = int(dim * mlp_ratio)
372
+ self.mlp = Mlp(in_features=dim,
373
+ hidden_features=mlp_hidden_dim,
374
+ act_layer=act_layer,
375
+ drop=drop)
376
+
377
+ self.H = None
378
+ self.W = None
379
+
380
+ def forward(self, x, mask_matrix):
381
+ """ Forward function.
382
+ Args:
383
+ x: Input feature, tensor size (B, H*W, C).
384
+ H, W: Spatial resolution of the input feature.
385
+ mask_matrix: Attention mask for cyclic shift.
386
+ """
387
+ B, L, C = x.shape
388
+ H, W = self.H, self.W
389
+ assert L == H * W, "input feature has wrong size"
390
+
391
+ shortcut = x
392
+ x = self.norm1(x)
393
+ x = x.view(B, H, W, C)
394
+
395
+ # pad feature maps to multiples of window size
396
+ pad_l = pad_t = 0
397
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
398
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
399
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
400
+ _, Hp, Wp, _ = x.shape
401
+
402
+ # cyclic shift
403
+ if self.shift_size > 0:
404
+ shifted_x = torch.roll(x,
405
+ shifts=(-self.shift_size, -self.shift_size),
406
+ dims=(1, 2))
407
+ attn_mask = mask_matrix
408
+ else:
409
+ shifted_x = x
410
+ attn_mask = None
411
+
412
+ # partition windows
413
+ x_windows = window_partition(
414
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
415
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
416
+ C) # nW*B, window_size*window_size, C
417
+
418
+ # W-MSA/SW-MSA
419
+ attn_windows = self.attn(
420
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
421
+
422
+ # merge windows
423
+ attn_windows = attn_windows.view(-1, self.window_size,
424
+ self.window_size, C)
425
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
426
+ Wp) # B H' W' C
427
+
428
+ # reverse cyclic shift
429
+ if self.shift_size > 0:
430
+ x = torch.roll(shifted_x,
431
+ shifts=(self.shift_size, self.shift_size),
432
+ dims=(1, 2))
433
+ else:
434
+ x = shifted_x
435
+
436
+ if pad_r > 0 or pad_b > 0:
437
+ x = x[:, :H, :W, :].contiguous()
438
+
439
+ x = x.view(B, H * W, C)
440
+
441
+ # FFN
442
+ x = shortcut + self.drop_path(x)
443
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
444
+
445
+ return x
446
+
447
+
448
+ class PatchMerging(nn.Module):
449
+ """ Patch Merging Layer
450
+ Args:
451
+ dim (int): Number of input channels.
452
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
453
+ """
454
+
455
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
456
+ super().__init__()
457
+ self.dim = dim
458
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
459
+ self.norm = norm_layer(4 * dim)
460
+
461
+ def forward(self, x, H, W):
462
+ """ Forward function.
463
+ Args:
464
+ x: Input feature, tensor size (B, H*W, C).
465
+ H, W: Spatial resolution of the input feature.
466
+ """
467
+ B, L, C = x.shape
468
+ assert L == H * W, "input feature has wrong size"
469
+
470
+ x = x.view(B, H, W, C)
471
+
472
+ # padding
473
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
474
+ if pad_input:
475
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
476
+
477
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
478
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
479
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
480
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
481
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
482
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
483
+
484
+ x = self.norm(x)
485
+ x = self.reduction(x)
486
+
487
+ return x
488
+
489
+
490
+ class BasicLayer(nn.Module):
491
+ """ A basic Swin Transformer layer for one stage.
492
+ Args:
493
+ dim (int): Number of feature channels
494
+ depth (int): Depths of this stage.
495
+ num_heads (int): Number of attention heads.
496
+ window_size (int): Local window size. Default: 7.
497
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
498
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
499
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
500
+ drop (float, optional): Dropout rate. Default: 0.0
501
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
502
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
504
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
505
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
506
+ """
507
+
508
+ def __init__(self,
509
+ dim,
510
+ depth,
511
+ num_heads,
512
+ window_size=7,
513
+ mlp_ratio=4.,
514
+ qkv_bias=True,
515
+ qk_scale=None,
516
+ drop=0.,
517
+ attn_drop=0.,
518
+ drop_path=0.,
519
+ norm_layer=nn.LayerNorm,
520
+ downsample=None,
521
+ use_checkpoint=False):
522
+
523
+ super().__init__()
524
+ self.window_size = window_size
525
+ self.shift_size = window_size // 2
526
+ self.depth = depth
527
+ self.use_checkpoint = use_checkpoint
528
+
529
+ # build blocks
530
+ self.blocks = nn.ModuleList([
531
+ SwinTransformerBlock(dim=dim,
532
+ num_heads=num_heads,
533
+ window_size=window_size,
534
+ shift_size=0 if
535
+ (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ qk_scale=qk_scale,
539
+ drop=drop,
540
+ attn_drop=attn_drop,
541
+ drop_path=drop_path[i] if isinstance(
542
+ drop_path, list) else drop_path,
543
+ norm_layer=norm_layer) for i in range(depth)
544
+ ])
545
+
546
+ # patch merging layer
547
+ if downsample is not None:
548
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
549
+ else:
550
+ self.downsample = None
551
+
552
+ def forward(self, x, H, W):
553
+ """ Forward function.
554
+ Args:
555
+ x: Input feature, tensor size (B, H*W, C).
556
+ H, W: Spatial resolution of the input feature.
557
+ """
558
+ # print(x.shape,H,W)
559
+ # calculate attention mask for SW-MSA
560
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
561
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
562
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
563
+ h_slices = (slice(0, -self.window_size),
564
+ slice(-self.window_size,
565
+ -self.shift_size), slice(-self.shift_size, None))
566
+ w_slices = (slice(0, -self.window_size),
567
+ slice(-self.window_size,
568
+ -self.shift_size), slice(-self.shift_size, None))
569
+ cnt = 0
570
+ for h in h_slices:
571
+ for w in w_slices:
572
+ img_mask[:, h, w, :] = cnt
573
+ cnt += 1
574
+
575
+ mask_windows = window_partition(
576
+ img_mask, self.window_size) # nW, window_size, window_size, 1
577
+
578
+ mask_windows = mask_windows.view(-1,
579
+ self.window_size * self.window_size)
580
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
581
+ 2) # nW, ww window_size*window_size
582
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
583
+ float(-100.0)).masked_fill(
584
+ attn_mask == 0, float(0.0))
585
+
586
+ for blk in self.blocks:
587
+ blk.H, blk.W = H, W
588
+ if self.use_checkpoint:
589
+ x = checkpoint.checkpoint(blk, x, attn_mask)
590
+ else:
591
+ x = blk(x, attn_mask)
592
+
593
+ if self.downsample is not None:
594
+ x_down = self.downsample(x, H, W)
595
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
596
+ return x, H, W, x_down, Wh, Ww
597
+ else:
598
+ return x, H, W, x, H, W
599
+
600
+
601
+ class PatchEmbed(nn.Module):
602
+ """ Image to Patch Embedding
603
+ Args:
604
+ patch_size (int): Patch token size. Default: 4.
605
+ in_chans (int): Number of input image channels. Default: 3.
606
+ embed_dim (int): Number of linear projection output channels. Default: 96.
607
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
608
+ """
609
+
610
+ def __init__(self,
611
+ patch_size=4,
612
+ in_chans=3,
613
+ embed_dim=96,
614
+ norm_layer=None):
615
+
616
+ super().__init__()
617
+ patch_size = to_2tuple(patch_size)
618
+ self.patch_size = patch_size
619
+
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv2d(in_chans,
624
+ embed_dim,
625
+ kernel_size=patch_size,
626
+ stride=patch_size)
627
+ if norm_layer is not None:
628
+ self.norm = norm_layer(embed_dim)
629
+ else:
630
+ self.norm = None
631
+
632
+ def forward(self, x):
633
+ """Forward function."""
634
+ # padding
635
+ _, _, H, W = x.size()
636
+ if W % self.patch_size[1] != 0:
637
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
638
+ if H % self.patch_size[0] != 0:
639
+ x = F.pad(x,
640
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
641
+
642
+ x = self.proj(x) # B C Wh Ww
643
+ if self.norm is not None:
644
+ Wh, Ww = x.size(2), x.size(3)
645
+ x = x.flatten(2).transpose(1, 2)
646
+ x = self.norm(x)
647
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
648
+
649
+ return x
650
+
651
+
652
+ class SwinTransformer(nn.Module):
653
+ """ Swin Transformer backbone.
654
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
655
+ https://arxiv.org/pdf/2103.14030
656
+ Args:
657
+ pretrain_img_size (int): Input image size for training the pretrained model,
658
+ used in absolute position embedding. Default: 224.
659
+ patch_size (int | tuple(int)): Patch size. Default: 4.
660
+ in_chans (int): Number of input image channels. Default: 3.
661
+ embed_dim (int): Number of linear projection output channels. Default: 96.
662
+ depths (tuple[int]): Depths of each Swin Transformer stage.
663
+ num_heads (tuple[int]): Number of attention head of each stage.
664
+ window_size (int): Window size. Default: 7.
665
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
666
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
667
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
668
+ drop_rate (float): Dropout rate.
669
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
670
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
671
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
672
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
673
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
674
+ out_indices (Sequence[int]): Output from which stages.
675
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
676
+ -1 means not freezing any parameters.
677
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
678
+ """
679
+
680
+ def __init__(self,
681
+ pretrain_img_size=224,
682
+ patch_size=4,
683
+ in_chans=3,
684
+ embed_dim=96,
685
+ depths=[2, 2, 6, 2],
686
+ num_heads=[3, 6, 12, 24],
687
+ window_size=7,
688
+ mlp_ratio=4.,
689
+ qkv_bias=True,
690
+ qk_scale=None,
691
+ drop_rate=0.,
692
+ attn_drop_rate=0.,
693
+ drop_path_rate=0.2,
694
+ norm_layer=nn.LayerNorm,
695
+ ape=False,
696
+ patch_norm=True,
697
+ out_indices=(0, 1, 2, 3),
698
+ frozen_stages=-1,
699
+ use_checkpoint=False):
700
+
701
+ super().__init__()
702
+
703
+ self.pretrain_img_size = pretrain_img_size
704
+ self.num_layers = len(depths)
705
+ self.embed_dim = embed_dim
706
+ self.ape = ape
707
+ self.patch_norm = patch_norm
708
+ self.out_indices = out_indices
709
+ self.frozen_stages = frozen_stages
710
+
711
+ # split image into non-overlapping patches
712
+ self.patch_embed = PatchEmbed(
713
+ patch_size=patch_size,
714
+ in_chans=in_chans,
715
+ embed_dim=embed_dim,
716
+ norm_layer=norm_layer if self.patch_norm else None)
717
+
718
+ # absolute position embedding
719
+ if self.ape:
720
+ pretrain_img_size = to_2tuple(pretrain_img_size)
721
+ patch_size = to_2tuple(patch_size)
722
+ patches_resolution = [
723
+ pretrain_img_size[0] // patch_size[0],
724
+ pretrain_img_size[1] // patch_size[1]
725
+ ]
726
+
727
+ self.absolute_pos_embed = nn.Parameter(
728
+ torch.zeros(1, embed_dim, patches_resolution[0],
729
+ patches_resolution[1]))
730
+ trunc_normal_(self.absolute_pos_embed, std=.02)
731
+
732
+ self.pos_drop = nn.Dropout(p=drop_rate)
733
+
734
+ # stochastic depth
735
+ dpr = [
736
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
737
+ ] # stochastic depth decay rule
738
+
739
+ # build layers
740
+ self.layers = nn.ModuleList()
741
+ for i_layer in range(self.num_layers):
742
+ layer = BasicLayer(
743
+ dim=int(embed_dim * 2**i_layer),
744
+ depth=depths[i_layer],
745
+ num_heads=num_heads[i_layer],
746
+ window_size=window_size,
747
+ mlp_ratio=mlp_ratio,
748
+ qkv_bias=qkv_bias,
749
+ qk_scale=qk_scale,
750
+ drop=drop_rate,
751
+ attn_drop=attn_drop_rate,
752
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
753
+ norm_layer=norm_layer,
754
+ downsample=PatchMerging if
755
+ (i_layer < self.num_layers - 1) else None,
756
+ use_checkpoint=use_checkpoint)
757
+ self.layers.append(layer)
758
+
759
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
760
+ self.num_features = num_features
761
+
762
+ # add a norm layer for each output
763
+ for i_layer in out_indices:
764
+ layer = norm_layer(num_features[i_layer])
765
+ layer_name = f'norm{i_layer}'
766
+ self.add_module(layer_name, layer)
767
+
768
+ self._freeze_stages()
769
+
770
+ def _freeze_stages(self):
771
+ if self.frozen_stages >= 0:
772
+ self.patch_embed.eval()
773
+ for param in self.patch_embed.parameters():
774
+ param.requires_grad = False
775
+
776
+ if self.frozen_stages >= 1 and self.ape:
777
+ self.absolute_pos_embed.requires_grad = False
778
+
779
+ if self.frozen_stages >= 2:
780
+ self.pos_drop.eval()
781
+ for i in range(0, self.frozen_stages - 1):
782
+ m = self.layers[i]
783
+ m.eval()
784
+ for param in m.parameters():
785
+ param.requires_grad = False
786
+
787
+ def init_weights(self, pretrained=None):
788
+ """Initialize the weights in backbone.
789
+ Args:
790
+ pretrained (str, optional): Path to pre-trained weights.
791
+ Defaults to None.
792
+ """
793
+
794
+ def forward(self, x):
795
+ """Forward function."""
796
+ x = self.patch_embed(x)
797
+
798
+ Wh, Ww = x.size(2), x.size(3)
799
+ if self.ape:
800
+ # interpolate the position embedding to the corresponding size
801
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
802
+ size=(Wh, Ww),
803
+ mode='bicubic')
804
+ x = (x + absolute_pos_embed).flatten(2).transpose(1,
805
+ 2) # B Wh*Ww C
806
+ else:
807
+ x = x.flatten(2).transpose(1, 2)
808
+ x = self.pos_drop(x)
809
+
810
+ outs = []
811
+ for i in range(self.num_layers):
812
+ layer = self.layers[i]
813
+
814
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
815
+
816
+ if i in self.out_indices:
817
+ norm_layer = getattr(self, f'norm{i}')
818
+ x_out = norm_layer(x_out)
819
+
820
+ out = x_out.view(-1, H, W,
821
+ self.num_features[i]).permute(0, 3, 1,
822
+ 2).contiguous()
823
+ outs.append(out)
824
+
825
+ return tuple(outs)
826
+
827
+ def train(self, mode=True):
828
+ """Convert the model into training mode while keep layers freezed."""
829
+ super(SwinTransformer, self).train(mode)
830
+ self._freeze_stages()
831
+
832
+
833
+ class Mlp(nn.Module):
834
+ """ Multilayer perceptron."""
835
+
836
+ def __init__(self,
837
+ in_features,
838
+ hidden_features=None,
839
+ out_features=None,
840
+ act_layer=nn.GELU,
841
+ drop=0.):
842
+ super().__init__()
843
+ out_features = out_features or in_features
844
+ hidden_features = hidden_features or in_features
845
+ self.fc1 = nn.Linear(in_features, hidden_features)
846
+ self.act = act_layer()
847
+ self.fc2 = nn.Linear(hidden_features, out_features)
848
+ self.drop = nn.Dropout(drop)
849
+
850
+ def forward(self, x):
851
+ x = self.fc1(x)
852
+ x = self.act(x)
853
+ x = self.drop(x)
854
+ x = self.fc2(x)
855
+ x = self.drop(x)
856
+ return x
857
+
858
+
859
+ class ResBlock(nn.Module):
860
+
861
+ def __init__(self, inc, midc):
862
+ super(ResBlock, self).__init__()
863
+ self.conv1 = nn.Conv2d(inc,
864
+ midc,
865
+ kernel_size=1,
866
+ stride=1,
867
+ padding=0,
868
+ bias=True)
869
+ self.gn1 = nn.GroupNorm(16, midc)
870
+ self.conv2 = nn.Conv2d(midc,
871
+ midc,
872
+ kernel_size=3,
873
+ stride=1,
874
+ padding=1,
875
+ bias=True)
876
+ self.gn2 = nn.GroupNorm(16, midc)
877
+ self.conv3 = nn.Conv2d(midc,
878
+ inc,
879
+ kernel_size=1,
880
+ stride=1,
881
+ padding=0,
882
+ bias=True)
883
+ self.relu = nn.LeakyReLU(0.1)
884
+
885
+ def forward(self, x):
886
+ x_ = x
887
+ x = self.conv1(x)
888
+ x = self.gn1(x)
889
+ x = self.relu(x)
890
+ x = self.conv2(x)
891
+ x = self.gn2(x)
892
+ x = self.relu(x)
893
+ x = self.conv3(x)
894
+ x = x + x_
895
+ x = self.relu(x)
896
+ return x
897
+
898
+
899
+ class AEALblock(nn.Module):
900
+
901
+ def __init__(self,
902
+ d_model,
903
+ nhead,
904
+ dim_feedforward=512,
905
+ dropout=0.0,
906
+ layer_norm_eps=1e-5,
907
+ batch_first=True,
908
+ norm_first=False,
909
+ width=5):
910
+ super(AEALblock, self).__init__()
911
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
912
+ nhead // 2,
913
+ dropout=dropout,
914
+ batch_first=batch_first)
915
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
916
+ nhead // 2,
917
+ dropout=dropout,
918
+ batch_first=batch_first)
919
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
920
+ self.dropout = nn.Dropout(dropout)
921
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
922
+ self.norm_first = norm_first
923
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
924
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
925
+ self.dropout1 = nn.Dropout(dropout)
926
+ self.dropout2 = nn.Dropout(dropout)
927
+ self.activation = nn.ReLU()
928
+ self.width = width
929
+ self.trans = nn.Sequential(
930
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
931
+ ResBlock(d_model // 2, d_model // 4),
932
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
933
+ self.gamma = nn.Parameter(torch.zeros(1))
934
+
935
+ def forward(
936
+ self,
937
+ src,
938
+ feats,
939
+ ):
940
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
941
+ b, c, h, w = src.shape
942
+ x1 = src[:, 0:c // 2]
943
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
944
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
945
+ x2 = src[:, c // 2:]
946
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
947
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
948
+ x = rearrange(src, 'b c h w-> b (h w) c')
949
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
950
+ x = self.norm2(x + self._ff_block(x))
951
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
952
+ return x
953
+
954
+ def _sa_block(self, x1, x2, h, w):
955
+ x1 = self.self_attn1(x1,
956
+ x1,
957
+ x1,
958
+ attn_mask=None,
959
+ key_padding_mask=None,
960
+ need_weights=False)[0]
961
+
962
+ x2 = self.self_attn2(x2,
963
+ x2,
964
+ x2,
965
+ attn_mask=None,
966
+ key_padding_mask=None,
967
+ need_weights=False)[0]
968
+
969
+ x1 = rearrange(x1,
970
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
971
+ h2=self.width,
972
+ h1=h // self.width)
973
+ x2 = rearrange(x2,
974
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
975
+ w2=self.width,
976
+ w1=w // self.width)
977
+ x = torch.cat([x1, x2], dim=2)
978
+ return self.dropout1(x)
979
+
980
+ def _ff_block(self, x):
981
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
982
+ return self.dropout2(x)
983
+
984
+
985
+ class AEMatter(nn.Module):
986
+
987
+ def __init__(self):
988
+ super(AEMatter, self).__init__()
989
+ trans = SwinTransformer(pretrain_img_size=224,
990
+ embed_dim=96,
991
+ depths=[2, 2, 6, 2],
992
+ num_heads=[3, 6, 12, 24],
993
+ window_size=7,
994
+ ape=False,
995
+ drop_path_rate=0.2,
996
+ patch_norm=True,
997
+ use_checkpoint=False)
998
+
999
+ # trans.load_state_dict(torch.load(
1000
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
1001
+ # map_location="cpu")["model"],
1002
+ # strict=False)
1003
+
1004
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
1005
+
1006
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
1007
+ nn.PReLU(48))
1008
+
1009
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
1010
+ 1), nn.PReLU(64),
1011
+ nn.Conv2d(64, 64, 3, 1, 1),
1012
+ nn.PReLU(64))
1013
+
1014
+ self.trans = trans
1015
+ self.conv1 = nn.Sequential(
1016
+ nn.Conv2d(in_channels=640 + 768,
1017
+ out_channels=256,
1018
+ kernel_size=1,
1019
+ stride=1,
1020
+ padding=0,
1021
+ bias=True))
1022
+ self.conv2 = nn.Sequential(
1023
+ nn.Conv2d(in_channels=256 + 384,
1024
+ out_channels=256,
1025
+ kernel_size=1,
1026
+ stride=1,
1027
+ padding=0,
1028
+ bias=True), )
1029
+ self.conv3 = nn.Sequential(
1030
+ nn.Conv2d(in_channels=256 + 192,
1031
+ out_channels=192,
1032
+ kernel_size=1,
1033
+ stride=1,
1034
+ padding=0,
1035
+ bias=True), )
1036
+ self.conv4 = nn.Sequential(
1037
+ nn.Conv2d(in_channels=192 + 96,
1038
+ out_channels=128,
1039
+ kernel_size=1,
1040
+ stride=1,
1041
+ padding=0,
1042
+ bias=True), )
1043
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
1044
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
1045
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
1046
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
1047
+ self.conv5 = nn.Sequential(
1048
+ nn.Conv2d(in_channels=192,
1049
+ out_channels=64,
1050
+ kernel_size=3,
1051
+ stride=1,
1052
+ padding=1,
1053
+ bias=True), nn.PReLU(64),
1054
+ nn.Conv2d(in_channels=64,
1055
+ out_channels=64,
1056
+ kernel_size=3,
1057
+ stride=1,
1058
+ padding=1,
1059
+ bias=True), nn.PReLU(64),
1060
+ nn.Conv2d(in_channels=64,
1061
+ out_channels=48,
1062
+ kernel_size=3,
1063
+ stride=1,
1064
+ padding=1,
1065
+ bias=True), nn.PReLU(48))
1066
+ self.convo = nn.Sequential(
1067
+ nn.Conv2d(in_channels=48 + 48 + 6,
1068
+ out_channels=32,
1069
+ kernel_size=3,
1070
+ stride=1,
1071
+ padding=1,
1072
+ bias=True), nn.PReLU(32),
1073
+ nn.Conv2d(in_channels=32,
1074
+ out_channels=32,
1075
+ kernel_size=3,
1076
+ stride=1,
1077
+ padding=1,
1078
+ bias=True), nn.PReLU(32),
1079
+ nn.Conv2d(in_channels=32,
1080
+ out_channels=1,
1081
+ kernel_size=3,
1082
+ stride=1,
1083
+ padding=1,
1084
+ bias=True))
1085
+ self.up = nn.Upsample(scale_factor=2,
1086
+ mode='bilinear',
1087
+ align_corners=False)
1088
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1089
+ self.apptrans = nn.Sequential(
1090
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1091
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1092
+ ResBlock(512, 128))
1093
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1094
+ ResBlock(640, 160))
1095
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1096
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1097
+ self.width = 5
1098
+ self.trans1 = AEALblock(d_model=640,
1099
+ nhead=20,
1100
+ dim_feedforward=2048,
1101
+ dropout=0.2,
1102
+ width=self.width)
1103
+ self.trans2 = AEALblock(d_model=640,
1104
+ nhead=20,
1105
+ dim_feedforward=2048,
1106
+ dropout=0.2,
1107
+ width=self.width)
1108
+ self.trans3 = AEALblock(d_model=640,
1109
+ nhead=20,
1110
+ dim_feedforward=2048,
1111
+ dropout=0.2,
1112
+ width=self.width)
1113
+
1114
+ def aeal(self, x, sem):
1115
+ xe = self.emb(x)
1116
+ x_ = xe
1117
+ x_ = self.embdp(x_)
1118
+ b, c, h1, w1 = x_.shape
1119
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1120
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1121
+ newph1 = bnew_ph // 2
1122
+ newph2 = bnew_ph - newph1
1123
+ newpw1 = bnew_pw // 2
1124
+ newpw2 = bnew_pw - newpw1
1125
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1126
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1127
+ x_ = self.trans1(x_, sem)
1128
+ x_ = self.trans2(x_, sem)
1129
+ x_ = self.trans3(x_, sem)
1130
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1131
+ return x_
1132
+
1133
+ def forward(self, x, y):
1134
+ inputs = torch.cat((x, y), 1)
1135
+ x = self.start_conv0(inputs)
1136
+ x_ = self.start_conv(x)
1137
+ x1, x2, x3, x4 = self.trans(x_)
1138
+ x4h = self.h2l(x4)
1139
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1140
+ x4_ = self.aeal(x4, x3s)
1141
+ x4 = torch.cat((x4, x4_), 1)
1142
+ X4 = self.conv1(x4)
1143
+ wh, ww = X4.shape[2], X4.shape[3]
1144
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1145
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1146
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1147
+ X3 = self.up(X4)
1148
+ X3 = torch.cat((x3, X3), 1)
1149
+ X3 = self.conv2(X3)
1150
+ wh, ww = X3.shape[2], X3.shape[3]
1151
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1152
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1153
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1154
+ X2 = self.up(X3)
1155
+ X2 = torch.cat((x2, X2), 1)
1156
+ X2 = self.conv3(X2)
1157
+ wh, ww = X2.shape[2], X2.shape[3]
1158
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1159
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1160
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1161
+ X1 = self.up(X2)
1162
+ X1 = torch.cat((x1, X1), 1)
1163
+ X1 = self.conv4(X1)
1164
+ wh, ww = X1.shape[2], X1.shape[3]
1165
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1166
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1167
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1168
+ X0 = self.up(X1)
1169
+ X0 = torch.cat((x_, X0), 1)
1170
+ X0 = self.conv5(X0)
1171
+ X = self.up(X0)
1172
+ X = torch.cat((inputs, x, X), 1)
1173
+ alpha = self.convo(X)
1174
+ alpha = torch.clamp(alpha, min=0, max=1)
1175
+ return alpha
1176
+
1177
+
1178
+ class load_AEMatter_Model:
1179
+
1180
+ def __init__(self):
1181
+ pass
1182
+
1183
+ @classmethod
1184
+ def INPUT_TYPES(s):
1185
+ return {
1186
+ "required": {},
1187
+ }
1188
+
1189
+ RETURN_TYPES = ("AEMatter_Model", )
1190
+ FUNCTION = "test"
1191
+ CATEGORY = "AEMatter"
1192
+
1193
+ def test(self):
1194
+ return (get_AEMatter_model(get_model_path()), )
1195
+
1196
+
1197
+ class run_AEMatter_inference:
1198
+
1199
+ def __init__(self):
1200
+ pass
1201
+
1202
+ @classmethod
1203
+ def INPUT_TYPES(s):
1204
+ return {
1205
+ "required": {
1206
+ "image": ("IMAGE", ),
1207
+ "trimap": ("MASK", ),
1208
+ "AEMatter_Model": ("AEMatter_Model", ),
1209
+ },
1210
+ }
1211
+
1212
+ RETURN_TYPES = ("MASK", )
1213
+ FUNCTION = "test"
1214
+ CATEGORY = "AEMatter"
1215
+
1216
+ def test(
1217
+ self,
1218
+ image,
1219
+ trimap,
1220
+ AEMatter_Model,
1221
+ ):
1222
+
1223
+ ret = []
1224
+ batch_size = image.shape[0]
1225
+
1226
+ for i in range(batch_size):
1227
+ tmp_i = from_torch_image(image[i])
1228
+ tmp_m = from_torch_image(trimap[i])
1229
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1230
+ ret.append(tmp)
1231
+
1232
+ ret = to_torch_image(np.array(ret))
1233
+ ret = ret.squeeze(-1)
1234
+ print(ret.shape)
1235
+
1236
+ return (ret, )
1237
+
1238
+
1239
+ #!/usr/bin/python3
1240
+ NODE_CLASS_MAPPINGS = {
1241
+ 'load_AEMatter_Model': load_AEMatter_Model,
1242
+ 'run_AEMatter_inference': run_AEMatter_inference,
1243
+ }
1244
+
1245
+ NODE_DISPLAY_NAME_MAPPINGS = {
1246
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1247
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1248
+ }
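
For reference, the sketch below reproduces the preprocessing that `do_infer` performs before the forward pass: reflect-padding the image and trimap to multiples of 32 and one-hot encoding the trimap (0 = background, 128 = unknown, 255 = foreground). It uses random stand-in arrays rather than files from this repository, so the shapes are illustrative only.

#+begin_src python
# Illustrative sketch (not part of the commit): the padding and trimap
# encoding applied by do_infer() in ComfyUI_AEMatter/AEMatter.py.
import cv2
import numpy as np
import torch

rawimg = np.random.randint(0, 256, (500, 353, 3), np.uint8)            # stand-in BGR image
trimap = np.random.choice([0, 128, 255], (500, 353)).astype(np.uint8)  # stand-in trimap

h, w = trimap.shape
newh, neww = -(-h // 32) * 32, -(-w // 32) * 32   # round both sides up to a multiple of 32
padh1, padw1 = (newh - h) // 2, (neww - w) // 2
padh2, padw2 = newh - h - padh1, neww - w - padw1

rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2, cv2.BORDER_REFLECT)
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2, cv2.BORDER_REFLECT)

# One-hot trimap: channel 0 = background (0), 1 = unknown (128), 2 = foreground (255).
tri = np.stack([trimap_pad == 0, trimap_pad == 128, trimap_pad == 255], 0).astype(np.float32)

# NCHW float image in [0, 1]; the channel flip mirrors the BGR -> RGB step in do_infer().
img = rawimg_pad.transpose(2, 0, 1)[np.newaxis, ::-1].astype(np.float32) / 255.0
img = torch.from_numpy(np.ascontiguousarray(img))
tri = torch.from_numpy(tri[np.newaxis])
print(img.shape, tri.shape)  # both spatial dimensions are now multiples of 32
#+end_src

The multiple-of-32 requirement comes from the encoder's overall stride: the two stride-2 stem convolutions, the stride-2 patch embedding, and the three patch-merging stages together downsample by 32, which is why do_infer rounds sizes up before padding.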
ComfyUI_AEMatter/README.org ADDED
@@ -0,0 +1,1357 @@
1
+ * COMMENT SAMPLE
2
+
3
+ ** AEMatter.import.py
4
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
5
+ #+end_src
6
+
7
+ ** AEMatter.function.py
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
9
+ #+end_src
10
+
11
+ ** AEMatter.class.py
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
13
+ #+end_src
14
+
15
+ ** AEMatter.execute.py
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
17
+ #+end_src
18
+
19
+ ** AEMatter.unify.sh
20
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
21
+ #+end_src
22
+
23
+ ** AEMatter.run.sh
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
25
+ #+end_src
26
+
27
+ * Code for AEMatter inference
28
+
29
+ ** AEMatter.import.py
30
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
31
+ import cv2
32
+ import math
33
+ import numpy as np
34
+ import os
35
+ import random
36
+ import wget
37
+
38
+ import torch
39
+ import torch.nn as nn
40
+ from torch.nn import init
41
+ import torch.nn.functional as F
42
+ import torch.utils.checkpoint as checkpoint
43
+
44
+ from collections import OrderedDict
45
+ from einops import rearrange, repeat
46
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
47
+
48
+ import folder_paths
49
+ from folder_paths import models_dir
50
+ #+end_src
51
+
52
+ ** Functions to prepare directory structure and download models
53
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
54
+ def mkdir_safe(out_path):
55
+ if type(out_path) == str:
56
+ if len(out_path) > 0:
57
+ if not os.path.exists(out_path):
58
+ os.mkdir(out_path)
59
+
60
+
61
+ def get_model_path():
62
+ import folder_paths
63
+ from folder_paths import models_dir
64
+
65
+ path_file_model = models_dir
66
+ mkdir_safe(out_path=path_file_model)
67
+
68
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
69
+ mkdir_safe(out_path=path_file_model)
70
+
71
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
72
+
73
+ return path_file_model
74
+
75
+
76
+ def download_model(path):
77
+ if not os.path.exists(path):
78
+ wget.download(
79
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
80
+ out=path)
81
+
82
+
83
+ def from_torch_image(image):
84
+ image = image.cpu().numpy() * 255.0
85
+ image = np.clip(image, 0, 255).astype(np.uint8)
86
+ return image
87
+
88
+
89
+ def to_torch_image(image):
90
+ image = image.astype(dtype=np.float32)
91
+ image /= 255.0
92
+ image = torch.from_numpy(image)
93
+ return image
94
+ #+end_src
95
+
96
+ ** AEMatter.function.py
97
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
98
+ def window_partition(x, window_size):
99
+ """
100
+ Args:
101
+ x: (B, H, W, C)
102
+ window_size (int): window size
103
+ Returns:
104
+ windows: (num_windows*B, window_size, window_size, C)
105
+ """
106
+ B, H, W, C = x.shape
107
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
108
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
109
+ return windows
110
+ #+end_src
111
+
112
+ ** AEMatter.function.py
113
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
114
+ def window_reverse(windows, window_size, H, W):
115
+ """
116
+ Args:
117
+ windows: (num_windows*B, window_size, window_size, C)
118
+ window_size (int): Window size
119
+ H (int): Height of image
120
+ W (int): Width of image
121
+ Returns:
122
+ x: (B, H, W, C)
123
+ """
124
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
125
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
126
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
127
+ return x
128
+ #+end_src
129
+
130
+ ** AEMatter.class.py
131
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
132
+ class WindowAttention(nn.Module):
133
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
134
+ It supports both shifted and non-shifted windows.
135
+ Args:
136
+ dim (int): Number of input channels.
137
+ window_size (tuple[int]): The height and width of the window.
138
+ num_heads (int): Number of attention heads.
139
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
140
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
141
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
142
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
143
+ """
144
+
145
+ def __init__(self,
146
+ dim,
147
+ window_size,
148
+ num_heads,
149
+ qkv_bias=True,
150
+ qk_scale=None,
151
+ attn_drop=0.,
152
+ proj_drop=0.):
153
+
154
+ super().__init__()
155
+ self.dim = dim
156
+ self.window_size = window_size # Wh, Ww
157
+ self.num_heads = num_heads
158
+ head_dim = dim // num_heads
159
+ self.scale = qk_scale or head_dim**-0.5
160
+
161
+ # define a parameter table of relative position bias
162
+ self.relative_position_bias_table = nn.Parameter(
163
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
164
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
165
+
166
+ # get pair-wise relative position index for each token inside the window
167
+ coords_h = torch.arange(self.window_size[0])
168
+ coords_w = torch.arange(self.window_size[1])
169
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
170
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
171
+ relative_coords = coords_flatten[:, :,
172
+ None] - coords_flatten[:,
173
+ None, :] # 2, Wh*Ww, Wh*Ww
174
+ relative_coords = relative_coords.permute(
175
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
176
+ relative_coords[:, :,
177
+ 0] += self.window_size[0] - 1 # shift to start from 0
178
+ relative_coords[:, :, 1] += self.window_size[1] - 1
179
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
180
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
181
+ self.register_buffer("relative_position_index",
182
+ relative_position_index)
183
+
184
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
185
+ self.attn_drop = nn.Dropout(attn_drop)
186
+ self.proj = nn.Linear(dim, dim)
187
+ self.proj_drop = nn.Dropout(proj_drop)
188
+
189
+ trunc_normal_(self.relative_position_bias_table, std=.02)
190
+ self.softmax = nn.Softmax(dim=-1)
191
+
192
+ def forward(self, x, mask=None):
193
+ """ Forward function.
194
+ Args:
195
+ x: input features with shape of (num_windows*B, N, C)
196
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
197
+ """
198
+ B_, N, C = x.shape
199
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
200
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
201
+ q, k, v = qkv[0], qkv[1], qkv[
202
+ 2] # make torchscript happy (cannot use tensor as tuple)
203
+
204
+ q = q * self.scale
205
+ attn = (q @ k.transpose(-2, -1))
206
+
207
+ relative_position_bias = self.relative_position_bias_table[
208
+ self.relative_position_index.view(-1)].view(
209
+ self.window_size[0] * self.window_size[1],
210
+ self.window_size[0] * self.window_size[1],
211
+ -1) # Wh*Ww,Wh*Ww,nH
212
+ relative_position_bias = relative_position_bias.permute(
213
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
214
+ attn = attn + relative_position_bias.unsqueeze(0)
215
+
216
+ if mask is not None:
217
+ nW = mask.shape[0]
218
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
219
+ N) + mask.unsqueeze(1).unsqueeze(0)
220
+ attn = attn.view(-1, self.num_heads, N, N)
221
+ attn = self.softmax(attn)
222
+ else:
223
+ attn = self.softmax(attn)
224
+
225
+ attn = self.attn_drop(attn)
226
+
227
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
228
+ x = self.proj(x)
229
+ x = self.proj_drop(x)
230
+ return x
231
+ #+end_src
232
+
233
+ ** AEMatter.class.py
234
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
235
+ class SwinTransformerBlock(nn.Module):
236
+ """ Swin Transformer Block.
237
+ Args:
238
+ dim (int): Number of input channels.
239
+ num_heads (int): Number of attention heads.
240
+ window_size (int): Window size.
241
+ shift_size (int): Shift size for SW-MSA.
242
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
243
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
244
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
245
+ drop (float, optional): Dropout rate. Default: 0.0
246
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
247
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
248
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
249
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
250
+ """
251
+
252
+ def __init__(self,
253
+ dim,
254
+ num_heads,
255
+ window_size=7,
256
+ shift_size=0,
257
+ mlp_ratio=4.,
258
+ qkv_bias=True,
259
+ qk_scale=None,
260
+ drop=0.,
261
+ attn_drop=0.,
262
+ drop_path=0.,
263
+ act_layer=nn.GELU,
264
+ norm_layer=nn.LayerNorm):
265
+ super().__init__()
266
+ self.dim = dim
267
+ self.num_heads = num_heads
268
+ self.window_size = window_size
269
+ self.shift_size = shift_size
270
+ self.mlp_ratio = mlp_ratio
271
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
272
+
273
+ self.norm1 = norm_layer(dim)
274
+ self.attn = WindowAttention(dim,
275
+ window_size=to_2tuple(self.window_size),
276
+ num_heads=num_heads,
277
+ qkv_bias=qkv_bias,
278
+ qk_scale=qk_scale,
279
+ attn_drop=attn_drop,
280
+ proj_drop=drop)
281
+
282
+ self.drop_path = DropPath(
283
+ drop_path) if drop_path > 0. else nn.Identity()
284
+ self.norm2 = norm_layer(dim)
285
+ mlp_hidden_dim = int(dim * mlp_ratio)
286
+ self.mlp = Mlp(in_features=dim,
287
+ hidden_features=mlp_hidden_dim,
288
+ act_layer=act_layer,
289
+ drop=drop)
290
+
291
+ self.H = None
292
+ self.W = None
293
+
294
+ def forward(self, x, mask_matrix):
295
+ """ Forward function.
296
+ Args:
297
+ x: Input feature, tensor size (B, H*W, C).
298
+ H, W: Spatial resolution of the input feature.
299
+ mask_matrix: Attention mask for cyclic shift.
300
+ """
301
+ B, L, C = x.shape
302
+ H, W = self.H, self.W
303
+ assert L == H * W, "input feature has wrong size"
304
+
305
+ shortcut = x
306
+ x = self.norm1(x)
307
+ x = x.view(B, H, W, C)
308
+
309
+ # pad feature maps to multiples of window size
310
+ pad_l = pad_t = 0
311
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
312
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
313
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
314
+ _, Hp, Wp, _ = x.shape
315
+
316
+ # cyclic shift
317
+ if self.shift_size > 0:
318
+ shifted_x = torch.roll(x,
319
+ shifts=(-self.shift_size, -self.shift_size),
320
+ dims=(1, 2))
321
+ attn_mask = mask_matrix
322
+ else:
323
+ shifted_x = x
324
+ attn_mask = None
325
+
326
+ # partition windows
327
+ x_windows = window_partition(
328
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
329
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
330
+ C) # nW*B, window_size*window_size, C
331
+
332
+ # W-MSA/SW-MSA
333
+ attn_windows = self.attn(
334
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
335
+
336
+ # merge windows
337
+ attn_windows = attn_windows.view(-1, self.window_size,
338
+ self.window_size, C)
339
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
340
+ Wp) # B H' W' C
341
+
342
+ # reverse cyclic shift
343
+ if self.shift_size > 0:
344
+ x = torch.roll(shifted_x,
345
+ shifts=(self.shift_size, self.shift_size),
346
+ dims=(1, 2))
347
+ else:
348
+ x = shifted_x
349
+
350
+ if pad_r > 0 or pad_b > 0:
351
+ x = x[:, :H, :W, :].contiguous()
352
+
353
+ x = x.view(B, H * W, C)
354
+
355
+ # FFN
356
+ x = shortcut + self.drop_path(x)
357
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
358
+
359
+ return x
360
+ #+end_src
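
For orientation, a small usage sketch of the block above (not tangled). It assumes the companion definitions from the other blocks in this file (WindowAttention, Mlp, window_partition, window_reverse) are already in scope; the sizes are arbitrary:

#+begin_src python
import torch

# dim 96 with 3 heads matches the first Swin-Tiny stage; the spatial size is arbitrary.
blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
x = torch.randn(1, 56 * 56, 96)   # (B, H*W, C) with H = W = 56
blk.H, blk.W = 56, 56             # spatial size is attached before the call, as BasicLayer does
out = blk(x, mask_matrix=None)    # the mask is only consulted when shift_size > 0
print(out.shape)                  # torch.Size([1, 3136, 96])
#+end_src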
361
+
362
+ ** AEMatter.class.py
363
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
364
+ class PatchMerging(nn.Module):
365
+ """ Patch Merging Layer
366
+ Args:
367
+ dim (int): Number of input channels.
368
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
369
+ """
370
+
371
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
372
+ super().__init__()
373
+ self.dim = dim
374
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
375
+ self.norm = norm_layer(4 * dim)
376
+
377
+ def forward(self, x, H, W):
378
+ """ Forward function.
379
+ Args:
380
+ x: Input feature, tensor size (B, H*W, C).
381
+ H, W: Spatial resolution of the input feature.
382
+ """
383
+ B, L, C = x.shape
384
+ assert L == H * W, "input feature has wrong size"
385
+
386
+ x = x.view(B, H, W, C)
387
+
388
+ # padding
389
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
390
+ if pad_input:
391
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
392
+
393
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
394
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
395
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
396
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
397
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
398
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
399
+
400
+ x = self.norm(x)
401
+ x = self.reduction(x)
402
+
403
+ return x
404
+ #+end_src
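
A quick, untangled shape check of the merging step: the token count drops by 4x while the channel count doubles (sizes arbitrary):

#+begin_src python
import torch

pm = PatchMerging(dim=96)
x = torch.randn(1, 56 * 56, 96)   # (B, H*W, C)
y = pm(x, H=56, W=56)
print(y.shape)                    # torch.Size([1, 784, 192]): (H/2)*(W/2) tokens, 2*C channels
#+end_src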
405
+
406
+
407
+ ** AEMatter.class.py
408
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
409
+ class BasicLayer(nn.Module):
410
+ """ A basic Swin Transformer layer for one stage.
411
+ Args:
412
+ dim (int): Number of feature channels
413
+ depth (int): Depths of this stage.
414
+ num_heads (int): Number of attention head.
415
+ window_size (int): Local window size. Default: 7.
416
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
417
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
418
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
419
+ drop (float, optional): Dropout rate. Default: 0.0
420
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
421
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
422
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
423
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
424
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
425
+ """
426
+
427
+ def __init__(self,
428
+ dim,
429
+ depth,
430
+ num_heads,
431
+ window_size=7,
432
+ mlp_ratio=4.,
433
+ qkv_bias=True,
434
+ qk_scale=None,
435
+ drop=0.,
436
+ attn_drop=0.,
437
+ drop_path=0.,
438
+ norm_layer=nn.LayerNorm,
439
+ downsample=None,
440
+ use_checkpoint=False):
441
+
442
+ super().__init__()
443
+ self.window_size = window_size
444
+ self.shift_size = window_size // 2
445
+ self.depth = depth
446
+ self.use_checkpoint = use_checkpoint
447
+
448
+ # build blocks
449
+ self.blocks = nn.ModuleList([
450
+ SwinTransformerBlock(dim=dim,
451
+ num_heads=num_heads,
452
+ window_size=window_size,
453
+ shift_size=0 if
454
+ (i % 2 == 0) else window_size // 2,
455
+ mlp_ratio=mlp_ratio,
456
+ qkv_bias=qkv_bias,
457
+ qk_scale=qk_scale,
458
+ drop=drop,
459
+ attn_drop=attn_drop,
460
+ drop_path=drop_path[i] if isinstance(
461
+ drop_path, list) else drop_path,
462
+ norm_layer=norm_layer) for i in range(depth)
463
+ ])
464
+
465
+ # patch merging layer
466
+ if downsample is not None:
467
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
468
+ else:
469
+ self.downsample = None
470
+
471
+ def forward(self, x, H, W):
472
+ """ Forward function.
473
+ Args:
474
+ x: Input feature, tensor size (B, H*W, C).
475
+ H, W: Spatial resolution of the input feature.
476
+ """
477
+ # print(x.shape,H,W)
478
+ # calculate attention mask for SW-MSA
479
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
480
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
481
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
482
+ h_slices = (slice(0, -self.window_size),
483
+ slice(-self.window_size,
484
+ -self.shift_size), slice(-self.shift_size, None))
485
+ w_slices = (slice(0, -self.window_size),
486
+ slice(-self.window_size,
487
+ -self.shift_size), slice(-self.shift_size, None))
488
+ cnt = 0
489
+ for h in h_slices:
490
+ for w in w_slices:
491
+ img_mask[:, h, w, :] = cnt
492
+ cnt += 1
493
+
494
+ mask_windows = window_partition(
495
+ img_mask, self.window_size) # nW, window_size, window_size, 1
496
+
497
+ mask_windows = mask_windows.view(-1,
498
+ self.window_size * self.window_size)
499
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
500
+ 2) # nW, window_size*window_size, window_size*window_size
501
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
502
+ float(-100.0)).masked_fill(
503
+ attn_mask == 0, float(0.0))
504
+
505
+ for blk in self.blocks:
506
+ blk.H, blk.W = H, W
507
+ if self.use_checkpoint:
508
+ x = checkpoint.checkpoint(blk, x, attn_mask)
509
+ else:
510
+ x = blk(x, attn_mask)
511
+
512
+ if self.downsample is not None:
513
+ x_down = self.downsample(x, H, W)
514
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
515
+ return x, H, W, x_down, Wh, Ww
516
+ else:
517
+ return x, H, W, x, H, W
518
+ #+end_src
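
The six-value return above is easiest to see with a small, untangled example; depth, heads, and sizes are arbitrary, and PatchMerging is passed so the downsampling path is exercised:

#+begin_src python
import torch

layer = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7, downsample=PatchMerging)
x = torch.randn(1, 56 * 56, 96)
x_out, H, W, x_down, Wh, Ww = layer(x, 56, 56)
print(x_out.shape, (H, W))     # torch.Size([1, 3136, 96]) (56, 56)  -- pre-merge features
print(x_down.shape, (Wh, Ww))  # torch.Size([1, 784, 192]) (28, 28)  -- input to the next stage
#+end_src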
519
+
520
+ ** AEMatter.class.py
521
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
522
+ class PatchEmbed(nn.Module):
523
+ """ Image to Patch Embedding
524
+ Args:
525
+ patch_size (int): Patch token size. Default: 4.
526
+ in_chans (int): Number of input image channels. Default: 3.
527
+ embed_dim (int): Number of linear projection output channels. Default: 96.
528
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
529
+ """
530
+
531
+ def __init__(self,
532
+ patch_size=4,
533
+ in_chans=3,
534
+ embed_dim=96,
535
+ norm_layer=None):
536
+
537
+ super().__init__()
538
+ patch_size = to_2tuple(patch_size)
539
+ self.patch_size = patch_size
540
+
541
+ self.in_chans = in_chans
542
+ self.embed_dim = embed_dim
543
+
544
+ self.proj = nn.Conv2d(in_chans,
545
+ embed_dim,
546
+ kernel_size=patch_size,
547
+ stride=patch_size)
548
+ if norm_layer is not None:
549
+ self.norm = norm_layer(embed_dim)
550
+ else:
551
+ self.norm = None
552
+
553
+ def forward(self, x):
554
+ """Forward function."""
555
+ # padding
556
+ _, _, H, W = x.size()
557
+ if W % self.patch_size[1] != 0:
558
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
559
+ if H % self.patch_size[0] != 0:
560
+ x = F.pad(x,
561
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
562
+
563
+ x = self.proj(x) # B C Wh Ww
564
+ if self.norm is not None:
565
+ Wh, Ww = x.size(2), x.size(3)
566
+ x = x.flatten(2).transpose(1, 2)
567
+ x = self.norm(x)
568
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
569
+
570
+ return x
571
+ #+end_src
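
An untangled sketch of the patch embedding, assuming the imports at the top of this file: a stride-4 convolution turns an RGB image into a 96-channel map at 1/4 resolution. (Note that AEMatter later replaces this projection with a stride-2 conv over 64 channels.)

#+begin_src python
import torch

pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
img = torch.randn(1, 3, 224, 224)
feat = pe(img)
print(feat.shape)  # torch.Size([1, 96, 56, 56])
#+end_src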
572
+
573
+
574
+ ** AEMatter.class.py
575
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
576
+ class SwinTransformer(nn.Module):
577
+ """ Swin Transformer backbone.
578
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
579
+ https://arxiv.org/pdf/2103.14030
580
+ Args:
581
+ pretrain_img_size (int): Input image size for training the pretrained model,
582
+ used in absolute position embedding. Default: 224.
583
+ patch_size (int | tuple(int)): Patch size. Default: 4.
584
+ in_chans (int): Number of input image channels. Default: 3.
585
+ embed_dim (int): Number of linear projection output channels. Default: 96.
586
+ depths (tuple[int]): Depths of each Swin Transformer stage.
587
+ num_heads (tuple[int]): Number of attention head of each stage.
588
+ window_size (int): Window size. Default: 7.
589
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
590
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
591
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
592
+ drop_rate (float): Dropout rate.
593
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
594
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
595
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
596
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
597
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
598
+ out_indices (Sequence[int]): Output from which stages.
599
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
600
+ -1 means not freezing any parameters.
601
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
602
+ """
603
+
604
+ def __init__(self,
605
+ pretrain_img_size=224,
606
+ patch_size=4,
607
+ in_chans=3,
608
+ embed_dim=96,
609
+ depths=[2, 2, 6, 2],
610
+ num_heads=[3, 6, 12, 24],
611
+ window_size=7,
612
+ mlp_ratio=4.,
613
+ qkv_bias=True,
614
+ qk_scale=None,
615
+ drop_rate=0.,
616
+ attn_drop_rate=0.,
617
+ drop_path_rate=0.2,
618
+ norm_layer=nn.LayerNorm,
619
+ ape=False,
620
+ patch_norm=True,
621
+ out_indices=(0, 1, 2, 3),
622
+ frozen_stages=-1,
623
+ use_checkpoint=False):
624
+
625
+ super().__init__()
626
+
627
+ self.pretrain_img_size = pretrain_img_size
628
+ self.num_layers = len(depths)
629
+ self.embed_dim = embed_dim
630
+ self.ape = ape
631
+ self.patch_norm = patch_norm
632
+ self.out_indices = out_indices
633
+ self.frozen_stages = frozen_stages
634
+
635
+ # split image into non-overlapping patches
636
+ self.patch_embed = PatchEmbed(
637
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
638
+ norm_layer=norm_layer if self.patch_norm else None)
639
+
640
+ # absolute position embedding
641
+ if self.ape:
642
+ pretrain_img_size = to_2tuple(pretrain_img_size)
643
+ patch_size = to_2tuple(patch_size)
644
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
645
+
646
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
647
+ trunc_normal_(self.absolute_pos_embed, std=.02)
648
+
649
+ self.pos_drop = nn.Dropout(p=drop_rate)
650
+
651
+ # stochastic depth
652
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
653
+
654
+ # build layers
655
+ self.layers = nn.ModuleList()
656
+ for i_layer in range(self.num_layers):
657
+ layer = BasicLayer(
658
+ dim=int(embed_dim * 2 ** i_layer),
659
+ depth=depths[i_layer],
660
+ num_heads=num_heads[i_layer],
661
+ window_size=window_size,
662
+ mlp_ratio=mlp_ratio,
663
+ qkv_bias=qkv_bias,
664
+ qk_scale=qk_scale,
665
+ drop=drop_rate,
666
+ attn_drop=attn_drop_rate,
667
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
668
+ norm_layer=norm_layer,
669
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
670
+ use_checkpoint=use_checkpoint)
671
+ self.layers.append(layer)
672
+
673
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
674
+ self.num_features = num_features
675
+
676
+ # add a norm layer for each output
677
+ for i_layer in out_indices:
678
+ layer = norm_layer(num_features[i_layer])
679
+ layer_name = f'norm{i_layer}'
680
+ self.add_module(layer_name, layer)
681
+
682
+ self._freeze_stages()
683
+
684
+ def _freeze_stages(self):
685
+ if self.frozen_stages >= 0:
686
+ self.patch_embed.eval()
687
+ for param in self.patch_embed.parameters():
688
+ param.requires_grad = False
689
+
690
+ if self.frozen_stages >= 1 and self.ape:
691
+ self.absolute_pos_embed.requires_grad = False
692
+
693
+ if self.frozen_stages >= 2:
694
+ self.pos_drop.eval()
695
+ for i in range(0, self.frozen_stages - 1):
696
+ m = self.layers[i]
697
+ m.eval()
698
+ for param in m.parameters():
699
+ param.requires_grad = False
700
+
701
+ def init_weights(self, pretrained=None):
702
+ """Initialize the weights in backbone.
703
+ Args:
704
+ pretrained (str, optional): Path to pre-trained weights.
705
+ Defaults to None.
706
+ """
707
+
708
+
709
+ def forward(self, x):
710
+ """Forward function."""
711
+ x = self.patch_embed(x)
712
+
713
+ Wh, Ww = x.size(2), x.size(3)
714
+ if self.ape:
715
+ # interpolate the position embedding to the corresponding size
716
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
717
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
718
+ else:
719
+ x = x.flatten(2).transpose(1, 2)
720
+ x = self.pos_drop(x)
721
+
722
+ outs = []
723
+ for i in range(self.num_layers):
724
+ layer = self.layers[i]
725
+
726
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
727
+
728
+ if i in self.out_indices:
729
+ norm_layer = getattr(self, f'norm{i}')
730
+ x_out = norm_layer(x_out)
731
+
732
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
733
+ outs.append(out)
734
+
735
+ return tuple(outs)
736
+
737
+ def train(self, mode=True):
738
+ """Convert the model into training mode while keep layers freezed."""
739
+ super(SwinTransformer, self).train(mode)
740
+ self._freeze_stages()
741
+ #+end_src
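
An untangled smoke test of the backbone with Swin-Tiny settings (random weights here; the real weights come from the AEMatter checkpoint). The four outputs form a feature pyramid at strides 4/8/16/32 with 96/192/384/768 channels:

#+begin_src python
import torch

backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)
# torch.Size([1, 96, 56, 56]), [1, 192, 28, 28], [1, 384, 14, 14], [1, 768, 7, 7]
#+end_src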
742
+
743
+ ** AEMatter.class.py
744
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
745
+ class Mlp(nn.Module):
746
+ """ Multilayer perceptron."""
747
+
748
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
749
+ super().__init__()
750
+ out_features = out_features or in_features
751
+ hidden_features = hidden_features or in_features
752
+ self.fc1 = nn.Linear(in_features, hidden_features)
753
+ self.act = act_layer()
754
+ self.fc2 = nn.Linear(hidden_features, out_features)
755
+ self.drop = nn.Dropout(drop)
756
+
757
+ def forward(self, x):
758
+ x = self.fc1(x)
759
+ x = self.act(x)
760
+ x = self.drop(x)
761
+ x = self.fc2(x)
762
+ x = self.drop(x)
763
+ return x
764
+ #+end_src
765
+
766
+
767
+ ** AEMatter.class.py
768
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
769
+ class ResBlock(nn.Module):
770
+
771
+ def __init__(self, inc, midc):
772
+ super(ResBlock, self).__init__()
773
+ self.conv1 = nn.Conv2d(inc,
774
+ midc,
775
+ kernel_size=1,
776
+ stride=1,
777
+ padding=0,
778
+ bias=True)
779
+ self.gn1 = nn.GroupNorm(16, midc)
780
+ self.conv2 = nn.Conv2d(midc,
781
+ midc,
782
+ kernel_size=3,
783
+ stride=1,
784
+ padding=1,
785
+ bias=True)
786
+ self.gn2 = nn.GroupNorm(16, midc)
787
+ self.conv3 = nn.Conv2d(midc,
788
+ inc,
789
+ kernel_size=1,
790
+ stride=1,
791
+ padding=0,
792
+ bias=True)
793
+ self.relu = nn.LeakyReLU(0.1)
794
+
795
+ def forward(self, x):
796
+ x_ = x
797
+ x = self.conv1(x)
798
+ x = self.gn1(x)
799
+ x = self.relu(x)
800
+ x = self.conv2(x)
801
+ x = self.gn2(x)
802
+ x = self.relu(x)
803
+ x = self.conv3(x)
804
+ x = x + x_
805
+ x = self.relu(x)
806
+ return x
807
+ #+end_src
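
A one-line sanity check (not tangled): the bottleneck residual block keeps the input shape, so it can be stacked freely inside the decoder heads.

#+begin_src python
import torch

rb = ResBlock(inc=256, midc=128)              # midc must be divisible by 16 for the GroupNorm
print(rb(torch.randn(1, 256, 32, 32)).shape)  # torch.Size([1, 256, 32, 32])
#+end_src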
808
+
809
+ ** AEMatter.class.py
810
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
811
+ class AEALblock(nn.Module):
812
+
813
+ def __init__(self,
814
+ d_model,
815
+ nhead,
816
+ dim_feedforward=512,
817
+ dropout=0.0,
818
+ layer_norm_eps=1e-5,
819
+ batch_first=True,
820
+ norm_first=False,
821
+ width=5):
822
+ super(AEALblock, self).__init__()
823
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
824
+ nhead // 2,
825
+ dropout=dropout,
826
+ batch_first=batch_first)
827
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
828
+ nhead // 2,
829
+ dropout=dropout,
830
+ batch_first=batch_first)
831
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
832
+ self.dropout = nn.Dropout(dropout)
833
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
834
+ self.norm_first = norm_first
835
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
836
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
837
+ self.dropout1 = nn.Dropout(dropout)
838
+ self.dropout2 = nn.Dropout(dropout)
839
+ self.activation = nn.ReLU()
840
+ self.width = width
841
+ self.trans = nn.Sequential(
842
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
843
+ ResBlock(d_model // 2, d_model // 4),
844
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
845
+ self.gamma = nn.Parameter(torch.zeros(1))
846
+
847
+ def forward(
848
+ self,
849
+ src,
850
+ feats,
851
+ ):
852
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
853
+ b, c, h, w = src.shape
854
+ x1 = src[:, 0:c // 2]
855
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
856
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
857
+ x2 = src[:, c // 2:]
858
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
859
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
860
+ x = rearrange(src, 'b c h w-> b (h w) c')
861
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
862
+ x = self.norm2(x + self._ff_block(x))
863
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
864
+ return x
865
+
866
+ def _sa_block(self, x1, x2, h, w):
867
+ x1 = self.self_attn1(x1,
868
+ x1,
869
+ x1,
870
+ attn_mask=None,
871
+ key_padding_mask=None,
872
+ need_weights=False)[0]
873
+
874
+ x2 = self.self_attn2(x2,
875
+ x2,
876
+ x2,
877
+ attn_mask=None,
878
+ key_padding_mask=None,
879
+ need_weights=False)[0]
880
+
881
+ x1 = rearrange(x1,
882
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
883
+ h2=self.width,
884
+ h1=h // self.width)
885
+ x2 = rearrange(x2,
886
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
887
+ w2=self.width,
888
+ w1=w // self.width)
889
+ x = torch.cat([x1, x2], dim=2)
890
+ return self.dropout1(x)
891
+
892
+ def _ff_block(self, x):
893
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
894
+ return self.dropout2(x)
895
+ #+end_src
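
In the block above, one half of the channels attends inside horizontal bands of `width` rows and the other half inside vertical bands of `width` columns, with the context features injected through `self.trans`. An untangled shape sketch with arbitrary sizes; H and W must be multiples of `width` (AEMatter pads to guarantee this), and the 512-channel `feats` matches what `apptrans` produces:

#+begin_src python
import torch

blk = AEALblock(d_model=640, nhead=20, dim_feedforward=2048, dropout=0.0, width=5)
src = torch.randn(1, 640, 20, 20)    # H and W are multiples of width
feats = torch.randn(1, 512, 20, 20)  # appearance/context features fed through self.trans
out = blk(src, feats)
print(out.shape)                     # torch.Size([1, 640, 20, 20])
#+end_src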
896
+
897
+ ** AEMatter.class.py
898
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
899
+ class AEMatter(nn.Module):
900
+
901
+ def __init__(self):
902
+ super(AEMatter, self).__init__()
903
+ trans = SwinTransformer(pretrain_img_size=224,
904
+ embed_dim=96,
905
+ depths=[2, 2, 6, 2],
906
+ num_heads=[3, 6, 12, 24],
907
+ window_size=7,
908
+ ape=False,
909
+ drop_path_rate=0.2,
910
+ patch_norm=True,
911
+ use_checkpoint=False)
912
+
913
+ # trans.load_state_dict(torch.load(
914
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
915
+ # map_location="cpu")["model"],
916
+ # strict=False)
917
+
918
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
919
+
920
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
921
+ nn.PReLU(48))
922
+
923
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
924
+ 1), nn.PReLU(64),
925
+ nn.Conv2d(64, 64, 3, 1, 1),
926
+ nn.PReLU(64))
927
+
928
+ self.trans = trans
929
+ self.conv1 = nn.Sequential(
930
+ nn.Conv2d(in_channels=640 + 768,
931
+ out_channels=256,
932
+ kernel_size=1,
933
+ stride=1,
934
+ padding=0,
935
+ bias=True))
936
+ self.conv2 = nn.Sequential(
937
+ nn.Conv2d(in_channels=256 + 384,
938
+ out_channels=256,
939
+ kernel_size=1,
940
+ stride=1,
941
+ padding=0,
942
+ bias=True), )
943
+ self.conv3 = nn.Sequential(
944
+ nn.Conv2d(in_channels=256 + 192,
945
+ out_channels=192,
946
+ kernel_size=1,
947
+ stride=1,
948
+ padding=0,
949
+ bias=True), )
950
+ self.conv4 = nn.Sequential(
951
+ nn.Conv2d(in_channels=192 + 96,
952
+ out_channels=128,
953
+ kernel_size=1,
954
+ stride=1,
955
+ padding=0,
956
+ bias=True), )
957
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
958
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
959
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
960
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
961
+ self.conv5 = nn.Sequential(
962
+ nn.Conv2d(in_channels=192,
963
+ out_channels=64,
964
+ kernel_size=3,
965
+ stride=1,
966
+ padding=1,
967
+ bias=True), nn.PReLU(64),
968
+ nn.Conv2d(in_channels=64,
969
+ out_channels=64,
970
+ kernel_size=3,
971
+ stride=1,
972
+ padding=1,
973
+ bias=True), nn.PReLU(64),
974
+ nn.Conv2d(in_channels=64,
975
+ out_channels=48,
976
+ kernel_size=3,
977
+ stride=1,
978
+ padding=1,
979
+ bias=True), nn.PReLU(48))
980
+ self.convo = nn.Sequential(
981
+ nn.Conv2d(in_channels=48 + 48 + 6,
982
+ out_channels=32,
983
+ kernel_size=3,
984
+ stride=1,
985
+ padding=1,
986
+ bias=True), nn.PReLU(32),
987
+ nn.Conv2d(in_channels=32,
988
+ out_channels=32,
989
+ kernel_size=3,
990
+ stride=1,
991
+ padding=1,
992
+ bias=True), nn.PReLU(32),
993
+ nn.Conv2d(in_channels=32,
994
+ out_channels=1,
995
+ kernel_size=3,
996
+ stride=1,
997
+ padding=1,
998
+ bias=True))
999
+ self.up = nn.Upsample(scale_factor=2,
1000
+ mode='bilinear',
1001
+ align_corners=False)
1002
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1003
+ self.apptrans = nn.Sequential(
1004
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1005
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1006
+ ResBlock(512, 128))
1007
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1008
+ ResBlock(640, 160))
1009
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1010
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1011
+ self.width = 5
1012
+ self.trans1 = AEALblock(d_model=640,
1013
+ nhead=20,
1014
+ dim_feedforward=2048,
1015
+ dropout=0.2,
1016
+ width=self.width)
1017
+ self.trans2 = AEALblock(d_model=640,
1018
+ nhead=20,
1019
+ dim_feedforward=2048,
1020
+ dropout=0.2,
1021
+ width=self.width)
1022
+ self.trans3 = AEALblock(d_model=640,
1023
+ nhead=20,
1024
+ dim_feedforward=2048,
1025
+ dropout=0.2,
1026
+ width=self.width)
1027
+
1028
+ def aeal(self, x, sem):
1029
+ xe = self.emb(x)
1030
+ x_ = xe
1031
+ x_ = self.embdp(x_)
1032
+ b, c, h1, w1 = x_.shape
1033
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1034
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1035
+ newph1 = bnew_ph // 2
1036
+ newph2 = bnew_ph - newph1
1037
+ newpw1 = bnew_pw // 2
1038
+ newpw2 = bnew_pw - newpw1
1039
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1040
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1041
+ x_ = self.trans1(x_, sem)
1042
+ x_ = self.trans2(x_, sem)
1043
+ x_ = self.trans3(x_, sem)
1044
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1045
+ return x_
1046
+
1047
+ def forward(self, x, y):
1048
+ inputs = torch.cat((x, y), 1)
1049
+ x = self.start_conv0(inputs)
1050
+ x_ = self.start_conv(x)
1051
+ x1, x2, x3, x4 = self.trans(x_)
1052
+ x4h = self.h2l(x4)
1053
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1054
+ x4_ = self.aeal(x4, x3s)
1055
+ x4 = torch.cat((x4, x4_), 1)
1056
+ X4 = self.conv1(x4)
1057
+ wh, ww = X4.shape[2], X4.shape[3]
1058
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1059
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1060
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1061
+ X3 = self.up(X4)
1062
+ X3 = torch.cat((x3, X3), 1)
1063
+ X3 = self.conv2(X3)
1064
+ wh, ww = X3.shape[2], X3.shape[3]
1065
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1066
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1067
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1068
+ X2 = self.up(X3)
1069
+ X2 = torch.cat((x2, X2), 1)
1070
+ X2 = self.conv3(X2)
1071
+ wh, ww = X2.shape[2], X2.shape[3]
1072
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1073
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1074
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1075
+ X1 = self.up(X2)
1076
+ X1 = torch.cat((x1, X1), 1)
1077
+ X1 = self.conv4(X1)
1078
+ wh, ww = X1.shape[2], X1.shape[3]
1079
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1080
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1081
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1082
+ X0 = self.up(X1)
1083
+ X0 = torch.cat((x_, X0), 1)
1084
+ X0 = self.conv5(X0)
1085
+ X = self.up(X0)
1086
+ X = torch.cat((inputs, x, X), 1)
1087
+ alpha = self.convo(X)
1088
+ alpha = torch.clamp(alpha, min=0, max=1)
1089
+ return alpha
1090
+ #+end_src
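
A CPU smoke test of the full network (untangled, random weights): the two inputs are the RGB image and the three-plane one-hot trimap, both with sides that are multiples of 32 to match the padding done in `do_infer`; the output is a single-channel alpha clamped to [0, 1]:

#+begin_src python
import torch

net = AEMatter().eval()            # random weights here; real weights come from AEM_RWA.ckpt
rgb = torch.randn(1, 3, 224, 224)  # image tensor
tri = torch.randn(1, 3, 224, 224)  # stand-in for the one-hot (bg / unknown / fg) trimap planes
with torch.no_grad():
    alpha = net(rgb, tri)
print(alpha.shape)                 # torch.Size([1, 1, 224, 224])
#+end_src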
1091
+
1092
+ ** Function to load model
1093
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1094
+ def get_AEMatter_model(path_model_checkpoint):
1095
+
1096
+ download_model(path=path_model_checkpoint)
1097
+
1098
+ matmodel = AEMatter()
1099
+ matmodel.load_state_dict(
1100
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
1101
+
1102
+ matmodel = matmodel.cuda()
1103
+ matmodel.eval()
1104
+
1105
+ return matmodel
1106
+ #+end_src
1107
+
1108
+ ** Function to do inference
1109
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1110
+ def do_infer(rawimg, trimap, matmodel):
1111
+ trimap_nonp = trimap.copy()
1112
+ h, w, c = rawimg.shape
1113
+ nonph, nonpw, _ = rawimg.shape
1114
+ newh = (((h - 1) // 32) + 1) * 32
1115
+ neww = (((w - 1) // 32) + 1) * 32
1116
+ padh = newh - h
1117
+ padh1 = int(padh / 2)
1118
+ padh2 = padh - padh1
1119
+ padw = neww - w
1120
+ padw1 = int(padw / 2)
1121
+ padw2 = padw - padw1
1122
+
1123
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
1124
+ cv2.BORDER_REFLECT)
1125
+
1126
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
1127
+ cv2.BORDER_REFLECT)
1128
+
1129
+ h_pad, w_pad, _ = rawimg_pad.shape
1130
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
1131
+ tritemp[:, :, 0] = (trimap_pad == 0)
1132
+ tritemp[:, :, 1] = (trimap_pad == 128)
1133
+ tritemp[:, :, 2] = (trimap_pad == 255)
1134
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
1135
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
1136
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
1137
+ img = np.array(img, np.float32)
1138
+ img = img / 255.
1139
+ img = torch.from_numpy(img).cuda()
1140
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
1141
+ with torch.no_grad():
1142
+ pred = matmodel(img, tritempimgs)
1143
+ pred = pred.detach().cpu().numpy()[0]
1144
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
1145
+ preda = pred[
1146
+ 0:1,
1147
+ ] * 255
1148
+ preda = np.transpose(preda, (1, 2, 0))
1149
+ preda = preda * (trimap_nonp[:, :, None]
1150
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
1151
+ preda = np.array(preda, np.uint8)
1152
+ return preda
1153
+ #+end_src
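
A hedged usage sketch of `do_infer` (not tangled; needs a CUDA device, and the checkpoint path below is only an example). The trimap must already be quantised to the three values 0 / 128 / 255, and the returned alpha is uint8 with a trailing singleton channel:

#+begin_src python
import numpy as np

rgb = np.random.randint(0, 256, (480, 640, 3), np.uint8)  # stand-in for cv2.imread(..., IMREAD_COLOR)
tri = np.full((480, 640), 128, np.uint8)                  # unknown everywhere ...
tri[:160] = 0                                             # ... except a background band
tri[-160:] = 255                                          # ... and a foreground band

matmodel = get_AEMatter_model('./AEM_RWA.ckpt')           # downloads the checkpoint if missing; needs CUDA
alpha = do_infer(rgb, tri, matmodel)
print(alpha.shape, alpha.dtype)                           # (480, 640, 1) uint8
#+end_src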
1154
+
1155
+ ** Load ComfyUI AEMatter model
1156
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
1157
+ class load_AEMatter_Model:
1158
+
1159
+ def __init__(self):
1160
+ pass
1161
+
1162
+ @classmethod
1163
+ def INPUT_TYPES(s):
1164
+ return {
1165
+ "required": {},
1166
+ }
1167
+
1168
+ RETURN_TYPES = ("AEMatter_Model", )
1169
+ FUNCTION = "test"
1170
+ CATEGORY = "AEMatter"
1171
+
1172
+ def test(self):
1173
+ return (get_AEMatter_model(get_model_path()), )
1174
+
1175
+
1176
+ class run_AEMatter_inference:
1177
+
1178
+ def __init__(self):
1179
+ pass
1180
+
1181
+ @classmethod
1182
+ def INPUT_TYPES(s):
1183
+ return {
1184
+ "required": {
1185
+ "image": ("IMAGE", ),
1186
+ "trimap": ("MASK", ),
1187
+ "AEMatter_Model": ("AEMatter_Model", ),
1188
+ },
1189
+ }
1190
+
1191
+ RETURN_TYPES = ("MASK", )
1192
+ FUNCTION = "test"
1193
+ CATEGORY = "AEMatter"
1194
+
1195
+ def test(
1196
+ self,
1197
+ image,
1198
+ trimap,
1199
+ AEMatter_Model,
1200
+ ):
1201
+
1202
+ ret = []
1203
+ batch_size = image.shape[0]
1204
+
1205
+ for i in range(batch_size):
1206
+ tmp_i = from_torch_image(image[i])
1207
+ tmp_m = from_torch_image(trimap[i])
1208
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1209
+ ret.append(tmp)
1210
+
1211
+ ret = to_torch_image(np.array(ret))
1212
+ ret = ret.squeeze(-1)
1213
+ print(ret.shape)
1214
+
1215
+ return (ret, )  # ComfyUI expects a tuple matching RETURN_TYPES
1216
+ #+end_src
1217
+
1218
+ ** Main function
1219
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1220
+ def main():
1221
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
1222
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
1223
+ p_outs = 'alpha.png'
1224
+
1225
+ matmodel = get_AEMatter_model(
1226
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
1227
+
1228
+ # matmodel = AEMatter()
1229
+ # matmodel.load_state_dict(
1230
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
1231
+ # map_location='cpu')['model'])
1232
+
1233
+ # matmodel = matmodel.cuda()
1234
+ # matmodel.eval()
1235
+
1236
+ rawimg = pimgs
1237
+ trimap = ptrimap
1238
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
1239
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
1240
+ trimap_nonp = trimap.copy()
1241
+ h, w, c = rawimg.shape
1242
+ nonph, nonpw, _ = rawimg.shape
1243
+ newh = (((h - 1) // 32) + 1) * 32
1244
+ neww = (((w - 1) // 32) + 1) * 32
1245
+ padh = newh - h
1246
+ padh1 = int(padh / 2)
1247
+ padh2 = padh - padh1
1248
+ padw = neww - w
1249
+ padw1 = int(padw / 2)
1250
+ padw2 = padw - padw1
1251
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
1252
+ cv2.BORDER_REFLECT)
1253
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
1254
+ cv2.BORDER_REFLECT)
1255
+ h_pad, w_pad, _ = rawimg_pad.shape
1256
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
1257
+ tritemp[:, :, 0] = (trimap_pad == 0)
1258
+ tritemp[:, :, 1] = (trimap_pad == 128)
1259
+ tritemp[:, :, 2] = (trimap_pad == 255)
1260
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
1261
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
1262
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
1263
+ img = np.array(img, np.float32)
1264
+ img = img / 255.
1265
+ img = torch.from_numpy(img).cuda()
1266
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
1267
+ with torch.no_grad():
1268
+ pred = matmodel(img, tritempimgs)
1269
+ pred = pred.detach().cpu().numpy()[0]
1270
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
1271
+ preda = pred[
1272
+ 0:1,
1273
+ ] * 255
1274
+ preda = np.transpose(preda, (1, 2, 0))
1275
+ preda = preda * (trimap_nonp[:, :, None]
1276
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
1277
+ preda = np.array(preda, np.uint8)
1278
+ cv2.imwrite(p_outs, preda)
1279
+
1280
+ #+end_src
1281
+
1282
+ ** ComfyUI node mappings
1283
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
1284
+ NODE_CLASS_MAPPINGS = {
1285
+ 'load_AEMatter_Model': load_AEMatter_Model,
1286
+ 'run_AEMatter_inference': run_AEMatter_inference,
1287
+ }
1288
+
1289
+ NODE_DISPLAY_NAME_MAPPINGS = {
1290
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1291
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1292
+ }
1293
+ #+end_src
1294
+
1295
+ ** COMMENT AEMatter.execute.py
1296
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
1297
+ if __name__ == '__main__':
1298
+ # main()
1299
+
1300
+ rawimg = cv2.imread('/home/asd/Desktop/demo/retriever_rgb.png',
1301
+ cv2.IMREAD_COLOR)
1302
+
1303
+ trimap = cv2.imread('/home/asd/Desktop/demo/retriever_trimap.png',
1304
+ cv2.IMREAD_GRAYSCALE)
1305
+
1306
+ do_infer(rawimg, trimap,
1307
+ get_AEMatter_model('/home/asd/Desktop/AEM_RWA.ckpt'))
1308
+ #+end_src
1309
+
1310
+ ** AEMatter.unify.sh
1311
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
1312
+ . "${HOME}/dbnew.sh"
1313
+
1314
+ cat \
1315
+ 'AEMatter.import.py' \
1316
+ 'AEMatter.function.py' \
1317
+ 'AEMatter.class.py' \
1318
+ 'AEMatter.execute.py' \
1319
+ | expand | yapf3 \
1320
+ > 'AEMatter.py' \
1321
+ ;
1322
+
1323
+ cp 'AEMatter.py' '__init__.py'
1324
+ #+end_src
1325
+
1326
+ ** AEMatter.run.sh
1327
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
1328
+ . "${HOME}/dbnew.sh"
1329
+ python3 './AEMatter.py'
1330
+ #+end_src
1331
+
1332
+ #+RESULTS:
1333
+
1334
+ * COMMENT WORK SPACE
1335
+
1336
+ ** ESHELL
1337
+ #+begin_src elisp
1338
+ (save-buffer)
1339
+ (org-babel-tangle)
1340
+ (shell-command "./AEMatter.unify.sh")
1341
+ #+end_src
1342
+
1343
+ #+RESULTS:
1344
+ : 0
1345
+
1346
+ ** SHELL
1347
+ #+begin_src sh :shebang #!/bin/sh :results output
1348
+ realpath .
1349
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/AEMatter
1350
+ #+end_src
1351
+
1352
+ #+RESULTS:
1353
+
1354
+ ** SHELL
1355
+ #+begin_src sh :shebang #!/bin/sh :results output
1356
+ ls
1357
+ #+end_src
ComfyUI_AEMatter/__init__.py ADDED
@@ -0,0 +1,1248 @@
1
+ #!/usr/bin/python3
2
+ import cv2
3
+ import math
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import wget
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ from collections import OrderedDict
16
+ from einops import rearrange, repeat
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ import folder_paths
20
+ from folder_paths import models_dir
21
+
22
+
23
+ #!/usr/bin/python3
24
+ def mkdir_safe(out_path):
25
+ if type(out_path) == str:
26
+ if len(out_path) > 0:
27
+ if not os.path.exists(out_path):
28
+ os.mkdir(out_path)
29
+
30
+
31
+ def get_model_path():
32
+ import folder_paths
33
+ from folder_paths import models_dir
34
+
35
+ path_file_model = models_dir
36
+ mkdir_safe(out_path=path_file_model)
37
+
38
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
39
+ mkdir_safe(out_path=path_file_model)
40
+
41
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
42
+
43
+ return path_file_model
44
+
45
+
46
+ def download_model(path):
47
+ if not os.path.exists(path):
48
+ wget.download(
49
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
50
+ out=path)
51
+
52
+
53
+ def from_torch_image(image):
54
+ image = image.cpu().numpy() * 255.0
55
+ image = np.clip(image, 0, 255).astype(np.uint8)
56
+ return image
57
+
58
+
59
+ def to_torch_image(image):
60
+ image = image.astype(dtype=np.float32)
61
+ image /= 255.0
62
+ image = torch.from_numpy(image)
63
+ return image
64
+
65
+
66
+ def window_partition(x, window_size):
67
+ """
68
+ Args:
69
+ x: (B, H, W, C)
70
+ window_size (int): window size
71
+ Returns:
72
+ windows: (num_windows*B, window_size, window_size, C)
73
+ """
74
+ B, H, W, C = x.shape
75
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
76
+ C)
77
+ windows = x.permute(0, 1, 3, 2, 4,
78
+ 5).contiguous().view(-1, window_size, window_size, C)
79
+ return windows
80
+
81
+
82
+ def window_reverse(windows, window_size, H, W):
83
+ """
84
+ Args:
85
+ windows: (num_windows*B, window_size, window_size, C)
86
+ window_size (int): Window size
87
+ H (int): Height of image
88
+ W (int): Width of image
89
+ Returns:
90
+ x: (B, H, W, C)
91
+ """
92
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
93
+ x = windows.view(B, H // window_size, W // window_size, window_size,
94
+ window_size, -1)
95
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
96
+ return x
97
+
98
+
99
+ def get_AEMatter_model(path_model_checkpoint):
100
+
101
+ download_model(path=path_model_checkpoint)
102
+
103
+ matmodel = AEMatter()
104
+ matmodel.load_state_dict(
105
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
106
+
107
+ matmodel = matmodel.cuda()
108
+ matmodel.eval()
109
+
110
+ return matmodel
111
+
112
+
113
+ def do_infer(rawimg, trimap, matmodel):
114
+ trimap_nonp = trimap.copy()
115
+ h, w, c = rawimg.shape
116
+ nonph, nonpw, _ = rawimg.shape
117
+ newh = (((h - 1) // 32) + 1) * 32
118
+ neww = (((w - 1) // 32) + 1) * 32
119
+ padh = newh - h
120
+ padh1 = int(padh / 2)
121
+ padh2 = padh - padh1
122
+ padw = neww - w
123
+ padw1 = int(padw / 2)
124
+ padw2 = padw - padw1
125
+
126
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
127
+ cv2.BORDER_REFLECT)
128
+
129
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
130
+ cv2.BORDER_REFLECT)
131
+
132
+ h_pad, w_pad, _ = rawimg_pad.shape
133
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
134
+ tritemp[:, :, 0] = (trimap_pad == 0)
135
+ tritemp[:, :, 1] = (trimap_pad == 128)
136
+ tritemp[:, :, 2] = (trimap_pad == 255)
137
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
138
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
139
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
140
+ img = np.array(img, np.float32)
141
+ img = img / 255.
142
+ img = torch.from_numpy(img).cuda()
143
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
144
+ with torch.no_grad():
145
+ pred = matmodel(img, tritempimgs)
146
+ pred = pred.detach().cpu().numpy()[0]
147
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
148
+ preda = pred[
149
+ 0:1,
150
+ ] * 255
151
+ preda = np.transpose(preda, (1, 2, 0))
152
+ preda = preda * (trimap_nonp[:, :, None]
153
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
154
+ preda = np.array(preda, np.uint8)
155
+ return preda
156
+
157
+
158
+ def main():
159
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
160
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
161
+ p_outs = 'alpha.png'
162
+
163
+ matmodel = get_AEMatter_model(
164
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
165
+
166
+ # matmodel = AEMatter()
167
+ # matmodel.load_state_dict(
168
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
169
+ # map_location='cpu')['model'])
170
+
171
+ # matmodel = matmodel.cuda()
172
+ # matmodel.eval()
173
+
174
+ rawimg = pimgs
175
+ trimap = ptrimap
176
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
177
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
178
+ trimap_nonp = trimap.copy()
179
+ h, w, c = rawimg.shape
180
+ nonph, nonpw, _ = rawimg.shape
181
+ newh = (((h - 1) // 32) + 1) * 32
182
+ neww = (((w - 1) // 32) + 1) * 32
183
+ padh = newh - h
184
+ padh1 = int(padh / 2)
185
+ padh2 = padh - padh1
186
+ padw = neww - w
187
+ padw1 = int(padw / 2)
188
+ padw2 = padw - padw1
189
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
190
+ cv2.BORDER_REFLECT)
191
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
192
+ cv2.BORDER_REFLECT)
193
+ h_pad, w_pad, _ = rawimg_pad.shape
194
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
195
+ tritemp[:, :, 0] = (trimap_pad == 0)
196
+ tritemp[:, :, 1] = (trimap_pad == 128)
197
+ tritemp[:, :, 2] = (trimap_pad == 255)
198
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
199
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
200
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
201
+ img = np.array(img, np.float32)
202
+ img = img / 255.
203
+ img = torch.from_numpy(img).cuda()
204
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
205
+ with torch.no_grad():
206
+ pred = matmodel(img, tritempimgs)
207
+ pred = pred.detach().cpu().numpy()[0]
208
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
209
+ preda = pred[
210
+ 0:1,
211
+ ] * 255
212
+ preda = np.transpose(preda, (1, 2, 0))
213
+ preda = preda * (trimap_nonp[:, :, None]
214
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
215
+ preda = np.array(preda, np.uint8)
216
+ cv2.imwrite(p_outs, preda)
217
+
218
+
219
+ #!/usr/bin/python3
220
+ class WindowAttention(nn.Module):
221
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
222
+ It supports both of shifted and non-shifted window.
223
+ Args:
224
+ dim (int): Number of input channels.
225
+ window_size (tuple[int]): The height and width of the window.
226
+ num_heads (int): Number of attention heads.
227
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
228
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
229
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
230
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
231
+ """
232
+
233
+ def __init__(self,
234
+ dim,
235
+ window_size,
236
+ num_heads,
237
+ qkv_bias=True,
238
+ qk_scale=None,
239
+ attn_drop=0.,
240
+ proj_drop=0.):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.window_size = window_size # Wh, Ww
245
+ self.num_heads = num_heads
246
+ head_dim = dim // num_heads
247
+ self.scale = qk_scale or head_dim**-0.5
248
+
249
+ # define a parameter table of relative position bias
250
+ self.relative_position_bias_table = nn.Parameter(
251
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
252
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
253
+
254
+ # get pair-wise relative position index for each token inside the window
255
+ coords_h = torch.arange(self.window_size[0])
256
+ coords_w = torch.arange(self.window_size[1])
257
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
258
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
259
+ relative_coords = coords_flatten[:, :,
260
+ None] - coords_flatten[:,
261
+ None, :] # 2, Wh*Ww, Wh*Ww
262
+ relative_coords = relative_coords.permute(
263
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
264
+ relative_coords[:, :,
265
+ 0] += self.window_size[0] - 1 # shift to start from 0
266
+ relative_coords[:, :, 1] += self.window_size[1] - 1
267
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
268
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
269
+ self.register_buffer("relative_position_index",
270
+ relative_position_index)
271
+
272
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
273
+ self.attn_drop = nn.Dropout(attn_drop)
274
+ self.proj = nn.Linear(dim, dim)
275
+ self.proj_drop = nn.Dropout(proj_drop)
276
+
277
+ trunc_normal_(self.relative_position_bias_table, std=.02)
278
+ self.softmax = nn.Softmax(dim=-1)
279
+
280
+ def forward(self, x, mask=None):
281
+ """ Forward function.
282
+ Args:
283
+ x: input features with shape of (num_windows*B, N, C)
284
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
285
+ """
286
+ B_, N, C = x.shape
287
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
288
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[
290
+ 2] # make torchscript happy (cannot use tensor as tuple)
291
+
292
+ q = q * self.scale
293
+ attn = (q @ k.transpose(-2, -1))
294
+
295
+ relative_position_bias = self.relative_position_bias_table[
296
+ self.relative_position_index.view(-1)].view(
297
+ self.window_size[0] * self.window_size[1],
298
+ self.window_size[0] * self.window_size[1],
299
+ -1) # Wh*Ww,Wh*Ww,nH
300
+ relative_position_bias = relative_position_bias.permute(
301
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
302
+ attn = attn + relative_position_bias.unsqueeze(0)
303
+
304
+ if mask is not None:
305
+ nW = mask.shape[0]
306
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
307
+ N) + mask.unsqueeze(1).unsqueeze(0)
308
+ attn = attn.view(-1, self.num_heads, N, N)
309
+ attn = self.softmax(attn)
310
+ else:
311
+ attn = self.softmax(attn)
312
+
313
+ attn = self.attn_drop(attn)
314
+
315
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
316
+ x = self.proj(x)
317
+ x = self.proj_drop(x)
318
+ return x
319
+
320
+
321
+ class SwinTransformerBlock(nn.Module):
322
+ """ Swin Transformer Block.
323
+ Args:
324
+ dim (int): Number of input channels.
325
+ num_heads (int): Number of attention heads.
326
+ window_size (int): Window size.
327
+ shift_size (int): Shift size for SW-MSA.
328
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
329
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
330
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
331
+ drop (float, optional): Dropout rate. Default: 0.0
332
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
333
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
334
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
335
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
336
+ """
337
+
338
+ def __init__(self,
339
+ dim,
340
+ num_heads,
341
+ window_size=7,
342
+ shift_size=0,
343
+ mlp_ratio=4.,
344
+ qkv_bias=True,
345
+ qk_scale=None,
346
+ drop=0.,
347
+ attn_drop=0.,
348
+ drop_path=0.,
349
+ act_layer=nn.GELU,
350
+ norm_layer=nn.LayerNorm):
351
+ super().__init__()
352
+ self.dim = dim
353
+ self.num_heads = num_heads
354
+ self.window_size = window_size
355
+ self.shift_size = shift_size
356
+ self.mlp_ratio = mlp_ratio
357
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
358
+
359
+ self.norm1 = norm_layer(dim)
360
+ self.attn = WindowAttention(dim,
361
+ window_size=to_2tuple(self.window_size),
362
+ num_heads=num_heads,
363
+ qkv_bias=qkv_bias,
364
+ qk_scale=qk_scale,
365
+ attn_drop=attn_drop,
366
+ proj_drop=drop)
367
+
368
+ self.drop_path = DropPath(
369
+ drop_path) if drop_path > 0. else nn.Identity()
370
+ self.norm2 = norm_layer(dim)
371
+ mlp_hidden_dim = int(dim * mlp_ratio)
372
+ self.mlp = Mlp(in_features=dim,
373
+ hidden_features=mlp_hidden_dim,
374
+ act_layer=act_layer,
375
+ drop=drop)
376
+
377
+ self.H = None
378
+ self.W = None
379
+
380
+ def forward(self, x, mask_matrix):
381
+ """ Forward function.
382
+ Args:
383
+ x: Input feature, tensor size (B, H*W, C).
384
+ H, W: Spatial resolution of the input feature.
385
+ mask_matrix: Attention mask for cyclic shift.
386
+ """
387
+ B, L, C = x.shape
388
+ H, W = self.H, self.W
389
+ assert L == H * W, "input feature has wrong size"
390
+
391
+ shortcut = x
392
+ x = self.norm1(x)
393
+ x = x.view(B, H, W, C)
394
+
395
+ # pad feature maps to multiples of window size
396
+ pad_l = pad_t = 0
397
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
398
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
399
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
400
+ _, Hp, Wp, _ = x.shape
401
+
402
+ # cyclic shift
403
+ if self.shift_size > 0:
404
+ shifted_x = torch.roll(x,
405
+ shifts=(-self.shift_size, -self.shift_size),
406
+ dims=(1, 2))
407
+ attn_mask = mask_matrix
408
+ else:
409
+ shifted_x = x
410
+ attn_mask = None
411
+
412
+ # partition windows
413
+ x_windows = window_partition(
414
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
415
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
416
+ C) # nW*B, window_size*window_size, C
417
+
418
+ # W-MSA/SW-MSA
419
+ attn_windows = self.attn(
420
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
421
+
422
+ # merge windows
423
+ attn_windows = attn_windows.view(-1, self.window_size,
424
+ self.window_size, C)
425
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
426
+ Wp) # B H' W' C
427
+
428
+ # reverse cyclic shift
429
+ if self.shift_size > 0:
430
+ x = torch.roll(shifted_x,
431
+ shifts=(self.shift_size, self.shift_size),
432
+ dims=(1, 2))
433
+ else:
434
+ x = shifted_x
435
+
436
+ if pad_r > 0 or pad_b > 0:
437
+ x = x[:, :H, :W, :].contiguous()
438
+
439
+ x = x.view(B, H * W, C)
440
+
441
+ # FFN
442
+ x = shortcut + self.drop_path(x)
443
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
444
+
445
+ return x
446
+
447
+
448
+ class PatchMerging(nn.Module):
449
+ """ Patch Merging Layer
450
+ Args:
451
+ dim (int): Number of input channels.
452
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
453
+ """
454
+
455
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
456
+ super().__init__()
457
+ self.dim = dim
458
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
459
+ self.norm = norm_layer(4 * dim)
460
+
461
+ def forward(self, x, H, W):
462
+ """ Forward function.
463
+ Args:
464
+ x: Input feature, tensor size (B, H*W, C).
465
+ H, W: Spatial resolution of the input feature.
466
+ """
467
+ B, L, C = x.shape
468
+ assert L == H * W, "input feature has wrong size"
469
+
470
+ x = x.view(B, H, W, C)
471
+
472
+ # padding
473
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
474
+ if pad_input:
475
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
476
+
477
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
478
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
479
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
480
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
481
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
482
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
483
+
484
+ x = self.norm(x)
485
+ x = self.reduction(x)
486
+
487
+ return x
488
+
489
+
490
+ class BasicLayer(nn.Module):
491
+ """ A basic Swin Transformer layer for one stage.
492
+ Args:
493
+ dim (int): Number of feature channels
494
+ depth (int): Depths of this stage.
495
+ num_heads (int): Number of attention head.
496
+ window_size (int): Local window size. Default: 7.
497
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
498
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
499
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
500
+ drop (float, optional): Dropout rate. Default: 0.0
501
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
502
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
504
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
505
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
506
+ """
507
+
508
+ def __init__(self,
509
+ dim,
510
+ depth,
511
+ num_heads,
512
+ window_size=7,
513
+ mlp_ratio=4.,
514
+ qkv_bias=True,
515
+ qk_scale=None,
516
+ drop=0.,
517
+ attn_drop=0.,
518
+ drop_path=0.,
519
+ norm_layer=nn.LayerNorm,
520
+ downsample=None,
521
+ use_checkpoint=False):
522
+
523
+ super().__init__()
524
+ self.window_size = window_size
525
+ self.shift_size = window_size // 2
526
+ self.depth = depth
527
+ self.use_checkpoint = use_checkpoint
528
+
529
+ # build blocks
530
+ self.blocks = nn.ModuleList([
531
+ SwinTransformerBlock(dim=dim,
532
+ num_heads=num_heads,
533
+ window_size=window_size,
534
+ shift_size=0 if
535
+ (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ qk_scale=qk_scale,
539
+ drop=drop,
540
+ attn_drop=attn_drop,
541
+ drop_path=drop_path[i] if isinstance(
542
+ drop_path, list) else drop_path,
543
+ norm_layer=norm_layer) for i in range(depth)
544
+ ])
545
+
546
+ # patch merging layer
547
+ if downsample is not None:
548
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
549
+ else:
550
+ self.downsample = None
551
+
552
+ def forward(self, x, H, W):
553
+ """ Forward function.
554
+ Args:
555
+ x: Input feature, tensor size (B, H*W, C).
556
+ H, W: Spatial resolution of the input feature.
557
+ """
558
+ # print(x.shape,H,W)
559
+ # calculate attention mask for SW-MSA
560
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
561
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
562
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
563
+ h_slices = (slice(0, -self.window_size),
564
+ slice(-self.window_size,
565
+ -self.shift_size), slice(-self.shift_size, None))
566
+ w_slices = (slice(0, -self.window_size),
567
+ slice(-self.window_size,
568
+ -self.shift_size), slice(-self.shift_size, None))
569
+ cnt = 0
570
+ for h in h_slices:
571
+ for w in w_slices:
572
+ img_mask[:, h, w, :] = cnt
573
+ cnt += 1
574
+
575
+ mask_windows = window_partition(
576
+ img_mask, self.window_size) # nW, window_size, window_size, 1
577
+
578
+ mask_windows = mask_windows.view(-1,
579
+ self.window_size * self.window_size)
580
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
581
+ 2) # nW, ww window_size*window_size
582
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
583
+ float(-100.0)).masked_fill(
584
+ attn_mask == 0, float(0.0))
585
+
586
+ for blk in self.blocks:
587
+ blk.H, blk.W = H, W
588
+ if self.use_checkpoint:
589
+ x = checkpoint.checkpoint(blk, x, attn_mask)
590
+ else:
591
+ x = blk(x, attn_mask)
592
+
593
+ if self.downsample is not None:
594
+ x_down = self.downsample(x, H, W)
595
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
596
+ return x, H, W, x_down, Wh, Ww
597
+ else:
598
+ return x, H, W, x, H, W
599
+
600
+
601
+ class PatchEmbed(nn.Module):
602
+ """ Image to Patch Embedding
603
+ Args:
604
+ patch_size (int): Patch token size. Default: 4.
605
+ in_chans (int): Number of input image channels. Default: 3.
606
+ embed_dim (int): Number of linear projection output channels. Default: 96.
607
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
608
+ """
609
+
610
+ def __init__(self,
611
+ patch_size=4,
612
+ in_chans=3,
613
+ embed_dim=96,
614
+ norm_layer=None):
615
+
616
+ super().__init__()
617
+ patch_size = to_2tuple(patch_size)
618
+ self.patch_size = patch_size
619
+
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv2d(in_chans,
624
+ embed_dim,
625
+ kernel_size=patch_size,
626
+ stride=patch_size)
627
+ if norm_layer is not None:
628
+ self.norm = norm_layer(embed_dim)
629
+ else:
630
+ self.norm = None
631
+
632
+ def forward(self, x):
633
+ """Forward function."""
634
+ # padding
635
+ _, _, H, W = x.size()
636
+ if W % self.patch_size[1] != 0:
637
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
638
+ if H % self.patch_size[0] != 0:
639
+ x = F.pad(x,
640
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
641
+
642
+ x = self.proj(x) # B C Wh Ww
643
+ if self.norm is not None:
644
+ Wh, Ww = x.size(2), x.size(3)
645
+ x = x.flatten(2).transpose(1, 2)
646
+ x = self.norm(x)
647
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
648
+
649
+ return x
650
+
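+ # Shape sketch for PatchEmbed with its default arguments: a (B, 3, 512, 512)
+ # input is projected to a (B, 96, 128, 128) feature map (patch_size=4,
+ # embed_dim=96); inputs whose H or W is not a multiple of the patch size are
+ # right/bottom padded first. AEMatter below swaps in its own stride-2
+ # projection, halving the resolution instead of quartering it.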
651
+
652
+ class SwinTransformer(nn.Module):
653
+ """ Swin Transformer backbone.
654
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
655
+ https://arxiv.org/pdf/2103.14030
656
+ Args:
657
+ pretrain_img_size (int): Input image size for training the pretrained model,
658
+ used in absolute position embedding. Default: 224.
659
+ patch_size (int | tuple(int)): Patch size. Default: 4.
660
+ in_chans (int): Number of input image channels. Default: 3.
661
+ embed_dim (int): Number of linear projection output channels. Default: 96.
662
+ depths (tuple[int]): Depths of each Swin Transformer stage.
663
+ num_heads (tuple[int]): Number of attention heads in each stage.
664
+ window_size (int): Window size. Default: 7.
665
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
666
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
667
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
668
+ drop_rate (float): Dropout rate.
669
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
670
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
671
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
672
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
673
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
674
+ out_indices (Sequence[int]): Indices of the stages whose feature maps are returned.
675
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
676
+ -1 means not freezing any parameters.
677
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
678
+ """
679
+
680
+ def __init__(self,
681
+ pretrain_img_size=224,
682
+ patch_size=4,
683
+ in_chans=3,
684
+ embed_dim=96,
685
+ depths=[2, 2, 6, 2],
686
+ num_heads=[3, 6, 12, 24],
687
+ window_size=7,
688
+ mlp_ratio=4.,
689
+ qkv_bias=True,
690
+ qk_scale=None,
691
+ drop_rate=0.,
692
+ attn_drop_rate=0.,
693
+ drop_path_rate=0.2,
694
+ norm_layer=nn.LayerNorm,
695
+ ape=False,
696
+ patch_norm=True,
697
+ out_indices=(0, 1, 2, 3),
698
+ frozen_stages=-1,
699
+ use_checkpoint=False):
700
+
701
+ super().__init__()
702
+
703
+ self.pretrain_img_size = pretrain_img_size
704
+ self.num_layers = len(depths)
705
+ self.embed_dim = embed_dim
706
+ self.ape = ape
707
+ self.patch_norm = patch_norm
708
+ self.out_indices = out_indices
709
+ self.frozen_stages = frozen_stages
710
+
711
+ # split image into non-overlapping patches
712
+ self.patch_embed = PatchEmbed(
713
+ patch_size=patch_size,
714
+ in_chans=in_chans,
715
+ embed_dim=embed_dim,
716
+ norm_layer=norm_layer if self.patch_norm else None)
717
+
718
+ # absolute position embedding
719
+ if self.ape:
720
+ pretrain_img_size = to_2tuple(pretrain_img_size)
721
+ patch_size = to_2tuple(patch_size)
722
+ patches_resolution = [
723
+ pretrain_img_size[0] // patch_size[0],
724
+ pretrain_img_size[1] // patch_size[1]
725
+ ]
726
+
727
+ self.absolute_pos_embed = nn.Parameter(
728
+ torch.zeros(1, embed_dim, patches_resolution[0],
729
+ patches_resolution[1]))
730
+ trunc_normal_(self.absolute_pos_embed, std=.02)
731
+
732
+ self.pos_drop = nn.Dropout(p=drop_rate)
733
+
734
+ # stochastic depth
735
+ dpr = [
736
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
737
+ ] # stochastic depth decay rule
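+ # dpr assigns every transformer block a drop-path rate that grows linearly
+ # from 0 to drop_path_rate across all sum(depths) blocks; each stage below
+ # takes its contiguous slice of this list.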
738
+
739
+ # build layers
740
+ self.layers = nn.ModuleList()
741
+ for i_layer in range(self.num_layers):
742
+ layer = BasicLayer(
743
+ dim=int(embed_dim * 2**i_layer),
744
+ depth=depths[i_layer],
745
+ num_heads=num_heads[i_layer],
746
+ window_size=window_size,
747
+ mlp_ratio=mlp_ratio,
748
+ qkv_bias=qkv_bias,
749
+ qk_scale=qk_scale,
750
+ drop=drop_rate,
751
+ attn_drop=attn_drop_rate,
752
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
753
+ norm_layer=norm_layer,
754
+ downsample=PatchMerging if
755
+ (i_layer < self.num_layers - 1) else None,
756
+ use_checkpoint=use_checkpoint)
757
+ self.layers.append(layer)
758
+
759
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
760
+ self.num_features = num_features
761
+
762
+ # add a norm layer for each output
763
+ for i_layer in out_indices:
764
+ layer = norm_layer(num_features[i_layer])
765
+ layer_name = f'norm{i_layer}'
766
+ self.add_module(layer_name, layer)
767
+
768
+ self._freeze_stages()
769
+
770
+ def _freeze_stages(self):
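+ # frozen_stages semantics: 0 freezes the patch embedding, 1 additionally
+ # freezes the absolute position embedding (when ape=True), and k >= 2 also
+ # puts pos_drop in eval mode and freezes transformer stages 0 .. k-2.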
771
+ if self.frozen_stages >= 0:
772
+ self.patch_embed.eval()
773
+ for param in self.patch_embed.parameters():
774
+ param.requires_grad = False
775
+
776
+ if self.frozen_stages >= 1 and self.ape:
777
+ self.absolute_pos_embed.requires_grad = False
778
+
779
+ if self.frozen_stages >= 2:
780
+ self.pos_drop.eval()
781
+ for i in range(0, self.frozen_stages - 1):
782
+ m = self.layers[i]
783
+ m.eval()
784
+ for param in m.parameters():
785
+ param.requires_grad = False
786
+
787
+ def init_weights(self, pretrained=None):
788
+ """Initialize the weights in backbone.
789
+ Args:
790
+ pretrained (str, optional): Path to pre-trained weights.
791
+ Defaults to None.
792
+ """
793
+
794
+ def forward(self, x):
795
+ """Forward function."""
796
+ x = self.patch_embed(x)
797
+
798
+ Wh, Ww = x.size(2), x.size(3)
799
+ if self.ape:
800
+ # interpolate the position embedding to the corresponding size
801
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
802
+ size=(Wh, Ww),
803
+ mode='bicubic')
804
+ x = (x + absolute_pos_embed).flatten(2).transpose(1,
805
+ 2) # B Wh*Ww C
806
+ else:
807
+ x = x.flatten(2).transpose(1, 2)
808
+ x = self.pos_drop(x)
809
+
810
+ outs = []
811
+ for i in range(self.num_layers):
812
+ layer = self.layers[i]
813
+
814
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
815
+
816
+ if i in self.out_indices:
817
+ norm_layer = getattr(self, f'norm{i}')
818
+ x_out = norm_layer(x_out)
819
+
820
+ out = x_out.view(-1, H, W,
821
+ self.num_features[i]).permute(0, 3, 1,
822
+ 2).contiguous()
823
+ outs.append(out)
824
+
825
+ return tuple(outs)
826
+
827
+ def train(self, mode=True):
828
+ """Convert the model into training mode while keep layers freezed."""
829
+ super(SwinTransformer, self).train(mode)
830
+ self._freeze_stages()
831
+
832
+
833
+ class Mlp(nn.Module):
834
+ """ Multilayer perceptron."""
835
+
836
+ def __init__(self,
837
+ in_features,
838
+ hidden_features=None,
839
+ out_features=None,
840
+ act_layer=nn.GELU,
841
+ drop=0.):
842
+ super().__init__()
843
+ out_features = out_features or in_features
844
+ hidden_features = hidden_features or in_features
845
+ self.fc1 = nn.Linear(in_features, hidden_features)
846
+ self.act = act_layer()
847
+ self.fc2 = nn.Linear(hidden_features, out_features)
848
+ self.drop = nn.Dropout(drop)
849
+
850
+ def forward(self, x):
851
+ x = self.fc1(x)
852
+ x = self.act(x)
853
+ x = self.drop(x)
854
+ x = self.fc2(x)
855
+ x = self.drop(x)
856
+ return x
857
+
858
+
859
+ class ResBlock(nn.Module):
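+ """1x1 -> 3x3 -> 1x1 bottleneck with GroupNorm(16 groups) and LeakyReLU,
+ wrapped in an identity skip connection; reused by AEALblock and AEMatter
+ below."""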
860
+
861
+ def __init__(self, inc, midc):
862
+ super(ResBlock, self).__init__()
863
+ self.conv1 = nn.Conv2d(inc,
864
+ midc,
865
+ kernel_size=1,
866
+ stride=1,
867
+ padding=0,
868
+ bias=True)
869
+ self.gn1 = nn.GroupNorm(16, midc)
870
+ self.conv2 = nn.Conv2d(midc,
871
+ midc,
872
+ kernel_size=3,
873
+ stride=1,
874
+ padding=1,
875
+ bias=True)
876
+ self.gn2 = nn.GroupNorm(16, midc)
877
+ self.conv3 = nn.Conv2d(midc,
878
+ inc,
879
+ kernel_size=1,
880
+ stride=1,
881
+ padding=0,
882
+ bias=True)
883
+ self.relu = nn.LeakyReLU(0.1)
884
+
885
+ def forward(self, x):
886
+ x_ = x
887
+ x = self.conv1(x)
888
+ x = self.gn1(x)
889
+ x = self.relu(x)
890
+ x = self.conv2(x)
891
+ x = self.gn2(x)
892
+ x = self.relu(x)
893
+ x = self.conv3(x)
894
+ x = x + x_
895
+ x = self.relu(x)
896
+ return x
897
+
898
+
899
+ class AEALblock(nn.Module):
900
+
901
+ def __init__(self,
902
+ d_model,
903
+ nhead,
904
+ dim_feedforward=512,
905
+ dropout=0.0,
906
+ layer_norm_eps=1e-5,
907
+ batch_first=True,
908
+ norm_first=False,
909
+ width=5):
910
+ super(AEALblock, self).__init__()
911
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
912
+ nhead // 2,
913
+ dropout=dropout,
914
+ batch_first=batch_first)
915
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
916
+ nhead // 2,
917
+ dropout=dropout,
918
+ batch_first=batch_first)
919
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
920
+ self.dropout = nn.Dropout(dropout)
921
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
922
+ self.norm_first = norm_first
923
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
924
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
925
+ self.dropout1 = nn.Dropout(dropout)
926
+ self.dropout2 = nn.Dropout(dropout)
927
+ self.activation = nn.ReLU()
928
+ self.width = width
929
+ self.trans = nn.Sequential(
930
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
931
+ ResBlock(d_model // 2, d_model // 4),
932
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
933
+ self.gamma = nn.Parameter(torch.zeros(1))
934
+
935
+ def forward(
936
+ self,
937
+ src,
938
+ feats,
939
+ ):
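+ # feats are the 512-channel context features produced by AEMatter.apptrans;
+ # they are fused into src through the zero-initialised, learnable gamma gate.
+ # The channels of src are then split in half: the first half attends within
+ # horizontal stripes of height `width`, the second within vertical stripes of
+ # width `width`, and the two halves are concatenated again before the
+ # feed-forward block.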
940
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
941
+ b, c, h, w = src.shape
942
+ x1 = src[:, 0:c // 2]
943
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
944
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
945
+ x2 = src[:, c // 2:]
946
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
947
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
948
+ x = rearrange(src, 'b c h w-> b (h w) c')
949
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
950
+ x = self.norm2(x + self._ff_block(x))
951
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
952
+ return x
953
+
954
+ def _sa_block(self, x1, x2, h, w):
955
+ x1 = self.self_attn1(x1,
956
+ x1,
957
+ x1,
958
+ attn_mask=None,
959
+ key_padding_mask=None,
960
+ need_weights=False)[0]
961
+
962
+ x2 = self.self_attn2(x2,
963
+ x2,
964
+ x2,
965
+ attn_mask=None,
966
+ key_padding_mask=None,
967
+ need_weights=False)[0]
968
+
969
+ x1 = rearrange(x1,
970
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
971
+ h2=self.width,
972
+ h1=h // self.width)
973
+ x2 = rearrange(x2,
974
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
975
+ w2=self.width,
976
+ w1=w // self.width)
977
+ x = torch.cat([x1, x2], dim=2)
978
+ return self.dropout1(x)
979
+
980
+ def _ff_block(self, x):
981
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
982
+ return self.dropout2(x)
983
+
984
+
985
+ class AEMatter(nn.Module):
986
+
987
+ def __init__(self):
988
+ super(AEMatter, self).__init__()
989
+ trans = SwinTransformer(pretrain_img_size=224,
990
+ embed_dim=96,
991
+ depths=[2, 2, 6, 2],
992
+ num_heads=[3, 6, 12, 24],
993
+ window_size=7,
994
+ ape=False,
995
+ drop_path_rate=0.2,
996
+ patch_norm=True,
997
+ use_checkpoint=False)
998
+
999
+ # trans.load_state_dict(torch.load(
1000
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
1001
+ # map_location="cpu")["model"],
1002
+ # strict=False)
1003
+
1004
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
1005
+
1006
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
1007
+ nn.PReLU(48))
1008
+
1009
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
1010
+ 1), nn.PReLU(64),
1011
+ nn.Conv2d(64, 64, 3, 1, 1),
1012
+ nn.PReLU(64))
1013
+
1014
+ self.trans = trans
1015
+ self.conv1 = nn.Sequential(
1016
+ nn.Conv2d(in_channels=640 + 768,
1017
+ out_channels=256,
1018
+ kernel_size=1,
1019
+ stride=1,
1020
+ padding=0,
1021
+ bias=True))
1022
+ self.conv2 = nn.Sequential(
1023
+ nn.Conv2d(in_channels=256 + 384,
1024
+ out_channels=256,
1025
+ kernel_size=1,
1026
+ stride=1,
1027
+ padding=0,
1028
+ bias=True), )
1029
+ self.conv3 = nn.Sequential(
1030
+ nn.Conv2d(in_channels=256 + 192,
1031
+ out_channels=192,
1032
+ kernel_size=1,
1033
+ stride=1,
1034
+ padding=0,
1035
+ bias=True), )
1036
+ self.conv4 = nn.Sequential(
1037
+ nn.Conv2d(in_channels=192 + 96,
1038
+ out_channels=128,
1039
+ kernel_size=1,
1040
+ stride=1,
1041
+ padding=0,
1042
+ bias=True), )
1043
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
1044
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
1045
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
1046
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
1047
+ self.conv5 = nn.Sequential(
1048
+ nn.Conv2d(in_channels=192,
1049
+ out_channels=64,
1050
+ kernel_size=3,
1051
+ stride=1,
1052
+ padding=1,
1053
+ bias=True), nn.PReLU(64),
1054
+ nn.Conv2d(in_channels=64,
1055
+ out_channels=64,
1056
+ kernel_size=3,
1057
+ stride=1,
1058
+ padding=1,
1059
+ bias=True), nn.PReLU(64),
1060
+ nn.Conv2d(in_channels=64,
1061
+ out_channels=48,
1062
+ kernel_size=3,
1063
+ stride=1,
1064
+ padding=1,
1065
+ bias=True), nn.PReLU(48))
1066
+ self.convo = nn.Sequential(
1067
+ nn.Conv2d(in_channels=48 + 48 + 6,
1068
+ out_channels=32,
1069
+ kernel_size=3,
1070
+ stride=1,
1071
+ padding=1,
1072
+ bias=True), nn.PReLU(32),
1073
+ nn.Conv2d(in_channels=32,
1074
+ out_channels=32,
1075
+ kernel_size=3,
1076
+ stride=1,
1077
+ padding=1,
1078
+ bias=True), nn.PReLU(32),
1079
+ nn.Conv2d(in_channels=32,
1080
+ out_channels=1,
1081
+ kernel_size=3,
1082
+ stride=1,
1083
+ padding=1,
1084
+ bias=True))
1085
+ self.up = nn.Upsample(scale_factor=2,
1086
+ mode='bilinear',
1087
+ align_corners=False)
1088
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1089
+ self.apptrans = nn.Sequential(
1090
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1091
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1092
+ ResBlock(512, 128))
1093
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1094
+ ResBlock(640, 160))
1095
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1096
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1097
+ self.width = 5
1098
+ self.trans1 = AEALblock(d_model=640,
1099
+ nhead=20,
1100
+ dim_feedforward=2048,
1101
+ dropout=0.2,
1102
+ width=self.width)
1103
+ self.trans2 = AEALblock(d_model=640,
1104
+ nhead=20,
1105
+ dim_feedforward=2048,
1106
+ dropout=0.2,
1107
+ width=self.width)
1108
+ self.trans3 = AEALblock(d_model=640,
1109
+ nhead=20,
1110
+ dim_feedforward=2048,
1111
+ dropout=0.2,
1112
+ width=self.width)
1113
+
1114
+ def aeal(self, x, sem):
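+ # Embed x, pad H and W (and sem) up to multiples of the stripe width so the
+ # AEAL blocks can tile the plane evenly, run the three AEAL blocks, then crop
+ # the padding off again.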
1115
+ xe = self.emb(x)
1116
+ x_ = xe
1117
+ x_ = self.embdp(x_)
1118
+ b, c, h1, w1 = x_.shape
1119
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1120
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1121
+ newph1 = bnew_ph // 2
1122
+ newph2 = bnew_ph - newph1
1123
+ newpw1 = bnew_pw // 2
1124
+ newpw2 = bnew_pw - newpw1
1125
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1126
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1127
+ x_ = self.trans1(x_, sem)
1128
+ x_ = self.trans2(x_, sem)
1129
+ x_ = self.trans3(x_, sem)
1130
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1131
+ return x_
1132
+
1133
+ def forward(self, x, y):
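+ # x: image batch, y: trimap batch. They are concatenated into the 6-channel
+ # input expected by start_conv0, so for an RGB image the trimap is assumed to
+ # be encoded with 3 channels as well.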
1134
+ inputs = torch.cat((x, y), 1)
1135
+ x = self.start_conv0(inputs)
1136
+ x_ = self.start_conv(x)
1137
+ x1, x2, x3, x4 = self.trans(x_)
1138
+ x4h = self.h2l(x4)
1139
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1140
+ x4_ = self.aeal(x4, x3s)
1141
+ x4 = torch.cat((x4, x4_), 1)
1142
+ X4 = self.conv1(x4)
1143
+ wh, ww = X4.shape[2], X4.shape[3]
1144
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1145
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1146
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1147
+ X3 = self.up(X4)
1148
+ X3 = torch.cat((x3, X3), 1)
1149
+ X3 = self.conv2(X3)
1150
+ wh, ww = X3.shape[2], X3.shape[3]
1151
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1152
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1153
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1154
+ X2 = self.up(X3)
1155
+ X2 = torch.cat((x2, X2), 1)
1156
+ X2 = self.conv3(X2)
1157
+ wh, ww = X2.shape[2], X2.shape[3]
1158
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1159
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1160
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1161
+ X1 = self.up(X2)
1162
+ X1 = torch.cat((x1, X1), 1)
1163
+ X1 = self.conv4(X1)
1164
+ wh, ww = X1.shape[2], X1.shape[3]
1165
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1166
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1167
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1168
+ X0 = self.up(X1)
1169
+ X0 = torch.cat((x_, X0), 1)
1170
+ X0 = self.conv5(X0)
1171
+ X = self.up(X0)
1172
+ X = torch.cat((inputs, x, X), 1)
1173
+ alpha = self.convo(X)
1174
+ alpha = torch.clamp(alpha, min=0, max=1)
1175
+ return alpha
1176
+
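+ # A hedged usage sketch (not called anywhere in this file): `image` and
+ # `trimap` are assumed to be float tensors of shape (B, 3, H, W) with H and W
+ # divisible by 32, matching the five 2x downsamplings in the encoder above;
+ # the ComfyUI node below goes through do_infer instead.
+ def _example_predict_alpha(model: AEMatter, image: torch.Tensor,
+ trimap: torch.Tensor) -> torch.Tensor:
+ model = model.eval()
+ with torch.no_grad():
+ alpha = model(image, trimap) # (B, 1, H, W), clamped to [0, 1]
+ return alpha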
1177
+
1178
+ class load_AEMatter_Model:
1179
+
1180
+ def __init__(self):
1181
+ pass
1182
+
1183
+ @classmethod
1184
+ def INPUT_TYPES(s):
1185
+ return {
1186
+ "required": {},
1187
+ }
1188
+
1189
+ RETURN_TYPES = ("AEMatter_Model", )
1190
+ FUNCTION = "test"
1191
+ CATEGORY = "AEMatter"
1192
+
1193
+ def test(self):
1194
+ return (get_AEMatter_model(get_model_path()), )
1195
+
1196
+
1197
+ class run_AEMatter_inference:
1198
+
1199
+ def __init__(self):
1200
+ pass
1201
+
1202
+ @classmethod
1203
+ def INPUT_TYPES(s):
1204
+ return {
1205
+ "required": {
1206
+ "image": ("IMAGE", ),
1207
+ "trimap": ("MASK", ),
1208
+ "AEMatter_Model": ("AEMatter_Model", ),
1209
+ },
1210
+ }
1211
+
1212
+ RETURN_TYPES = ("MASK", )
1213
+ FUNCTION = "test"
1214
+ CATEGORY = "AEMatter"
1215
+
1216
+ def test(
1217
+ self,
1218
+ image,
1219
+ trimap,
1220
+ AEMatter_Model,
1221
+ ):
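+ # image arrives as a ComfyUI IMAGE batch (B, H, W, C float in [0, 1]) and
+ # trimap as a MASK batch (typically B, H, W); from_torch_image / do_infer /
+ # to_torch_image, defined earlier in this file, convert each item to and from
+ # the numpy representation the matting model is run on.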
1222
+
1223
+ ret = []
1224
+ batch_size = image.shape[0]
1225
+
1226
+ for i in range(batch_size):
1227
+ tmp_i = from_torch_image(image[i])
1228
+ tmp_m = from_torch_image(trimap[i])
1229
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1230
+ ret.append(tmp)
1231
+
1232
+ ret = to_torch_image(np.array(ret))
1233
+ ret = ret.squeeze(-1)
1234
+ print(ret.shape)
1235
+
1236
+ return (ret, ) # ComfyUI expects a tuple matching RETURN_TYPES
1237
+
1238
+
1240
+ NODE_CLASS_MAPPINGS = {
1241
+ 'load_AEMatter_Model': load_AEMatter_Model,
1242
+ 'run_AEMatter_inference': run_AEMatter_inference,
1243
+ }
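+ # ComfyUI picks up NODE_CLASS_MAPPINGS and the display-name mapping below when
+ # the ComfyUI_AEMatter package is loaded; a typical __init__.py (a sketch, the
+ # actual file is not shown here) simply re-exports them:
+ # from .AEMatter import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+ # __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']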
1244
+
1245
+ NODE_DISPLAY_NAME_MAPPINGS = {
1246
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1247
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1248
+ }