Artyom committed on
Commit
f8d6c27
1 Parent(s): 94f9590
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ SCBC/Input/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
+ SCBC/Input/IMG_20240215_213619.png filter=lfs diff=lfs merge=lfs -text
+ SCBC/Input/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
+ SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
+ SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
SCBC/CPNet_model.py ADDED
@@ -0,0 +1,629 @@
+ from __future__ import division
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.init as init
+ import torch.utils.model_zoo as model_zoo
+ from torchvision import models
+ from torchvision import transforms
+ import cv2
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import numpy as np
+ import math
+ import time
+ import tqdm
+ import os
+ import argparse
+ import copy
+ import sys
+ import networks as N
+ from model_module import *
+ sys.path.insert(0, '.')
+ # from .common import *
+ sys.path.insert(0, '../utils/')
+ 
+ 
+ class LiteISPNet(nn.Module):
+     def __init__(self):
+         super(LiteISPNet, self).__init__()
+ 
+         ch_1 = 64
+         ch_2 = 128
+         ch_3 = 128
+         n_blocks = 4
+ 
+         self.head = N.seq(
+             N.conv(3, ch_1, mode='C')
+         )  # shape: (N, ch_1, H, W)
+ 
+         self.down1 = N.seq(
+             N.conv(ch_1, ch_1, mode='C'),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1, mode='C'),
+             N.DWTForward(ch_1)
+         )  # shape: (N, ch_1*4, H/2, W/2)
+ 
+         self.down2 = N.seq(
+             N.conv(ch_1*4, ch_1, mode='C'),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.DWTForward(ch_1)
+         )  # shape: (N, ch_1*4, H/4, W/4)
+ 
+         self.down3 = N.seq(
+             N.conv(ch_1*4, ch_2, mode='C'),
+             N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
+             N.DWTForward(ch_2)
+         )  # shape: (N, ch_2*4, H/8, W/8)
+ 
+         self.middle = N.seq(
+             N.conv(ch_2*4, ch_3, mode='C'),
+             N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
+             N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
+             N.conv(ch_3, ch_2*4, mode='C')
+         )  # shape: (N, ch_2*4, H/8, W/8)
+ 
+         self.up3 = N.seq(
+             N.DWTInverse(ch_2*4),
+             N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
+             N.conv(ch_2, ch_1*4, mode='C')
+         )  # shape: (N, ch_1*4, H/4, W/4)
+ 
+         self.up2 = N.seq(
+             N.DWTInverse(ch_1*4),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1*4, mode='C')
+         )  # shape: (N, ch_1*4, H/2, W/2)
+ 
+         self.up1 = N.seq(
+             N.DWTInverse(ch_1*4),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1, mode='C')
+         )  # shape: (N, ch_1, H, W)
+ 
+         self.tail = N.seq(
+             # N.conv(ch_1, ch_1*4, mode='C'),
+             # nn.PixelShuffle(upscale_factor=2),
+             N.conv(ch_1, 3, mode='C')
+         )  # shape: (N, 3, H, W)
+ 
+     def forward(self, raw):
+         # input = raw
+         input = torch.pow(raw, 1/2.2)
+ 
+         h = self.head(input)
+         h_coord = h
+ 
+         d1 = self.down1(h_coord)
+         d2 = self.down2(d1)
+         d3 = self.down3(d2)
+         m = self.middle(d3) + d3
+         u3 = self.up3(m) + d2
+         u2 = self.up2(u3) + d1
+         u1 = self.up1(u2) + h
+         out = self.tail(u1)
+ 
+         return out
+ 
+ 
+ class LiteAWBISPNet(nn.Module):
+     def __init__(self):
+         super(LiteAWBISPNet, self).__init__()
+ 
+         ch_1 = 64
+         ch_2 = 128
+         ch_3 = 128
+         n_blocks = 4
+ 
+         self.head = N.seq(
+             N.conv(3, ch_1, mode='C')
+         )  # shape: (N, ch_1, H, W)
+ 
+         self.down1 = N.seq(
+             N.conv(ch_1, ch_1, mode='C'),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1, mode='C'),
+             N.DWTForward(ch_1)
+         )  # shape: (N, ch_1*4, H/2, W/2)
+ 
+         self.down2 = N.seq(
+             N.conv(ch_1*4, ch_1, mode='C'),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.DWTForward(ch_1)
+         )  # shape: (N, ch_1*4, H/4, W/4)
+ 
+         self.down3 = N.seq(
+             N.conv(ch_1*4, ch_2, mode='C'),
+             N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
+             N.DWTForward(ch_2)
+         )  # shape: (N, ch_2*4, H/8, W/8)
+ 
+         self.middle = N.seq(
+             N.conv(ch_2*4, ch_3, mode='C'),
+             N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
+             N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
+             N.conv(ch_3, ch_2*4, mode='C')
+         )  # shape: (N, ch_2*4, H/8, W/8)
+ 
+         self.up3 = N.seq(
+             N.DWTInverse(ch_2*4),
+             N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
+             N.conv(ch_2, ch_1*4, mode='C')
+         )  # shape: (N, ch_1*4, H/4, W/4)
+ 
+         self.up2 = N.seq(
+             N.DWTInverse(ch_1*4),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1*4, mode='C')
+         )  # shape: (N, ch_1*4, H/2, W/2)
+ 
+         self.up1 = N.seq(
+             N.DWTInverse(ch_1*4),
+             N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
+             N.conv(ch_1, ch_1, mode='C')
+         )  # shape: (N, ch_1, H, W)
+ 
+         self.tail = N.seq(
+             # N.conv(ch_1, ch_1*4, mode='C'),
+             # nn.PixelShuffle(upscale_factor=2),
+             N.conv(ch_1, 3, mode='C')
+         )  # shape: (N, 3, H, W)
+ 
+     def forward(self, raw):
+         input = raw
+         h = self.head(input)
+         h_coord = h
+ 
+         d1 = self.down1(h_coord)
+         d2 = self.down2(d1)
+         d3 = self.down3(d2)
+         m = self.middle(d3) + d3
+         u3 = self.up3(m) + d2
+         u2 = self.up2(u3) + d1
+         u1 = self.up1(u2) + h
+         out = self.tail(u1)
+ 
+         return out
+ 
+ 
+ # Alignment Encoder
+ class A_Encoder(nn.Module):
+     def __init__(self):
+         super(A_Encoder, self).__init__()
+         self.conv12 = Conv2d(3, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())    # 2
+         self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())    # 2
+         self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 4
+         self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
+         self.conv34 = Conv2d(128, 256, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 8
+         self.conv4a = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
+         self.conv4b = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
+         init_He(self)
+         self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+ 
+     def forward(self, in_f, in_v=None):
+         # in_v (hole mask) is accepted for compatibility with the
+         # encoding()/inpainting() call sites in CPNet; it is unused here.
+         f = (in_f - self.mean) / self.std
+         x = f
+         x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
+         x = self.conv12(x)
+         x = self.conv2(x)
+         x = self.conv23(x)
+         x = self.conv3(x)
+         x = self.conv34(x)
+         x = self.conv4a(x)
+         x = self.conv4b(x)
+         return x
+ 
+ # Alignment Regressor
+ class A_Regressor(nn.Module):
+     def __init__(self):
+         super(A_Regressor, self).__init__()
+         self.conv45 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 16
+         self.conv5a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 16
+         self.conv5b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 16
+         self.conv56 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 32
+         self.conv6a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 32
+         self.conv6b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 32
+         init_He(self)
+ 
+         self.fc = nn.Linear(512, 6)
+         self.fc.weight.data.zero_()
+         self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32))
+ 
+     def forward(self, feat1, feat2):
+         x = torch.cat([feat1, feat2], dim=1)
+         x = self.conv45(x)
+         x = self.conv5a(x)
+         x = self.conv5b(x)
+         x = self.conv56(x)
+         x = self.conv6a(x)
+         x = self.conv6b(x)
+ 
+         x = F.avg_pool2d(x, x.shape[2])
+         x = x.view(-1, x.shape[1])
+ 
+         theta = self.fc(x)
+         theta = theta.view(-1, 2, 3)
+ 
+         return theta
+ 
+ # Encoder (Copy network)
+ class Encoder(nn.Module):
+     def __init__(self):
+         super(Encoder, self).__init__()
+         self.conv12 = Conv2d(4, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())    # 2
+         self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())    # 2
+         self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 4
+         self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
+         self.value3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)      # 4
+         init_He(self)
+         self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+ 
+     def forward(self, in_f, in_v):
+         f = (in_f - self.mean) / self.std
+         x = torch.cat([f, in_v], dim=1)
+         x = self.conv12(x)
+         x = self.conv2(x)
+         x = self.conv23(x)
+         x = self.conv3(x)
+         v = self.value3(x)
+         return v
+ 
+ # Decoder (Paste network)
+ class Decoder(nn.Module):
+     def __init__(self):
+         super(Decoder, self).__init__()
+         self.conv4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
+         self.conv5_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
+         self.conv5_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
+ 
+         # dilated convolution blocks
+         self.convA4_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=2, D=2, activation=nn.ReLU())
+         self.convA4_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=4, D=4, activation=nn.ReLU())
+         self.convA4_3 = Conv2d(257, 257, kernel_size=3, stride=1, padding=8, D=8, activation=nn.ReLU())
+         self.convA4_4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=16, D=16, activation=nn.ReLU())
+ 
+         self.conv3c = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
+         self.conv3b = Conv2d(257, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
+         self.conv3a = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
+         self.conv32 = Conv2d(128, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())   # 2
+         self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())     # 2
+         self.conv21 = Conv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)          # 1
+ 
+         self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+ 
+     def forward(self, x):
+         x = self.conv4(x)
+         x = self.conv5_1(x)
+         x = self.conv5_2(x)
+ 
+         x = self.convA4_1(x)
+         x = self.convA4_2(x)
+         x = self.convA4_3(x)
+         x = self.convA4_4(x)
+ 
+         x = self.conv3c(x)
+         x = self.conv3b(x)
+         x = self.conv3a(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
+         x = self.conv32(x)
+         x = self.conv2(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
+         x = self.conv21(x)
+ 
+         p = (x * self.std) + self.mean
+         return p
+ 
+ 
+ # Context Matching Module
+ class CM_Module(nn.Module):
+     def __init__(self):
+         super(CM_Module, self).__init__()
+ 
+     def masked_softmax(self, vec, mask, dim):
+         masked_vec = vec * mask.float()
+         max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
+         exps = torch.exp(masked_vec - max_vec)
+         masked_exps = exps * mask.float()
+         masked_sums = masked_exps.sum(dim, keepdim=True)
+         zeros = (masked_sums < 1e-4)
+         masked_sums += zeros.float()
+         return masked_exps / masked_sums
+ 
+     def forward(self, values, tvmap, rvmaps):
+         B, C, T, H, W = values.size()
+         # t_feat: target feature
+         t_feat = values[:, :, 0]
+         # r_feats: reference features
+         r_feats = values[:, :, 1:]
+ 
+         B, Cv, T, H, W = r_feats.size()
+         # vmap: visibility map
+         # tvmap: target visibility map
+         # rvmap: reference visibility map
+         # gs: cosine similarity
+         # c_m: c_match
+         gs_, vmap_ = [], []
+         tvmap_t = (F.interpolate(tvmap, size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
+         for r in range(T):
+             rvmap_t = (F.interpolate(rvmaps[:, :, r], size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
+             # vmap: visibility map
+             vmap = tvmap_t * rvmap_t
+             gs = (vmap * t_feat * r_feats[:, :, r]).sum(-1).sum(-1).sum(-1)
+             # valid sum
+             v_sum = vmap[:, 0].sum(-1).sum(-1)
+             zeros = (v_sum < 1e-4)
+             gs[zeros] = 0
+             v_sum += zeros.float()
+             gs = gs / v_sum / C
+             gs = torch.ones_like(t_feat) * gs.view(B, 1, 1, 1)
+             gs_.append(gs)
+             vmap_.append(rvmap_t)
+ 
+         gss = torch.stack(gs_, dim=2)
+         vmaps = torch.stack(vmap_, dim=2)
+ 
+         # weighted pixelwise masked softmax
+         c_match = self.masked_softmax(gss, vmaps, dim=2)
+         c_out = torch.sum(r_feats * c_match, dim=2)
+ 
+         # c_mask
+         c_mask = (c_match * vmaps)
+         c_mask = torch.sum(c_mask, 2)
+         c_mask = 1. - (torch.mean(c_mask, 1, keepdim=True))
+ 
+         return torch.cat([t_feat, c_out, c_mask], dim=1), c_mask
+ 
+ 
+ class GCMModel(nn.Module):
+     def __init__(self):
+         super(GCMModel, self).__init__()
+         self.ch_1 = 16
+         self.ch_2 = 32
+         guide_input_channels = 3
+         align_input_channels = 3
+         self.gcm_coord = None
+ 
+         if not self.gcm_coord:
+             guide_input_channels = 3
+             align_input_channels = 3
+ 
+         self.guide_net = N.seq(
+             N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
+             N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
+             nn.AdaptiveAvgPool2d(1),
+             N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
+         )
+ 
+         self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
+ 
+         self.align_base = N.seq(
+             N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCRCRCR')
+         )
+         self.align_tail = N.seq(
+             N.conv(self.ch_2, 3, 1, padding=0, mode='C')
+         )
+ 
+     def forward(self, demosaic_raw):
+         demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
+         guide_input = demosaic_raw
+         base_input = demosaic_raw
+         guide = self.guide_net(guide_input)
+         out = self.align_head(base_input)
+         out = guide * out + out
+         out = self.align_base(out)
+         out = self.align_tail(out) + demosaic_raw
+ 
+         return out
+ 
+ class Fusion(nn.Module):
+     def __init__(self):
+         super(Fusion, self).__init__()
+         self.ch_1 = 16
+         self.ch_2 = 32
+         guide_input_channels = 9
+         align_input_channels = 9
+         self.gcm_coord = None
+ 
+         if not self.gcm_coord:
+             guide_input_channels = 9
+             align_input_channels = 9
+ 
+         self.guide_net = N.seq(
+             N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
+             N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
+             nn.AdaptiveAvgPool2d(1),
+             N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
+         )
+ 
+         self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
+ 
+         self.align_base = N.seq(
+             N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
+         )
+         self.align_tail = N.seq(
+             N.conv(self.ch_2, 3, 1, padding=0, mode='C')
+         )
+ 
+     def forward(self, demosaic_raw):
+         # demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
+         guide_input = demosaic_raw
+         base_input = demosaic_raw
+         guide = self.guide_net(guide_input)
+         out = self.align_head(base_input)
+         out = guide * out + out
+         out = self.align_base(out)
+         out = self.align_tail(out)
+ 
+         return out
+ 
+ 
+ class CPNet(nn.Module):
+     def __init__(self, mode='Train'):
+         super(CPNet, self).__init__()
+         self.A_Encoder = A_Encoder()      # Align
+         self.A_Regressor = A_Regressor()  # output: alignment network
+         self.GCMModel = GCMModel()
+         self.Encoder = Encoder()          # Merge
+         self.CM_Module = CM_Module()
+ 
+         self.Decoder = Decoder()
+ 
+         self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))
+ 
+     def encoding(self, frames, holes):
+         batch_size, _, num_frames, height, width = frames.size()
+         # padding
+         (frames, holes), pad = pad_divide_by([frames, holes], 8, (frames.size()[3], frames.size()[4]))
+ 
+         feat_ = []
+         for t in range(num_frames):
+             feat = self.A_Encoder(frames[:, :, t], holes[:, :, t])
+             feat_.append(feat)
+         feats = torch.stack(feat_, dim=2)
+         return feats
+ 
+     def inpainting(self, rfeats, rframes, rholes, frame, hole, gt):
+         batch_size, _, height, width = frame.size()  # B C H W
+         num_r = rfeats.size()[2]  # number of reference frames
+ 
+         # padding
+         (rframes, rholes, frame, hole, gt), pad = pad_divide_by([rframes, rholes, frame, hole, gt], 8, (height, width))
+ 
+         # Target embedding
+         tfeat = self.A_Encoder(frame, hole)
+ 
+         # c_feat: Encoder (Copy network) features
+         c_feat_ = [self.Encoder(frame, hole)]
+         L_align = torch.zeros_like(frame)
+ 
+         # aligned_r: aligned reference frames
+         aligned_r_ = []
+ 
+         # rvmap: aligned reference frames valid maps
+         rvmap_ = []
+ 
+         for r in range(num_r):
+             theta_rt = self.A_Regressor(tfeat, rfeats[:, :, r])
+             grid_rt = F.affine_grid(theta_rt, frame.size())
+ 
+             # aligned_r: aligned reference frame
+             # reference frame affine transformation
+             aligned_r = F.grid_sample(rframes[:, :, r], grid_rt)
+ 
+             # aligned_v: aligned reference visibility map
+             # reference mask affine transformation
+             aligned_v = F.grid_sample(1 - rholes[:, :, r], grid_rt)
+             aligned_v = (aligned_v > 0.5).float()
+ 
+             aligned_r_.append(aligned_r)
+ 
+             # intersection of target and reference valid map
+             trvmap = (1 - hole) * aligned_v
+             # compare the aligned frame - target frame
+ 
+             c_feat_.append(self.Encoder(aligned_r, aligned_v))
+ 
+             rvmap_.append(aligned_v)
+ 
+         aligned_rs = torch.stack(aligned_r_, 2)
+ 
+         c_feats = torch.stack(c_feat_, dim=2)
+         rvmaps = torch.stack(rvmap_, dim=2)
+ 
+         # p_in: paste network input (target features + c_out + c_mask)
+         p_in, c_mask = self.CM_Module(c_feats, 1 - hole, rvmaps)
+ 
+         pred = self.Decoder(p_in)
+ 
+         _, _, _, H, W = aligned_rs.shape
+         c_mask = (F.interpolate(c_mask, size=(H, W), mode='bilinear', align_corners=False)).detach()
+ 
+         comp = pred * (hole) + gt * (1. - hole)
+ 
+         if pad[2] + pad[3] > 0:
+             comp = comp[:, :, pad[2]:-pad[3], :]
+ 
+         if pad[0] + pad[1] > 0:
+             comp = comp[:, :, :, pad[0]:-pad[1]]
+ 
+         comp = torch.clamp(comp, 0, 1)
+ 
+         return comp
+ 
+     def forward(self, Source, Target):
+         feat_target = self.A_Encoder(Target)
+         feat_source = self.A_Encoder(Source)
+ 
+         theta = self.A_Regressor(feat_target, feat_source)
+         grid_rt = F.affine_grid(theta, Target.size())
+         aligned = F.grid_sample(Source, grid_rt)
+         mask = torch.ones_like(Source)
+         mask = F.grid_sample(mask, grid_rt)
+ 
+         return aligned, mask
+ 
+ 
+ class AC(nn.Module):
+     def __init__(self):
+         super(AC, self).__init__()
+         self.ch_1 = 32
+         self.ch_2 = 64
+         guide_input_channels = 8
+         align_input_channels = 5
+         self.gcm_coord = None
+ 
+         if not self.gcm_coord:
+             guide_input_channels = 6
+             align_input_channels = 3
+ 
+         self.guide_net = N.seq(
+             N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
+             N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
+             nn.AdaptiveAvgPool2d(1),
+             N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
+         )
+ 
+         self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
+ 
+         self.align_base = N.seq(
+             N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
+         )
+         self.align_tail = N.seq(
+             N.conv(self.ch_2, 3, 1, padding=0, mode='C')
+         )
+ 
+     def forward(self, demosaic_raw, dslr, coord=None):
+         demosaic_raw = demosaic_raw + 0.01 * torch.ones_like(demosaic_raw)
+         demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
+         demosaic_raw = demosaic_raw / 2
+         if self.gcm_coord:
+             guide_input = torch.cat((demosaic_raw, dslr, coord), 1)
+             base_input = torch.cat((demosaic_raw, coord), 1)
+         else:
+             guide_input = torch.cat((demosaic_raw, dslr), 1)
+             base_input = demosaic_raw
+ 
+         guide = self.guide_net(guide_input)
+ 
+         out = self.align_head(base_input)
+         out = guide * out + out
+         out = self.align_base(out)
+         out = self.align_tail(out) + demosaic_raw
+ 
+         return out
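For orientation, a minimal smoke test of the refinement network above (an illustrative sketch, not part of the commit; it assumes the interpreter is started inside the SCBC directory so that `networks` and `model_module` resolve). The model is a fully convolutional residual U-Net with three DWT levels, so it preserves input shape as long as H and W are multiples of 8; 768x1024 is the resolution SCBC_Solution.py feeds it.

```python
# Shape check for LiteAWBISPNet with random weights (CPU is fine).
import torch
from CPNet_model import LiteAWBISPNet

model = LiteAWBISPNet().eval()
with torch.no_grad():
    y = model(torch.rand(1, 3, 768, 1024))
print(y.shape)  # expected: torch.Size([1, 3, 768, 1024])
```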
SCBC/Dockerfile ADDED
@@ -0,0 +1,16 @@
+ FROM python:3.8
+ 
+ COPY . /SCBC
+ WORKDIR /SCBC
+ 
+ ARG DEBIAN_FRONTEND=noninteractive
+ ENV TZ=Asia/Shanghai
+ 
+ RUN apt-get update && apt-get install -y \
+     libpng-dev libjpeg-dev \
+     libopencv-dev ffmpeg \
+     libgl1-mesa-glx
+ 
+ RUN python -m pip install --no-cache -r requirements.txt
+ 
+ CMD ["./run.sh"]
SCBC/Input/IMG_20240215_213330.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "black_level": [
+         256,
+         256,
+         256,
+         256
+     ],
+     "white_level": 4095,
+     "noise_profile": [
+         0.001180699005,
+         6.3947934705e-06
+     ],
+     "cfa_pattern": [
+         0,
+         1,
+         1,
+         2
+     ],
+     "orientation": "Horizontal (normal)",
+     "as_shot_neutral": [
+         0.4234199302,
+         1.0,
+         0.2275
+     ]
+ }
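These fields are exactly what SCBC_Solution.py reads back at run time. As a sketch of how they are consumed (paths assumed relative to the SCBC directory):

```python
# Black/white-level normalization driven by the sidecar JSON, mirroring
# the first step of SCBC_Solution.py for this input pair.
import cv2
import json
import numpy as np

raw = cv2.imread('Input/IMG_20240215_213330.png', -1).astype(np.float32)
with open('Input/IMG_20240215_213330.json') as f:
    meta = json.load(f)

black = meta['black_level'][0]   # 256, identical for all four CFA sites here
white = meta['white_level']      # 4095
raw = np.clip((raw - black) / (white - black), 0.0, 1.0)
```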
SCBC/Input/IMG_20240215_213330.png ADDED

Git LFS Details

  • SHA256: d8dee2e87044c00bbfbc570d19f6993753b66640710be4df8f8d1f8d536bbc28
  • Pointer size: 133 Bytes
  • Size of remote file: 48.5 MB
SCBC/Input/IMG_20240215_213619.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "black_level": [
+         256,
+         256,
+         256,
+         256
+     ],
+     "white_level": 4095,
+     "noise_profile": [
+         0.000575730186,
+         3.09754693248e-06
+     ],
+     "cfa_pattern": [
+         0,
+         1,
+         1,
+         2
+     ],
+     "orientation": "Horizontal (normal)",
+     "as_shot_neutral": [
+         0.4354066986,
+         1.0,
+         0.2288348701
+     ]
+ }
SCBC/Input/IMG_20240215_213619.png ADDED

Git LFS Details

  • SHA256: 3da5e417f363a74e103a6c06dd08b175f2d7e2b4f2b0010ecb15b5fce8b59de0
  • Pointer size: 133 Bytes
  • Size of remote file: 36.8 MB
SCBC/Input/IMG_20240215_214449.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "black_level": [
+         256,
+         256,
+         256,
+         256
+     ],
+     "white_level": 4095,
+     "noise_profile": [
+         0.002300534904,
+         2.25042231834722e-05
+     ],
+     "cfa_pattern": [
+         0,
+         1,
+         1,
+         2
+     ],
+     "orientation": "Horizontal (normal)",
+     "as_shot_neutral": [
+         0.4204851752,
+         1.0,
+         0.224368194
+     ]
+ }
SCBC/Input/IMG_20240215_214449.png ADDED

Git LFS Details

  • SHA256: 09ae87c5b6996d3600f439c85283dadab8ba3d157ad0460ccf8064904f998ad7
  • Pointer size: 133 Bytes
  • Size of remote file: 60.2 MB
SCBC/Output/IMG_20240215_213330.png ADDED

Git LFS Details

  • SHA256: bd76e050c3548ce761e99badce8dc7073abb45e46d0730cd1e889aebf0ff27d2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
SCBC/Output/IMG_20240215_213619.png ADDED
SCBC/Output/IMG_20240215_214449.png ADDED

Git LFS Details

  • SHA256: ee34b6063a519c1f9b7867bc7237ca6c17412fecf6658a4ee00522f908bf27d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
SCBC/Readme.txt ADDED
@@ -0,0 +1,2 @@
+ > docker build -t scbc .
+ > docker run --gpus all -it --rm -v $PWD/:/SCBC scbc sh run.sh
SCBC/SCBC_Solution.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import cv2
+ import json
+ import torch
+ import torchvision.transforms as transforms
+ from CPNet_model import LiteAWBISPNet
+ import torchvision
+ import numpy as np
+ from Utiles import white_balance, apply_color_space_transform, transform_xyz_to_srgb, apply_gamma, fix_orientation, binning, Four2One, One2Four
+ import time
+ from net.mwrcanet import Net
+ import torch.nn as nn
+ from PIL import Image
+ import torch.nn.functional as F
+ 
+ ####### Set raw path ###########
+ Rpath = './Input'
+ image_files = []
+ 
+ ####### Timing ###############################
+ infer_times = []
+ 
+ ####### Color matrix from baseline #############
+ color_matrix = [1.06835938, -0.29882812, -0.14257812,
+                 -0.43164062, 1.35546875, 0.05078125,
+                 -0.1015625, 0.24414062, 0.5859375]
+ 
+ ####### Data transforms ###########################
+ transforms_ = [transforms.ToTensor(),
+                transforms.Resize([768, 1024])]
+ transform = transforms.Compose(transforms_)
+ 
+ transforms_ = [transforms.ToTensor()]
+ transformo = transforms.Compose(transforms_)
+ 
+ ####### Load the pretrained refinement model ####
+ model = LiteAWBISPNet()
+ model.cuda()
+ model.load_state_dict(torch.load('./model_zoo/CC2.pth'))
+ 
+ ####### Load the pretrained denoising model ##############
+ last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
+ dn_net = Net()
+ dn_model = nn.DataParallel(dn_net).cuda()
+ tmp_ckpt = torch.load(last_ckpt)
+ pretrained_dict = tmp_ckpt['state_dict']
+ model_dict = dn_model.state_dict()
+ pretrained_dict_update = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ assert(len(pretrained_dict) == len(pretrained_dict_update))
+ assert(len(pretrained_dict_update) == len(model_dict))
+ model_dict.update(pretrained_dict_update)
+ dn_model.load_state_dict(model_dict)
+ 
+ ####### Start processing ######################
+ 
+ for filename in os.listdir(Rpath):
+     if os.path.splitext(filename)[-1].lower() == ".png":
+         image_files.append(filename)
+ 
+ with torch.no_grad():
+     for fp in image_files:
+ 
+         fp = os.path.join(Rpath, fp)
+         mn = os.path.splitext(fp)[-2]
+         mf = str(mn) + '.json'
+ 
+         raw_image = cv2.imread(fp, -1)
+         with open(mf, 'r') as file:
+             data = json.load(file)
+ 
+         ####### Black & white level normalization ##########
+         time_BL_S = time.time()
+ 
+         raw_image = (raw_image.astype(np.float32) - 256.)
+         raw_image = raw_image / (4095. - 256.)
+         raw_image = np.clip(raw_image, 0.0, 1.0)
+ 
+         ####### Binning ############################
+         raw_image = binning(raw_image, data)
+ 
+         ####### Downsample ###########################
+         raw_image = cv2.resize(raw_image, [1024, 768])
+ 
+         ####### Raw denoise ##########################
+         Temp_I = Four2One(raw_image)
+         Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
+         Temp_I = dn_model(Temp_I)
+         Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
+         raw_image = One2Four(Temp_I)
+         # raw_image = cv2.resize(raw_image, [1024, 768])
+ 
+         ####### White balance, color matrix, gamma #########
+         raw_image = white_balance(raw_image, data['as_shot_neutral'])
+         raw_image = apply_color_space_transform(raw_image, color_matrix)
+         raw_image = transform_xyz_to_srgb(raw_image)
+         raw_image = apply_gamma(raw_image)
+ 
+         ####### Refinement #############################
+         Source = transform(raw_image).unsqueeze(0).float().cuda()
+         Out = model(Source)
+ 
+         ####### Saving #############################
+         Out = Out.clip(0, 1)
+         OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
+         OA = OA * 255.
+         OA = OA.astype(np.uint8)
+         OA = fix_orientation(OA, data["orientation"])
+         time_Save_F = time.time()
+         OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
+         cv2.imwrite('./Output/' + str(os.path.basename(fp)), OA)
+ 
+         infer_times.append(time_Save_F - time_BL_S)
+ 
+ print(f"Average inference time: {np.mean(infer_times)} seconds")
SCBC/Utiles.py ADDED
@@ -0,0 +1,143 @@
+ import numpy as np
+ from fractions import Fraction
+ import cv2
+ import exifread
+ from exifread.utils import Ratio
+ import struct
+ import json
+ import torch
+ import time
+ 
+ # Preallocated working buffers for the fixed processing resolution:
+ # a 1536x2048 re-mosaicked plane and its 768x1024 3-channel counterpart.
+ Temp = np.ones([1536, 2048]).astype(np.float32)
+ Timg = np.ones([768, 1024, 3]).astype(np.float32)
+ 
+ def apply_gamma(x):
+     # return x ** (1.0 / 2.2)
+     x = x.copy()
+     idx = x <= 0.0031308
+     x[idx] *= 12.92
+     x[idx == False] = (x[idx == False] ** (1.0 / 2.4)) * 1.055 - 0.055
+     return x
+ 
+ def binning(img, data):
+     if data['cfa_pattern'] == [0, 1, 1, 2]:
+         ch_R = img[0::2, 0::2]
+         ch_G = (img[1::2, 0::2] + img[0::2, 1::2]) / 2
+         ch_B = img[1::2, 1::2]
+         out = np.dstack((ch_R, ch_G, ch_B))
+ 
+     if data['cfa_pattern'] == [2, 1, 1, 0]:
+         ch_R = img[1::2, 1::2]
+         ch_G = (img[1::2, 0::2] + img[0::2, 1::2]) / 2
+         ch_B = img[0::2, 0::2]
+         out = np.dstack((ch_R, ch_G, ch_B))
+ 
+     return out
+ 
+ def Four2One(img):
+     Temp[0::2, 0::2] = img[:, :, 0]
+     Temp[1::2, 0::2] = img[:, :, 1]
+     Temp[0::2, 1::2] = img[:, :, 1]
+     Temp[1::2, 1::2] = img[:, :, 2]
+     return Temp
+ 
+ def One2Four(Temp):
+     Timg[:, :, 0] = Temp[0::2, 0::2]
+     Timg[:, :, 1] = (Temp[1::2, 0::2] + Temp[0::2, 1::2]) / 2
+     Timg[:, :, 2] = Temp[1::2, 1::2]
+     return Timg
+ 
+ def ratios2floats(ratios):
+     # exifread returns rational values as Ratio objects; convert to floats.
+     # (Referenced by white_balance but previously missing from this file.)
+     return [float(r.num) / float(r.den) for r in ratios]
+ 
+ def white_balance(demosaic_img, as_shot_neutral):
+     if type(as_shot_neutral[0]) is Ratio:
+         as_shot_neutral = ratios2floats(as_shot_neutral)
+ 
+     as_shot_neutral = np.asarray(as_shot_neutral)
+     # transform vector into matrix
+     if as_shot_neutral.shape == (3,):
+         as_shot_neutral = np.diag(1. / as_shot_neutral)
+ 
+     assert as_shot_neutral.shape == (3, 3)
+ 
+     white_balanced_image = np.dot(demosaic_img, as_shot_neutral.T)
+     white_balanced_image = np.clip(white_balanced_image, 0.0, 1.0)
+ 
+     return white_balanced_image
+ 
+ def apply_color_space_transform(demosaiced_image, color_matrix):
+     xyz2cam = np.reshape(np.asarray(color_matrix), (3, 3))
+     # normalize rows (needed?)
+     xyz2cam = xyz2cam / np.sum(xyz2cam, axis=1, keepdims=True)
+     # inverse
+     cam2xyz = np.linalg.inv(xyz2cam)
+     # simplified matrix multiplication
+     xyz_image = cam2xyz[np.newaxis, np.newaxis, :, :] * \
+         demosaiced_image[:, :, np.newaxis, :]
+     xyz_image = np.sum(xyz_image, axis=-1)
+     xyz_image = np.clip(xyz_image, 0.0, 1.0)
+     return xyz_image
+ 
+ def transform_xyz_to_srgb(xyz_image):
+     xyz2srgb = np.array([[3.2404542, -1.5371385, -0.4985314],
+                          [-0.9692660, 1.8760108, 0.0415560],
+                          [0.0556434, -0.2040259, 1.0572252]])
+ 
+     # normalize rows (needed?)
+     xyz2srgb = xyz2srgb / np.sum(xyz2srgb, axis=-1, keepdims=True)
+ 
+     srgb_image = xyz2srgb[np.newaxis, np.newaxis, :, :] * xyz_image[:, :, np.newaxis, :]
+     srgb_image = np.sum(srgb_image, axis=-1)
+     srgb_image = np.clip(srgb_image, 0.0, 1.0)
+     return srgb_image
+ 
+ def fix_orientation(image, orientation):
+     # 1 = Horizontal (normal)
+     # 2 = Mirror horizontal
+     # 3 = Rotate 180
+     # 4 = Mirror vertical
+     # 5 = Mirror horizontal and rotate 270 CW
+     # 6 = Rotate 90 CW
+     # 7 = Mirror horizontal and rotate 90 CW
+     # 8 = Rotate 270 CW
+ 
+     if type(orientation) is list:
+         orientation = orientation[0]
+ 
+     if orientation == "Horizontal (normal)":
+         pass
+     elif orientation == "Mirror horizontal":
+         image = cv2.flip(image, 0)
+     elif orientation == "Rotate 180":
+         image = cv2.rotate(image, cv2.ROTATE_180)
+     elif orientation == "Mirror vertical":
+         image = cv2.flip(image, 1)
+     elif orientation == "Mirror horizontal and rotate 270 CW":
+         image = cv2.flip(image, 0)
+         image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+     elif orientation == "Rotate 90 CW":
+         image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+     elif orientation == "Mirror horizontal and rotate 90 CW":
+         image = cv2.flip(image, 0)
+         image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+     elif orientation == "Rotate 270 CW":
+         image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+ 
+     return image
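A tiny worked example of the `binning` helper above (illustrative only): for the `[0, 1, 1, 2]` (RGGB) CFA pattern used by the inputs in this commit, each 2x2 mosaic cell collapses to one RGB pixel with the two greens averaged, so an (H, W) mosaic becomes an (H/2, W/2, 3) image.

```python
import numpy as np
from Utiles import binning

# 4x4 RGGB mosaic -> 2x2x3 RGB image; greens averaged per cell.
mosaic = np.arange(16, dtype=np.float32).reshape(4, 4) / 16.0
rgb = binning(mosaic, {'cfa_pattern': [0, 1, 1, 2]})
print(rgb.shape)  # (2, 2, 3)
print(rgb[0, 0])  # [R, (G1+G2)/2, B] of the top-left 2x2 cell
```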
SCBC/__pycache__/CPNet_model.cpython-38.pyc ADDED
Binary file (15.1 kB)
SCBC/__pycache__/Utiles.cpython-38.pyc ADDED
Binary file (3.6 kB)
SCBC/__pycache__/datasets.cpython-38.pyc ADDED
Binary file (1.94 kB)
SCBC/__pycache__/datasets_crop.cpython-38.pyc ADDED
Binary file (2.1 kB)
SCBC/__pycache__/datasets_fine.cpython-38.pyc ADDED
Binary file (1.94 kB)
SCBC/__pycache__/model_module.cpython-38.pyc ADDED
Binary file (1.8 kB)
SCBC/__pycache__/models.cpython-38.pyc ADDED
Binary file (2.69 kB)
SCBC/__pycache__/networks.cpython-38.pyc ADDED
Binary file (8.62 kB)
SCBC/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.09 kB)
SCBC/model_module.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ import sys
+ 
+ 
+ class Conv2d(nn.Module):
+     def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, D=1, activation=nn.ReLU()):
+         super(Conv2d, self).__init__()
+         if activation:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=D),
+                 activation
+             )
+         else:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=D)
+             )
+ 
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+ 
+ def init_He(module):
+     for m in module.modules():
+         if isinstance(m, nn.Conv2d):
+             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.weight, 1)
+             nn.init.constant_(m.bias, 0)
+ 
+ def pad_divide_by(in_list, d, in_size):
+     out_list = []
+     h, w = in_size
+     if h % d > 0:
+         new_h = h + d - h % d
+     else:
+         new_h = h
+     if w % d > 0:
+         new_w = w + d - w % d
+     else:
+         new_w = w
+     lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
+     lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
+     pad_array = (int(lw), int(uw), int(lh), int(uh))
+     for inp in in_list:
+         out_list.append(F.pad(inp, pad_array))
+     return out_list, pad_array
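A small usage sketch of `pad_divide_by` (hypothetical sizes, illustrative only): it symmetrically pads H and W up to the next multiple of `d` and returns the pad tuple `(lw, uw, lh, uh)`, which `CPNet.inpainting` later uses to crop the result back.

```python
import torch
from model_module import pad_divide_by

x = torch.rand(1, 3, 100, 150)
(xp,), pad = pad_divide_by([x], 8, (100, 150))
print(xp.shape)   # torch.Size([1, 3, 104, 152]); pad == (1, 1, 2, 2)
if pad[2] + pad[3] > 0:
    xp = xp[:, :, pad[2]:-pad[3], :]
if pad[0] + pad[1] > 0:
    xp = xp[:, :, :, pad[0]:-pad[1]]
print(xp.shape)   # back to torch.Size([1, 3, 100, 150])
```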
SCBC/model_zoo/CC2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:867b8163d95115d73911c0c994044089b65291130196be72a91a5633fc91a873
+ size 35619323
SCBC/model_zoo/dn_mwrcanet_raw_c1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b33267f07b484900a327da312cd25b015486453cda55174b95e691310c597d6c
+ size 109093370
SCBC/models.py ADDED
@@ -0,0 +1,92 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_features):
+         super(ResidualBlock, self).__init__()
+ 
+         conv_block = [nn.ReflectionPad2d(1),
+                       nn.Conv2d(in_features, in_features, 3),
+                       nn.InstanceNorm2d(in_features),
+                       nn.ReLU(inplace=True),
+                       nn.ReflectionPad2d(1),
+                       nn.Conv2d(in_features, in_features, 3),
+                       nn.InstanceNorm2d(in_features)]
+ 
+         self.conv_block = nn.Sequential(*conv_block)
+ 
+     def forward(self, x):
+         return x + self.conv_block(x)
+ 
+ class Generator(nn.Module):
+     def __init__(self, input_nc, output_nc, n_residual_blocks=9):
+         super(Generator, self).__init__()
+ 
+         # Initial convolution block
+         model = [nn.ReflectionPad2d(3),
+                  nn.Conv2d(input_nc, 64, 7),
+                  nn.InstanceNorm2d(64),
+                  nn.ReLU(inplace=True)]
+ 
+         # Downsampling
+         in_features = 64
+         out_features = in_features*2
+         for _ in range(2):
+             model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                       nn.InstanceNorm2d(out_features),
+                       nn.ReLU(inplace=True)]
+             in_features = out_features
+             out_features = in_features*2
+ 
+         # Residual blocks
+         for _ in range(n_residual_blocks):
+             model += [ResidualBlock(in_features)]
+ 
+         # Upsampling
+         out_features = in_features//2
+         for _ in range(2):
+             model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                       nn.InstanceNorm2d(out_features),
+                       nn.ReLU(inplace=True)]
+             in_features = out_features
+             out_features = in_features//2
+ 
+         # Output layer
+         model += [nn.ReflectionPad2d(3),
+                   nn.Conv2d(64, output_nc, 7),
+                   nn.Tanh()]
+ 
+         self.model = nn.Sequential(*model)
+ 
+     def forward(self, x):
+         return self.model(x)
+ 
+ class Discriminator(nn.Module):
+     def __init__(self, input_nc):
+         super(Discriminator, self).__init__()
+ 
+         # A bunch of convolutions one after another
+         model = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
+                  nn.LeakyReLU(0.2, inplace=True)]
+ 
+         model += [nn.Conv2d(64, 128, 4, stride=2, padding=1),
+                   nn.InstanceNorm2d(128),
+                   nn.LeakyReLU(0.2, inplace=True)]
+ 
+         model += [nn.Conv2d(128, 256, 4, stride=2, padding=1),
+                   nn.InstanceNorm2d(256),
+                   nn.LeakyReLU(0.2, inplace=True)]
+ 
+         model += [nn.Conv2d(256, 512, 4, padding=1),
+                   nn.InstanceNorm2d(512),
+                   nn.LeakyReLU(0.2, inplace=True)]
+ 
+         # FCN classification layer
+         model += [nn.Conv2d(512, 1, 4, padding=1)]
+ 
+         self.model = nn.Sequential(*model)
+ 
+     def forward(self, x):
+         x = self.model(x)
+         # Average pooling and flatten
+         return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
SCBC/net/__pycache__/mwrcanet.cpython-38.pyc ADDED
Binary file (5.93 kB)
SCBC/net/mwrcanet.py ADDED
@@ -0,0 +1,167 @@
+ # -*- coding: utf-8 -*-
+ # Yue Cao (cscaoyue@gmail.com) (cscaoyue@hit.edu.cn)
+ # supervisor : Wangmeng Zuo (cswmzuo@gmail.com)
+ # github: https://github.com/happycaoyue
+ # personal link: happycaoyue.com
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.nn.init as init
+ import torch.nn.functional as F
+ 
+ class HITVPCTeam:
+     r"""
+     DWT and IDWT block written by: Yue Cao
+     """
+     class CALayer(nn.Module):
+         def __init__(self, channel=64, reduction=16):
+             super(HITVPCTeam.CALayer, self).__init__()
+ 
+             self.avg_pool = nn.AdaptiveAvgPool2d(1)
+             self.conv_du = nn.Sequential(
+                 nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
+                 nn.Sigmoid()
+             )
+ 
+         def forward(self, x):
+             y = self.avg_pool(x)
+             y = self.conv_du(y)
+             return x * y
+ 
+     # conv - prelu - conv - sum
+     class RB(nn.Module):
+         def __init__(self, filters):
+             super(HITVPCTeam.RB, self).__init__()
+             self.conv1 = nn.Conv2d(filters, filters, 3, 1, 1)
+             self.act = nn.PReLU()
+             self.conv2 = nn.Conv2d(filters, filters, 3, 1, 1)
+             self.cuca = HITVPCTeam.CALayer(channel=filters)
+ 
+         def forward(self, x):
+             c0 = x
+             x = self.conv1(x)
+             x = self.act(x)
+             x = self.conv2(x)
+             out = self.cuca(x)
+             return out + c0
+ 
+     class NRB(nn.Module):
+         def __init__(self, n, f):
+             super(HITVPCTeam.NRB, self).__init__()
+             nets = []
+             for i in range(n):
+                 nets.append(HITVPCTeam.RB(f))
+             self.body = nn.Sequential(*nets)
+             self.tail = nn.Conv2d(f, f, 3, 1, 1)
+ 
+         def forward(self, x):
+             return x + self.tail(self.body(x))
+ 
+     class DWTForward(nn.Module):
+         def __init__(self):
+             super(HITVPCTeam.DWTForward, self).__init__()
+             ll = np.array([[0.5, 0.5], [0.5, 0.5]])
+             lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
+             hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
+             hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
+             filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
+                               hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
+                              axis=0)
+             self.weight = nn.Parameter(
+                 torch.tensor(filts).to(torch.get_default_dtype()),
+                 requires_grad=False)
+ 
+         def forward(self, x):
+             C = x.shape[1]
+             filters = torch.cat([self.weight, ] * C, dim=0)
+             y = F.conv2d(x, filters, groups=C, stride=2)
+             return y
+ 
+     class DWTInverse(nn.Module):
+         def __init__(self):
+             super(HITVPCTeam.DWTInverse, self).__init__()
+             ll = np.array([[0.5, 0.5], [0.5, 0.5]])
+             lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
+             hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
+             hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
+             filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
+                               hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
+                              axis=0)
+             self.weight = nn.Parameter(
+                 torch.tensor(filts).to(torch.get_default_dtype()),
+                 requires_grad=False)
+ 
+         def forward(self, x):
+             C = int(x.shape[1] / 4)
+             filters = torch.cat([self.weight, ] * C, dim=0)
+             y = F.conv_transpose2d(x, filters, groups=C, stride=2)
+             return y
+ 
+ 
+ class Net(nn.Module):
+     def __init__(self, channels=1, filters_level1=96, filters_level2=256//2, filters_level3=256//2, n_rb=4*5):
+         super(Net, self).__init__()
+ 
+         self.head = HITVPCTeam.DWTForward()
+ 
+         self.down1 = nn.Sequential(
+             nn.Conv2d(channels * 4, filters_level1, 3, 1, 1),
+             nn.PReLU(),
+             HITVPCTeam.NRB(n_rb, filters_level1))
+ 
+         # sum 1
+         # self.down1 = HITVPCTeam.NRB(n_rb, filters_level1),
+ 
+         # sum 2
+         self.down2 = nn.Sequential(
+             HITVPCTeam.DWTForward(),
+             nn.Conv2d(filters_level1 * 4, filters_level2, 3, 1, 1),
+             nn.PReLU(),
+             HITVPCTeam.NRB(n_rb, filters_level2))
+ 
+         self.down3 = nn.Sequential(
+             HITVPCTeam.DWTForward(),
+             nn.Conv2d(filters_level2 * 4, filters_level3, 3, 1, 1),
+             nn.PReLU())
+ 
+         self.middle = HITVPCTeam.NRB(n_rb, filters_level3)
+ 
+         self.up1 = nn.Sequential(
+             nn.Conv2d(filters_level3, filters_level2 * 4, 3, 1, 1),
+             nn.PReLU(),
+             HITVPCTeam.DWTInverse())
+ 
+         self.up2 = nn.Sequential(
+             HITVPCTeam.NRB(n_rb, filters_level2),
+             nn.Conv2d(filters_level2, filters_level1 * 4, 3, 1, 1),
+             nn.PReLU(),
+             HITVPCTeam.DWTInverse())
+ 
+         self.up3 = nn.Sequential(
+             HITVPCTeam.NRB(n_rb, filters_level1),
+             nn.Conv2d(filters_level1, channels * 4, 3, 1, 1))
+ 
+         self.tail = HITVPCTeam.DWTInverse()
+ 
+     def forward(self, inputs):
+         c0 = inputs
+         c1 = self.head(c0)
+         c2 = self.down1(c1)
+         c3 = self.down2(c2)
+         c4 = self.down3(c3)
+         m = self.middle(c4)
+         c5 = self.up1(m) + c3
+         c6 = self.up2(c5) + c2
+         c7 = self.up3(c6) + c1
+         return self.tail(c7)
+ 
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.orthogonal_(m.weight)
+                 print('init weight')
+                 if m.bias is not None:
+                     init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias, 0)
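A quick property check of the DWT pair above (illustrative, assuming imports resolve from the SCBC directory): the four 2x2 filters with entries of plus or minus 0.5 form an orthonormal basis of each 2x2 block, so `DWTInverse` exactly undoes `DWTForward` on even-sized inputs.

```python
import torch
from net.mwrcanet import HITVPCTeam

x = torch.rand(1, 1, 8, 8)
dwt, idwt = HITVPCTeam.DWTForward(), HITVPCTeam.DWTInverse()
y = dwt(x)    # (1, 4, 4, 4): four subbands at half resolution
xr = idwt(y)  # (1, 1, 8, 8)
print(torch.allclose(x, xr, atol=1e-6))  # True
```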
SCBC/networks.py ADDED
@@ -0,0 +1,294 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+ from torch.optim import lr_scheduler
+ from collections import OrderedDict
+ 
+ 
+ def get_scheduler(optimizer, opt):
+     if opt.lr_policy == 'linear':
+         def lambda_rule(epoch):
+             return 1 - max(0, epoch - opt.niter) / max(1, float(opt.niter_decay))
+         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+     elif opt.lr_policy == 'step':
+         scheduler = lr_scheduler.StepLR(optimizer,
+                                         step_size=opt.lr_decay_iters,
+                                         gamma=0.5)
+     elif opt.lr_policy == 'plateau':
+         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
+                                                    mode='min',
+                                                    factor=0.2,
+                                                    threshold=0.01,
+                                                    patience=5)
+     elif opt.lr_policy == 'cosine':
+         scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
+                                                    T_max=opt.niter,
+                                                    eta_min=0)
+     else:
+         raise NotImplementedError('lr [%s] is not implemented' % opt.lr_policy)
+     return scheduler
+ 
+ def init_weights(net, init_type='normal', init_gain=0.02):
+     def init_func(m):  # define the initialization function
+         classname = m.__class__.__name__
+         if hasattr(m, 'weight') and (classname.find('Conv') != -1
+                                      or classname.find('Linear') != -1):
+             if init_type == 'normal':
+                 init.normal_(m.weight.data, 0.0, init_gain)
+             elif init_type == 'xavier':
+                 init.xavier_normal_(m.weight.data, gain=init_gain)
+             elif init_type == 'kaiming':
+                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+             elif init_type == 'orthogonal':
+                 init.orthogonal_(m.weight.data, gain=init_gain)
+             elif init_type == 'uniform':
+                 init.uniform_(m.weight.data, b=init_gain)
+             else:
+                 raise NotImplementedError('[%s] is not implemented' % init_type)
+             if hasattr(m, 'bias') and m.bias is not None:
+                 init.constant_(m.bias.data, 0.0)
+         elif classname.find('BatchNorm2d') != -1:
+             init.normal_(m.weight.data, 1.0, init_gain)
+             init.constant_(m.bias.data, 0.0)
+ 
+     print('initialize network with %s' % init_type)
+     net.apply(init_func)  # apply the initialization function <init_func>
+ 
+ def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]):
+     if len(gpu_ids) > 0:
+         assert(torch.cuda.is_available())
+         net.to(gpu_ids[0])
+         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+     if init_type != 'default' and init_type is not None:
+         init_weights(net, init_type, init_gain=init_gain)
+     return net
+ 
+ 
+ '''
+ # ===================================
+ # Advanced nn.Sequential
+ # reform nn.Sequentials and nn.Modules
+ # to a single nn.Sequential
+ # ===================================
+ '''
+ 
+ def seq(*args):
+     if len(args) == 1:
+         args = args[0]
+     if isinstance(args, nn.Module):
+         return args
+     modules = OrderedDict()
+     if isinstance(args, OrderedDict):
+         for k, v in args.items():
+             modules[k] = seq(v)
+         return nn.Sequential(modules)
+     assert isinstance(args, (list, tuple))
+     return nn.Sequential(*[seq(i) for i in args])
+ 
+ '''
+ # ===================================
+ # Useful blocks
+ # --------------------------------
+ # conv (+ normalization + relu)
+ # concat
+ # sum
+ # resblock (ResBlock)
+ # resdenseblock (ResidualDenseBlock_5C)
+ # resinresdenseblock (RRDB)
+ # ===================================
+ '''
+ 
+ # -------------------------------------------------------
+ # return nn.Sequential of (Conv + BN + ReLU)
+ # -------------------------------------------------------
+ def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
+          output_padding=0, dilation=1, groups=1, bias=True,
+          padding_mode='zeros', mode='CBR'):
+     L = []
+     for t in mode:
+         if t == 'C':
+             L.append(nn.Conv2d(in_channels=in_channels,
+                                out_channels=out_channels,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                dilation=dilation,
+                                groups=groups,
+                                bias=bias,
+                                padding_mode=padding_mode))
+         elif t == 'X':
+             assert in_channels == out_channels
+             L.append(nn.Conv2d(in_channels=in_channels,
+                                out_channels=out_channels,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                padding=padding,
+                                dilation=dilation,
+                                groups=in_channels,
+                                bias=bias,
+                                padding_mode=padding_mode))
+         elif t == 'T':
+             L.append(nn.ConvTranspose2d(in_channels=in_channels,
+                                         out_channels=out_channels,
+                                         kernel_size=kernel_size,
+                                         stride=stride,
+                                         padding=padding,
+                                         output_padding=output_padding,
+                                         groups=groups,
+                                         bias=bias,
+                                         dilation=dilation,
+                                         padding_mode=padding_mode))
+         elif t == 'B':
+             L.append(nn.BatchNorm2d(out_channels))
+         elif t == 'I':
+             L.append(nn.InstanceNorm2d(out_channels, affine=True))
+         elif t == 'i':
+             L.append(nn.InstanceNorm2d(out_channels))
+         elif t == 'R':
+             L.append(nn.ReLU(inplace=True))
+         elif t == 'r':
+             L.append(nn.ReLU(inplace=False))
+         elif t == 'S':
+             L.append(nn.Sigmoid())
+         elif t == 'P':
+             L.append(nn.PReLU())
+         elif t == 'L':
+             L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
+         elif t == 'l':
+             L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
+         elif t == '2':
+             L.append(nn.PixelShuffle(upscale_factor=2))
+         elif t == '3':
+             L.append(nn.PixelShuffle(upscale_factor=3))
+         elif t == '4':
+             L.append(nn.PixelShuffle(upscale_factor=4))
+         elif t == 'U':
+             L.append(nn.Upsample(scale_factor=2, mode='nearest'))
+         elif t == 'u':
+             L.append(nn.Upsample(scale_factor=3, mode='nearest'))
+         elif t == 'M':
+             L.append(nn.MaxPool2d(kernel_size=kernel_size,
+                                   stride=stride,
+                                   padding=0))
+         elif t == 'A':
+             L.append(nn.AvgPool2d(kernel_size=kernel_size,
+                                   stride=stride,
+                                   padding=0))
+         else:
+             raise NotImplementedError('Undefined type: {}'.format(t))
+     return seq(*L)
+ 
+ 
+ class DWTForward(nn.Conv2d):
+     def __init__(self, in_channels=64):
+         super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2,
+                                          groups=in_channels, bias=False)
+         weight = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]]],
+                                [[[0.5, 0.5], [-0.5, -0.5]]],
+                                [[[0.5, -0.5], [0.5, -0.5]]],
+                                [[[0.5, -0.5], [-0.5, 0.5]]]],
+                               dtype=torch.get_default_dtype()
+                               ).repeat(in_channels, 1, 1, 1)  # / 2
+         self.weight.data.copy_(weight)
+         self.requires_grad_(False)
+ 
+ 
+ class DWTInverse(nn.ConvTranspose2d):
+     def __init__(self, in_channels=64):
+         super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2,
+                                          groups=in_channels//4, bias=False)
+         weight = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]]],
+                                [[[0.5, 0.5], [-0.5, -0.5]]],
+                                [[[0.5, -0.5], [0.5, -0.5]]],
+                                [[[0.5, -0.5], [-0.5, 0.5]]]],
+                               dtype=torch.get_default_dtype()
+                               ).repeat(in_channels//4, 1, 1, 1)  # * 2
+         self.weight.data.copy_(weight)
+         self.requires_grad_(False)
+ 
+ 
+ # -------------------------------------------------------
+ # Channel Attention (CA) Layer
+ # -------------------------------------------------------
+ class CALayer(nn.Module):
+     def __init__(self, channel=64, reduction=16):
+         super(CALayer, self).__init__()
+ 
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.conv_du = nn.Sequential(
+             nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
+             nn.Sigmoid()
+         )
+ 
+     def forward(self, x):
+         y = self.avg_pool(x)
+         y = self.conv_du(y)
+         return x * y
+ 
+ 
+ # -------------------------------------------------------
+ # Res Block: x + conv(relu(conv(x)))
+ # -------------------------------------------------------
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
+                  padding=1, bias=True, mode='CRC'):
+         super(ResBlock, self).__init__()
+ 
+         assert in_channels == out_channels
+         if mode[0] in ['R', 'L']:
+             mode = mode[0].lower() + mode[1:]
+ 
+         self.res = conv(in_channels, out_channels, kernel_size,
+                         stride, padding=padding, bias=bias, mode=mode)
+ 
+     def forward(self, x):
+         res = self.res(x)
+         return x + res
+ 
+ 
+ # -------------------------------------------------------
+ # Residual Channel Attention Block (RCAB)
+ # -------------------------------------------------------
+ class RCABlock(nn.Module):
+     def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
+                  padding=1, bias=True, mode='CRC', reduction=16):
+         super(RCABlock, self).__init__()
+         assert in_channels == out_channels
+         if mode[0] in ['R', 'L']:
+             mode = mode[0].lower() + mode[1:]
+ 
+         self.res = conv(in_channels, out_channels, kernel_size,
+                         stride, padding, bias=bias, mode=mode)
+         self.ca = CALayer(out_channels, reduction)
+ 
+     def forward(self, x):
+         res = self.res(x)
+         res = self.ca(res)
+         return res + x
+ 
+ 
+ # -------------------------------------------------------
+ # Residual Channel Attention Group (RG)
+ # -------------------------------------------------------
+ class RCAGroup(nn.Module):
+     def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
+                  padding=1, bias=True, mode='CRC', reduction=16, nb=12):
+         super(RCAGroup, self).__init__()
+         assert in_channels == out_channels
+         if mode[0] in ['R', 'L']:
+             mode = mode[0].lower() + mode[1:]
+ 
+         RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding,
+                        bias, mode, reduction) for _ in range(nb)]
+         # RG = [ResBlock(in_channels, out_channels, kernel_size, stride, padding,
+         #                bias, mode) for _ in range(nb)]
+         RG.append(conv(out_channels, out_channels, mode='C'))
+ 
+         self.rg = nn.Sequential(*RG)
+ 
+     def forward(self, x):
+         res = self.rg(x)
+         return res + x
+ 
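The `mode` argument of `conv` above acts as a tiny layer DSL: each character appends one module ('C' conv, 'B' batch norm, 'R' ReLU, '2' PixelShuffle, and so on), which is how the models in this commit spell blocks like 'CRC' and 'CRCRCR'. A usage sketch (assuming it is run from the SCBC directory so `networks` imports):

```python
import torch
import networks as N

block = N.conv(3, 64, kernel_size=3, stride=1, padding=1, mode='CBR')
print(block)                        # Sequential(Conv2d, BatchNorm2d, ReLU)
y = block(torch.rand(1, 3, 32, 32))
print(y.shape)                      # torch.Size([1, 64, 32, 32])
```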
SCBC/requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ scipy
3
+ numpy
4
+ torch
5
+ pandas
6
+ torchvision
7
+ Pillow
8
+ matplotlib
9
+ tqdm
10
+ imageio
11
+ seaborn
12
+ hdf5storage
13
+ exifread
SCBC/run.sh ADDED
@@ -0,0 +1 @@
+ python SCBC_Solution.py