Artyom committed
Commit f8d6c27 • 1 Parent(s): 94f9590
scbc
- .gitattributes +5 -0
- SCBC/CPNet_model.py +629 -0
- SCBC/Dockerfile +16 -0
- SCBC/Input/IMG_20240215_213330.json +25 -0
- SCBC/Input/IMG_20240215_213330.png +3 -0
- SCBC/Input/IMG_20240215_213619.json +25 -0
- SCBC/Input/IMG_20240215_213619.png +3 -0
- SCBC/Input/IMG_20240215_214449.json +25 -0
- SCBC/Input/IMG_20240215_214449.png +3 -0
- SCBC/Output/IMG_20240215_213330.png +3 -0
- SCBC/Output/IMG_20240215_213619.png +0 -0
- SCBC/Output/IMG_20240215_214449.png +3 -0
- SCBC/Readme.txt +2 -0
- SCBC/SCBC_Solution.py +130 -0
- SCBC/Utiles.py +143 -0
- SCBC/__pycache__/CPNet_model.cpython-38.pyc +0 -0
- SCBC/__pycache__/Utiles.cpython-38.pyc +0 -0
- SCBC/__pycache__/datasets.cpython-38.pyc +0 -0
- SCBC/__pycache__/datasets_crop.cpython-38.pyc +0 -0
- SCBC/__pycache__/datasets_fine.cpython-38.pyc +0 -0
- SCBC/__pycache__/model_module.cpython-38.pyc +0 -0
- SCBC/__pycache__/models.cpython-38.pyc +0 -0
- SCBC/__pycache__/networks.cpython-38.pyc +0 -0
- SCBC/__pycache__/utils.cpython-38.pyc +0 -0
- SCBC/model_module.py +49 -0
- SCBC/model_zoo/CC2.pth +3 -0
- SCBC/model_zoo/dn_mwrcanet_raw_c1.pth +3 -0
- SCBC/models.py +92 -0
- SCBC/net/__pycache__/mwrcanet.cpython-38.pyc +0 -0
- SCBC/net/mwrcanet.py +167 -0
- SCBC/networks.py +294 -0
- SCBC/requirements.txt +13 -0
- SCBC/run.sh +1 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+SCBC/Input/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
+SCBC/Input/IMG_20240215_213619.png filter=lfs diff=lfs merge=lfs -text
+SCBC/Input/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
+SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
+SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
SCBC/CPNet_model.py
ADDED
@@ -0,0 +1,629 @@
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchvision import models
from torchvision import transforms
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import math
import time
import tqdm
import os
import argparse
import copy
import sys
import networks as N
from model_module import *
sys.path.insert(0, '.')
# from .common import *
sys.path.insert(0, '../utils/')


class LiteISPNet(nn.Module):
    def __init__(self):
        super(LiteISPNet, self).__init__()

        ch_1 = 64
        ch_2 = 128
        ch_3 = 128
        n_blocks = 4

        self.head = N.seq(
            N.conv(3, ch_1, mode='C')
        )  # shape: (N, ch_1, H/2, W/2)

        self.down1 = N.seq(
            N.conv(ch_1, ch_1, mode='C'),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1, mode='C'),
            N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/4, W/4)

        self.down2 = N.seq(
            N.conv(ch_1*4, ch_1, mode='C'),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/8, W/8)

        self.down3 = N.seq(
            N.conv(ch_1*4, ch_2, mode='C'),
            N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.DWTForward(ch_2)
        )  # shape: (N, ch_2*4, H/16, W/16)

        self.middle = N.seq(
            N.conv(ch_2*4, ch_3, mode='C'),
            N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            N.conv(ch_3, ch_2*4, mode='C')
        )  # shape: (N, ch_2*4, H/16, W/16)

        self.up3 = N.seq(
            N.DWTInverse(ch_2*4),
            N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.conv(ch_2, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/8, W/8)

        self.up2 = N.seq(
            N.DWTInverse(ch_1*4),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/4, W/4)

        self.up1 = N.seq(
            N.DWTInverse(ch_1*4),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1, mode='C')
        )  # shape: (N, ch_1, H/2, W/2)

        self.tail = N.seq(
            # N.conv(ch_1, ch_1*4, mode='C'),
            # nn.PixelShuffle(upscale_factor=2),
            N.conv(ch_1, 3, mode='C')
        )  # shape: (N, 3, H, W)

    def forward(self, raw):
        # input = raw
        input = torch.pow(raw, 1/2.2)  # gamma-compress the linear raw input

        h = self.head(input)
        h_coord = h

        d1 = self.down1(h_coord)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        m = self.middle(d3) + d3
        u3 = self.up3(m) + d2
        u2 = self.up2(u3) + d1
        u1 = self.up1(u2) + h
        out = self.tail(u1)

        return out
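
A quick shape check of the model above (an illustrative sketch, not part of the commit; it assumes the conv/RCAGroup/DWT blocks that networks.py is imported for are available at runtime, since their definitions fall outside the portion of networks.py shown later in this diff):

# Hypothetical smoke test: H and W must be divisible by 8 for the three DWT stages.
import torch
from CPNet_model import LiteISPNet

net = LiteISPNet().eval()
with torch.no_grad():
    y = net(torch.rand(1, 3, 768, 1024))  # the resolution the solution script works at
print(y.shape)  # expected: torch.Size([1, 3, 768, 1024]), the U-Net preserves spatial size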


class LiteAWBISPNet(nn.Module):
    def __init__(self):
        super(LiteAWBISPNet, self).__init__()

        ch_1 = 64
        ch_2 = 128
        ch_3 = 128
        n_blocks = 4

        self.head = N.seq(
            N.conv(3, ch_1, mode='C')
        )  # shape: (N, ch_1, H/2, W/2)

        self.down1 = N.seq(
            N.conv(ch_1, ch_1, mode='C'),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1, mode='C'),
            N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/4, W/4)

        self.down2 = N.seq(
            N.conv(ch_1*4, ch_1, mode='C'),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/8, W/8)

        self.down3 = N.seq(
            N.conv(ch_1*4, ch_2, mode='C'),
            N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.DWTForward(ch_2)
        )  # shape: (N, ch_2*4, H/16, W/16)

        self.middle = N.seq(
            N.conv(ch_2*4, ch_3, mode='C'),
            N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            N.conv(ch_3, ch_2*4, mode='C')
        )  # shape: (N, ch_2*4, H/16, W/16)

        self.up3 = N.seq(
            N.DWTInverse(ch_2*4),
            N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.conv(ch_2, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/8, W/8)

        self.up2 = N.seq(
            N.DWTInverse(ch_1*4),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/4, W/4)

        self.up1 = N.seq(
            N.DWTInverse(ch_1*4),
            N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.conv(ch_1, ch_1, mode='C')
        )  # shape: (N, ch_1, H/2, W/2)

        self.tail = N.seq(
            # N.conv(ch_1, ch_1*4, mode='C'),
            # nn.PixelShuffle(upscale_factor=2),
            N.conv(ch_1, 3, mode='C')
        )  # shape: (N, 3, H, W)

    def forward(self, raw):
        # unlike LiteISPNet, no gamma compression is applied to the input here
        input = raw
        h = self.head(input)
        h_coord = h

        d1 = self.down1(h_coord)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        m = self.middle(d3) + d3
        u3 = self.up3(m) + d2
        u2 = self.up2(u3) + d1
        u1 = self.up1(u2) + h
        out = self.tail(u1)

        return out


# Alignment Encoder
class A_Encoder(nn.Module):
    def __init__(self):
        super(A_Encoder, self).__init__()
        self.conv12 = Conv2d(3, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())    # 2
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())    # 2
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 4
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
        self.conv34 = Conv2d(128, 256, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 8
        self.conv4a = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
        self.conv4b = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
        init_He(self)
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f):
        f = (in_f - self.mean) / self.std
        x = f
        x = F.upsample(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3(x)
        x = self.conv34(x)
        x = self.conv4a(x)
        x = self.conv4b(x)
        return x


# Alignment Regressor
class A_Regressor(nn.Module):
    def __init__(self):
        super(A_Regressor, self).__init__()
        self.conv45 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 16
        self.conv5a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 16
        self.conv5b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 16
        self.conv56 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 32
        self.conv6a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 32
        self.conv6b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 32
        init_He(self)

        self.fc = nn.Linear(512, 6)
        # initialize the regression head to the identity affine transform
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32))

    def forward(self, feat1, feat2):
        x = torch.cat([feat1, feat2], dim=1)
        x = self.conv45(x)
        x = self.conv5a(x)
        x = self.conv5b(x)
        x = self.conv56(x)
        x = self.conv6a(x)  # fix: the original repeated conv5a/conv5b here, leaving conv6a/conv6b unused
        x = self.conv6b(x)

        x = F.avg_pool2d(x, x.shape[2])
        x = x.view(-1, x.shape[1])

        theta = self.fc(x)
        theta = theta.view(-1, 2, 3)

        return theta


# Encoder (Copy network)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv12 = Conv2d(4, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())    # 2
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())    # 2
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 4
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
        self.value3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)      # 4
        init_He(self)
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f, in_v):
        f = (in_f - self.mean) / self.std
        x = torch.cat([f, in_v], dim=1)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3(x)
        v = self.value3(x)
        return v


# Decoder (Paste network)
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())

        # dilated convolution blocks
        self.convA4_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=2, D=2, activation=nn.ReLU())
        self.convA4_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=4, D=4, activation=nn.ReLU())
        self.convA4_3 = Conv2d(257, 257, kernel_size=3, stride=1, padding=8, D=8, activation=nn.ReLU())
        self.convA4_4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=16, D=16, activation=nn.ReLU())

        self.conv3c = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
        self.conv3b = Conv2d(257, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
        self.conv3a = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())  # 4
        self.conv32 = Conv2d(128, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())   # 2
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())     # 2
        self.conv21 = Conv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)          # 1

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        x = self.conv4(x)
        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.convA4_1(x)
        x = self.convA4_2(x)
        x = self.convA4_3(x)
        x = self.convA4_4(x)

        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        x = F.upsample(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.upsample(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv21(x)

        p = (x * self.std) + self.mean
        return p


# Context Matching Module
class CM_Module(nn.Module):
    def __init__(self):
        super(CM_Module, self).__init__()

    def masked_softmax(self, vec, mask, dim):
        masked_vec = vec * mask.float()
        max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
        exps = torch.exp(masked_vec - max_vec)
        masked_exps = exps * mask.float()
        masked_sums = masked_exps.sum(dim, keepdim=True)
        zeros = (masked_sums < 1e-4)
        masked_sums += zeros.float()  # avoid division by zero where everything is masked
        return masked_exps / masked_sums

    def forward(self, values, tvmap, rvmaps):
        B, C, T, H, W = values.size()
        # t_feat: target feature
        t_feat = values[:, :, 0]
        # r_feats: reference features
        r_feats = values[:, :, 1:]

        B, Cv, T, H, W = r_feats.size()
        # vmap: visibility map
        # tvmap: target visibility map
        # rvmap: reference visibility map
        # gs: cosine similarity
        # c_m: c_match
        gs_, vmap_ = [], []
        tvmap_t = (F.upsample(tvmap, size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
        for r in range(T):
            rvmap_t = (F.upsample(rvmaps[:, :, r], size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
            # vmap: visibility map
            vmap = tvmap_t * rvmap_t
            gs = (vmap * t_feat * r_feats[:, :, r]).sum(-1).sum(-1).sum(-1)
            # valid sum
            v_sum = vmap[:, 0].sum(-1).sum(-1)
            zeros = (v_sum < 1e-4)
            gs[zeros] = 0
            v_sum += zeros.float()
            gs = gs / v_sum / C
            gs = torch.ones(t_feat.shape).float().cuda() * gs.view(B, 1, 1, 1)
            gs_.append(gs)
            vmap_.append(rvmap_t)

        gss = torch.stack(gs_, dim=2)
        vmaps = torch.stack(vmap_, dim=2)

        # weighted pixelwise masked softmax
        c_match = self.masked_softmax(gss, vmaps, dim=2)
        c_out = torch.sum(r_feats * c_match, dim=2)

        # c_mask
        c_mask = (c_match * vmaps)
        c_mask = torch.sum(c_mask, 2)
        c_mask = 1. - (torch.mean(c_mask, 1, keepdim=True))

        return torch.cat([t_feat, c_out, c_mask], dim=1), c_mask
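
To make masked_softmax concrete (a hypothetical check, not part of the commit): entries with mask 0 receive zero weight and the rest renormalize over the masked dimension, which is what lets the module pool only over visible reference pixels.

import torch
from CPNet_model import CM_Module

cm = CM_Module()
vec = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])   # third entry is masked out
print(cm.masked_softmax(vec, mask, dim=1))
# tensor([[0.2689, 0.7311, 0.0000]]): softmax over the first two entries only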
383 |
+
|
384 |
+
|
385 |
+
class GCMModel(nn.Module):
|
386 |
+
def __init__(self):
|
387 |
+
super(GCMModel, self).__init__()
|
388 |
+
self.ch_1 = 16
|
389 |
+
self.ch_2 = 32
|
390 |
+
guide_input_channels = 3
|
391 |
+
align_input_channels = 3
|
392 |
+
self.gcm_coord = None
|
393 |
+
|
394 |
+
if not self.gcm_coord:
|
395 |
+
guide_input_channels = 3
|
396 |
+
align_input_channels = 3
|
397 |
+
|
398 |
+
self.guide_net = N.seq(
|
399 |
+
N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
|
400 |
+
N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
|
401 |
+
nn.AdaptiveAvgPool2d(1),
|
402 |
+
N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
|
403 |
+
)
|
404 |
+
|
405 |
+
self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
|
406 |
+
|
407 |
+
self.align_base = N.seq(
|
408 |
+
N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCRCRCR')
|
409 |
+
)
|
410 |
+
self.align_tail = N.seq(
|
411 |
+
N.conv(self.ch_2, 3, 1, padding=0, mode='C')
|
412 |
+
)
|
413 |
+
|
414 |
+
def forward(self, demosaic_raw):
|
415 |
+
demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
|
416 |
+
guide_input = demosaic_raw
|
417 |
+
base_input =demosaic_raw
|
418 |
+
guide = self.guide_net(guide_input)
|
419 |
+
out = self.align_head(base_input)
|
420 |
+
out = guide * out + out
|
421 |
+
out = self.align_base(out)
|
422 |
+
out = self.align_tail(out)+demosaic_raw
|
423 |
+
|
424 |
+
return out
|


class Fusion(nn.Module):
    def __init__(self):
        super(Fusion, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        guide_input_channels = 9
        align_input_channels = 9
        self.gcm_coord = None

        if not self.gcm_coord:
            guide_input_channels = 9
            align_input_channels = 9

        self.guide_net = N.seq(
            N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
        )

        self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C')
        )

    def forward(self, demosaic_raw):
        # demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
        guide_input = demosaic_raw
        base_input = demosaic_raw
        guide = self.guide_net(guide_input)
        out = self.align_head(base_input)
        out = guide * out + out
        out = self.align_base(out)
        out = self.align_tail(out)

        return out


class CPNet(nn.Module):
    def __init__(self, mode='Train'):
        super(CPNet, self).__init__()
        self.A_Encoder = A_Encoder()      # Align
        self.A_Regressor = A_Regressor()  # output: alignment network
        self.GCMModel = GCMModel()
        self.Encoder = Encoder()          # Merge
        self.CM_Module = CM_Module()

        self.Decoder = Decoder()

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))

    # Note: encoding() and inpainting() are legacy video-inpainting paths; they call
    # A_Encoder with a hole argument that the encoder above no longer takes, and the
    # SCBC solution script only imports LiteAWBISPNet and only uses forward().
    def encoding(self, frames, holes):
        batch_size, _, num_frames, height, width = frames.size()
        # padding
        (frames, holes), pad = pad_divide_by([frames, holes], 8, (frames.size()[3], frames.size()[4]))

        feat_ = []
        for t in range(num_frames):
            feat = self.A_Encoder(frames[:, :, t], holes[:, :, t])
            feat_.append(feat)
        feats = torch.stack(feat_, dim=2)
        return feats

    def inpainting(self, rfeats, rframes, rholes, frame, hole, gt):
        batch_size, _, height, width = frame.size()  # B C H W
        num_r = rfeats.size()[2]  # number of reference frames

        # padding
        (rframes, rholes, frame, hole, gt), pad = pad_divide_by([rframes, rholes, frame, hole, gt], 8, (height, width))

        # Target embedding
        tfeat = self.A_Encoder(frame, hole)

        # c_feat: Encoder (Copy Network) features
        c_feat_ = [self.Encoder(frame, hole)]
        L_align = torch.zeros_like(frame)

        # aligned_r: aligned reference frames
        aligned_r_ = []

        # rvmap: aligned reference frames valid maps
        rvmap_ = []

        for r in range(num_r):
            theta_rt = self.A_Regressor(tfeat, rfeats[:, :, r])
            grid_rt = F.affine_grid(theta_rt, frame.size())

            # aligned_r: aligned reference frame
            # reference frame affine transformation
            aligned_r = F.grid_sample(rframes[:, :, r], grid_rt)

            # aligned_v: aligned reference visibility map
            # reference mask affine transformation
            aligned_v = F.grid_sample(1 - rholes[:, :, r], grid_rt)
            aligned_v = (aligned_v > 0.5).float()

            aligned_r_.append(aligned_r)

            # intersection of target and reference valid map
            trvmap = (1 - hole) * aligned_v
            # compare the aligned frame with the target frame

            c_feat_.append(self.Encoder(aligned_r, aligned_v))

            rvmap_.append(aligned_v)

        aligned_rs = torch.stack(aligned_r_, 2)

        c_feats = torch.stack(c_feat_, dim=2)
        rvmaps = torch.stack(rvmap_, dim=2)

        # p_in: paste network input (target features + c_out + c_mask)
        p_in, c_mask = self.CM_Module(c_feats, 1 - hole, rvmaps)

        pred = self.Decoder(p_in)

        _, _, _, H, W = aligned_rs.shape
        c_mask = (F.upsample(c_mask, size=(H, W), mode='bilinear', align_corners=False)).detach()

        comp = pred * (hole) + gt * (1. - hole)

        if pad[2] + pad[3] > 0:
            comp = comp[:, :, pad[2]:-pad[3], :]

        if pad[0] + pad[1] > 0:
            comp = comp[:, :, :, pad[0]:-pad[1]]

        comp = torch.clamp(comp, 0, 1)

        return comp

    def forward(self, Source, Target):
        feat_target = self.A_Encoder(Target)
        feat_source = self.A_Encoder(Source)

        theta = self.A_Regressor(feat_target, feat_source)
        grid_rt = F.affine_grid(theta, Target.size())
        aligned = F.grid_sample(Source, grid_rt)
        mask = torch.ones_like(Source)
        mask = F.grid_sample(mask, grid_rt)

        return aligned, mask
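
The forward() path above is the alignment part of CPNet: regress a 2x3 affine theta from the two encodings, then warp the source and an all-ones validity mask with the same grid. A minimal sketch (an untrained instance, shapes only; not part of the commit, and it assumes the networks.py building blocks are importable):

import torch
from CPNet_model import CPNet

net = CPNet().eval()
src, tgt = torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)
with torch.no_grad():
    aligned, mask = net(src, tgt)   # src warped toward tgt, plus a validity mask
print(aligned.shape, mask.shape)    # both torch.Size([1, 3, 224, 224])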


class AC(nn.Module):
    def __init__(self):
        super(AC, self).__init__()
        self.ch_1 = 32
        self.ch_2 = 64
        guide_input_channels = 8
        align_input_channels = 5
        self.gcm_coord = None

        if not self.gcm_coord:
            guide_input_channels = 6
            align_input_channels = 3

        self.guide_net = N.seq(
            N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
        )

        self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C')
        )

    def forward(self, demosaic_raw, dslr, coord=None):
        demosaic_raw = demosaic_raw + 0.01 * torch.ones_like(demosaic_raw)
        demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
        demosaic_raw = demosaic_raw / 2
        if self.gcm_coord:
            guide_input = torch.cat((demosaic_raw, dslr, coord), 1)
            base_input = torch.cat((demosaic_raw, coord), 1)
        else:
            guide_input = torch.cat((demosaic_raw, dslr), 1)
            base_input = demosaic_raw

        guide = self.guide_net(guide_input)

        out = self.align_head(base_input)
        out = guide * out + out
        out = self.align_base(out)
        out = self.align_tail(out) + demosaic_raw

        return out
SCBC/Dockerfile
ADDED
@@ -0,0 +1,16 @@
FROM python:3.8

COPY . /SCBC
WORKDIR /SCBC

ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=Asia/Shanghai

RUN apt-get update && apt-get install -y \
    libpng-dev libjpeg-dev \
    libopencv-dev ffmpeg \
    libgl1-mesa-glx

RUN python -m pip install --no-cache -r requirements.txt

CMD ["./run.sh"]
SCBC/Input/IMG_20240215_213330.json
ADDED
@@ -0,0 +1,25 @@
{
    "black_level": [256, 256, 256, 256],
    "white_level": 4095,
    "noise_profile": [0.001180699005, 6.3947934705e-06],
    "cfa_pattern": [0, 1, 1, 2],
    "orientation": "Horizontal (normal)",
    "as_shot_neutral": [0.4234199302, 1.0, 0.2275]
}
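
These sidecar fields are what SCBC_Solution.py consumes: black_level/white_level for normalization (hard-coded as 256 and 4095 there, matching all three JSONs), cfa_pattern for binning, as_shot_neutral for white balance, and orientation for the final flip/rotate. A minimal reader (illustrative only):

import json

with open('SCBC/Input/IMG_20240215_213330.json') as f:
    meta = json.load(f)
black, white = meta['black_level'][0], meta['white_level']   # 256, 4095
print(meta['cfa_pattern'], meta['as_shot_neutral'], meta['orientation'])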
SCBC/Input/IMG_20240215_213330.png
ADDED
Git LFS Details
SCBC/Input/IMG_20240215_213619.json
ADDED
@@ -0,0 +1,25 @@
{
    "black_level": [256, 256, 256, 256],
    "white_level": 4095,
    "noise_profile": [0.000575730186, 3.09754693248e-06],
    "cfa_pattern": [0, 1, 1, 2],
    "orientation": "Horizontal (normal)",
    "as_shot_neutral": [0.4354066986, 1.0, 0.2288348701]
}
SCBC/Input/IMG_20240215_213619.png
ADDED
Git LFS Details
SCBC/Input/IMG_20240215_214449.json
ADDED
@@ -0,0 +1,25 @@
{
    "black_level": [256, 256, 256, 256],
    "white_level": 4095,
    "noise_profile": [0.002300534904, 2.25042231834722e-05],
    "cfa_pattern": [0, 1, 1, 2],
    "orientation": "Horizontal (normal)",
    "as_shot_neutral": [0.4204851752, 1.0, 0.224368194]
}
SCBC/Input/IMG_20240215_214449.png
ADDED
Git LFS Details

SCBC/Output/IMG_20240215_213330.png
ADDED
Git LFS Details

SCBC/Output/IMG_20240215_213619.png
ADDED

SCBC/Output/IMG_20240215_214449.png
ADDED
Git LFS Details
SCBC/Readme.txt
ADDED
@@ -0,0 +1,2 @@
> docker build -t scbc .
> docker run --gpus all -it --rm -v $PWD/:/SCBC scbc sh run.sh
SCBC/SCBC_Solution.py
ADDED
@@ -0,0 +1,130 @@
import os
import cv2
import json
import torch
import torchvision.transforms as transforms
from CPNet_model import LiteAWBISPNet
import torchvision
import numpy as np
from Utiles import white_balance, apply_color_space_transform, transform_xyz_to_srgb, apply_gamma, fix_orientation, binning, Four2One, One2Four
import time
from net.mwrcanet import Net
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F

####### Set raw path ###########
Rpath = './Input'
image_files = []

####### Temp ###############################
infer_times = []

####### Color matrix from baseline #############
color_matrix = [1.06835938, -0.29882812, -0.14257812,
                -0.43164062, 1.35546875, 0.05078125,
                -0.1015625, 0.24414062, 0.5859375]

####### Data transforms ###########################
transforms_ = [transforms.ToTensor(),
               transforms.Resize([768, 1024])]
transform = transforms.Compose(transforms_)

transforms_ = [transforms.ToTensor()]
transformo = transforms.Compose(transforms_)

######## Load the pretrained refinement model ####
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth'))

###### Load the pretrained denoising model ##############
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt)
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()
pretrained_dict_update = {k: v for k, v in pretrained_dict.items() if k in model_dict}
assert len(pretrained_dict) == len(pretrained_dict_update)
assert len(pretrained_dict_update) == len(model_dict)
model_dict.update(pretrained_dict_update)
dn_model.load_state_dict(model_dict)

############################ Start processing #########
for filename in os.listdir(Rpath):
    if os.path.splitext(filename)[-1].lower() == ".png":
        image_files.append(filename)

with torch.no_grad():
    for fp in image_files:
        fp = os.path.join(Rpath, fp)
        mn = os.path.splitext(fp)[-2]
        mf = str(mn) + '.json'

        raw_image = cv2.imread(fp, -1)
        with open(mf, 'r') as file:
            data = json.load(file)

        ############ Black & white levels ##########################
        time_BL_S = time.time()

        # hard-coded levels; they match black_level/white_level in the sidecar JSONs
        raw_image = (raw_image.astype(np.float32) - 256.)
        raw_image = raw_image / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)

        ############# Binning ############################
        raw_image = binning(raw_image, data)

        ############# Downsampling ###########################
        raw_image = cv2.resize(raw_image, [1024, 768])

        ############ Raw denoising ##########################
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)
        # raw_image = cv2.resize(raw_image, [1024,768])

        ############# White balance, color matrix, sRGB + gamma #########
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)

        ############# Refinement #############################
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)

        ################# Saving #############################
        Out = Out.clip(0, 1)
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        OA = OA * 255.
        OA = OA.astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])
        time_Save_F = time.time()
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        cv2.imwrite('./Output/' + str(os.path.basename(fp)), OA)  # fix: the original overwrote OA with imwrite's bool return value

        infer_times.append(time_Save_F - time_BL_S)

print(f"Average inference time: {np.mean(infer_times)} seconds")
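
For concreteness, the black/white-level step above maps the sensor's coded range onto [0, 1] via (x - black) / (white - black) followed by a clip; a worked check with the levels from the sidecar JSONs:

black, white = 256., 4095.
for x in (256, 2175, 4095):
    print((x - black) / (white - black))   # 0.0, ~0.49987, 1.0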
SCBC/Utiles.py
ADDED
@@ -0,0 +1,143 @@
import numpy as np
from fractions import Fraction
import cv2
import exifread
from exifread.utils import Ratio
import struct
import json
import torch
import time

# preallocated buffers for the 2048x1536 raw mosaic and its 1024x768 binned RGB image
Temp = np.ones([1536, 2048]).astype(np.float32)
Timg = np.ones([768, 1024, 3]).astype(np.float32)


def ratios2floats(ratios):
    # helper added so white_balance() below is self-contained: the standard
    # conversion of exifread Ratio objects to plain floats
    return [float(r.num) / float(r.den) for r in ratios]


def apply_gamma(x):
    # return x ** (1.0 / 2.2)
    # sRGB encoding curve (piecewise, not a plain power law)
    x = x.copy()
    idx = x <= 0.0031308
    x[idx] *= 12.92
    x[~idx] = (x[~idx] ** (1.0 / 2.4)) * 1.055 - 0.055
    return x


def binning(img, data):
    # 2x2 binning of the Bayer mosaic into an RGB image, ordered by the CFA pattern
    if data['cfa_pattern'] == [0, 1, 1, 2]:   # RGGB
        ch_R = img[0::2, 0::2]
        ch_G = (img[1::2, 0::2] + img[0::2, 1::2]) / 2
        ch_B = img[1::2, 1::2]
        out = np.dstack((ch_R, ch_G, ch_B))

    if data['cfa_pattern'] == [2, 1, 1, 0]:   # BGGR
        ch_R = img[1::2, 1::2]
        ch_G = (img[1::2, 0::2] + img[0::2, 1::2]) / 2
        ch_B = img[0::2, 0::2]
        out = np.dstack((ch_R, ch_G, ch_B))

    return out


def Four2One(img):
    # scatter the binned RGB planes back into a single-channel RGGB mosaic
    Temp[0::2, 0::2] = img[:, :, 0]
    Temp[1::2, 0::2] = img[:, :, 1]
    Temp[0::2, 1::2] = img[:, :, 1]
    Temp[1::2, 1::2] = img[:, :, 2]
    return Temp


def One2Four(Temp):
    # gather a single-channel RGGB mosaic back into binned RGB planes
    Timg[:, :, 0] = Temp[0::2, 0::2]
    Timg[:, :, 1] = (Temp[1::2, 0::2] + Temp[0::2, 1::2]) / 2
    Timg[:, :, 2] = Temp[1::2, 1::2]
    return Timg


def white_balance(demosaic_img, as_shot_neutral):
    if type(as_shot_neutral[0]) is Ratio:
        as_shot_neutral = ratios2floats(as_shot_neutral)

    as_shot_neutral = np.asarray(as_shot_neutral)
    # transform vector into matrix
    if as_shot_neutral.shape == (3,):
        as_shot_neutral = np.diag(1. / as_shot_neutral)

    assert as_shot_neutral.shape == (3, 3)

    white_balanced_image = np.dot(demosaic_img, as_shot_neutral.T)
    white_balanced_image = np.clip(white_balanced_image, 0.0, 1.0)

    return white_balanced_image


def apply_color_space_transform(demosaiced_image, color_matrix):
    xyz2cam = np.reshape(np.asarray(color_matrix), (3, 3))
    # normalize rows (needed?)
    xyz2cam = xyz2cam / np.sum(xyz2cam, axis=1, keepdims=True)
    # inverse
    cam2xyz = np.linalg.inv(xyz2cam)
    # simplified matrix multiplication
    xyz_image = cam2xyz[np.newaxis, np.newaxis, :, :] * demosaiced_image[:, :, np.newaxis, :]
    xyz_image = np.sum(xyz_image, axis=-1)
    xyz_image = np.clip(xyz_image, 0.0, 1.0)
    return xyz_image


def transform_xyz_to_srgb(xyz_image):
    xyz2srgb = np.array([[3.2404542, -1.5371385, -0.4985314],
                         [-0.9692660, 1.8760108, 0.0415560],
                         [0.0556434, -0.2040259, 1.0572252]])

    # normalize rows (needed?)
    xyz2srgb = xyz2srgb / np.sum(xyz2srgb, axis=-1, keepdims=True)

    srgb_image = xyz2srgb[np.newaxis, np.newaxis, :, :] * xyz_image[:, :, np.newaxis, :]
    srgb_image = np.sum(srgb_image, axis=-1)
    srgb_image = np.clip(srgb_image, 0.0, 1.0)
    return srgb_image


def fix_orientation(image, orientation):
    # 1 = Horizontal (normal)
    # 2 = Mirror horizontal
    # 3 = Rotate 180
    # 4 = Mirror vertical
    # 5 = Mirror horizontal and rotate 270 CW
    # 6 = Rotate 90 CW
    # 7 = Mirror horizontal and rotate 90 CW
    # 8 = Rotate 270 CW

    if type(orientation) is list:
        orientation = orientation[0]

    if orientation == "Horizontal (normal)":   # fix: was "Horizontal(normal)", which never matched the sidecar JSON
        pass
    elif orientation == "Mirror horizontal":   # fix: was the typo "Mirror horizonta"
        image = cv2.flip(image, 0)
    elif orientation == "Rotate 180":
        image = cv2.rotate(image, cv2.ROTATE_180)
    elif orientation == "Mirror vertical":
        image = cv2.flip(image, 1)
    elif orientation == "Mirror horizontal and rotate 270 CW":
        image = cv2.flip(image, 0)
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    elif orientation == "Rotate 90 CW":
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == "Mirror horizontal and rotate 90 CW":
        image = cv2.flip(image, 0)
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == "Rotate 270 CW":
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)

    return image
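
A small self-contained check of the binning helper above (illustrative only; a tiny synthetic RGGB mosaic rather than the full 1536x2048 buffers):

import numpy as np
from Utiles import binning

mosaic = np.zeros((4, 4), np.float32)
mosaic[0::2, 0::2] = 1.0   # R sites
mosaic[1::2, 0::2] = 0.5   # G sites
mosaic[0::2, 1::2] = 0.5   # G sites
mosaic[1::2, 1::2] = 0.0   # B sites

rgb = binning(mosaic, {'cfa_pattern': [0, 1, 1, 2]})
print(rgb.shape)   # (2, 2, 3): each 2x2 Bayer cell becomes one RGB pixel
print(rgb[0, 0])   # [1.  0.5 0. ]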
SCBC/__pycache__/CPNet_model.cpython-38.pyc
ADDED
Binary file (15.1 kB)

SCBC/__pycache__/Utiles.cpython-38.pyc
ADDED
Binary file (3.6 kB)

SCBC/__pycache__/datasets.cpython-38.pyc
ADDED
Binary file (1.94 kB)

SCBC/__pycache__/datasets_crop.cpython-38.pyc
ADDED
Binary file (2.1 kB)

SCBC/__pycache__/datasets_fine.cpython-38.pyc
ADDED
Binary file (1.94 kB)

SCBC/__pycache__/model_module.cpython-38.pyc
ADDED
Binary file (1.8 kB)

SCBC/__pycache__/models.cpython-38.pyc
ADDED
Binary file (2.69 kB)

SCBC/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (8.62 kB)

SCBC/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.09 kB)
SCBC/model_module.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys


class Conv2d(nn.Module):
    # thin wrapper: a Conv2d followed by an optional activation
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, D=1, activation=nn.ReLU()):
        super(Conv2d, self).__init__()
        if activation:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=D),
                activation
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=D)
            )

    def forward(self, x):
        x = self.conv(x)
        return x


def init_He(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def pad_divide_by(in_list, d, in_size):
    # symmetrically zero-pad every tensor in in_list so (h, w) become multiples of d;
    # returns the padded tensors plus the (left, right, top, bottom) pad amounts
    out_list = []
    h, w = in_size
    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    for inp in in_list:
        out_list.append(F.pad(inp, pad_array))
    return out_list, pad_array
SCBC/model_zoo/CC2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:867b8163d95115d73911c0c994044089b65291130196be72a91a5633fc91a873
size 35619323
SCBC/model_zoo/dn_mwrcanet_raw_c1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b33267f07b484900a327da312cd25b015486453cda55174b95e691310c597d6c
size 109093370
SCBC/models.py
ADDED
@@ -0,0 +1,92 @@
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, 64, 7),
                 nn.InstanceNorm2d(64),
                 nn.ReLU(inplace=True)]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(64, 128, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(128, 256, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(256, 512, 4, padding=1),
                  nn.InstanceNorm2d(512),
                  nn.LeakyReLU(0.2, inplace=True)]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x = self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
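
These are standard CycleGAN-style generator/discriminator blocks; SCBC_Solution.py does not reference them. A quick instantiation check (illustrative only, not part of the commit):

import torch
from models import Generator, Discriminator

G, D = Generator(input_nc=3, output_nc=3), Discriminator(input_nc=3)
x = torch.rand(1, 3, 256, 256)
print(G(x).shape)   # torch.Size([1, 3, 256, 256]): tanh output in [-1, 1]
print(D(x).shape)   # torch.Size([1, 1]): patch logits averaged to one score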
SCBC/net/__pycache__/mwrcanet.cpython-38.pyc
ADDED
Binary file (5.93 kB)
SCBC/net/mwrcanet.py
ADDED
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Yue Cao (cscaoyue@gmail.com) (cscaoyue@hit.edu.cn)
# supervisor : Wangmeng Zuo (cswmzuo@gmail.com)
# github: https://github.com/happycaoyue
# personal link: happycaoyue.com
import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
import torch.nn.functional as F


class HITVPCTeam:
    r"""
    DWT and IDWT block written by: Yue Cao
    """
    class CALayer(nn.Module):
        def __init__(self, channel=64, reduction=16):
            super(HITVPCTeam.CALayer, self).__init__()

            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
            )

        def forward(self, x):
            y = self.avg_pool(x)
            y = self.conv_du(y)
            return x * y

    # conv - prelu - conv - sum
    class RB(nn.Module):
        def __init__(self, filters):
            super(HITVPCTeam.RB, self).__init__()
            self.conv1 = nn.Conv2d(filters, filters, 3, 1, 1)
            self.act = nn.PReLU()
            self.conv2 = nn.Conv2d(filters, filters, 3, 1, 1)
            self.cuca = HITVPCTeam.CALayer(channel=filters)

        def forward(self, x):
            c0 = x
            x = self.conv1(x)
            x = self.act(x)
            x = self.conv2(x)
            out = self.cuca(x)
            return out + c0

    class NRB(nn.Module):
        def __init__(self, n, f):
            super(HITVPCTeam.NRB, self).__init__()
            nets = []
            for i in range(n):
                nets.append(HITVPCTeam.RB(f))
            self.body = nn.Sequential(*nets)
            self.tail = nn.Conv2d(f, f, 3, 1, 1)

        def forward(self, x):
            return x + self.tail(self.body(x))

    class DWTForward(nn.Module):
        def __init__(self):
            super(HITVPCTeam.DWTForward, self).__init__()
            ll = np.array([[0.5, 0.5], [0.5, 0.5]])
            lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
            hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
            hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
            filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                              hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                             axis=0)
            self.weight = nn.Parameter(
                torch.tensor(filts).to(torch.get_default_dtype()),
                requires_grad=False)

        def forward(self, x):
            C = x.shape[1]
            filters = torch.cat([self.weight, ] * C, dim=0)
            y = F.conv2d(x, filters, groups=C, stride=2)
            return y

    class DWTInverse(nn.Module):
        def __init__(self):
            super(HITVPCTeam.DWTInverse, self).__init__()
            ll = np.array([[0.5, 0.5], [0.5, 0.5]])
            lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
            hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
            hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
            filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                              hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                             axis=0)
            self.weight = nn.Parameter(
                torch.tensor(filts).to(torch.get_default_dtype()),
                requires_grad=False)

        def forward(self, x):
            C = int(x.shape[1] / 4)
            filters = torch.cat([self.weight, ] * C, dim=0)
            y = F.conv_transpose2d(x, filters, groups=C, stride=2)
            return y


class Net(nn.Module):
    def __init__(self, channels=1, filters_level1=96, filters_level2=256//2, filters_level3=256//2, n_rb=4*5):
        super(Net, self).__init__()

        self.head = HITVPCTeam.DWTForward()

        self.down1 = nn.Sequential(
            nn.Conv2d(channels * 4, filters_level1, 3, 1, 1),
            nn.PReLU(),
            HITVPCTeam.NRB(n_rb, filters_level1))

        # sum 1
        # self.down1 = HITVPCTeam.NRB(n_rb, filters_level1),

        # sum 2
        self.down2 = nn.Sequential(
            HITVPCTeam.DWTForward(),
            nn.Conv2d(filters_level1 * 4, filters_level2, 3, 1, 1),
            nn.PReLU(),
            HITVPCTeam.NRB(n_rb, filters_level2))

        self.down3 = nn.Sequential(
            HITVPCTeam.DWTForward(),
            nn.Conv2d(filters_level2 * 4, filters_level3, 3, 1, 1),
            nn.PReLU())

        self.middle = HITVPCTeam.NRB(n_rb, filters_level3)

        self.up1 = nn.Sequential(
            nn.Conv2d(filters_level3, filters_level2 * 4, 3, 1, 1),
            nn.PReLU(),
            HITVPCTeam.DWTInverse())

        self.up2 = nn.Sequential(
            HITVPCTeam.NRB(n_rb, filters_level2),
            nn.Conv2d(filters_level2, filters_level1 * 4, 3, 1, 1),
            nn.PReLU(),
            HITVPCTeam.DWTInverse())

        self.up3 = nn.Sequential(
            HITVPCTeam.NRB(n_rb, filters_level1),
            nn.Conv2d(filters_level1, channels * 4, 3, 1, 1))

        self.tail = HITVPCTeam.DWTInverse()

    def forward(self, inputs):
        c0 = inputs
        c1 = self.head(c0)
        c2 = self.down1(c1)
        c3 = self.down2(c2)
        c4 = self.down3(c3)
        m = self.middle(c4)
        c5 = self.up1(m) + c3
        c6 = self.up2(c5) + c2
        c7 = self.up3(c6) + c1
        return self.tail(c7)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
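
The Haar filter bank above is orthonormal, so DWTInverse exactly undoes DWTForward for even H and W, which is why Net can use them as lossless down/upsampling. A quick round-trip check (illustrative, not part of the commit):

import torch
from net.mwrcanet import HITVPCTeam

dwt, idwt = HITVPCTeam.DWTForward(), HITVPCTeam.DWTInverse()
x = torch.rand(1, 1, 8, 8)
y = dwt(x)                        # (1, 4, 4, 4): LL, LH, HL, HH subbands
print((x - idwt(y)).abs().max())  # ~1e-7, i.e. perfect reconstruction up to float error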
SCBC/networks.py
ADDED
@@ -0,0 +1,294 @@
1	+import torch
2	+import torch.nn as nn
3	+from torch.nn import init
4	+from torch.optim import lr_scheduler
5	+from collections import OrderedDict
6	+
7	+
8	+def get_scheduler(optimizer, opt):
9	+    if opt.lr_policy == 'linear':
10	+        def lambda_rule(epoch):
11	+            return 1 - max(0, epoch - opt.niter) / max(1, float(opt.niter_decay))
12	+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
13	+    elif opt.lr_policy == 'step':
14	+        scheduler = lr_scheduler.StepLR(optimizer,
15	+                                        step_size=opt.lr_decay_iters,
16	+                                        gamma=0.5)
17	+    elif opt.lr_policy == 'plateau':
18	+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
19	+                                                   mode='min',
20	+                                                   factor=0.2,
21	+                                                   threshold=0.01,
22	+                                                   patience=5)
23	+    elif opt.lr_policy == 'cosine':
24	+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
25	+                                                   T_max=opt.niter,
26	+                                                   eta_min=0)
27	+    else:
28	+        raise NotImplementedError('lr policy [%s] is not implemented' % opt.lr_policy)
29	+    return scheduler
30	+
31	+def init_weights(net, init_type='normal', init_gain=0.02):
32	+    def init_func(m):  # define the initialization function
33	+        classname = m.__class__.__name__
34	+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 \
35	+                or classname.find('Linear') != -1):
36	+            if init_type == 'normal':
37	+                init.normal_(m.weight.data, 0.0, init_gain)
38	+            elif init_type == 'xavier':
39	+                init.xavier_normal_(m.weight.data, gain=init_gain)
40	+            elif init_type == 'kaiming':
41	+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
42	+            elif init_type == 'orthogonal':
43	+                init.orthogonal_(m.weight.data, gain=init_gain)
44	+            elif init_type == 'uniform':
45	+                init.uniform_(m.weight.data, b=init_gain)
46	+            else:
47	+                raise NotImplementedError('[%s] is not implemented' % init_type)
48	+            if hasattr(m, 'bias') and m.bias is not None:
49	+                init.constant_(m.bias.data, 0.0)
50	+        elif classname.find('BatchNorm2d') != -1:
51	+            init.normal_(m.weight.data, 1.0, init_gain)
52	+            init.constant_(m.bias.data, 0.0)
53	+
54	+    print('initialize network with %s' % init_type)
55	+    net.apply(init_func)  # apply the initialization function <init_func>
56	+
57	+def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]):
58	+    if len(gpu_ids) > 0:
59	+        assert(torch.cuda.is_available())
60	+        net.to(gpu_ids[0])
61	+        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
62	+    if init_type != 'default' and init_type is not None:
63	+        init_weights(net, init_type, init_gain=init_gain)
64	+    return net
65	+
66	+
67	+'''
68	+# ===================================
69	+# Advanced nn.Sequential
70	+# reform nn.Sequentials and nn.Modules
71	+# to a single nn.Sequential
72	+# ===================================
73	+'''
74	+
75	+def seq(*args):
76	+    if len(args) == 1:
77	+        args = args[0]
78	+    if isinstance(args, nn.Module):
79	+        return args
80	+    modules = OrderedDict()
81	+    if isinstance(args, OrderedDict):
82	+        for k, v in args.items():
83	+            modules[k] = seq(v)
84	+        return nn.Sequential(modules)
85	+    assert isinstance(args, (list, tuple))
86	+    return nn.Sequential(*[seq(i) for i in args])
87	+
88	+'''
89	+# ===================================
90	+# Useful blocks
91	+# --------------------------------
92	+# conv (+ normalization + relu)
93	+# concat
94	+# sum
95	+# resblock (ResBlock)
96	+# resdenseblock (ResidualDenseBlock_5C)
97	+# resinresdenseblock (RRDB)
98	+# ===================================
99	+'''
100	+
101	+# -------------------------------------------------------
102	+# return nn.Sequential of (Conv + BN + ReLU)
103	+# -------------------------------------------------------
104	+def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
105	+         output_padding=0, dilation=1, groups=1, bias=True,
106	+         padding_mode='zeros', mode='CBR'):
107	+    L = []
108	+    for t in mode:
109	+        if t == 'C':
110	+            L.append(nn.Conv2d(in_channels=in_channels,
111	+                               out_channels=out_channels,
112	+                               kernel_size=kernel_size,
113	+                               stride=stride,
114	+                               padding=padding,
115	+                               dilation=dilation,
116	+                               groups=groups,
117	+                               bias=bias,
118	+                               padding_mode=padding_mode))
119	+        elif t == 'X':
120	+            assert in_channels == out_channels
121	+            L.append(nn.Conv2d(in_channels=in_channels,
122	+                               out_channels=out_channels,
123	+                               kernel_size=kernel_size,
124	+                               stride=stride,
125	+                               padding=padding,
126	+                               dilation=dilation,
127	+                               groups=in_channels,
128	+                               bias=bias,
129	+                               padding_mode=padding_mode))
130	+        elif t == 'T':
131	+            L.append(nn.ConvTranspose2d(in_channels=in_channels,
132	+                                        out_channels=out_channels,
133	+                                        kernel_size=kernel_size,
134	+                                        stride=stride,
135	+                                        padding=padding,
136	+                                        output_padding=output_padding,
137	+                                        groups=groups,
138	+                                        bias=bias,
139	+                                        dilation=dilation,
140	+                                        padding_mode=padding_mode))
141	+        elif t == 'B':
142	+            L.append(nn.BatchNorm2d(out_channels))
143	+        elif t == 'I':
144	+            L.append(nn.InstanceNorm2d(out_channels, affine=True))
145	+        elif t == 'i':
146	+            L.append(nn.InstanceNorm2d(out_channels))
147	+        elif t == 'R':
148	+            L.append(nn.ReLU(inplace=True))
149	+        elif t == 'r':
150	+            L.append(nn.ReLU(inplace=False))
151	+        elif t == 'S':
152	+            L.append(nn.Sigmoid())
153	+        elif t == 'P':
154	+            L.append(nn.PReLU())
155	+        elif t == 'L':
156	+            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
157	+        elif t == 'l':
158	+            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
159	+        elif t == '2':
160	+            L.append(nn.PixelShuffle(upscale_factor=2))
161	+        elif t == '3':
162	+            L.append(nn.PixelShuffle(upscale_factor=3))
163	+        elif t == '4':
164	+            L.append(nn.PixelShuffle(upscale_factor=4))
165	+        elif t == 'U':
166	+            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
167	+        elif t == 'u':
168	+            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
169	+        elif t == 'M':
170	+            L.append(nn.MaxPool2d(kernel_size=kernel_size,
171	+                                  stride=stride,
172	+                                  padding=0))
173	+        elif t == 'A':
174	+            L.append(nn.AvgPool2d(kernel_size=kernel_size,
175	+                                  stride=stride,
176	+                                  padding=0))
177	+        else:
178	+            raise NotImplementedError('Undefined type: {}'.format(t))
179	+    return seq(*L)
180	+
181	+
182	+class DWTForward(nn.Conv2d):
183	+    def __init__(self, in_channels=64):
184	+        super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2,
185	+                                         groups=in_channels, bias=False)
186	+        weight = torch.tensor([[[[0.5,  0.5], [ 0.5,  0.5]]],
187	+                               [[[0.5,  0.5], [-0.5, -0.5]]],
188	+                               [[[0.5, -0.5], [ 0.5, -0.5]]],
189	+                               [[[0.5, -0.5], [-0.5,  0.5]]]],
190	+                              dtype=torch.get_default_dtype()
191	+                              ).repeat(in_channels, 1, 1, 1)  # / 2
192	+        self.weight.data.copy_(weight)
193	+        self.requires_grad_(False)
194	+
195	+
196	+class DWTInverse(nn.ConvTranspose2d):
197	+    def __init__(self, in_channels=64):
198	+        super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2,
199	+                                         groups=in_channels//4, bias=False)
200	+        weight = torch.tensor([[[[0.5,  0.5], [ 0.5,  0.5]]],
201	+                               [[[0.5,  0.5], [-0.5, -0.5]]],
202	+                               [[[0.5, -0.5], [ 0.5, -0.5]]],
203	+                               [[[0.5, -0.5], [-0.5,  0.5]]]],
204	+                              dtype=torch.get_default_dtype()
205	+                              ).repeat(in_channels//4, 1, 1, 1)  # * 2
206	+        self.weight.data.copy_(weight)
207	+        self.requires_grad_(False)
208	+
209	+
210	+# -------------------------------------------------------
211	+# Channel Attention (CA) Layer
212	+# -------------------------------------------------------
213	+class CALayer(nn.Module):
214	+    def __init__(self, channel=64, reduction=16):
215	+        super(CALayer, self).__init__()
216	+
217	+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
218	+        self.conv_du = nn.Sequential(
219	+            nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
220	+            nn.ReLU(inplace=True),
221	+            nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
222	+            nn.Sigmoid()
223	+        )
224	+
225	+    def forward(self, x):
226	+        y = self.avg_pool(x)
227	+        y = self.conv_du(y)
228	+        return x * y
229	+
230	+
231	+# -------------------------------------------------------
232	+# Res Block: x + conv(relu(conv(x)))
233	+# -------------------------------------------------------
234	+class ResBlock(nn.Module):
235	+    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
236	+                 padding=1, bias=True, mode='CRC'):
237	+        super(ResBlock, self).__init__()
238	+
239	+        assert in_channels == out_channels
240	+        if mode[0] in ['R', 'L']:
241	+            mode = mode[0].lower() + mode[1:]
242	+
243	+        self.res = conv(in_channels, out_channels, kernel_size,
244	+                        stride, padding=padding, bias=bias, mode=mode)
245	+
246	+    def forward(self, x):
247	+        res = self.res(x)
248	+        return x + res
249	+
250	+
251	+# -------------------------------------------------------
252	+# Residual Channel Attention Block (RCAB)
253	+# -------------------------------------------------------
254	+class RCABlock(nn.Module):
255	+    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
256	+                 padding=1, bias=True, mode='CRC', reduction=16):
257	+        super(RCABlock, self).__init__()
258	+        assert in_channels == out_channels
259	+        if mode[0] in ['R', 'L']:
260	+            mode = mode[0].lower() + mode[1:]
261	+
262	+        self.res = conv(in_channels, out_channels, kernel_size,
263	+                        stride, padding, bias=bias, mode=mode)
264	+        self.ca = CALayer(out_channels, reduction)
265	+
266	+    def forward(self, x):
267	+        res = self.res(x)
268	+        res = self.ca(res)
269	+        return res + x
270	+
271	+
272	+# -------------------------------------------------------
273	+# Residual Channel Attention Group (RG)
274	+# -------------------------------------------------------
275	+class RCAGroup(nn.Module):
276	+    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
277	+                 padding=1, bias=True, mode='CRC', reduction=16, nb=12):
278	+        super(RCAGroup, self).__init__()
279	+        assert in_channels == out_channels
280	+        if mode[0] in ['R', 'L']:
281	+            mode = mode[0].lower() + mode[1:]
282	+
283	+        RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding,
284	+                       bias, mode, reduction) for _ in range(nb)]
285	+        # RG = [ResBlock(in_channels, out_channels, kernel_size, stride, padding,
286	+        #                bias, mode) for _ in range(nb)]
287	+        RG.append(conv(out_channels, out_channels, mode='C'))
288	+
289	+        self.rg = nn.Sequential(*RG)
290	+
291	+    def forward(self, x):
292	+        res = self.rg(x)
293	+        return res + x
294	+
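networks.py is a small building-block toolkit: conv() assembles a layer stack from a mode string (one layer per character, e.g. 'CBR' = Conv2d + BatchNorm2d + ReLU), DWTForward/DWTInverse implement the orthonormal 2x2 Haar transform as fixed grouped (transposed) convolutions, and RCAGroup chains channel-attention residual blocks. A short usage sketch (not part of the commit), assuming it runs from the SCBC directory so networks.py imports directly; the shapes are illustrative only:

    # illustrative usage of the helpers added above
    import torch
    from networks import conv, DWTForward, DWTInverse, RCAGroup

    block = conv(3, 64, mode='CBR')          # Conv2d(3->64) + BatchNorm2d + ReLU
    print(block(torch.randn(2, 3, 32, 32)).shape)       # [2, 64, 32, 32]

    # The Haar pair is orthonormal, so analysis followed by synthesis reconstructs
    # the input (up to float rounding): 64 channels -> 256 subbands -> 64 channels.
    fwd, inv = DWTForward(64), DWTInverse(256)
    z = torch.randn(1, 64, 16, 16)
    print(torch.allclose(inv(fwd(z)), z, atol=1e-6))    # True

    rg = RCAGroup(64, 64, nb=2)              # two RCABlocks + conv, residual overall
    print(rg(torch.randn(1, 64, 16, 16)).shape)         # [1, 64, 16, 16]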
SCBC/requirements.txt
ADDED
@@ -0,0 +1,13 @@
1	+opencv-python
2	+scipy
3	+numpy
4	+torch
5	+pandas
6	+torchvision
7	+Pillow
8	+matplotlib
9	+tqdm
10	+imageio
11	+seaborn
12	+hdf5storage
13	+exifread
SCBC/run.sh
ADDED
@@ -0,0 +1 @@
1	+python SCBC_Solution.py