Artyom
commited on
Commit
•
bd1c686
1
Parent(s):
6721043
dh-aisp
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- DH-AISP/1/__pycache__/awb.cpython-36.pyc +0 -0
- DH-AISP/1/awb.py +184 -0
- DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint +2 -0
- DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 +3 -0
- DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index +0 -0
- DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta +3 -0
- DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py +451 -0
- DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc +0 -0
- DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc +0 -0
- DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc +0 -0
- DH-AISP/2/focal_frequency_loss/__init__.py +3 -0
- DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc +0 -0
- DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc +0 -0
- DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py +114 -0
- DH-AISP/2/model_convnext2_hdr.py +592 -0
- DH-AISP/2/myFFCResblock0.py +60 -0
- DH-AISP/2/perceptual.py +30 -0
- DH-AISP/2/pytorch_msssim/__init__.py +133 -0
- DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc +0 -0
- DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc +0 -0
- DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth +3 -0
- DH-AISP/2/saicinpainting/__init__.py +0 -0
- DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc +0 -0
- DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc +0 -0
- DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc +0 -0
- DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc +0 -0
- DH-AISP/2/saicinpainting/evaluation/__init__.py +33 -0
- DH-AISP/2/saicinpainting/evaluation/data.py +168 -0
- DH-AISP/2/saicinpainting/evaluation/evaluator.py +220 -0
- DH-AISP/2/saicinpainting/evaluation/losses/__init__.py +0 -0
- DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py +528 -0
- DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py +0 -0
- DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py +328 -0
- DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py +323 -0
- DH-AISP/2/saicinpainting/evaluation/losses/lpips.py +891 -0
- DH-AISP/2/saicinpainting/evaluation/losses/ssim.py +74 -0
- DH-AISP/2/saicinpainting/evaluation/masks/README.md +27 -0
- DH-AISP/2/saicinpainting/evaluation/masks/__init__.py +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md +25 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py +529 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py +356 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg +3 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png +0 -0
- DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png +0 -0
.gitattributes
CHANGED
@@ -39,3 +39,6 @@ SCBC/Input/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
|
|
39 |
SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
|
40 |
SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
|
41 |
PolyuColor/resources/average_shading.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
39 |
SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
|
40 |
SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
|
41 |
PolyuColor/resources/average_shading.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
43 |
+
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta filter=lfs diff=lfs merge=lfs -text
|
44 |
+
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text
|
DH-AISP/1/__pycache__/awb.cpython-36.pyc
ADDED
Binary file (3.82 kB). View file
|
|
DH-AISP/1/awb.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from glob import glob
|
5 |
+
|
6 |
+
|
7 |
+
def dynamic(rgb):
|
8 |
+
|
9 |
+
rgb = rgb[:-1, :-1, :] # 删去一行一列
|
10 |
+
h, w, _ = rgb.shape
|
11 |
+
col = 4
|
12 |
+
row = 3
|
13 |
+
h1 = h // row
|
14 |
+
w1 = w // col
|
15 |
+
|
16 |
+
r, g, b = cv2.split(rgb)
|
17 |
+
r_mask = r < 0.95
|
18 |
+
g_mask = g < 0.95
|
19 |
+
b_mask = b < 0.95
|
20 |
+
mask = r_mask * g_mask * b_mask
|
21 |
+
r *= mask
|
22 |
+
g *= mask
|
23 |
+
b *= mask
|
24 |
+
rgb = np.stack((r, g, b), axis=2)
|
25 |
+
|
26 |
+
y, cr, cb = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb))
|
27 |
+
cr -= 0.5
|
28 |
+
cb -= 0.5
|
29 |
+
|
30 |
+
mr, mb, dr, db = 0, 0, 0, 0
|
31 |
+
for r in range(row):
|
32 |
+
for c in range(col):
|
33 |
+
cr_1 = cr[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1]
|
34 |
+
cb_1 = cb[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1]
|
35 |
+
mr_1 = np.mean(cr_1)
|
36 |
+
mb_1 = np.mean(cb_1)
|
37 |
+
dr_1 = np.mean(np.abs(cr_1 - mr))
|
38 |
+
db_1 = np.mean(np.abs(cb_1 - mb))
|
39 |
+
|
40 |
+
mr += mr_1
|
41 |
+
mb += mb_1
|
42 |
+
dr += dr_1
|
43 |
+
db += db_1
|
44 |
+
|
45 |
+
mr /= col * row
|
46 |
+
mb /= col * row
|
47 |
+
dr /= col * row
|
48 |
+
db /= col * row
|
49 |
+
|
50 |
+
cb_mask = np.abs(cb - (mb + db * np.sign(mb))) < 1.5 * db
|
51 |
+
cr_mask = np.abs(cr - (1.5 * mr + dr * np.sign(mr))) < 1.5 * dr
|
52 |
+
|
53 |
+
mask = cb_mask * cr_mask
|
54 |
+
y_white = y * mask
|
55 |
+
|
56 |
+
hist_y = np.zeros(256, dtype=np.int)
|
57 |
+
y_white_uint8 = (y_white * 255).astype(np.int)
|
58 |
+
|
59 |
+
for v in range(255):
|
60 |
+
hist_y[v] = np.sum(y_white_uint8 == v)
|
61 |
+
|
62 |
+
thr_sum = 0.05 * np.sum(mask)
|
63 |
+
sum_v = 0
|
64 |
+
thr = 0
|
65 |
+
for v in range(255, -1, -1):
|
66 |
+
sum_v = sum_v + hist_y[v]
|
67 |
+
if sum_v > thr_sum:
|
68 |
+
thr = v
|
69 |
+
break
|
70 |
+
|
71 |
+
white_mask = y_white_uint8 > thr
|
72 |
+
cv2.imwrite(r'V:\Project\3_MEWDR\data\2nd_awb\t.png', (white_mask + 0) * 255)
|
73 |
+
|
74 |
+
r, g, b = cv2.split(rgb)
|
75 |
+
r_ave = np.sum(r[white_mask]) / np.sum(white_mask)
|
76 |
+
g_ave = np.sum(g[white_mask]) / np.sum(white_mask)
|
77 |
+
b_ave = np.sum(b[white_mask]) / np.sum(white_mask)
|
78 |
+
|
79 |
+
return 1 / r_ave, 1 / g_ave, 1 / b_ave
|
80 |
+
|
81 |
+
|
82 |
+
def perf_ref(rgb, eps):
|
83 |
+
h, w, _ = rgb.shape
|
84 |
+
|
85 |
+
r, g, b = cv2.split(rgb)
|
86 |
+
r_mask = r < 0.95
|
87 |
+
g_mask = g < 0.95
|
88 |
+
b_mask = b < 0.95
|
89 |
+
mask = r_mask * g_mask * b_mask
|
90 |
+
r *= mask
|
91 |
+
g *= mask
|
92 |
+
b *= mask
|
93 |
+
rgb = np.stack((r, g, b), axis=2)
|
94 |
+
rgb = np.clip(rgb * 255, 0, 255).astype(np.int)
|
95 |
+
|
96 |
+
hist_rgb = np.zeros(255 * 3, dtype=np.int)
|
97 |
+
rgb_sum = np.sum(rgb, axis=2)
|
98 |
+
|
99 |
+
for v in range(255 * 3):
|
100 |
+
hist_rgb[v] = np.sum(rgb_sum == v)
|
101 |
+
|
102 |
+
thr_sum = eps * h * w
|
103 |
+
sum_v = 0
|
104 |
+
thr = 0
|
105 |
+
for v in range(255 * 3 - 1, -1, -1):
|
106 |
+
sum_v = sum_v + hist_rgb[v]
|
107 |
+
if sum_v > thr_sum:
|
108 |
+
thr = v
|
109 |
+
break
|
110 |
+
|
111 |
+
thr_mask = rgb_sum > thr
|
112 |
+
r_ave = np.sum(r[thr_mask]) / np.sum(thr_mask)
|
113 |
+
g_ave = np.sum(g[thr_mask]) / np.sum(thr_mask)
|
114 |
+
b_ave = np.sum(b[thr_mask]) / np.sum(thr_mask)
|
115 |
+
|
116 |
+
# k = (r_ave + g_ave + b_ave) / 3.
|
117 |
+
# k = 255
|
118 |
+
|
119 |
+
# print(k)
|
120 |
+
|
121 |
+
# r = np.clip(r * k / r_ave, 0, 255)
|
122 |
+
# g = np.clip(g * k / g_ave, 0, 255)
|
123 |
+
# b = np.clip(b * k / b_ave, 0, 255)
|
124 |
+
|
125 |
+
return 1 / r_ave, 1 / g_ave, 1 / b_ave
|
126 |
+
|
127 |
+
|
128 |
+
def awb_v(in_image, bayer, eps):
|
129 |
+
|
130 |
+
assert bayer in ['GBRG', 'RGGB']
|
131 |
+
|
132 |
+
if bayer == 'GBRG':
|
133 |
+
g = in_image[0::2, 0::2] # [0,0]
|
134 |
+
b = in_image[0::2, 1::2] # [0,1]
|
135 |
+
r = in_image[1::2, 0::2] # [1,0]
|
136 |
+
else:
|
137 |
+
r = in_image[0::2, 0::2] # [0,0]
|
138 |
+
g = in_image[0::2, 1::2] # [0,1]
|
139 |
+
b = in_image[1::2, 1::2] # [1,1]
|
140 |
+
|
141 |
+
rgb = cv2.merge((r, g, b)) * 1
|
142 |
+
|
143 |
+
r_gain, g_gain, b_gain = perf_ref(rgb, eps)
|
144 |
+
|
145 |
+
return r_gain / g_gain, b_gain / g_gain
|
146 |
+
|
147 |
+
|
148 |
+
def main():
|
149 |
+
path = r'V:\Project\3_MEWDR\data\2nd_raw'
|
150 |
+
# out_path = r'V:\Project\3_MEWDR\data\2nd_awb'
|
151 |
+
|
152 |
+
files = glob(os.path.join(path, '*.png'))
|
153 |
+
|
154 |
+
for f in files:
|
155 |
+
img = cv2.imread(f, cv2.CV_16UC1)
|
156 |
+
img = (img.astype(np.float) - 2048) / (15400 - 2048) * 4
|
157 |
+
|
158 |
+
g = img[0::2, 0::2] # [0,0]
|
159 |
+
b = img[0::2, 1::2] # [0,1]
|
160 |
+
r = img[1::2, 0::2] # [1,0]
|
161 |
+
# g_ = img[1::2, 1::2]
|
162 |
+
|
163 |
+
rgb = cv2.merge((r, g, b))
|
164 |
+
|
165 |
+
# save_name = f.replace('.png', '_rgb.png').replace('2nd_raw', '2nd_awb')
|
166 |
+
|
167 |
+
r_gain, g_gain, b_gain = perf_ref(rgb, eps=0.1)
|
168 |
+
# r_gain, g_gain, b_gain = dynamic(rgb.astype(np.float32))
|
169 |
+
|
170 |
+
r *= r_gain / g_gain
|
171 |
+
b *= b_gain / g_gain
|
172 |
+
print(r_gain / g_gain, b_gain / g_gain)
|
173 |
+
|
174 |
+
out_rgb = np.clip(cv2.merge((r, g, b)) * 255, 0, 255)
|
175 |
+
|
176 |
+
save_name = f.replace('.png', '_awb4_dyn.png').replace('2nd_raw', '2nd_awb')
|
177 |
+
|
178 |
+
cv2.imwrite(save_name, cv2.cvtColor(out_rgb.astype(np.uint8), cv2.COLOR_RGB2BGR))
|
179 |
+
|
180 |
+
# break
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == '__main__':
|
184 |
+
main()
|
DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
model_checkpoint_path: "model.ckpt"
|
2 |
+
all_model_checkpoint_paths: "model.ckpt"
|
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6997bfa5624aba66e2497088cc8f379db63bac343a0a648e08f6a5840a48259f
|
3 |
+
size 175070404
|
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index
ADDED
Binary file (6.36 kB). View file
|
|
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79f84947bf3a5a9e851539308b85b43ecc6a8e93ed2c7ab9adb23f0fd6796286
|
3 |
+
size 124053471
|
DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# uniform content loss + adaptive threshold + per_class_input + recursive G
|
2 |
+
# improvement upon cqf37
|
3 |
+
from __future__ import division
|
4 |
+
import os
|
5 |
+
import tensorflow.compat.v1 as tf
|
6 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
7 |
+
import tf_slim as slim
|
8 |
+
import tensorflow as tf2
|
9 |
+
tf2.test.is_gpu_available()
|
10 |
+
import numpy as np
|
11 |
+
import glob
|
12 |
+
# import scipy.io as sio
|
13 |
+
import cv2
|
14 |
+
import json
|
15 |
+
from fractions import Fraction
|
16 |
+
import pdb
|
17 |
+
import sys
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
from awb import awb_v
|
21 |
+
|
22 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
23 |
+
input_dir = '../data/'
|
24 |
+
cha1 = 32
|
25 |
+
|
26 |
+
# get train IDs
|
27 |
+
train_fns = glob.glob(input_dir + '*.png')
|
28 |
+
train_ids = [os.path.basename(train_fn) for train_fn in train_fns]
|
29 |
+
|
30 |
+
result_dir = './mid/'
|
31 |
+
checkpoint_dir = './daylight_isp_03_3_unet_sid_5/'
|
32 |
+
|
33 |
+
if not os.path.exists(result_dir):
|
34 |
+
os.mkdir(result_dir)
|
35 |
+
|
36 |
+
#run python tensorflow2to1_1214_5202x3464_01_unetpp3.py ./data/ ./result/ ./daylight_isp_03/
|
37 |
+
|
38 |
+
# DEBUG = 0
|
39 |
+
# if DEBUG == 1:
|
40 |
+
# save_freq = 2
|
41 |
+
# test_ids = test_ids[0:5]
|
42 |
+
|
43 |
+
def json_read(fname, **kwargs):
|
44 |
+
with open(fname) as j:
|
45 |
+
data = json.load(j, **kwargs)
|
46 |
+
return data
|
47 |
+
|
48 |
+
def fraction_from_json(json_object):
|
49 |
+
if 'Fraction' in json_object:
|
50 |
+
return Fraction(*json_object['Fraction'])
|
51 |
+
return json_object
|
52 |
+
|
53 |
+
def fractions2floats(fractions):
|
54 |
+
floats = []
|
55 |
+
for fraction in fractions:
|
56 |
+
floats.append(float(fraction.numerator) / fraction.denominator)
|
57 |
+
return floats
|
58 |
+
|
59 |
+
def tv_loss(input_, output):
|
60 |
+
I = tf.image.rgb_to_grayscale(input_)
|
61 |
+
L = tf.log(I+0.0001)
|
62 |
+
dx = L[:, :-1, :-1, :] - L[:, :-1, 1:, :]
|
63 |
+
dy = L[:, :-1, :-1, :] - L[:, 1:, :-1, :]
|
64 |
+
|
65 |
+
alpha = tf.constant(1.2)
|
66 |
+
lamda = tf.constant(1.5)
|
67 |
+
dx = tf.divide(lamda, tf.pow(tf.abs(dx),alpha)+ tf.constant(0.0001))
|
68 |
+
dy = tf.divide(lamda, tf.pow(tf.abs(dy),alpha)+ tf.constant(0.0001))
|
69 |
+
shape = output.get_shape()
|
70 |
+
x_loss = dx *((output[:, :-1, :-1, :] - output[:, :-1, 1:, :])**2)
|
71 |
+
y_loss = dy *((output[:, :-1, :-1, :] - output[:, 1:, :-1, :])**2)
|
72 |
+
tvloss = tf.reduce_mean(x_loss + y_loss)/2.0
|
73 |
+
return tvloss
|
74 |
+
|
75 |
+
def lrelu(x):
|
76 |
+
return tf.maximum(x * 0.2, x)
|
77 |
+
|
78 |
+
|
79 |
+
def upsample_and_concat_3(x1, x2, output_channels, in_channels, name):
|
80 |
+
with tf.variable_scope(name):
|
81 |
+
x1 = slim.conv2d(x1, output_channels, [3, 3], rate=1, activation_fn=lrelu, scope='conv_2to1')
|
82 |
+
deconv = tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2])
|
83 |
+
deconv_output = tf.concat([deconv, x2], 3)
|
84 |
+
deconv_output.set_shape([None, None, None, output_channels * 2])
|
85 |
+
return deconv_output
|
86 |
+
|
87 |
+
|
88 |
+
def upsample_and_concat_h(x1, x2, output_channels, in_channels, name):
|
89 |
+
with tf.variable_scope(name):
|
90 |
+
#deconv = tf.image.resize_images(x1, [x1.shape[1].value*2, x1.shape[2].value*2])
|
91 |
+
pool_size = 2
|
92 |
+
deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
|
93 |
+
deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])
|
94 |
+
|
95 |
+
deconv_output = tf.concat([deconv, x2], 3)
|
96 |
+
deconv_output.set_shape([None, None, None, output_channels * 2])
|
97 |
+
|
98 |
+
return deconv_output
|
99 |
+
|
100 |
+
def upsample_and_concat_h_only(x1, output_channels, in_channels, name):
|
101 |
+
with tf.variable_scope(name):
|
102 |
+
x1 = tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2])
|
103 |
+
x1.set_shape([None, None, None, output_channels])
|
104 |
+
return x1
|
105 |
+
|
106 |
+
|
107 |
+
def conv_block(input, output_channels, name):
|
108 |
+
with tf.variable_scope(name):
|
109 |
+
conv = slim.conv2d(input, output_channels, [3, 3], activation_fn=lrelu, scope='conv1')
|
110 |
+
conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2')
|
111 |
+
return conv
|
112 |
+
|
113 |
+
|
114 |
+
def conv_block_up(input, output_channels, name):
|
115 |
+
with tf.variable_scope(name):
|
116 |
+
conv = slim.conv2d(input, output_channels, [1, 1], scope='conv0')
|
117 |
+
conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv1')
|
118 |
+
conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2')
|
119 |
+
|
120 |
+
return conv
|
121 |
+
|
122 |
+
|
123 |
+
def upsample_and_concat(x1, x2, output_channels, in_channels, p, name):
|
124 |
+
with tf.variable_scope(name):
|
125 |
+
pool_size = 2
|
126 |
+
|
127 |
+
deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
|
128 |
+
deconv_filter = tf.cast(deconv_filter, x1.dtype)
|
129 |
+
deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2[0]), strides=[1, pool_size, pool_size, 1])
|
130 |
+
# x2.append(deconv)
|
131 |
+
x2 = tf.concat(x2, axis=3)
|
132 |
+
deconv_output = tf.concat([x2, deconv], axis=3)
|
133 |
+
deconv_output.set_shape([None, None, None, output_channels * (p + 1)])
|
134 |
+
|
135 |
+
return deconv_output
|
136 |
+
|
137 |
+
|
138 |
+
def network(input):
|
139 |
+
with tf.variable_scope("generator_h"):
|
140 |
+
conv1_h = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
|
141 |
+
conv1_h = slim.conv2d(conv1_h, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2')
|
142 |
+
pool1_h = slim.max_pool2d(conv1_h, [2, 2], padding='SAME')
|
143 |
+
|
144 |
+
conv2_h = slim.conv2d(pool1_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
|
145 |
+
conv2_h = slim.conv2d(conv2_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2')
|
146 |
+
pool2_h = slim.max_pool2d(conv2_h, [2, 2], padding='SAME')
|
147 |
+
|
148 |
+
conv3_h = slim.conv2d(pool2_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
|
149 |
+
conv3_h = slim.conv2d(conv3_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2')
|
150 |
+
pool3_h = slim.max_pool2d(conv3_h, [2, 2], padding='SAME')
|
151 |
+
|
152 |
+
conv4_h = slim.conv2d(pool3_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1')
|
153 |
+
conv4_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2')
|
154 |
+
conv6_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1')
|
155 |
+
conv6_h = slim.conv2d(conv6_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2')
|
156 |
+
|
157 |
+
up7_h = upsample_and_concat_3(conv6_h, conv3_h, cha1*4,cha1*8, name='up7')
|
158 |
+
conv7_h = slim.conv2d(up7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1')
|
159 |
+
conv7_h = slim.conv2d(conv7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2')
|
160 |
+
|
161 |
+
up8_h = upsample_and_concat_3(conv7_h, conv2_h, cha1*2,cha1*4, name='up8')
|
162 |
+
conv8_h = slim.conv2d(up8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
|
163 |
+
conv8_h = slim.conv2d(conv8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2')
|
164 |
+
|
165 |
+
up9_h = upsample_and_concat_3(conv8_h, conv1_h, cha1,cha1*2, name='up9')
|
166 |
+
conv9_h = slim.conv2d(up9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
|
167 |
+
conv9_h = slim.conv2d(conv9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2')
|
168 |
+
|
169 |
+
up10_h = upsample_and_concat_h_only(conv9_h, cha1,cha1, name='up10')
|
170 |
+
conv10_h = slim.conv2d(up10_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv10_1')
|
171 |
+
out = slim.conv2d(conv10_h, 3, [3, 3], rate=1, activation_fn=None, scope='g_conv10_2')
|
172 |
+
return out
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
def fix_orientation(image, orientation):
|
177 |
+
# 1 = Horizontal(normal)
|
178 |
+
# 2 = Mirror horizontal
|
179 |
+
# 3 = Rotate 180
|
180 |
+
# 4 = Mirror vertical
|
181 |
+
# 5 = Mirror horizontal and rotate 270 CW
|
182 |
+
# 6 = Rotate 90 CW
|
183 |
+
# 7 = Mirror horizontal and rotate 90 CW
|
184 |
+
# 8 = Rotate 270 CW
|
185 |
+
|
186 |
+
if type(orientation) is list:
|
187 |
+
orientation = orientation[0]
|
188 |
+
|
189 |
+
if orientation == 'Horizontal (normal)':
|
190 |
+
pass
|
191 |
+
elif orientation == 'Mirror horizontal':
|
192 |
+
image = cv2.flip(image, 0)
|
193 |
+
elif orientation == 'Rotate 180':
|
194 |
+
image = cv2.rotate(image, cv2.ROTATE_180)
|
195 |
+
elif orientation == 'Mirror vertical':
|
196 |
+
image = cv2.flip(image, 1)
|
197 |
+
elif orientation == 'Mirror horizontal and rotate 270 CW':
|
198 |
+
image = cv2.flip(image, 0)
|
199 |
+
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
200 |
+
elif orientation == 'Rotate 90 CW':
|
201 |
+
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
|
202 |
+
elif orientation == 'Mirror horizontal and rotate 90 CW':
|
203 |
+
image = cv2.flip(image, 0)
|
204 |
+
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
|
205 |
+
elif orientation == 'Rotate 270 CW':
|
206 |
+
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
207 |
+
|
208 |
+
return image
|
209 |
+
|
210 |
+
class ExposureFusion(object):
|
211 |
+
def __init__(self, sequence, best_exposedness=0.5, sigma=0.2, eps=1e-12, exponents=(1.0, 1.0, 1.0), layers=11):
|
212 |
+
self.sequence = sequence # [N, H, W, 3], (0..1), float32
|
213 |
+
self.img_num = sequence.shape[0]
|
214 |
+
self.best_exposedness = best_exposedness
|
215 |
+
self.sigma = sigma
|
216 |
+
self.eps = eps
|
217 |
+
self.exponents = exponents
|
218 |
+
self.layers = layers
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def cal_contrast(src):
|
222 |
+
gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
|
223 |
+
laplace_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
|
224 |
+
contrast = cv2.filter2D(gray, -1, laplace_kernel, borderType=cv2.BORDER_REPLICATE)
|
225 |
+
return np.abs(contrast)
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def cal_saturation(src):
|
229 |
+
mean = np.mean(src, axis=-1)
|
230 |
+
channels = [(src[:, :, c] - mean)**2 for c in range(3)]
|
231 |
+
saturation = np.sqrt(np.mean(channels, axis=0))
|
232 |
+
return saturation
|
233 |
+
|
234 |
+
@staticmethod
|
235 |
+
def cal_exposedness(src, best_exposedness, sigma):
|
236 |
+
exposedness = [gauss_curve(src[:, :, c], best_exposedness, sigma) for c in range(3)]
|
237 |
+
exposedness = np.prod(exposedness, axis=0)
|
238 |
+
return exposedness
|
239 |
+
|
240 |
+
def cal_weight_map(self):
|
241 |
+
#pdb.set_trace()
|
242 |
+
weights = []
|
243 |
+
for idx in range(self.sequence.shape[0]):
|
244 |
+
contrast = self.cal_contrast(self.sequence[idx])
|
245 |
+
saturation = self.cal_saturation(self.sequence[idx])
|
246 |
+
exposedness = self.cal_exposedness(self.sequence[idx], self.best_exposedness, self.sigma)
|
247 |
+
weight = np.power(contrast, self.exponents[0]) * np.power(saturation, self.exponents[1]) * np.power(exposedness, self.exponents[2])
|
248 |
+
# Gauss Blur
|
249 |
+
# weight = cv2.GaussianBlur(weight, (21, 21), 2.1)
|
250 |
+
weights.append(weight)
|
251 |
+
#pdb.set_trace()
|
252 |
+
weights = np.stack(weights, 0) + self.eps
|
253 |
+
# normalize
|
254 |
+
weights = weights / np.expand_dims(np.sum(weights, axis=0), axis=0)
|
255 |
+
return weights
|
256 |
+
|
257 |
+
def naive_fusion(self):
|
258 |
+
weights = self.cal_weight_map() # [N, H, W]
|
259 |
+
weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3]
|
260 |
+
naive_fusion = np.sum(weights * self.sequence * 255, axis=0)
|
261 |
+
naive_fusion = np.clip(naive_fusion, 0, 255).astype(np.uint8)
|
262 |
+
return naive_fusion
|
263 |
+
|
264 |
+
def build_gaussian_pyramid(self, high_res):
|
265 |
+
#pdb.set_trace()
|
266 |
+
gaussian_pyramid = [high_res]
|
267 |
+
for idx in range(1, self.layers):
|
268 |
+
kernel1=np.array([[0.0039,0.0156,0.0234,0.0156,0.0039],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0234,0.0938,0.1406,0.0938,0.0234],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0039,0.0156,0.0234,0.0156,0.0039]],dtype='float32')
|
269 |
+
gaussian_pyramid.append(cv2.filter2D(gaussian_pyramid[-1], -1,kernel=kernel1)[::2, ::2])
|
270 |
+
#gaussian_pyramid.append(cv2.GaussianBlur(gaussian_pyramid[-1], (5, 5), 0.83)[::2, ::2])
|
271 |
+
return gaussian_pyramid
|
272 |
+
|
273 |
+
def build_laplace_pyramid(self, gaussian_pyramid):
|
274 |
+
laplace_pyramid = [gaussian_pyramid[-1]]
|
275 |
+
for idx in range(1, self.layers):
|
276 |
+
size = (gaussian_pyramid[self.layers - idx - 1].shape[1], gaussian_pyramid[self.layers - idx - 1].shape[0])
|
277 |
+
upsampled = cv2.resize(gaussian_pyramid[self.layers - idx], size, interpolation=cv2.INTER_LINEAR)
|
278 |
+
laplace_pyramid.append(gaussian_pyramid[self.layers - idx - 1] - upsampled)
|
279 |
+
laplace_pyramid.reverse()
|
280 |
+
return laplace_pyramid
|
281 |
+
|
282 |
+
def multi_resolution_fusion(self):
|
283 |
+
#pdb.set_trace()
|
284 |
+
weights = self.cal_weight_map() # [N, H, W]
|
285 |
+
weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3]
|
286 |
+
|
287 |
+
image_gaussian_pyramid = [self.build_gaussian_pyramid(self.sequence[i] * 255) for i in range(self.img_num)]
|
288 |
+
image_laplace_pyramid = [self.build_laplace_pyramid(image_gaussian_pyramid[i]) for i in range(self.img_num)]
|
289 |
+
weights_gaussian_pyramid = [self.build_gaussian_pyramid(weights[i]) for i in range(self.img_num)]
|
290 |
+
|
291 |
+
fused_laplace_pyramid = [np.sum([image_laplace_pyramid[n][l] *
|
292 |
+
weights_gaussian_pyramid[n][l] for n in range(self.img_num)], axis=0) for l in range(self.layers)]
|
293 |
+
|
294 |
+
result = fused_laplace_pyramid[-1]
|
295 |
+
for k in range(1, self.layers):
|
296 |
+
size = (fused_laplace_pyramid[self.layers - k - 1].shape[1], fused_laplace_pyramid[self.layers - k - 1].shape[0])
|
297 |
+
upsampled = cv2.resize(result, size, interpolation=cv2.INTER_LINEAR)
|
298 |
+
result = upsampled + fused_laplace_pyramid[self.layers - k - 1]
|
299 |
+
#pdb.set_trace()
|
300 |
+
#result = np.clip(result, 0, 255).astype(np.uint8)
|
301 |
+
|
302 |
+
|
303 |
+
return result
|
304 |
+
|
305 |
+
h_pre1, w_pre1 = 6144, 8192
|
306 |
+
pad_1 = 0
|
307 |
+
pad_2 = 0
|
308 |
+
h_exp1, w_exp1 = h_pre1 // 2, w_pre1 // 2
|
309 |
+
|
310 |
+
sess = tf.Session()
|
311 |
+
in_image = tf.placeholder(tf.float32, [None, h_exp1, w_exp1, 4])
|
312 |
+
|
313 |
+
in_image1 = tf.nn.avg_pool(in_image,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')
|
314 |
+
in_image2 = tf.nn.avg_pool(in_image,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
|
315 |
+
|
316 |
+
out_image1 = network(in_image1)
|
317 |
+
out_image2 = network(in_image2, reuse=True)
|
318 |
+
|
319 |
+
t_vars = tf.trainable_variables()
|
320 |
+
for ele1 in t_vars:
|
321 |
+
print("variable: ", ele1)
|
322 |
+
|
323 |
+
saver = tf.train.Saver()
|
324 |
+
sess.run(tf.global_variables_initializer())
|
325 |
+
|
326 |
+
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
|
327 |
+
if ckpt:
|
328 |
+
print('loaded ' + ckpt.model_checkpoint_path)
|
329 |
+
saver.restore(sess, ckpt.model_checkpoint_path)
|
330 |
+
|
331 |
+
in_pic4 = np.zeros([h_exp1, w_exp1, 4])
|
332 |
+
for k in range(len(train_ids)):
|
333 |
+
|
334 |
+
print(k)
|
335 |
+
train_id = train_ids[k]
|
336 |
+
in_path = input_dir + train_id[:-4] + '.png'
|
337 |
+
#raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
338 |
+
raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
339 |
+
#meta = np.load(input_dir1 + train_id[:-4] + '.npy').astype(np.float32)
|
340 |
+
#meta = scipy.io.loadmat(input_dir2 + train_id[:-4] + '.mat')
|
341 |
+
metadata = json_read(in_path[:-4] + '.json', object_hook=fraction_from_json)
|
342 |
+
|
343 |
+
white_level = float(metadata['white_level'])
|
344 |
+
black_level = float(metadata['black_level'][0].numerator)
|
345 |
+
|
346 |
+
orientation = metadata['orientation']
|
347 |
+
|
348 |
+
in_pic2 = np.clip((raw_image - black_level) /(white_level-black_level),0,1)
|
349 |
+
|
350 |
+
mean = np.mean(np.mean(in_pic2))
|
351 |
+
var = np.var(in_pic2)
|
352 |
+
|
353 |
+
bining = 4
|
354 |
+
|
355 |
+
if (mean < 0.01):
|
356 |
+
ratio = 6
|
357 |
+
elif (mean < 0.02):
|
358 |
+
ratio = 4
|
359 |
+
elif (mean < 0.037):
|
360 |
+
ratio = 3
|
361 |
+
else:
|
362 |
+
ratio = 2
|
363 |
+
|
364 |
+
if (var > 0.015):
|
365 |
+
ratio = ratio + 1
|
366 |
+
|
367 |
+
noise_profile = float(metadata['noise_profile'][0]) * ratio
|
368 |
+
if (noise_profile > 0.02):
|
369 |
+
bining = 8
|
370 |
+
ratio = np.clip(ratio - 1,2,4)
|
371 |
+
|
372 |
+
#r_gain, b_gain = awb_v(in_pic2, bayer='RGGB', eps=1)
|
373 |
+
r_gain1 = 1./metadata['as_shot_neutral'][0]
|
374 |
+
b_gain1 = 1./metadata['as_shot_neutral'][2]
|
375 |
+
|
376 |
+
#in_pic3 = np.pad(in_pic2, ((top_pad, btm_pad), (left_pad, right_pad)), mode='reflect') # GBRG to RGGB + reflect padding
|
377 |
+
h_pre,w_pre = in_pic2.shape
|
378 |
+
|
379 |
+
if (metadata['cfa_pattern'][0].numerator == 2):
|
380 |
+
in_pic2[0:h_pre-1,0:w_pre-1] = in_pic2[1:h_pre,1:w_pre]
|
381 |
+
|
382 |
+
r_gain, b_gain = awb_v(in_pic2 * (ratio**2), bayer='RGGB', eps=1)
|
383 |
+
in_pic3 = in_pic2
|
384 |
+
|
385 |
+
in_pic4[0:h_pre//2, 0:w_pre//2, 0] = in_pic3[0::2, 0::2] * r_gain
|
386 |
+
in_pic4[0:h_pre//2, 0:w_pre//2, 1] = in_pic3[0::2, 1::2]
|
387 |
+
in_pic4[0:h_pre//2, 0:w_pre//2, 2] = in_pic3[1::2, 1::2] * b_gain
|
388 |
+
in_pic4[0:h_pre//2, 0:w_pre//2, 3] = in_pic3[1::2, 0::2]
|
389 |
+
|
390 |
+
im1=np.clip(in_pic4*1,0,1)
|
391 |
+
in_np1 = np.expand_dims(im1,axis = 0)
|
392 |
+
if (bining == 4):
|
393 |
+
out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
|
394 |
+
else:
|
395 |
+
out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
|
396 |
+
|
397 |
+
out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
|
398 |
+
h_pre2,w_pre2,cc = out_np2.shape
|
399 |
+
|
400 |
+
if h_pre2 > w_pre2:
|
401 |
+
out_np_1 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
|
402 |
+
if w_pre2 > h_pre2:
|
403 |
+
out_np_1 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
|
404 |
+
|
405 |
+
im1=np.clip(in_pic4*ratio,0,1)
|
406 |
+
in_np1 = np.expand_dims(im1,axis = 0)
|
407 |
+
if (bining == 4):
|
408 |
+
out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
|
409 |
+
else:
|
410 |
+
out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
|
411 |
+
|
412 |
+
out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
|
413 |
+
h_pre2,w_pre2,cc = out_np2.shape
|
414 |
+
|
415 |
+
if h_pre2 > w_pre2:
|
416 |
+
out_np_2 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
|
417 |
+
if w_pre2 > h_pre2:
|
418 |
+
out_np_2 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
|
419 |
+
|
420 |
+
|
421 |
+
im1=np.clip(in_pic4*(ratio**2),0,1)
|
422 |
+
in_np1 = np.expand_dims(im1,axis = 0)
|
423 |
+
|
424 |
+
if (bining == 4):
|
425 |
+
out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
|
426 |
+
else:
|
427 |
+
out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
|
428 |
+
|
429 |
+
out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
|
430 |
+
h_pre2,w_pre2,cc = out_np2.shape
|
431 |
+
|
432 |
+
if h_pre2 > w_pre2:
|
433 |
+
out_np_3 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
|
434 |
+
if w_pre2 > h_pre2:
|
435 |
+
out_np_3 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
|
436 |
+
|
437 |
+
#pdb.set_trace()
|
438 |
+
'''sequence = np.stack([out_np_1, out_np_2, out_np_3], axis=0)
|
439 |
+
#sequence0 = sequence[0]
|
440 |
+
mef = ExposureFusion(sequence.astype(np.float32))
|
441 |
+
multi_res_fusion = mef.multi_resolution_fusion()
|
442 |
+
#pdb.set_trace()
|
443 |
+
result = reprocessing(multi_res_fusion)'''
|
444 |
+
|
445 |
+
#out_crop = multi_res_fusion
|
446 |
+
|
447 |
+
#np.save(result_dir + train_id[0:-4] + '_gray_{:}.npy'.format(gain), out_crop)
|
448 |
+
cv2.imwrite(result_dir + train_id[0:-4] + '_1.png', out_np_1[:,:,::-1]*255)
|
449 |
+
cv2.imwrite(result_dir + train_id[0:-4] + '_2.png', out_np_2[:,:,::-1]*255)
|
450 |
+
cv2.imwrite(result_dir + train_id[0:-4] + '_3.png', out_np_3[:,:,::-1]*255)
|
451 |
+
|
DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc
ADDED
Binary file (18.5 kB). View file
|
|
DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc
ADDED
Binary file (1.55 kB). View file
|
|
DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc
ADDED
Binary file (1.99 kB). View file
|
|
DH-AISP/2/focal_frequency_loss/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .focal_frequency_loss import FocalFrequencyLoss
|
2 |
+
|
3 |
+
__all__ = ['FocalFrequencyLoss']
|
DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (263 Bytes). View file
|
|
DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc
ADDED
Binary file (4.01 kB). View file
|
|
DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
# version adaptation for PyTorch > 1.7.1
|
5 |
+
IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
|
6 |
+
if IS_HIGH_VERSION:
|
7 |
+
import torch.fft
|
8 |
+
|
9 |
+
|
10 |
+
class FocalFrequencyLoss(nn.Module):
|
11 |
+
"""The torch.nn.Module class that implements focal frequency loss - a
|
12 |
+
frequency domain loss function for optimizing generative models.
|
13 |
+
|
14 |
+
Ref:
|
15 |
+
Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
|
16 |
+
<https://arxiv.org/pdf/2012.12821.pdf>
|
17 |
+
|
18 |
+
Args:
|
19 |
+
loss_weight (float): weight for focal frequency loss. Default: 1.0
|
20 |
+
alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
|
21 |
+
patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
|
22 |
+
ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
|
23 |
+
log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
|
24 |
+
batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
|
28 |
+
super(FocalFrequencyLoss, self).__init__()
|
29 |
+
self.loss_weight = loss_weight
|
30 |
+
self.alpha = alpha
|
31 |
+
self.patch_factor = patch_factor
|
32 |
+
self.ave_spectrum = ave_spectrum
|
33 |
+
self.log_matrix = log_matrix
|
34 |
+
self.batch_matrix = batch_matrix
|
35 |
+
|
36 |
+
def tensor2freq(self, x):
|
37 |
+
# crop image patches
|
38 |
+
patch_factor = self.patch_factor
|
39 |
+
_, _, h, w = x.shape
|
40 |
+
assert h % patch_factor == 0 and w % patch_factor == 0, (
|
41 |
+
'Patch factor should be divisible by image height and width')
|
42 |
+
patch_list = []
|
43 |
+
patch_h = h // patch_factor
|
44 |
+
patch_w = w // patch_factor
|
45 |
+
for i in range(patch_factor):
|
46 |
+
for j in range(patch_factor):
|
47 |
+
patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
|
48 |
+
|
49 |
+
# stack to patch tensor
|
50 |
+
y = torch.stack(patch_list, 1)
|
51 |
+
|
52 |
+
# perform 2D DFT (real-to-complex, orthonormalization)
|
53 |
+
if IS_HIGH_VERSION:
|
54 |
+
freq = torch.fft.fft2(y, norm='ortho')
|
55 |
+
freq = torch.stack([freq.real, freq.imag], -1)
|
56 |
+
else:
|
57 |
+
freq = torch.rfft(y, 2, onesided=False, normalized=True)
|
58 |
+
return freq
|
59 |
+
|
60 |
+
def loss_formulation(self, recon_freq, real_freq, matrix=None):
|
61 |
+
# spectrum weight matrix
|
62 |
+
if matrix is not None:
|
63 |
+
# if the matrix is predefined
|
64 |
+
weight_matrix = matrix.detach()
|
65 |
+
else:
|
66 |
+
# if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
|
67 |
+
matrix_tmp = (recon_freq - real_freq) ** 2
|
68 |
+
matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
|
69 |
+
|
70 |
+
# whether to adjust the spectrum weight matrix by logarithm
|
71 |
+
if self.log_matrix:
|
72 |
+
matrix_tmp = torch.log(matrix_tmp + 1.0)
|
73 |
+
|
74 |
+
# whether to calculate the spectrum weight matrix using batch-based statistics
|
75 |
+
if self.batch_matrix:
|
76 |
+
matrix_tmp = matrix_tmp / matrix_tmp.max()
|
77 |
+
else:
|
78 |
+
matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
|
79 |
+
|
80 |
+
matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
|
81 |
+
matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
|
82 |
+
weight_matrix = matrix_tmp.clone().detach()
|
83 |
+
|
84 |
+
assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
|
85 |
+
'The values of spectrum weight matrix should be in the range [0, 1], '
|
86 |
+
'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
|
87 |
+
|
88 |
+
# frequency distance using (squared) Euclidean distance
|
89 |
+
tmp = (recon_freq - real_freq) ** 2
|
90 |
+
freq_distance = tmp[..., 0] + tmp[..., 1]
|
91 |
+
|
92 |
+
# dynamic spectrum weighting (Hadamard product)
|
93 |
+
loss = weight_matrix * freq_distance
|
94 |
+
return torch.mean(loss)
|
95 |
+
|
96 |
+
def forward(self, pred, target, matrix=None, **kwargs):
|
97 |
+
"""Forward function to calculate focal frequency loss.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
|
101 |
+
target (torch.Tensor): of shape (N, C, H, W). Target tensor.
|
102 |
+
matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
|
103 |
+
Default: None (If set to None: calculated online, dynamic).
|
104 |
+
"""
|
105 |
+
pred_freq = self.tensor2freq(pred)
|
106 |
+
target_freq = self.tensor2freq(target)
|
107 |
+
|
108 |
+
# whether to use minibatch average spectrum
|
109 |
+
if self.ave_spectrum:
|
110 |
+
pred_freq = torch.mean(pred_freq, 0, keepdim=True)
|
111 |
+
target_freq = torch.mean(target_freq, 0, keepdim=True)
|
112 |
+
|
113 |
+
# calculate focal frequency loss
|
114 |
+
return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
|
DH-AISP/2/model_convnext2_hdr.py
ADDED
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from timm.models.layers import trunc_normal_, DropPath
|
5 |
+
from timm.models.registry import register_model
|
6 |
+
|
7 |
+
#import Convnext as PreConv
|
8 |
+
from myFFCResblock0 import myFFCResblock
|
9 |
+
|
10 |
+
|
11 |
+
# A ConvNet for the 2020s
|
12 |
+
# original implementation https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
|
13 |
+
# paper https://arxiv.org/pdf/2201.03545.pdf
|
14 |
+
|
15 |
+
class ConvNeXt0(nn.Module):
|
16 |
+
r""" ConvNeXt
|
17 |
+
A PyTorch impl of : `A ConvNet for the 2020s` -
|
18 |
+
https://arxiv.org/pdf/2201.03545.pdf
|
19 |
+
Args:
|
20 |
+
in_chans (int): Number of input image channels. Default: 3
|
21 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
22 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
23 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
24 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
25 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
26 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
27 |
+
"""
|
28 |
+
def __init__(self, block, in_chans=3, num_classes=1000,
|
29 |
+
depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0.,
|
30 |
+
layer_scale_init_value=1e-6, head_init_scale=1.,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
35 |
+
stem = nn.Sequential(
|
36 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
37 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
38 |
+
)
|
39 |
+
self.downsample_layers.append(stem)
|
40 |
+
for i in range(3):
|
41 |
+
downsample_layer = nn.Sequential(
|
42 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
43 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
44 |
+
)
|
45 |
+
self.downsample_layers.append(downsample_layer)
|
46 |
+
|
47 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
48 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
49 |
+
cur = 0
|
50 |
+
for i in range(4):
|
51 |
+
stage = nn.Sequential(
|
52 |
+
*[block(dim=dims[i], drop_path=dp_rates[cur + j],
|
53 |
+
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
|
54 |
+
)
|
55 |
+
self.stages.append(stage)
|
56 |
+
cur += depths[i]
|
57 |
+
|
58 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
59 |
+
self.head = nn.Linear(dims[-1], num_classes)
|
60 |
+
|
61 |
+
self.apply(self._init_weights)
|
62 |
+
self.head.weight.data.mul_(head_init_scale)
|
63 |
+
self.head.bias.data.mul_(head_init_scale)
|
64 |
+
|
65 |
+
def _init_weights(self, m):
|
66 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
67 |
+
trunc_normal_(m.weight, std=.02)
|
68 |
+
nn.init.constant_(m.bias, 0)
|
69 |
+
|
70 |
+
def forward_features(self, x):
|
71 |
+
for i in range(4):
|
72 |
+
x = self.downsample_layers[i](x)
|
73 |
+
x = self.stages[i](x)
|
74 |
+
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
x = self.forward_features(x)
|
78 |
+
x = self.head(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
def dwt_init(x):
|
89 |
+
x01 = x[:, :, 0::2, :] / 2 #x01.shape=[4,3,128,256]
|
90 |
+
x02 = x[:, :, 1::2, :] / 2 #x02.shape=[4,3,128,256]
|
91 |
+
x1 = x01[:, :, :, 0::2] #x1.shape=[4,3,128,128]
|
92 |
+
x2 = x02[:, :, :, 0::2] #x2.shape=[4,3,128,128]
|
93 |
+
x3 = x01[:, :, :, 1::2] #x3.shape=[4,3,128,128]
|
94 |
+
x4 = x02[:, :, :, 1::2] #x4.shape=[4,3,128,128]
|
95 |
+
x_LL = x1 + x2 + x3 + x4
|
96 |
+
x_HL = -x1 - x2 + x3 + x4
|
97 |
+
x_LH = -x1 + x2 - x3 + x4
|
98 |
+
x_HH = x1 - x2 - x3 + x4
|
99 |
+
return x_LL, torch.cat((x_HL, x_LH, x_HH), 1)
|
100 |
+
|
101 |
+
class DWT(nn.Module):
|
102 |
+
def __init__(self):
|
103 |
+
super(DWT, self).__init__()
|
104 |
+
self.requires_grad = False
|
105 |
+
def forward(self, x):
|
106 |
+
return dwt_init(x)
|
107 |
+
|
108 |
+
class DWT_transform(nn.Module):
|
109 |
+
def __init__(self, in_channels,out_channels):
|
110 |
+
super().__init__()
|
111 |
+
self.dwt = DWT()
|
112 |
+
self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
113 |
+
self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels, kernel_size=1, padding=0)
|
114 |
+
def forward(self, x):
|
115 |
+
dwt_low_frequency,dwt_high_frequency = self.dwt(x)
|
116 |
+
dwt_low_frequency = self.conv1x1_low(dwt_low_frequency)
|
117 |
+
dwt_high_frequency = self.conv1x1_high(dwt_high_frequency)
|
118 |
+
return dwt_low_frequency,dwt_high_frequency
|
119 |
+
|
120 |
+
def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
|
121 |
+
block = nn.Sequential()
|
122 |
+
if relu:
|
123 |
+
block.add_module('%s_relu' % name, nn.ReLU(inplace=True))
|
124 |
+
else:
|
125 |
+
block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
|
126 |
+
if not transposed:
|
127 |
+
block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))
|
128 |
+
else:
|
129 |
+
block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1))
|
130 |
+
block.add_module('%s_bili' % name, nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
|
131 |
+
if bn:
|
132 |
+
block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c))
|
133 |
+
if dropout:
|
134 |
+
block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True))
|
135 |
+
return block
|
136 |
+
|
137 |
+
# DW-GAN: A Discrete Wavelet Transform GAN for NonHomogeneous Dehazing 2021
|
138 |
+
# original implementation https://github.com/liuh127/NTIRE-2021-Dehazing-DWGAN/blob/main/model.py
|
139 |
+
# paper https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Fu_DW-GAN_A_Discrete_Wavelet_Transform_GAN_for_NonHomogeneous_Dehazing_CVPRW_2021_paper.pdf
|
140 |
+
class dwt_ffc_UNet2(nn.Module):
|
141 |
+
def __init__(self,output_nc=3, nf=16):
|
142 |
+
super(dwt_ffc_UNet2, self).__init__()
|
143 |
+
layer_idx = 1
|
144 |
+
name = 'layer%d' % layer_idx
|
145 |
+
layer1 = nn.Sequential()
|
146 |
+
layer1.add_module(name, nn.Conv2d(16, nf-1, 4, 2, 1, bias=False))
|
147 |
+
layer_idx += 1
|
148 |
+
name = 'layer%d' % layer_idx
|
149 |
+
layer2 = blockUNet(nf, nf*2-2, name, transposed=False, bn=True, relu=False, dropout=False)
|
150 |
+
layer_idx += 1
|
151 |
+
name = 'layer%d' % layer_idx
|
152 |
+
layer3 = blockUNet(nf*2, nf*4-4, name, transposed=False, bn=True, relu=False, dropout=False)
|
153 |
+
layer_idx += 1
|
154 |
+
name = 'layer%d' % layer_idx
|
155 |
+
layer4 = blockUNet(nf*4, nf*8-8, name, transposed=False, bn=True, relu=False, dropout=False)
|
156 |
+
layer_idx += 1
|
157 |
+
name = 'layer%d' % layer_idx
|
158 |
+
layer5 = blockUNet(nf*8, nf*8-16, name, transposed=False, bn=True, relu=False, dropout=False)
|
159 |
+
layer_idx += 1
|
160 |
+
name = 'layer%d' % layer_idx
|
161 |
+
layer6 = blockUNet(nf*4, nf*4, name, transposed=False, bn=False, relu=False, dropout=False)
|
162 |
+
|
163 |
+
layer_idx -= 1
|
164 |
+
name = 'dlayer%d' % layer_idx
|
165 |
+
dlayer6 = blockUNet(nf * 4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
|
166 |
+
layer_idx -= 1
|
167 |
+
name = 'dlayer%d' % layer_idx
|
168 |
+
dlayer5 = blockUNet(nf * 16+16, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
|
169 |
+
layer_idx -= 1
|
170 |
+
name = 'dlayer%d' % layer_idx
|
171 |
+
dlayer4 = blockUNet(nf * 16+8, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False)
|
172 |
+
layer_idx -= 1
|
173 |
+
name = 'dlayer%d' % layer_idx
|
174 |
+
dlayer3 = blockUNet(nf * 8+4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
|
175 |
+
layer_idx -= 1
|
176 |
+
name = 'dlayer%d' % layer_idx
|
177 |
+
dlayer2 = blockUNet(nf * 4+2, nf, name, transposed=True, bn=True, relu=True, dropout=False)
|
178 |
+
layer_idx -= 1
|
179 |
+
name = 'dlayer%d' % layer_idx
|
180 |
+
dlayer1 = blockUNet(nf * 2+1, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
|
181 |
+
|
182 |
+
self.initial_conv=nn.Conv2d(9,16,3,padding=1)
|
183 |
+
self.bn1=nn.BatchNorm2d(16)
|
184 |
+
self.layer1 = layer1
|
185 |
+
self.DWT_down_0= DWT_transform(9,1)
|
186 |
+
self.layer2 = layer2
|
187 |
+
self.DWT_down_1 = DWT_transform(16, 2)
|
188 |
+
self.layer3 = layer3
|
189 |
+
self.DWT_down_2 = DWT_transform(32, 4)
|
190 |
+
self.layer4 = layer4
|
191 |
+
self.DWT_down_3 = DWT_transform(64, 8)
|
192 |
+
self.layer5 = layer5
|
193 |
+
self.DWT_down_4 = DWT_transform(128, 16)
|
194 |
+
self.layer6 = layer6
|
195 |
+
self.dlayer6 = dlayer6
|
196 |
+
self.dlayer5 = dlayer5
|
197 |
+
self.dlayer4 = dlayer4
|
198 |
+
self.dlayer3 = dlayer3
|
199 |
+
self.dlayer2 = dlayer2
|
200 |
+
self.dlayer1 = dlayer1
|
201 |
+
self.tail_conv1 = nn.Conv2d(48, 32, 3, padding=1, bias=True)
|
202 |
+
self.bn2=nn.BatchNorm2d(32)
|
203 |
+
self.tail_conv2 = nn.Conv2d(nf*2, output_nc, 3,padding=1, bias=True)
|
204 |
+
|
205 |
+
|
206 |
+
self.FFCResNet = myFFCResblock(input_nc=64, output_nc=64)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
conv_start=self.initial_conv(x)
|
210 |
+
conv_start=self.bn1(conv_start)
|
211 |
+
conv_out1 = self.layer1(conv_start)
|
212 |
+
dwt_low_0,dwt_high_0=self.DWT_down_0(x)
|
213 |
+
out1=torch.cat([conv_out1, dwt_low_0], 1)
|
214 |
+
conv_out2 = self.layer2(out1)
|
215 |
+
dwt_low_1,dwt_high_1= self.DWT_down_1(out1)
|
216 |
+
out2 = torch.cat([conv_out2, dwt_low_1], 1)
|
217 |
+
conv_out3 = self.layer3(out2)
|
218 |
+
|
219 |
+
dwt_low_2,dwt_high_2 = self.DWT_down_2(out2)
|
220 |
+
out3 = torch.cat([conv_out3, dwt_low_2], 1)
|
221 |
+
|
222 |
+
# conv_out4 = self.layer4(out3)
|
223 |
+
# dwt_low_3,dwt_high_3 = self.DWT_down_3(out3)
|
224 |
+
# out4 = torch.cat([conv_out4, dwt_low_3], 1)
|
225 |
+
|
226 |
+
# conv_out5 = self.layer5(out4)
|
227 |
+
# dwt_low_4,dwt_high_4 = self.DWT_down_4(out4)
|
228 |
+
# out5 = torch.cat([conv_out5, dwt_low_4], 1)
|
229 |
+
|
230 |
+
# out6 = self.layer6(out5)
|
231 |
+
|
232 |
+
|
233 |
+
out3_ffc= self.FFCResNet(out3)
|
234 |
+
|
235 |
+
|
236 |
+
dout3 = self.dlayer6(out3_ffc)
|
237 |
+
|
238 |
+
|
239 |
+
Tout3_out2 = torch.cat([dout3, out2,dwt_high_1], 1)
|
240 |
+
Tout2 = self.dlayer2(Tout3_out2)
|
241 |
+
Tout2_out1 = torch.cat([Tout2, out1,dwt_high_0], 1)
|
242 |
+
Tout1 = self.dlayer1(Tout2_out1)
|
243 |
+
|
244 |
+
Tout1_outinit = torch.cat([Tout1, conv_start], 1)
|
245 |
+
tail1=self.tail_conv1(Tout1_outinit)
|
246 |
+
tail2=self.bn2(tail1)
|
247 |
+
dout1 = self.tail_conv2(tail2)
|
248 |
+
|
249 |
+
|
250 |
+
return dout1
|
251 |
+
|
252 |
+
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
class Block(nn.Module):
|
258 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
259 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
260 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
261 |
+
We use (2) as we find it slightly faster in PyTorch
|
262 |
+
|
263 |
+
Args:
|
264 |
+
dim (int): Number of input channels.
|
265 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
266 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
267 |
+
"""
|
268 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
|
269 |
+
super().__init__()
|
270 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
271 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
272 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
273 |
+
self.act = nn.GELU()
|
274 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
275 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
276 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
277 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
278 |
+
|
279 |
+
def forward(self, x):
|
280 |
+
input = x
|
281 |
+
x = self.dwconv(x)
|
282 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
283 |
+
x = self.norm(x)
|
284 |
+
x = self.pwconv1(x)
|
285 |
+
x = self.act(x)
|
286 |
+
x = self.pwconv2(x)
|
287 |
+
if self.gamma is not None:
|
288 |
+
x = self.gamma * x
|
289 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
290 |
+
|
291 |
+
x = input + self.drop_path(x)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
class ConvNeXt(nn.Module):
|
296 |
+
def __init__(self, block, in_chans=3, num_classes=1000,
|
297 |
+
depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0.,
|
298 |
+
layer_scale_init_value=1e-6, head_init_scale=1.,
|
299 |
+
):
|
300 |
+
super().__init__()
|
301 |
+
|
302 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
303 |
+
stem = nn.Sequential(
|
304 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
305 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
306 |
+
)
|
307 |
+
self.downsample_layers.append(stem)
|
308 |
+
for i in range(3):
|
309 |
+
downsample_layer = nn.Sequential(
|
310 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
311 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
312 |
+
)
|
313 |
+
self.downsample_layers.append(downsample_layer)
|
314 |
+
|
315 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
316 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
317 |
+
cur = 0
|
318 |
+
for i in range(4):
|
319 |
+
stage = nn.Sequential(
|
320 |
+
*[block(dim=dims[i], drop_path=dp_rates[cur + j],
|
321 |
+
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
|
322 |
+
)
|
323 |
+
self.stages.append(stage)
|
324 |
+
cur += depths[i]
|
325 |
+
|
326 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
327 |
+
self.head = nn.Linear(dims[-1], num_classes)
|
328 |
+
|
329 |
+
self.head.weight.data.mul_(head_init_scale)
|
330 |
+
self.head.bias.data.mul_(head_init_scale)
|
331 |
+
|
332 |
+
|
333 |
+
def forward(self, x):
|
334 |
+
x_layer1 = self.downsample_layers[0](x)
|
335 |
+
x_layer1 = self.stages[0](x_layer1)
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
x_layer2 = self.downsample_layers[1](x_layer1)
|
340 |
+
x_layer2 = self.stages[1](x_layer2)
|
341 |
+
|
342 |
+
|
343 |
+
x_layer3 = self.downsample_layers[2](x_layer2)
|
344 |
+
out = self.stages[2](x_layer3)
|
345 |
+
|
346 |
+
|
347 |
+
return x_layer1, x_layer2, out
|
348 |
+
|
349 |
+
class LayerNorm(nn.Module):
|
350 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
351 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
352 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
353 |
+
with shape (batch_size, channels, height, width).
|
354 |
+
"""
|
355 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
356 |
+
super().__init__()
|
357 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
358 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
359 |
+
self.eps = eps
|
360 |
+
self.data_format = data_format
|
361 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
362 |
+
raise NotImplementedError
|
363 |
+
self.normalized_shape = (normalized_shape, )
|
364 |
+
|
365 |
+
def forward(self, x):
|
366 |
+
if self.data_format == "channels_last":
|
367 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
368 |
+
elif self.data_format == "channels_first":
|
369 |
+
u = x.mean(1, keepdim=True)
|
370 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
371 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
372 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
373 |
+
return x
|
374 |
+
|
375 |
+
|
376 |
+
class PALayer(nn.Module):
|
377 |
+
def __init__(self, channel):
|
378 |
+
super(PALayer, self).__init__()
|
379 |
+
self.pa = nn.Sequential(
|
380 |
+
nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
|
381 |
+
nn.ReLU(inplace=True),
|
382 |
+
nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
|
383 |
+
nn.Sigmoid()
|
384 |
+
)
|
385 |
+
def forward(self, x):
|
386 |
+
y = self.pa(x)
|
387 |
+
return x * y
|
388 |
+
|
389 |
+
class CALayer(nn.Module):
|
390 |
+
def __init__(self, channel):
|
391 |
+
super(CALayer, self).__init__()
|
392 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
393 |
+
self.ca = nn.Sequential(
|
394 |
+
nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
|
395 |
+
nn.ReLU(inplace=True),
|
396 |
+
nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
|
397 |
+
nn.Sigmoid()
|
398 |
+
)
|
399 |
+
def forward(self, x):
|
400 |
+
y = self.avg_pool(x)
|
401 |
+
y = self.ca(y)
|
402 |
+
return x * y
|
403 |
+
|
404 |
+
class CP_Attention_block(nn.Module):
|
405 |
+
def __init__(self, conv, dim, kernel_size):
|
406 |
+
super(CP_Attention_block, self).__init__()
|
407 |
+
self.conv1 = conv(dim, dim, kernel_size, bias=True)
|
408 |
+
self.act1 = nn.ReLU(inplace=True)
|
409 |
+
self.conv2 = conv(dim, dim, kernel_size, bias=True)
|
410 |
+
self.calayer = CALayer(dim)
|
411 |
+
self.palayer = PALayer(dim)
|
412 |
+
def forward(self, x):
|
413 |
+
res = self.act1(self.conv1(x))
|
414 |
+
res = res + x
|
415 |
+
res = self.conv2(res)
|
416 |
+
res = self.calayer(res)
|
417 |
+
res = self.palayer(res)
|
418 |
+
res += x
|
419 |
+
return res
|
420 |
+
|
421 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
422 |
+
return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
|
423 |
+
|
424 |
+
class knowledge_adaptation_convnext(nn.Module):
|
425 |
+
def __init__(self):
|
426 |
+
super(knowledge_adaptation_convnext, self).__init__()
|
427 |
+
self.encoder = ConvNeXt(Block, in_chans=9,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.)
|
428 |
+
'''pretrained_model = ConvNeXt0(Block, in_chans=3,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.)
|
429 |
+
#pretrained_model=nn.DataParallel(pretrained_model)
|
430 |
+
checkpoint=torch.load('./weights/convnext_xlarge_22k_1k_384_ema.pth')
|
431 |
+
#for k,v in checkpoint["model"].items():
|
432 |
+
#print(k)
|
433 |
+
#url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth"
|
434 |
+
|
435 |
+
#checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cuda:0")
|
436 |
+
pretrained_model.load_state_dict(checkpoint["model"])
|
437 |
+
|
438 |
+
pretrained_dict = pretrained_model.state_dict()
|
439 |
+
model_dict = self.encoder.state_dict()
|
440 |
+
key_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
441 |
+
model_dict.update(key_dict)
|
442 |
+
self.encoder.load_state_dict(model_dict)'''
|
443 |
+
|
444 |
+
|
445 |
+
self.up_block= nn.PixelShuffle(2)
|
446 |
+
self.attention0 = CP_Attention_block(default_conv, 1024, 3)
|
447 |
+
self.attention1 = CP_Attention_block(default_conv, 256, 3)
|
448 |
+
self.attention2 = CP_Attention_block(default_conv, 192, 3)
|
449 |
+
self.attention3 = CP_Attention_block(default_conv, 112, 3)
|
450 |
+
self.attention4 = CP_Attention_block(default_conv, 28, 3)
|
451 |
+
self.conv_process_1 = nn.Conv2d(28, 28, kernel_size=3,padding=1)
|
452 |
+
self.conv_process_2 = nn.Conv2d(28, 28, kernel_size=3,padding=1)
|
453 |
+
self.tail = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(28, 3, kernel_size=7, padding=0), nn.Tanh())
|
454 |
+
def forward(self, input):
|
455 |
+
x_layer1, x_layer2, x_output = self.encoder(input)
|
456 |
+
|
457 |
+
x_mid = self.attention0(x_output) #[1024,24,24]
|
458 |
+
|
459 |
+
x = self.up_block(x_mid) #[256,48,48]
|
460 |
+
x = self.attention1(x)
|
461 |
+
|
462 |
+
x = torch.cat((x, x_layer2), 1) #[768,48,48]
|
463 |
+
|
464 |
+
x = self.up_block(x) #[192,96,96]
|
465 |
+
x = self.attention2(x)
|
466 |
+
x = torch.cat((x, x_layer1), 1) #[448,96,96]
|
467 |
+
x = self.up_block(x) #[112,192,192]
|
468 |
+
x = self.attention3(x)
|
469 |
+
|
470 |
+
x = self.up_block(x) #[28,384,384]
|
471 |
+
x = self.attention4(x)
|
472 |
+
|
473 |
+
x=self.conv_process_1(x)
|
474 |
+
out=self.conv_process_2(x)
|
475 |
+
return out
|
476 |
+
|
477 |
+
|
478 |
+
class fusion_net(nn.Module):
|
479 |
+
def __init__(self):
|
480 |
+
super(fusion_net, self).__init__()
|
481 |
+
self.dwt_branch=dwt_ffc_UNet2()
|
482 |
+
self.knowledge_adaptation_branch=knowledge_adaptation_convnext()
|
483 |
+
self.fusion = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(31, 3, kernel_size=7, padding=0), nn.Tanh())
|
484 |
+
def forward(self, input):
|
485 |
+
dwt_branch=self.dwt_branch(input)
|
486 |
+
knowledge_adaptation_branch=self.knowledge_adaptation_branch(input)
|
487 |
+
x = torch.cat([dwt_branch, knowledge_adaptation_branch], 1)
|
488 |
+
x = self.fusion(x)
|
489 |
+
return x
|
490 |
+
|
491 |
+
|
492 |
+
|
493 |
+
class Discriminator(nn.Module):
|
494 |
+
def __init__(self):
|
495 |
+
super(Discriminator, self).__init__()
|
496 |
+
self.net = nn.Sequential(
|
497 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
498 |
+
nn.LeakyReLU(0.2),
|
499 |
+
|
500 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
|
501 |
+
nn.BatchNorm2d(64),
|
502 |
+
nn.LeakyReLU(0.2),
|
503 |
+
|
504 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
505 |
+
nn.BatchNorm2d(128),
|
506 |
+
nn.LeakyReLU(0.2),
|
507 |
+
|
508 |
+
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
|
509 |
+
nn.BatchNorm2d(128),
|
510 |
+
nn.LeakyReLU(0.2),
|
511 |
+
|
512 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
513 |
+
nn.BatchNorm2d(256),
|
514 |
+
nn.LeakyReLU(0.2),
|
515 |
+
|
516 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
|
517 |
+
nn.BatchNorm2d(256),
|
518 |
+
nn.LeakyReLU(0.2),
|
519 |
+
|
520 |
+
nn.Conv2d(256, 512, kernel_size=3, padding=1),
|
521 |
+
nn.BatchNorm2d(512),
|
522 |
+
nn.LeakyReLU(0.2),
|
523 |
+
|
524 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
525 |
+
nn.BatchNorm2d(512),
|
526 |
+
nn.LeakyReLU(0.2),
|
527 |
+
|
528 |
+
nn.AdaptiveAvgPool2d(1),
|
529 |
+
nn.Conv2d(512, 1024, kernel_size=1),
|
530 |
+
nn.LeakyReLU(0.2),
|
531 |
+
nn.Conv2d(1024, 1, kernel_size=1)
|
532 |
+
)
|
533 |
+
|
534 |
+
def forward(self, x):
|
535 |
+
batch_size = x.size(0)
|
536 |
+
return torch.sigmoid(self.net(x).view(batch_size))
|
537 |
+
|
538 |
+
|
539 |
+
class Discriminator2(nn.Module):
|
540 |
+
def __init__(self):
|
541 |
+
super(Discriminator2, self).__init__()
|
542 |
+
self.net = nn.Sequential(
|
543 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
544 |
+
nn.LeakyReLU(0.2),
|
545 |
+
|
546 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
|
547 |
+
nn.BatchNorm2d(64),
|
548 |
+
nn.LeakyReLU(0.2),
|
549 |
+
|
550 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
551 |
+
nn.BatchNorm2d(128),
|
552 |
+
nn.LeakyReLU(0.2),
|
553 |
+
|
554 |
+
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
|
555 |
+
nn.BatchNorm2d(128),
|
556 |
+
nn.LeakyReLU(0.2),
|
557 |
+
|
558 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
559 |
+
nn.BatchNorm2d(256),
|
560 |
+
nn.LeakyReLU(0.2),
|
561 |
+
|
562 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
|
563 |
+
nn.BatchNorm2d(256),
|
564 |
+
nn.LeakyReLU(0.2),
|
565 |
+
|
566 |
+
nn.Conv2d(256, 512, kernel_size=3, padding=1),
|
567 |
+
nn.BatchNorm2d(512),
|
568 |
+
nn.LeakyReLU(0.2),
|
569 |
+
|
570 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
571 |
+
nn.BatchNorm2d(512),
|
572 |
+
nn.LeakyReLU(0.2),
|
573 |
+
|
574 |
+
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
575 |
+
nn.BatchNorm2d(512),
|
576 |
+
nn.LeakyReLU(0.2),
|
577 |
+
|
578 |
+
nn.Conv2d(512, 1, kernel_size=3, padding=1),
|
579 |
+
)
|
580 |
+
|
581 |
+
def forward(self, x):
|
582 |
+
return self.net(x)
|
583 |
+
|
584 |
+
if __name__ == '__main__':
|
585 |
+
|
586 |
+
device = torch.device("cuda:0")
|
587 |
+
|
588 |
+
# Create model
|
589 |
+
im = torch.rand(1, 3, 640, 640).to(device)
|
590 |
+
model_g = fusion_net().to(device)
|
591 |
+
|
592 |
+
out_data = model_g(im)
|
DH-AISP/2/myFFCResblock0.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
from saicinpainting.training.modules.ffc0 import FFCResnetBlock
|
9 |
+
from saicinpainting.training.modules.ffc0 import FFC_BN_ACT
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class myFFCResblock(nn.Module):
|
15 |
+
def __init__(self, input_nc, output_nc, n_blocks=2, norm_layer=nn.BatchNorm2d, #128--->64
|
16 |
+
padding_type='reflect', activation_layer=nn.ReLU,
|
17 |
+
resnet_conv_kwargs={},
|
18 |
+
spatial_transform_layers=None, spatial_transform_kwargs={},
|
19 |
+
add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
|
20 |
+
assert (n_blocks >= 0)
|
21 |
+
|
22 |
+
super().__init__()
|
23 |
+
self.initial = FFC_BN_ACT(input_nc, input_nc, kernel_size=3, padding=1, dilation=1,
|
24 |
+
norm_layer=norm_layer, activation_layer=activation_layer,
|
25 |
+
padding_type=padding_type,
|
26 |
+
**resnet_conv_kwargs)
|
27 |
+
|
28 |
+
self.ffcresblock = FFCResnetBlock(input_nc, padding_type=padding_type, activation_layer=activation_layer,
|
29 |
+
norm_layer=norm_layer, **resnet_conv_kwargs)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
self.final = FFC_BN_ACT(input_nc, output_nc, kernel_size=3, padding=1, dilation=1,
|
34 |
+
norm_layer=norm_layer,
|
35 |
+
activation_layer=activation_layer,
|
36 |
+
padding_type=padding_type,
|
37 |
+
**resnet_conv_kwargs)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
|
47 |
+
x_l, x_g = self.initial(x)
|
48 |
+
|
49 |
+
x_l, x_g = self.ffcresblock(x_l, x_g)
|
50 |
+
x_l, x_g = self.ffcresblock(x_l, x_g)
|
51 |
+
|
52 |
+
out_ = torch.cat([x_l, x_g], 1)
|
53 |
+
|
54 |
+
x_lout, x_gout =self.final(out_)
|
55 |
+
|
56 |
+
out = torch.cat([x_lout, x_gout], 1)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
|
DH-AISP/2/perceptual.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- Imports --- #
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
# --- Perceptual loss network --- #
|
6 |
+
class LossNetwork(torch.nn.Module):
|
7 |
+
def __init__(self, vgg_model):
|
8 |
+
super(LossNetwork, self).__init__()
|
9 |
+
self.vgg_layers = vgg_model
|
10 |
+
self.layer_name_mapping = {
|
11 |
+
'3': "relu1_2",
|
12 |
+
'8': "relu2_2",
|
13 |
+
'15': "relu3_3"
|
14 |
+
}
|
15 |
+
|
16 |
+
def output_features(self, x):
|
17 |
+
output = {}
|
18 |
+
for name, module in self.vgg_layers._modules.items():
|
19 |
+
x = module(x)
|
20 |
+
if name in self.layer_name_mapping:
|
21 |
+
output[self.layer_name_mapping[name]] = x
|
22 |
+
return list(output.values())
|
23 |
+
|
24 |
+
def forward(self, dehaze, gt):
|
25 |
+
loss = []
|
26 |
+
dehaze_features = self.output_features(dehaze)
|
27 |
+
gt_features = self.output_features(gt)
|
28 |
+
for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
|
29 |
+
loss.append(F.mse_loss(dehaze_feature, gt_feature))
|
30 |
+
return sum(loss)/len(loss)
|
DH-AISP/2/pytorch_msssim/__init__.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from math import exp
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def gaussian(window_size, sigma):
|
8 |
+
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
9 |
+
return gauss/gauss.sum()
|
10 |
+
|
11 |
+
|
12 |
+
def create_window(window_size, channel=1):
|
13 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
14 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
15 |
+
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
16 |
+
return window
|
17 |
+
|
18 |
+
|
19 |
+
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
20 |
+
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
21 |
+
if val_range is None:
|
22 |
+
if torch.max(img1) > 128:
|
23 |
+
max_val = 255
|
24 |
+
else:
|
25 |
+
max_val = 1
|
26 |
+
|
27 |
+
if torch.min(img1) < -0.5:
|
28 |
+
min_val = -1
|
29 |
+
else:
|
30 |
+
min_val = 0
|
31 |
+
L = max_val - min_val
|
32 |
+
else:
|
33 |
+
L = val_range
|
34 |
+
|
35 |
+
padd = 0
|
36 |
+
(_, channel, height, width) = img1.size()
|
37 |
+
if window is None:
|
38 |
+
real_size = min(window_size, height, width)
|
39 |
+
window = create_window(real_size, channel=channel).to(img1.device)
|
40 |
+
|
41 |
+
mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
|
42 |
+
mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
|
43 |
+
|
44 |
+
mu1_sq = mu1.pow(2)
|
45 |
+
mu2_sq = mu2.pow(2)
|
46 |
+
mu1_mu2 = mu1 * mu2
|
47 |
+
|
48 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
|
49 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
|
50 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
|
51 |
+
|
52 |
+
C1 = (0.01 * L) ** 2
|
53 |
+
C2 = (0.03 * L) ** 2
|
54 |
+
|
55 |
+
v1 = 2.0 * sigma12 + C2
|
56 |
+
v2 = sigma1_sq + sigma2_sq + C2
|
57 |
+
cs = torch.mean(v1 / v2) # contrast sensitivity
|
58 |
+
|
59 |
+
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
60 |
+
|
61 |
+
if size_average:
|
62 |
+
ret = ssim_map.mean()
|
63 |
+
else:
|
64 |
+
ret = ssim_map.mean(1).mean(1).mean(1)
|
65 |
+
|
66 |
+
if full:
|
67 |
+
return ret, cs
|
68 |
+
return ret
|
69 |
+
|
70 |
+
|
71 |
+
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
|
72 |
+
device = img1.device
|
73 |
+
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
|
74 |
+
levels = weights.size()[0]
|
75 |
+
mssim = []
|
76 |
+
mcs = []
|
77 |
+
for _ in range(levels):
|
78 |
+
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
|
79 |
+
mssim.append(sim)
|
80 |
+
mcs.append(cs)
|
81 |
+
|
82 |
+
img1 = F.avg_pool2d(img1, (2, 2))
|
83 |
+
img2 = F.avg_pool2d(img2, (2, 2))
|
84 |
+
|
85 |
+
mssim = torch.stack(mssim)
|
86 |
+
mcs = torch.stack(mcs)
|
87 |
+
|
88 |
+
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
|
89 |
+
if normalize:
|
90 |
+
mssim = (mssim + 1) / 2
|
91 |
+
mcs = (mcs + 1) / 2
|
92 |
+
|
93 |
+
pow1 = mcs ** weights
|
94 |
+
pow2 = mssim ** weights
|
95 |
+
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
|
96 |
+
output = torch.prod(pow1[:-1] * pow2[-1])
|
97 |
+
return output
|
98 |
+
|
99 |
+
|
100 |
+
# Classes to re-use window
|
101 |
+
class SSIM(torch.nn.Module):
|
102 |
+
def __init__(self, window_size=11, size_average=True, val_range=None):
|
103 |
+
super(SSIM, self).__init__()
|
104 |
+
self.window_size = window_size
|
105 |
+
self.size_average = size_average
|
106 |
+
self.val_range = val_range
|
107 |
+
|
108 |
+
# Assume 1 channel for SSIM
|
109 |
+
self.channel = 1
|
110 |
+
self.window = create_window(window_size)
|
111 |
+
|
112 |
+
def forward(self, img1, img2):
|
113 |
+
(_, channel, _, _) = img1.size()
|
114 |
+
|
115 |
+
if channel == self.channel and self.window.dtype == img1.dtype:
|
116 |
+
window = self.window
|
117 |
+
else:
|
118 |
+
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
|
119 |
+
self.window = window
|
120 |
+
self.channel = channel
|
121 |
+
|
122 |
+
return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
|
123 |
+
|
124 |
+
class MSSSIM(torch.nn.Module):
|
125 |
+
def __init__(self, window_size=11, size_average=True, channel=3):
|
126 |
+
super(MSSSIM, self).__init__()
|
127 |
+
self.window_size = window_size
|
128 |
+
self.size_average = size_average
|
129 |
+
self.channel = channel
|
130 |
+
|
131 |
+
def forward(self, img1, img2):
|
132 |
+
# TODO: store window between calls if possible
|
133 |
+
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
|
DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (3.9 kB). View file
|
|
DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (3.88 kB). View file
|
|
DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e5952db983eb66b04c6a39348a0916164d9148ec99c4a3b8a77bf4e240657022
|
3 |
+
size 1491472482
|
DH-AISP/2/saicinpainting/__init__.py
ADDED
File without changes
|
DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (168 Bytes). View file
|
|
DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (155 Bytes). View file
|
|
DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (6.1 kB). View file
|
|
DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (6.08 kB). View file
|
|
DH-AISP/2/saicinpainting/evaluation/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
|
6 |
+
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
7 |
+
|
8 |
+
|
9 |
+
def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
|
10 |
+
logging.info(f'Make evaluator {kind}')
|
11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
+
metrics = {}
|
13 |
+
if ssim:
|
14 |
+
metrics['ssim'] = SSIMScore()
|
15 |
+
if lpips:
|
16 |
+
metrics['lpips'] = LPIPSScore()
|
17 |
+
if fid:
|
18 |
+
metrics['fid'] = FIDScore().to(device)
|
19 |
+
|
20 |
+
if integral_kind is None:
|
21 |
+
integral_func = None
|
22 |
+
elif integral_kind == 'ssim_fid100_f1':
|
23 |
+
integral_func = ssim_fid100_f1
|
24 |
+
elif integral_kind == 'lpips_fid100_f1':
|
25 |
+
integral_func = lpips_fid100_f1
|
26 |
+
else:
|
27 |
+
raise ValueError(f'Unexpected integral_kind={integral_kind}')
|
28 |
+
|
29 |
+
if kind == 'default':
|
30 |
+
return InpaintingEvaluatorOnline(scores=metrics,
|
31 |
+
integral_func=integral_func,
|
32 |
+
integral_title=integral_kind,
|
33 |
+
**kwargs)
|
DH-AISP/2/saicinpainting/evaluation/data.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import PIL.Image as Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def load_image(fname, mode='RGB', return_orig=False):
|
13 |
+
img = np.array(Image.open(fname).convert(mode))
|
14 |
+
if img.ndim == 3:
|
15 |
+
img = np.transpose(img, (2, 0, 1))
|
16 |
+
out_img = img.astype('float32') / 255
|
17 |
+
if return_orig:
|
18 |
+
return out_img, img
|
19 |
+
else:
|
20 |
+
return out_img
|
21 |
+
|
22 |
+
|
23 |
+
def ceil_modulo(x, mod):
|
24 |
+
if x % mod == 0:
|
25 |
+
return x
|
26 |
+
return (x // mod + 1) * mod
|
27 |
+
|
28 |
+
|
29 |
+
def pad_img_to_modulo(img, mod):
|
30 |
+
channels, height, width = img.shape
|
31 |
+
out_height = ceil_modulo(height, mod)
|
32 |
+
out_width = ceil_modulo(width, mod)
|
33 |
+
return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
|
34 |
+
|
35 |
+
|
36 |
+
def pad_tensor_to_modulo(img, mod):
|
37 |
+
batch_size, channels, height, width = img.shape
|
38 |
+
out_height = ceil_modulo(height, mod)
|
39 |
+
out_width = ceil_modulo(width, mod)
|
40 |
+
return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')
|
41 |
+
|
42 |
+
|
43 |
+
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
|
44 |
+
if img.shape[0] == 1:
|
45 |
+
img = img[0]
|
46 |
+
else:
|
47 |
+
img = np.transpose(img, (1, 2, 0))
|
48 |
+
|
49 |
+
img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
|
50 |
+
|
51 |
+
if img.ndim == 2:
|
52 |
+
img = img[None, ...]
|
53 |
+
else:
|
54 |
+
img = np.transpose(img, (2, 0, 1))
|
55 |
+
return img
|
56 |
+
|
57 |
+
|
58 |
+
class InpaintingDataset(Dataset):
|
59 |
+
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
|
60 |
+
self.datadir = datadir
|
61 |
+
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True)))
|
62 |
+
self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames]
|
63 |
+
self.pad_out_to_modulo = pad_out_to_modulo
|
64 |
+
self.scale_factor = scale_factor
|
65 |
+
|
66 |
+
def __len__(self):
|
67 |
+
return len(self.mask_filenames)
|
68 |
+
|
69 |
+
def __getitem__(self, i):
|
70 |
+
image = load_image(self.img_filenames[i], mode='RGB')
|
71 |
+
mask = load_image(self.mask_filenames[i], mode='L')
|
72 |
+
result = dict(image=image, mask=mask[None, ...])
|
73 |
+
|
74 |
+
if self.scale_factor is not None:
|
75 |
+
result['image'] = scale_image(result['image'], self.scale_factor)
|
76 |
+
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
|
77 |
+
|
78 |
+
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
|
79 |
+
result['unpad_to_size'] = result['image'].shape[1:]
|
80 |
+
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
|
81 |
+
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
|
82 |
+
|
83 |
+
return result
|
84 |
+
|
85 |
+
class OurInpaintingDataset(Dataset):
|
86 |
+
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
|
87 |
+
self.datadir = datadir
|
88 |
+
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True)))
|
89 |
+
self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames]
|
90 |
+
self.pad_out_to_modulo = pad_out_to_modulo
|
91 |
+
self.scale_factor = scale_factor
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.mask_filenames)
|
95 |
+
|
96 |
+
def __getitem__(self, i):
|
97 |
+
result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
|
98 |
+
mask=load_image(self.mask_filenames[i], mode='L')[None, ...])
|
99 |
+
|
100 |
+
if self.scale_factor is not None:
|
101 |
+
result['image'] = scale_image(result['image'], self.scale_factor)
|
102 |
+
result['mask'] = scale_image(result['mask'], self.scale_factor)
|
103 |
+
|
104 |
+
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
|
105 |
+
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
|
106 |
+
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
|
107 |
+
|
108 |
+
return result
|
109 |
+
|
110 |
+
class PrecomputedInpaintingResultsDataset(InpaintingDataset):
|
111 |
+
def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
|
112 |
+
super().__init__(datadir, **kwargs)
|
113 |
+
if not datadir.endswith('/'):
|
114 |
+
datadir += '/'
|
115 |
+
self.predictdir = predictdir
|
116 |
+
self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
|
117 |
+
for fname in self.mask_filenames]
|
118 |
+
|
119 |
+
def __getitem__(self, i):
|
120 |
+
result = super().__getitem__(i)
|
121 |
+
result['inpainted'] = load_image(self.pred_filenames[i])
|
122 |
+
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
|
123 |
+
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
|
124 |
+
return result
|
125 |
+
|
126 |
+
class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
|
127 |
+
def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
|
128 |
+
super().__init__(datadir, **kwargs)
|
129 |
+
if not datadir.endswith('/'):
|
130 |
+
datadir += '/'
|
131 |
+
self.predictdir = predictdir
|
132 |
+
self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
|
133 |
+
for fname in self.mask_filenames]
|
134 |
+
# self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
|
135 |
+
# for fname in self.mask_filenames]
|
136 |
+
|
137 |
+
def __getitem__(self, i):
|
138 |
+
result = super().__getitem__(i)
|
139 |
+
result['inpainted'] = self.file_loader(self.pred_filenames[i])
|
140 |
+
|
141 |
+
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
|
142 |
+
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
|
143 |
+
return result
|
144 |
+
|
145 |
+
class InpaintingEvalOnlineDataset(Dataset):
|
146 |
+
def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
|
147 |
+
self.indir = indir
|
148 |
+
self.mask_generator = mask_generator
|
149 |
+
self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True)))
|
150 |
+
self.pad_out_to_modulo = pad_out_to_modulo
|
151 |
+
self.scale_factor = scale_factor
|
152 |
+
|
153 |
+
def __len__(self):
|
154 |
+
return len(self.img_filenames)
|
155 |
+
|
156 |
+
def __getitem__(self, i):
|
157 |
+
img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
|
158 |
+
mask = self.mask_generator(img, raw_image=raw_image)
|
159 |
+
result = dict(image=img, mask=mask)
|
160 |
+
|
161 |
+
if self.scale_factor is not None:
|
162 |
+
result['image'] = scale_image(result['image'], self.scale_factor)
|
163 |
+
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
|
164 |
+
|
165 |
+
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
|
166 |
+
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
|
167 |
+
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
|
168 |
+
return result
|
DH-AISP/2/saicinpainting/evaluation/evaluator.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from typing import Dict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import tqdm
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
from saicinpainting.evaluation.utils import move_to_device
|
12 |
+
|
13 |
+
LOGGER = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class InpaintingEvaluator():
|
17 |
+
def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
|
18 |
+
integral_func=None, integral_title=None, clamp_image_range=None):
|
19 |
+
"""
|
20 |
+
:param dataset: torch.utils.data.Dataset which contains images and masks
|
21 |
+
:param scores: dict {score_name: EvaluatorScore object}
|
22 |
+
:param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
|
23 |
+
which are defined by share of area occluded by mask
|
24 |
+
:param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
|
25 |
+
:param batch_size: batch_size for the dataloader
|
26 |
+
:param device: device to use
|
27 |
+
"""
|
28 |
+
self.scores = scores
|
29 |
+
self.dataset = dataset
|
30 |
+
|
31 |
+
self.area_grouping = area_grouping
|
32 |
+
self.bins = bins
|
33 |
+
|
34 |
+
self.device = torch.device(device)
|
35 |
+
|
36 |
+
self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
|
37 |
+
|
38 |
+
self.integral_func = integral_func
|
39 |
+
self.integral_title = integral_title
|
40 |
+
self.clamp_image_range = clamp_image_range
|
41 |
+
|
42 |
+
def _get_bin_edges(self):
|
43 |
+
bin_edges = np.linspace(0, 1, self.bins + 1)
|
44 |
+
|
45 |
+
num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
|
46 |
+
interval_names = []
|
47 |
+
for idx_bin in range(self.bins):
|
48 |
+
start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
|
49 |
+
round(100 * bin_edges[idx_bin + 1], num_digits)
|
50 |
+
start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
|
51 |
+
end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
|
52 |
+
interval_names.append("{0}-{1}%".format(start_percent, end_percent))
|
53 |
+
|
54 |
+
groups = []
|
55 |
+
for batch in self.dataloader:
|
56 |
+
mask = batch['mask']
|
57 |
+
batch_size = mask.shape[0]
|
58 |
+
area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
|
59 |
+
bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
|
60 |
+
# corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
|
61 |
+
bin_indices[bin_indices == self.bins] = self.bins - 1
|
62 |
+
groups.append(bin_indices)
|
63 |
+
groups = np.hstack(groups)
|
64 |
+
|
65 |
+
return groups, interval_names
|
66 |
+
|
67 |
+
def evaluate(self, model=None):
|
68 |
+
"""
|
69 |
+
:param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
|
70 |
+
:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
|
71 |
+
name of the particular group arranged by area of mask (e.g. '10-20%')
|
72 |
+
and score statistics for the group as values.
|
73 |
+
"""
|
74 |
+
results = dict()
|
75 |
+
if self.area_grouping:
|
76 |
+
groups, interval_names = self._get_bin_edges()
|
77 |
+
else:
|
78 |
+
groups = None
|
79 |
+
|
80 |
+
for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
|
81 |
+
score.to(self.device)
|
82 |
+
with torch.no_grad():
|
83 |
+
score.reset()
|
84 |
+
for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
|
85 |
+
batch = move_to_device(batch, self.device)
|
86 |
+
image_batch, mask_batch = batch['image'], batch['mask']
|
87 |
+
if self.clamp_image_range is not None:
|
88 |
+
image_batch = torch.clamp(image_batch,
|
89 |
+
min=self.clamp_image_range[0],
|
90 |
+
max=self.clamp_image_range[1])
|
91 |
+
if model is None:
|
92 |
+
assert 'inpainted' in batch, \
|
93 |
+
'Model is None, so we expected precomputed inpainting results at key "inpainted"'
|
94 |
+
inpainted_batch = batch['inpainted']
|
95 |
+
else:
|
96 |
+
inpainted_batch = model(image_batch, mask_batch)
|
97 |
+
score(inpainted_batch, image_batch, mask_batch)
|
98 |
+
total_results, group_results = score.get_value(groups=groups)
|
99 |
+
|
100 |
+
results[(score_name, 'total')] = total_results
|
101 |
+
if groups is not None:
|
102 |
+
for group_index, group_values in group_results.items():
|
103 |
+
group_name = interval_names[group_index]
|
104 |
+
results[(score_name, group_name)] = group_values
|
105 |
+
|
106 |
+
if self.integral_func is not None:
|
107 |
+
results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
|
108 |
+
|
109 |
+
return results
|
110 |
+
|
111 |
+
|
112 |
+
def ssim_fid100_f1(metrics, fid_scale=100):
|
113 |
+
ssim = metrics[('ssim', 'total')]['mean']
|
114 |
+
fid = metrics[('fid', 'total')]['mean']
|
115 |
+
fid_rel = max(0, fid_scale - fid) / fid_scale
|
116 |
+
f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
|
117 |
+
return f1
|
118 |
+
|
119 |
+
|
120 |
+
def lpips_fid100_f1(metrics, fid_scale=100):
|
121 |
+
neg_lpips = 1 - metrics[('lpips', 'total')]['mean'] # invert, so bigger is better
|
122 |
+
fid = metrics[('fid', 'total')]['mean']
|
123 |
+
fid_rel = max(0, fid_scale - fid) / fid_scale
|
124 |
+
f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
|
125 |
+
return f1
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
class InpaintingEvaluatorOnline(nn.Module):
|
130 |
+
def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
|
131 |
+
integral_func=None, integral_title=None, clamp_image_range=None):
|
132 |
+
"""
|
133 |
+
:param scores: dict {score_name: EvaluatorScore object}
|
134 |
+
:param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
|
135 |
+
:param device: device to use
|
136 |
+
"""
|
137 |
+
super().__init__()
|
138 |
+
LOGGER.info(f'{type(self)} init called')
|
139 |
+
self.scores = nn.ModuleDict(scores)
|
140 |
+
self.image_key = image_key
|
141 |
+
self.inpainted_key = inpainted_key
|
142 |
+
self.bins_num = bins
|
143 |
+
self.bin_edges = np.linspace(0, 1, self.bins_num + 1)
|
144 |
+
|
145 |
+
num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
|
146 |
+
self.interval_names = []
|
147 |
+
for idx_bin in range(self.bins_num):
|
148 |
+
start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
|
149 |
+
round(100 * self.bin_edges[idx_bin + 1], num_digits)
|
150 |
+
start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
|
151 |
+
end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
|
152 |
+
self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))
|
153 |
+
|
154 |
+
self.groups = []
|
155 |
+
|
156 |
+
self.integral_func = integral_func
|
157 |
+
self.integral_title = integral_title
|
158 |
+
self.clamp_image_range = clamp_image_range
|
159 |
+
|
160 |
+
LOGGER.info(f'{type(self)} init done')
|
161 |
+
|
162 |
+
def _get_bins(self, mask_batch):
|
163 |
+
batch_size = mask_batch.shape[0]
|
164 |
+
area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
|
165 |
+
bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
|
166 |
+
return bin_indices
|
167 |
+
|
168 |
+
def forward(self, batch: Dict[str, torch.Tensor]):
|
169 |
+
"""
|
170 |
+
Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
|
171 |
+
:param batch: batch dict with mandatory fields mask, image, inpainted (can be overriden by self.inpainted_key)
|
172 |
+
"""
|
173 |
+
result = {}
|
174 |
+
with torch.no_grad():
|
175 |
+
image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
|
176 |
+
if self.clamp_image_range is not None:
|
177 |
+
image_batch = torch.clamp(image_batch,
|
178 |
+
min=self.clamp_image_range[0],
|
179 |
+
max=self.clamp_image_range[1])
|
180 |
+
self.groups.extend(self._get_bins(mask_batch))
|
181 |
+
|
182 |
+
for score_name, score in self.scores.items():
|
183 |
+
result[score_name] = score(inpainted_batch, image_batch, mask_batch)
|
184 |
+
return result
|
185 |
+
|
186 |
+
def process_batch(self, batch: Dict[str, torch.Tensor]):
|
187 |
+
return self(batch)
|
188 |
+
|
189 |
+
def evaluation_end(self, states=None):
|
190 |
+
""":return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
|
191 |
+
name of the particular group arranged by area of mask (e.g. '10-20%')
|
192 |
+
and score statistics for the group as values.
|
193 |
+
"""
|
194 |
+
LOGGER.info(f'{type(self)}: evaluation_end called')
|
195 |
+
|
196 |
+
self.groups = np.array(self.groups)
|
197 |
+
|
198 |
+
results = {}
|
199 |
+
for score_name, score in self.scores.items():
|
200 |
+
LOGGER.info(f'Getting value of {score_name}')
|
201 |
+
cur_states = [s[score_name] for s in states] if states is not None else None
|
202 |
+
total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
|
203 |
+
LOGGER.info(f'Getting value of {score_name} done')
|
204 |
+
results[(score_name, 'total')] = total_results
|
205 |
+
|
206 |
+
for group_index, group_values in group_results.items():
|
207 |
+
group_name = self.interval_names[group_index]
|
208 |
+
results[(score_name, group_name)] = group_values
|
209 |
+
|
210 |
+
if self.integral_func is not None:
|
211 |
+
results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
|
212 |
+
|
213 |
+
LOGGER.info(f'{type(self)}: reset scores')
|
214 |
+
self.groups = []
|
215 |
+
for sc in self.scores.values():
|
216 |
+
sc.reset()
|
217 |
+
LOGGER.info(f'{type(self)}: reset scores done')
|
218 |
+
|
219 |
+
LOGGER.info(f'{type(self)}: evaluation_end done')
|
220 |
+
return results
|
DH-AISP/2/saicinpainting/evaluation/losses/__init__.py
ADDED
File without changes
|
DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import abstractmethod, ABC
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import sklearn
|
6 |
+
import sklearn.svm
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
from scipy import linalg
|
12 |
+
|
13 |
+
from models.ade20k import SegmentationModule, NUM_CLASS, segm_options
|
14 |
+
from .fid.inception import InceptionV3
|
15 |
+
from .lpips import PerceptualLoss
|
16 |
+
from .ssim import SSIM
|
17 |
+
|
18 |
+
LOGGER = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
def get_groupings(groups):
|
22 |
+
"""
|
23 |
+
:param groups: group numbers for respective elements
|
24 |
+
:return: dict of kind {group_idx: indices of the corresponding group elements}
|
25 |
+
"""
|
26 |
+
label_groups, count_groups = np.unique(groups, return_counts=True)
|
27 |
+
|
28 |
+
indices = np.argsort(groups)
|
29 |
+
|
30 |
+
grouping = dict()
|
31 |
+
cur_start = 0
|
32 |
+
for label, count in zip(label_groups, count_groups):
|
33 |
+
cur_end = cur_start + count
|
34 |
+
cur_indices = indices[cur_start:cur_end]
|
35 |
+
grouping[label] = cur_indices
|
36 |
+
cur_start = cur_end
|
37 |
+
return grouping
|
38 |
+
|
39 |
+
|
40 |
+
class EvaluatorScore(nn.Module):
|
41 |
+
@abstractmethod
|
42 |
+
def forward(self, pred_batch, target_batch, mask):
|
43 |
+
pass
|
44 |
+
|
45 |
+
@abstractmethod
|
46 |
+
def get_value(self, groups=None, states=None):
|
47 |
+
pass
|
48 |
+
|
49 |
+
@abstractmethod
|
50 |
+
def reset(self):
|
51 |
+
pass
|
52 |
+
|
53 |
+
|
54 |
+
class PairwiseScore(EvaluatorScore, ABC):
|
55 |
+
def __init__(self):
|
56 |
+
super().__init__()
|
57 |
+
self.individual_values = None
|
58 |
+
|
59 |
+
def get_value(self, groups=None, states=None):
|
60 |
+
"""
|
61 |
+
:param groups:
|
62 |
+
:return:
|
63 |
+
total_results: dict of kind {'mean': score mean, 'std': score std}
|
64 |
+
group_results: None, if groups is None;
|
65 |
+
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
|
66 |
+
"""
|
67 |
+
individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \
|
68 |
+
else self.individual_values
|
69 |
+
|
70 |
+
total_results = {
|
71 |
+
'mean': individual_values.mean(),
|
72 |
+
'std': individual_values.std()
|
73 |
+
}
|
74 |
+
|
75 |
+
if groups is None:
|
76 |
+
return total_results, None
|
77 |
+
|
78 |
+
group_results = dict()
|
79 |
+
grouping = get_groupings(groups)
|
80 |
+
for label, index in grouping.items():
|
81 |
+
group_scores = individual_values[index]
|
82 |
+
group_results[label] = {
|
83 |
+
'mean': group_scores.mean(),
|
84 |
+
'std': group_scores.std()
|
85 |
+
}
|
86 |
+
return total_results, group_results
|
87 |
+
|
88 |
+
def reset(self):
|
89 |
+
self.individual_values = []
|
90 |
+
|
91 |
+
|
92 |
+
class SSIMScore(PairwiseScore):
|
93 |
+
def __init__(self, window_size=11):
|
94 |
+
super().__init__()
|
95 |
+
self.score = SSIM(window_size=window_size, size_average=False).eval()
|
96 |
+
self.reset()
|
97 |
+
|
98 |
+
def forward(self, pred_batch, target_batch, mask=None):
|
99 |
+
batch_values = self.score(pred_batch, target_batch)
|
100 |
+
self.individual_values = np.hstack([
|
101 |
+
self.individual_values, batch_values.detach().cpu().numpy()
|
102 |
+
])
|
103 |
+
return batch_values
|
104 |
+
|
105 |
+
|
106 |
+
class LPIPSScore(PairwiseScore):
|
107 |
+
def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True):
|
108 |
+
super().__init__()
|
109 |
+
self.score = PerceptualLoss(model=model, net=net, model_path=model_path,
|
110 |
+
use_gpu=use_gpu, spatial=False).eval()
|
111 |
+
self.reset()
|
112 |
+
|
113 |
+
def forward(self, pred_batch, target_batch, mask=None):
|
114 |
+
batch_values = self.score(pred_batch, target_batch).flatten()
|
115 |
+
self.individual_values = np.hstack([
|
116 |
+
self.individual_values, batch_values.detach().cpu().numpy()
|
117 |
+
])
|
118 |
+
return batch_values
|
119 |
+
|
120 |
+
|
121 |
+
def fid_calculate_activation_statistics(act):
|
122 |
+
mu = np.mean(act, axis=0)
|
123 |
+
sigma = np.cov(act, rowvar=False)
|
124 |
+
return mu, sigma
|
125 |
+
|
126 |
+
|
127 |
+
def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6):
|
128 |
+
mu1, sigma1 = fid_calculate_activation_statistics(activations_pred)
|
129 |
+
mu2, sigma2 = fid_calculate_activation_statistics(activations_target)
|
130 |
+
|
131 |
+
diff = mu1 - mu2
|
132 |
+
|
133 |
+
# Product might be almost singular
|
134 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
135 |
+
if not np.isfinite(covmean).all():
|
136 |
+
msg = ('fid calculation produces singular product; '
|
137 |
+
'adding %s to diagonal of cov estimates') % eps
|
138 |
+
LOGGER.warning(msg)
|
139 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
140 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
141 |
+
|
142 |
+
# Numerical error might give slight imaginary component
|
143 |
+
if np.iscomplexobj(covmean):
|
144 |
+
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
145 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
|
146 |
+
m = np.max(np.abs(covmean.imag))
|
147 |
+
raise ValueError('Imaginary component {}'.format(m))
|
148 |
+
covmean = covmean.real
|
149 |
+
|
150 |
+
tr_covmean = np.trace(covmean)
|
151 |
+
|
152 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
153 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
154 |
+
|
155 |
+
|
156 |
+
class FIDScore(EvaluatorScore):
|
157 |
+
def __init__(self, dims=2048, eps=1e-6):
|
158 |
+
LOGGER.info("FIDscore init called")
|
159 |
+
super().__init__()
|
160 |
+
if getattr(FIDScore, '_MODEL', None) is None:
|
161 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
162 |
+
FIDScore._MODEL = InceptionV3([block_idx]).eval()
|
163 |
+
self.model = FIDScore._MODEL
|
164 |
+
self.eps = eps
|
165 |
+
self.reset()
|
166 |
+
LOGGER.info("FIDscore init done")
|
167 |
+
|
168 |
+
def forward(self, pred_batch, target_batch, mask=None):
|
169 |
+
activations_pred = self._get_activations(pred_batch)
|
170 |
+
activations_target = self._get_activations(target_batch)
|
171 |
+
|
172 |
+
self.activations_pred.append(activations_pred.detach().cpu())
|
173 |
+
self.activations_target.append(activations_target.detach().cpu())
|
174 |
+
|
175 |
+
return activations_pred, activations_target
|
176 |
+
|
177 |
+
def get_value(self, groups=None, states=None):
|
178 |
+
LOGGER.info("FIDscore get_value called")
|
179 |
+
activations_pred, activations_target = zip(*states) if states is not None \
|
180 |
+
else (self.activations_pred, self.activations_target)
|
181 |
+
activations_pred = torch.cat(activations_pred).cpu().numpy()
|
182 |
+
activations_target = torch.cat(activations_target).cpu().numpy()
|
183 |
+
|
184 |
+
total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
|
185 |
+
total_results = dict(mean=total_distance)
|
186 |
+
|
187 |
+
if groups is None:
|
188 |
+
group_results = None
|
189 |
+
else:
|
190 |
+
group_results = dict()
|
191 |
+
grouping = get_groupings(groups)
|
192 |
+
for label, index in grouping.items():
|
193 |
+
if len(index) > 1:
|
194 |
+
group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index],
|
195 |
+
eps=self.eps)
|
196 |
+
group_results[label] = dict(mean=group_distance)
|
197 |
+
|
198 |
+
else:
|
199 |
+
group_results[label] = dict(mean=float('nan'))
|
200 |
+
|
201 |
+
self.reset()
|
202 |
+
|
203 |
+
LOGGER.info("FIDscore get_value done")
|
204 |
+
|
205 |
+
return total_results, group_results
|
206 |
+
|
207 |
+
def reset(self):
|
208 |
+
self.activations_pred = []
|
209 |
+
self.activations_target = []
|
210 |
+
|
211 |
+
def _get_activations(self, batch):
|
212 |
+
activations = self.model(batch)[0]
|
213 |
+
if activations.shape[2] != 1 or activations.shape[3] != 1:
|
214 |
+
assert False, \
|
215 |
+
'We should not have got here, because Inception always scales inputs to 299x299'
|
216 |
+
# activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
|
217 |
+
activations = activations.squeeze(-1).squeeze(-1)
|
218 |
+
return activations
|
219 |
+
|
220 |
+
|
221 |
+
class SegmentationAwareScore(EvaluatorScore):
|
222 |
+
def __init__(self, weights_path):
|
223 |
+
super().__init__()
|
224 |
+
self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval()
|
225 |
+
self.target_class_freq_by_image_total = []
|
226 |
+
self.target_class_freq_by_image_mask = []
|
227 |
+
self.pred_class_freq_by_image_mask = []
|
228 |
+
|
229 |
+
def forward(self, pred_batch, target_batch, mask):
|
230 |
+
pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
|
231 |
+
target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
|
232 |
+
mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy()
|
233 |
+
|
234 |
+
batch_target_class_freq_total = []
|
235 |
+
batch_target_class_freq_mask = []
|
236 |
+
batch_pred_class_freq_mask = []
|
237 |
+
|
238 |
+
for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat):
|
239 |
+
cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...]
|
240 |
+
cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
|
241 |
+
cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
|
242 |
+
|
243 |
+
self.target_class_freq_by_image_total.append(cur_target_class_freq_total)
|
244 |
+
self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask)
|
245 |
+
self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask)
|
246 |
+
|
247 |
+
batch_target_class_freq_total.append(cur_target_class_freq_total)
|
248 |
+
batch_target_class_freq_mask.append(cur_target_class_freq_mask)
|
249 |
+
batch_pred_class_freq_mask.append(cur_pred_class_freq_mask)
|
250 |
+
|
251 |
+
batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0)
|
252 |
+
batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0)
|
253 |
+
batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0)
|
254 |
+
return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask
|
255 |
+
|
256 |
+
def reset(self):
|
257 |
+
super().reset()
|
258 |
+
self.target_class_freq_by_image_total = []
|
259 |
+
self.target_class_freq_by_image_mask = []
|
260 |
+
self.pred_class_freq_by_image_mask = []
|
261 |
+
|
262 |
+
|
263 |
+
def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name):
|
264 |
+
assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0]
|
265 |
+
total_class_freq = target_class_freq_by_image_mask.sum(0)
|
266 |
+
distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0)
|
267 |
+
result = distr_values / (total_class_freq + 1e-3)
|
268 |
+
return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0}
|
269 |
+
|
270 |
+
|
271 |
+
def get_segmentation_idx2name():
|
272 |
+
return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()}
|
273 |
+
|
274 |
+
|
275 |
+
class SegmentationAwarePairwiseScore(SegmentationAwareScore):
|
276 |
+
def __init__(self, *args, **kwargs):
|
277 |
+
super().__init__(*args, **kwargs)
|
278 |
+
self.individual_values = []
|
279 |
+
self.segm_idx2name = get_segmentation_idx2name()
|
280 |
+
|
281 |
+
def forward(self, pred_batch, target_batch, mask):
|
282 |
+
cur_class_stats = super().forward(pred_batch, target_batch, mask)
|
283 |
+
score_values = self.calc_score(pred_batch, target_batch, mask)
|
284 |
+
self.individual_values.append(score_values)
|
285 |
+
return cur_class_stats + (score_values,)
|
286 |
+
|
287 |
+
@abstractmethod
|
288 |
+
def calc_score(self, pred_batch, target_batch, mask):
|
289 |
+
raise NotImplementedError()
|
290 |
+
|
291 |
+
def get_value(self, groups=None, states=None):
|
292 |
+
"""
|
293 |
+
:param groups:
|
294 |
+
:return:
|
295 |
+
total_results: dict of kind {'mean': score mean, 'std': score std}
|
296 |
+
group_results: None, if groups is None;
|
297 |
+
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
|
298 |
+
"""
|
299 |
+
if states is not None:
|
300 |
+
(target_class_freq_by_image_total,
|
301 |
+
target_class_freq_by_image_mask,
|
302 |
+
pred_class_freq_by_image_mask,
|
303 |
+
individual_values) = states
|
304 |
+
else:
|
305 |
+
target_class_freq_by_image_total = self.target_class_freq_by_image_total
|
306 |
+
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
|
307 |
+
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
|
308 |
+
individual_values = self.individual_values
|
309 |
+
|
310 |
+
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
|
311 |
+
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
|
312 |
+
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
|
313 |
+
individual_values = np.concatenate(individual_values, axis=0)
|
314 |
+
|
315 |
+
total_results = {
|
316 |
+
'mean': individual_values.mean(),
|
317 |
+
'std': individual_values.std(),
|
318 |
+
**distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name)
|
319 |
+
}
|
320 |
+
|
321 |
+
if groups is None:
|
322 |
+
return total_results, None
|
323 |
+
|
324 |
+
group_results = dict()
|
325 |
+
grouping = get_groupings(groups)
|
326 |
+
for label, index in grouping.items():
|
327 |
+
group_class_freq = target_class_freq_by_image_mask[index]
|
328 |
+
group_scores = individual_values[index]
|
329 |
+
group_results[label] = {
|
330 |
+
'mean': group_scores.mean(),
|
331 |
+
'std': group_scores.std(),
|
332 |
+
** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name)
|
333 |
+
}
|
334 |
+
return total_results, group_results
|
335 |
+
|
336 |
+
def reset(self):
|
337 |
+
super().reset()
|
338 |
+
self.individual_values = []
|
339 |
+
|
340 |
+
|
341 |
+
class SegmentationClassStats(SegmentationAwarePairwiseScore):
|
342 |
+
def calc_score(self, pred_batch, target_batch, mask):
|
343 |
+
return 0
|
344 |
+
|
345 |
+
def get_value(self, groups=None, states=None):
|
346 |
+
"""
|
347 |
+
:param groups:
|
348 |
+
:return:
|
349 |
+
total_results: dict of kind {'mean': score mean, 'std': score std}
|
350 |
+
group_results: None, if groups is None;
|
351 |
+
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
|
352 |
+
"""
|
353 |
+
if states is not None:
|
354 |
+
(target_class_freq_by_image_total,
|
355 |
+
target_class_freq_by_image_mask,
|
356 |
+
pred_class_freq_by_image_mask,
|
357 |
+
_) = states
|
358 |
+
else:
|
359 |
+
target_class_freq_by_image_total = self.target_class_freq_by_image_total
|
360 |
+
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
|
361 |
+
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
|
362 |
+
|
363 |
+
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
|
364 |
+
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
|
365 |
+
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
|
366 |
+
|
367 |
+
target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32')
|
368 |
+
target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum()
|
369 |
+
|
370 |
+
target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32')
|
371 |
+
target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum()
|
372 |
+
|
373 |
+
pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3)
|
374 |
+
|
375 |
+
total_results = dict()
|
376 |
+
total_results.update({f'total_freq/{self.segm_idx2name[i]}': v
|
377 |
+
for i, v in enumerate(target_class_freq_by_image_total_marginal)
|
378 |
+
if v > 0})
|
379 |
+
total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
|
380 |
+
for i, v in enumerate(target_class_freq_by_image_mask_marginal)
|
381 |
+
if v > 0})
|
382 |
+
total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
|
383 |
+
for i, v in enumerate(pred_class_freq_diff)
|
384 |
+
if target_class_freq_by_image_total_marginal[i] > 0})
|
385 |
+
|
386 |
+
if groups is None:
|
387 |
+
return total_results, None
|
388 |
+
|
389 |
+
group_results = dict()
|
390 |
+
grouping = get_groupings(groups)
|
391 |
+
for label, index in grouping.items():
|
392 |
+
group_target_class_freq_by_image_total = target_class_freq_by_image_total[index]
|
393 |
+
group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index]
|
394 |
+
group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index]
|
395 |
+
|
396 |
+
group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32')
|
397 |
+
group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum()
|
398 |
+
|
399 |
+
group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32')
|
400 |
+
group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum()
|
401 |
+
|
402 |
+
group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / (
|
403 |
+
group_target_class_freq_by_image_mask.sum(0) + 1e-3)
|
404 |
+
|
405 |
+
cur_group_results = dict()
|
406 |
+
cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v
|
407 |
+
for i, v in enumerate(group_target_class_freq_by_image_total_marginal)
|
408 |
+
if v > 0})
|
409 |
+
cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
|
410 |
+
for i, v in enumerate(group_target_class_freq_by_image_mask_marginal)
|
411 |
+
if v > 0})
|
412 |
+
cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
|
413 |
+
for i, v in enumerate(group_pred_class_freq_diff)
|
414 |
+
if group_target_class_freq_by_image_total_marginal[i] > 0})
|
415 |
+
|
416 |
+
group_results[label] = cur_group_results
|
417 |
+
return total_results, group_results
|
418 |
+
|
419 |
+
|
420 |
+
class SegmentationAwareSSIM(SegmentationAwarePairwiseScore):
|
421 |
+
def __init__(self, *args, window_size=11, **kwargs):
|
422 |
+
super().__init__(*args, **kwargs)
|
423 |
+
self.score_impl = SSIM(window_size=window_size, size_average=False).eval()
|
424 |
+
|
425 |
+
def calc_score(self, pred_batch, target_batch, mask):
|
426 |
+
return self.score_impl(pred_batch, target_batch).detach().cpu().numpy()
|
427 |
+
|
428 |
+
|
429 |
+
class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore):
|
430 |
+
def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs):
|
431 |
+
super().__init__(*args, **kwargs)
|
432 |
+
self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path,
|
433 |
+
use_gpu=use_gpu, spatial=False).eval()
|
434 |
+
|
435 |
+
def calc_score(self, pred_batch, target_batch, mask):
|
436 |
+
return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy()
|
437 |
+
|
438 |
+
|
439 |
+
def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6):
|
440 |
+
activations_pred = activations_pred.copy()
|
441 |
+
activations_pred[img_i] = activations_target[img_i]
|
442 |
+
return calculate_frechet_distance(activations_pred, activations_target, eps=eps)
|
443 |
+
|
444 |
+
|
445 |
+
class SegmentationAwareFID(SegmentationAwarePairwiseScore):
|
446 |
+
def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs):
|
447 |
+
super().__init__(*args, **kwargs)
|
448 |
+
if getattr(FIDScore, '_MODEL', None) is None:
|
449 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
450 |
+
FIDScore._MODEL = InceptionV3([block_idx]).eval()
|
451 |
+
self.model = FIDScore._MODEL
|
452 |
+
self.eps = eps
|
453 |
+
self.n_jobs = n_jobs
|
454 |
+
|
455 |
+
def calc_score(self, pred_batch, target_batch, mask):
|
456 |
+
activations_pred = self._get_activations(pred_batch)
|
457 |
+
activations_target = self._get_activations(target_batch)
|
458 |
+
return activations_pred, activations_target
|
459 |
+
|
460 |
+
def get_value(self, groups=None, states=None):
|
461 |
+
"""
|
462 |
+
:param groups:
|
463 |
+
:return:
|
464 |
+
total_results: dict of kind {'mean': score mean, 'std': score std}
|
465 |
+
group_results: None, if groups is None;
|
466 |
+
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
|
467 |
+
"""
|
468 |
+
if states is not None:
|
469 |
+
(target_class_freq_by_image_total,
|
470 |
+
target_class_freq_by_image_mask,
|
471 |
+
pred_class_freq_by_image_mask,
|
472 |
+
activation_pairs) = states
|
473 |
+
else:
|
474 |
+
target_class_freq_by_image_total = self.target_class_freq_by_image_total
|
475 |
+
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
|
476 |
+
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
|
477 |
+
activation_pairs = self.individual_values
|
478 |
+
|
479 |
+
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
|
480 |
+
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
|
481 |
+
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
|
482 |
+
activations_pred, activations_target = zip(*activation_pairs)
|
483 |
+
activations_pred = np.concatenate(activations_pred, axis=0)
|
484 |
+
activations_target = np.concatenate(activations_target, axis=0)
|
485 |
+
|
486 |
+
total_results = {
|
487 |
+
'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps),
|
488 |
+
'std': 0,
|
489 |
+
**self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target)
|
490 |
+
}
|
491 |
+
|
492 |
+
if groups is None:
|
493 |
+
return total_results, None
|
494 |
+
|
495 |
+
group_results = dict()
|
496 |
+
grouping = get_groupings(groups)
|
497 |
+
for label, index in grouping.items():
|
498 |
+
if len(index) > 1:
|
499 |
+
group_activations_pred = activations_pred[index]
|
500 |
+
group_activations_target = activations_target[index]
|
501 |
+
group_class_freq = target_class_freq_by_image_mask[index]
|
502 |
+
group_results[label] = {
|
503 |
+
'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps),
|
504 |
+
'std': 0,
|
505 |
+
**self.distribute_fid_to_classes(group_class_freq,
|
506 |
+
group_activations_pred,
|
507 |
+
group_activations_target)
|
508 |
+
}
|
509 |
+
else:
|
510 |
+
group_results[label] = dict(mean=float('nan'), std=0)
|
511 |
+
return total_results, group_results
|
512 |
+
|
513 |
+
def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target):
|
514 |
+
real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
|
515 |
+
|
516 |
+
fid_no_images = Parallel(n_jobs=self.n_jobs)(
|
517 |
+
delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps)
|
518 |
+
for img_i in range(activations_pred.shape[0])
|
519 |
+
)
|
520 |
+
errors = real_fid - fid_no_images
|
521 |
+
return distribute_values_to_classes(class_freq, errors, self.segm_idx2name)
|
522 |
+
|
523 |
+
def _get_activations(self, batch):
|
524 |
+
activations = self.model(batch)[0]
|
525 |
+
if activations.shape[2] != 1 or activations.shape[3] != 1:
|
526 |
+
activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
|
527 |
+
activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy()
|
528 |
+
return activations
|
DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py
ADDED
File without changes
|
DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
3 |
+
|
4 |
+
The FID metric calculates the distance between two distributions of images.
|
5 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
6 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
7 |
+
|
8 |
+
When run as a stand-alone program, it compares the distribution of
|
9 |
+
images that are stored as PNG/JPEG at a specified location with a
|
10 |
+
distribution given by summary statistics (in pickle format).
|
11 |
+
|
12 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
13 |
+
the pool_3 layer of the inception net for generated samples and real world
|
14 |
+
samples respectively.
|
15 |
+
|
16 |
+
See --help to see further details.
|
17 |
+
|
18 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
19 |
+
of Tensorflow
|
20 |
+
|
21 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
22 |
+
|
23 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
24 |
+
you may not use this file except in compliance with the License.
|
25 |
+
You may obtain a copy of the License at
|
26 |
+
|
27 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
28 |
+
|
29 |
+
Unless required by applicable law or agreed to in writing, software
|
30 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
31 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
32 |
+
See the License for the specific language governing permissions and
|
33 |
+
limitations under the License.
|
34 |
+
"""
|
35 |
+
import os
|
36 |
+
import pathlib
|
37 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
38 |
+
|
39 |
+
import numpy as np
|
40 |
+
import torch
|
41 |
+
# from scipy.misc import imread
|
42 |
+
from imageio import imread
|
43 |
+
from PIL import Image, JpegImagePlugin
|
44 |
+
from scipy import linalg
|
45 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
46 |
+
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
|
47 |
+
|
48 |
+
try:
|
49 |
+
from tqdm import tqdm
|
50 |
+
except ImportError:
|
51 |
+
# If not tqdm is not available, provide a mock version of it
|
52 |
+
def tqdm(x): return x
|
53 |
+
|
54 |
+
try:
|
55 |
+
from .inception import InceptionV3
|
56 |
+
except ModuleNotFoundError:
|
57 |
+
from inception import InceptionV3
|
58 |
+
|
59 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
60 |
+
parser.add_argument('path', type=str, nargs=2,
|
61 |
+
help=('Path to the generated images or '
|
62 |
+
'to .npz statistic files'))
|
63 |
+
parser.add_argument('--batch-size', type=int, default=50,
|
64 |
+
help='Batch size to use')
|
65 |
+
parser.add_argument('--dims', type=int, default=2048,
|
66 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
67 |
+
help=('Dimensionality of Inception features to use. '
|
68 |
+
'By default, uses pool3 features'))
|
69 |
+
parser.add_argument('-c', '--gpu', default='', type=str,
|
70 |
+
help='GPU to use (leave blank for CPU only)')
|
71 |
+
parser.add_argument('--resize', default=256)
|
72 |
+
|
73 |
+
transform = Compose([Resize(256), CenterCrop(256), ToTensor()])
|
74 |
+
|
75 |
+
|
76 |
+
def get_activations(files, model, batch_size=50, dims=2048,
|
77 |
+
cuda=False, verbose=False, keep_size=False):
|
78 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
79 |
+
|
80 |
+
Params:
|
81 |
+
-- files : List of image files paths
|
82 |
+
-- model : Instance of inception model
|
83 |
+
-- batch_size : Batch size of images for the model to process at once.
|
84 |
+
Make sure that the number of samples is a multiple of
|
85 |
+
the batch size, otherwise some samples are ignored. This
|
86 |
+
behavior is retained to match the original FID score
|
87 |
+
implementation.
|
88 |
+
-- dims : Dimensionality of features returned by Inception
|
89 |
+
-- cuda : If set to True, use GPU
|
90 |
+
-- verbose : If set to True and parameter out_step is given, the number
|
91 |
+
of calculated batches is reported.
|
92 |
+
Returns:
|
93 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
94 |
+
activations of the given tensor when feeding inception with the
|
95 |
+
query tensor.
|
96 |
+
"""
|
97 |
+
model.eval()
|
98 |
+
|
99 |
+
if len(files) % batch_size != 0:
|
100 |
+
print(('Warning: number of images is not a multiple of the '
|
101 |
+
'batch size. Some samples are going to be ignored.'))
|
102 |
+
if batch_size > len(files):
|
103 |
+
print(('Warning: batch size is bigger than the data size. '
|
104 |
+
'Setting batch size to data size'))
|
105 |
+
batch_size = len(files)
|
106 |
+
|
107 |
+
n_batches = len(files) // batch_size
|
108 |
+
n_used_imgs = n_batches * batch_size
|
109 |
+
|
110 |
+
pred_arr = np.empty((n_used_imgs, dims))
|
111 |
+
|
112 |
+
for i in tqdm(range(n_batches)):
|
113 |
+
if verbose:
|
114 |
+
print('\rPropagating batch %d/%d' % (i + 1, n_batches),
|
115 |
+
end='', flush=True)
|
116 |
+
start = i * batch_size
|
117 |
+
end = start + batch_size
|
118 |
+
|
119 |
+
# # Official code goes below
|
120 |
+
# images = np.array([imread(str(f)).astype(np.float32)
|
121 |
+
# for f in files[start:end]])
|
122 |
+
|
123 |
+
# # Reshape to (n_images, 3, height, width)
|
124 |
+
# images = images.transpose((0, 3, 1, 2))
|
125 |
+
# images /= 255
|
126 |
+
# batch = torch.from_numpy(images).type(torch.FloatTensor)
|
127 |
+
# #
|
128 |
+
|
129 |
+
t = transform if not keep_size else ToTensor()
|
130 |
+
|
131 |
+
if isinstance(files[0], pathlib.PosixPath):
|
132 |
+
images = [t(Image.open(str(f))) for f in files[start:end]]
|
133 |
+
|
134 |
+
elif isinstance(files[0], Image.Image):
|
135 |
+
images = [t(f) for f in files[start:end]]
|
136 |
+
|
137 |
+
else:
|
138 |
+
raise ValueError(f"Unknown data type for image: {type(files[0])}")
|
139 |
+
|
140 |
+
batch = torch.stack(images)
|
141 |
+
|
142 |
+
if cuda:
|
143 |
+
batch = batch.cuda()
|
144 |
+
|
145 |
+
pred = model(batch)[0]
|
146 |
+
|
147 |
+
# If model output is not scalar, apply global spatial average pooling.
|
148 |
+
# This happens if you choose a dimensionality not equal 2048.
|
149 |
+
if pred.shape[2] != 1 or pred.shape[3] != 1:
|
150 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
151 |
+
|
152 |
+
pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
|
153 |
+
|
154 |
+
if verbose:
|
155 |
+
print(' done')
|
156 |
+
|
157 |
+
return pred_arr
|
158 |
+
|
159 |
+
|
160 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
161 |
+
"""Numpy implementation of the Frechet Distance.
|
162 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
163 |
+
and X_2 ~ N(mu_2, C_2) is
|
164 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
165 |
+
|
166 |
+
Stable version by Dougal J. Sutherland.
|
167 |
+
|
168 |
+
Params:
|
169 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
170 |
+
inception net (like returned by the function 'get_predictions')
|
171 |
+
for generated samples.
|
172 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
173 |
+
representative data set.
|
174 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
175 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
176 |
+
representative data set.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
-- : The Frechet Distance.
|
180 |
+
"""
|
181 |
+
|
182 |
+
mu1 = np.atleast_1d(mu1)
|
183 |
+
mu2 = np.atleast_1d(mu2)
|
184 |
+
|
185 |
+
sigma1 = np.atleast_2d(sigma1)
|
186 |
+
sigma2 = np.atleast_2d(sigma2)
|
187 |
+
|
188 |
+
assert mu1.shape == mu2.shape, \
|
189 |
+
'Training and test mean vectors have different lengths'
|
190 |
+
assert sigma1.shape == sigma2.shape, \
|
191 |
+
'Training and test covariances have different dimensions'
|
192 |
+
|
193 |
+
diff = mu1 - mu2
|
194 |
+
|
195 |
+
# Product might be almost singular
|
196 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
197 |
+
if not np.isfinite(covmean).all():
|
198 |
+
msg = ('fid calculation produces singular product; '
|
199 |
+
'adding %s to diagonal of cov estimates') % eps
|
200 |
+
print(msg)
|
201 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
202 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
203 |
+
|
204 |
+
# Numerical error might give slight imaginary component
|
205 |
+
if np.iscomplexobj(covmean):
|
206 |
+
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
207 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
|
208 |
+
m = np.max(np.abs(covmean.imag))
|
209 |
+
raise ValueError('Imaginary component {}'.format(m))
|
210 |
+
covmean = covmean.real
|
211 |
+
|
212 |
+
tr_covmean = np.trace(covmean)
|
213 |
+
|
214 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
215 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
216 |
+
|
217 |
+
|
218 |
+
def calculate_activation_statistics(files, model, batch_size=50,
|
219 |
+
dims=2048, cuda=False, verbose=False, keep_size=False):
|
220 |
+
"""Calculation of the statistics used by the FID.
|
221 |
+
Params:
|
222 |
+
-- files : List of image files paths
|
223 |
+
-- model : Instance of inception model
|
224 |
+
-- batch_size : The images numpy array is split into batches with
|
225 |
+
batch size batch_size. A reasonable batch size
|
226 |
+
depends on the hardware.
|
227 |
+
-- dims : Dimensionality of features returned by Inception
|
228 |
+
-- cuda : If set to True, use GPU
|
229 |
+
-- verbose : If set to True and parameter out_step is given, the
|
230 |
+
number of calculated batches is reported.
|
231 |
+
Returns:
|
232 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
233 |
+
the inception model.
|
234 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
235 |
+
the inception model.
|
236 |
+
"""
|
237 |
+
act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size)
|
238 |
+
mu = np.mean(act, axis=0)
|
239 |
+
sigma = np.cov(act, rowvar=False)
|
240 |
+
return mu, sigma
|
241 |
+
|
242 |
+
|
243 |
+
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
|
244 |
+
if path.endswith('.npz'):
|
245 |
+
f = np.load(path)
|
246 |
+
m, s = f['mu'][:], f['sigma'][:]
|
247 |
+
f.close()
|
248 |
+
else:
|
249 |
+
path = pathlib.Path(path)
|
250 |
+
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
251 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
252 |
+
dims, cuda)
|
253 |
+
|
254 |
+
return m, s
|
255 |
+
|
256 |
+
|
257 |
+
def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False):
|
258 |
+
if isinstance(images, list): # exact paths to files are provided
|
259 |
+
m, s = calculate_activation_statistics(images, model, batch_size,
|
260 |
+
dims, cuda, keep_size=keep_size)
|
261 |
+
|
262 |
+
return m, s
|
263 |
+
|
264 |
+
else:
|
265 |
+
raise ValueError
|
266 |
+
|
267 |
+
|
268 |
+
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
|
269 |
+
"""Calculates the FID of two paths"""
|
270 |
+
for p in paths:
|
271 |
+
if not os.path.exists(p):
|
272 |
+
raise RuntimeError('Invalid path: %s' % p)
|
273 |
+
|
274 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
275 |
+
|
276 |
+
model = InceptionV3([block_idx])
|
277 |
+
if cuda:
|
278 |
+
model.cuda()
|
279 |
+
|
280 |
+
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
|
281 |
+
dims, cuda)
|
282 |
+
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
|
283 |
+
dims, cuda)
|
284 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
285 |
+
|
286 |
+
return fid_value
|
287 |
+
|
288 |
+
|
289 |
+
def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False):
|
290 |
+
if use_globals:
|
291 |
+
global FID_MODEL # for multiprocessing
|
292 |
+
|
293 |
+
for imgs in images:
|
294 |
+
if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)):
|
295 |
+
pass
|
296 |
+
else:
|
297 |
+
raise RuntimeError('Invalid images')
|
298 |
+
|
299 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
300 |
+
|
301 |
+
if 'FID_MODEL' not in globals() or not use_globals:
|
302 |
+
model = InceptionV3([block_idx])
|
303 |
+
if cuda:
|
304 |
+
model.cuda()
|
305 |
+
|
306 |
+
if use_globals:
|
307 |
+
FID_MODEL = model
|
308 |
+
|
309 |
+
else:
|
310 |
+
model = FID_MODEL
|
311 |
+
|
312 |
+
m1, s1 = _compute_statistics_of_images(images[0], model, batch_size,
|
313 |
+
dims, cuda, keep_size=False)
|
314 |
+
m2, s2 = _compute_statistics_of_images(images[1], model, batch_size,
|
315 |
+
dims, cuda, keep_size=False)
|
316 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
317 |
+
return fid_value
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == '__main__':
|
321 |
+
args = parser.parse_args()
|
322 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
|
323 |
+
|
324 |
+
fid_value = calculate_fid_given_paths(args.path,
|
325 |
+
args.batch_size,
|
326 |
+
args.gpu != '',
|
327 |
+
args.dims)
|
328 |
+
print('FID: ', fid_value)
|
DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision import models
|
7 |
+
|
8 |
+
try:
|
9 |
+
from torchvision.models.utils import load_state_dict_from_url
|
10 |
+
except ImportError:
|
11 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
12 |
+
|
13 |
+
# Inception weights ported to Pytorch from
|
14 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
15 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
16 |
+
|
17 |
+
|
18 |
+
LOGGER = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class InceptionV3(nn.Module):
|
22 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
23 |
+
|
24 |
+
# Index of default block of inception to return,
|
25 |
+
# corresponds to output of final average pooling
|
26 |
+
DEFAULT_BLOCK_INDEX = 3
|
27 |
+
|
28 |
+
# Maps feature dimensionality to their output blocks indices
|
29 |
+
BLOCK_INDEX_BY_DIM = {
|
30 |
+
64: 0, # First max pooling features
|
31 |
+
192: 1, # Second max pooling featurs
|
32 |
+
768: 2, # Pre-aux classifier features
|
33 |
+
2048: 3 # Final average pooling features
|
34 |
+
}
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
38 |
+
resize_input=True,
|
39 |
+
normalize_input=True,
|
40 |
+
requires_grad=False,
|
41 |
+
use_fid_inception=True):
|
42 |
+
"""Build pretrained InceptionV3
|
43 |
+
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
output_blocks : list of int
|
47 |
+
Indices of blocks to return features of. Possible values are:
|
48 |
+
- 0: corresponds to output of first max pooling
|
49 |
+
- 1: corresponds to output of second max pooling
|
50 |
+
- 2: corresponds to output which is fed to aux classifier
|
51 |
+
- 3: corresponds to output of final average pooling
|
52 |
+
resize_input : bool
|
53 |
+
If true, bilinearly resizes input to width and height 299 before
|
54 |
+
feeding input to model. As the network without fully connected
|
55 |
+
layers is fully convolutional, it should be able to handle inputs
|
56 |
+
of arbitrary size, so resizing might not be strictly needed
|
57 |
+
normalize_input : bool
|
58 |
+
If true, scales the input from range (0, 1) to the range the
|
59 |
+
pretrained Inception network expects, namely (-1, 1)
|
60 |
+
requires_grad : bool
|
61 |
+
If true, parameters of the model require gradients. Possibly useful
|
62 |
+
for finetuning the network
|
63 |
+
use_fid_inception : bool
|
64 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
65 |
+
FID implementation. If false, uses the pretrained Inception model
|
66 |
+
available in torchvision. The FID Inception model has different
|
67 |
+
weights and a slightly different structure from torchvision's
|
68 |
+
Inception model. If you want to compute FID scores, you are
|
69 |
+
strongly advised to set this parameter to true to get comparable
|
70 |
+
results.
|
71 |
+
"""
|
72 |
+
super(InceptionV3, self).__init__()
|
73 |
+
|
74 |
+
self.resize_input = resize_input
|
75 |
+
self.normalize_input = normalize_input
|
76 |
+
self.output_blocks = sorted(output_blocks)
|
77 |
+
self.last_needed_block = max(output_blocks)
|
78 |
+
|
79 |
+
assert self.last_needed_block <= 3, \
|
80 |
+
'Last possible output block index is 3'
|
81 |
+
|
82 |
+
self.blocks = nn.ModuleList()
|
83 |
+
|
84 |
+
if use_fid_inception:
|
85 |
+
inception = fid_inception_v3()
|
86 |
+
else:
|
87 |
+
inception = models.inception_v3(pretrained=True)
|
88 |
+
|
89 |
+
# Block 0: input to maxpool1
|
90 |
+
block0 = [
|
91 |
+
inception.Conv2d_1a_3x3,
|
92 |
+
inception.Conv2d_2a_3x3,
|
93 |
+
inception.Conv2d_2b_3x3,
|
94 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
95 |
+
]
|
96 |
+
self.blocks.append(nn.Sequential(*block0))
|
97 |
+
|
98 |
+
# Block 1: maxpool1 to maxpool2
|
99 |
+
if self.last_needed_block >= 1:
|
100 |
+
block1 = [
|
101 |
+
inception.Conv2d_3b_1x1,
|
102 |
+
inception.Conv2d_4a_3x3,
|
103 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
104 |
+
]
|
105 |
+
self.blocks.append(nn.Sequential(*block1))
|
106 |
+
|
107 |
+
# Block 2: maxpool2 to aux classifier
|
108 |
+
if self.last_needed_block >= 2:
|
109 |
+
block2 = [
|
110 |
+
inception.Mixed_5b,
|
111 |
+
inception.Mixed_5c,
|
112 |
+
inception.Mixed_5d,
|
113 |
+
inception.Mixed_6a,
|
114 |
+
inception.Mixed_6b,
|
115 |
+
inception.Mixed_6c,
|
116 |
+
inception.Mixed_6d,
|
117 |
+
inception.Mixed_6e,
|
118 |
+
]
|
119 |
+
self.blocks.append(nn.Sequential(*block2))
|
120 |
+
|
121 |
+
# Block 3: aux classifier to final avgpool
|
122 |
+
if self.last_needed_block >= 3:
|
123 |
+
block3 = [
|
124 |
+
inception.Mixed_7a,
|
125 |
+
inception.Mixed_7b,
|
126 |
+
inception.Mixed_7c,
|
127 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
128 |
+
]
|
129 |
+
self.blocks.append(nn.Sequential(*block3))
|
130 |
+
|
131 |
+
for param in self.parameters():
|
132 |
+
param.requires_grad = requires_grad
|
133 |
+
|
134 |
+
def forward(self, inp):
|
135 |
+
"""Get Inception feature maps
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
inp : torch.autograd.Variable
|
140 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
141 |
+
range (0, 1)
|
142 |
+
|
143 |
+
Returns
|
144 |
+
-------
|
145 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
146 |
+
block, sorted ascending by index
|
147 |
+
"""
|
148 |
+
outp = []
|
149 |
+
x = inp
|
150 |
+
|
151 |
+
if self.resize_input:
|
152 |
+
x = F.interpolate(x,
|
153 |
+
size=(299, 299),
|
154 |
+
mode='bilinear',
|
155 |
+
align_corners=False)
|
156 |
+
|
157 |
+
if self.normalize_input:
|
158 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
159 |
+
|
160 |
+
for idx, block in enumerate(self.blocks):
|
161 |
+
x = block(x)
|
162 |
+
if idx in self.output_blocks:
|
163 |
+
outp.append(x)
|
164 |
+
|
165 |
+
if idx == self.last_needed_block:
|
166 |
+
break
|
167 |
+
|
168 |
+
return outp
|
169 |
+
|
170 |
+
|
171 |
+
def fid_inception_v3():
|
172 |
+
"""Build pretrained Inception model for FID computation
|
173 |
+
|
174 |
+
The Inception model for FID computation uses a different set of weights
|
175 |
+
and has a slightly different structure than torchvision's Inception.
|
176 |
+
|
177 |
+
This method first constructs torchvision's Inception and then patches the
|
178 |
+
necessary parts that are different in the FID Inception model.
|
179 |
+
"""
|
180 |
+
LOGGER.info('fid_inception_v3 called')
|
181 |
+
inception = models.inception_v3(num_classes=1008,
|
182 |
+
aux_logits=False,
|
183 |
+
pretrained=False)
|
184 |
+
LOGGER.info('models.inception_v3 done')
|
185 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
186 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
187 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
188 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
189 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
190 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
191 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
192 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
193 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
194 |
+
|
195 |
+
LOGGER.info('fid_inception_v3 patching done')
|
196 |
+
|
197 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
198 |
+
LOGGER.info('fid_inception_v3 weights downloaded')
|
199 |
+
|
200 |
+
inception.load_state_dict(state_dict)
|
201 |
+
LOGGER.info('fid_inception_v3 weights loaded into model')
|
202 |
+
|
203 |
+
return inception
|
204 |
+
|
205 |
+
|
206 |
+
class FIDInceptionA(models.inception.InceptionA):
|
207 |
+
"""InceptionA block patched for FID computation"""
|
208 |
+
def __init__(self, in_channels, pool_features):
|
209 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
branch1x1 = self.branch1x1(x)
|
213 |
+
|
214 |
+
branch5x5 = self.branch5x5_1(x)
|
215 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
216 |
+
|
217 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
218 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
219 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
220 |
+
|
221 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
222 |
+
# its average calculation
|
223 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
224 |
+
count_include_pad=False)
|
225 |
+
branch_pool = self.branch_pool(branch_pool)
|
226 |
+
|
227 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
228 |
+
return torch.cat(outputs, 1)
|
229 |
+
|
230 |
+
|
231 |
+
class FIDInceptionC(models.inception.InceptionC):
|
232 |
+
"""InceptionC block patched for FID computation"""
|
233 |
+
def __init__(self, in_channels, channels_7x7):
|
234 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
branch1x1 = self.branch1x1(x)
|
238 |
+
|
239 |
+
branch7x7 = self.branch7x7_1(x)
|
240 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
241 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
242 |
+
|
243 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
244 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
245 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
246 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
247 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
248 |
+
|
249 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
250 |
+
# its average calculation
|
251 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
252 |
+
count_include_pad=False)
|
253 |
+
branch_pool = self.branch_pool(branch_pool)
|
254 |
+
|
255 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
256 |
+
return torch.cat(outputs, 1)
|
257 |
+
|
258 |
+
|
259 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
260 |
+
"""First InceptionE block patched for FID computation"""
|
261 |
+
def __init__(self, in_channels):
|
262 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
263 |
+
|
264 |
+
def forward(self, x):
|
265 |
+
branch1x1 = self.branch1x1(x)
|
266 |
+
|
267 |
+
branch3x3 = self.branch3x3_1(x)
|
268 |
+
branch3x3 = [
|
269 |
+
self.branch3x3_2a(branch3x3),
|
270 |
+
self.branch3x3_2b(branch3x3),
|
271 |
+
]
|
272 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
273 |
+
|
274 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
275 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
276 |
+
branch3x3dbl = [
|
277 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
278 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
279 |
+
]
|
280 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
281 |
+
|
282 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
283 |
+
# its average calculation
|
284 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
285 |
+
count_include_pad=False)
|
286 |
+
branch_pool = self.branch_pool(branch_pool)
|
287 |
+
|
288 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
289 |
+
return torch.cat(outputs, 1)
|
290 |
+
|
291 |
+
|
292 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
293 |
+
"""Second InceptionE block patched for FID computation"""
|
294 |
+
def __init__(self, in_channels):
|
295 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
296 |
+
|
297 |
+
def forward(self, x):
|
298 |
+
branch1x1 = self.branch1x1(x)
|
299 |
+
|
300 |
+
branch3x3 = self.branch3x3_1(x)
|
301 |
+
branch3x3 = [
|
302 |
+
self.branch3x3_2a(branch3x3),
|
303 |
+
self.branch3x3_2b(branch3x3),
|
304 |
+
]
|
305 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
306 |
+
|
307 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
308 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
309 |
+
branch3x3dbl = [
|
310 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
311 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
312 |
+
]
|
313 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
314 |
+
|
315 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
316 |
+
# pooling. This is likely an error in this specific Inception
|
317 |
+
# implementation, as other Inception models use average pooling here
|
318 |
+
# (which matches the description in the paper).
|
319 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
320 |
+
branch_pool = self.branch_pool(branch_pool)
|
321 |
+
|
322 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
323 |
+
return torch.cat(outputs, 1)
|
DH-AISP/2/saicinpainting/evaluation/losses/lpips.py
ADDED
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
############################################################
|
2 |
+
# The contents below have been combined using files in the #
|
3 |
+
# following repository: #
|
4 |
+
# https://github.com/richzhang/PerceptualSimilarity #
|
5 |
+
############################################################
|
6 |
+
|
7 |
+
############################################################
|
8 |
+
# __init__.py #
|
9 |
+
############################################################
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from skimage.metrics import structural_similarity
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from saicinpainting.utils import get_shape
|
16 |
+
|
17 |
+
|
18 |
+
class PerceptualLoss(torch.nn.Module):
|
19 |
+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True):
|
20 |
+
# VGG using our perceptually-learned weights (LPIPS metric)
|
21 |
+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
|
22 |
+
super(PerceptualLoss, self).__init__()
|
23 |
+
self.use_gpu = use_gpu
|
24 |
+
self.spatial = spatial
|
25 |
+
self.model = DistModel()
|
26 |
+
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace,
|
27 |
+
model_path=model_path, spatial=self.spatial)
|
28 |
+
|
29 |
+
def forward(self, pred, target, normalize=True):
|
30 |
+
"""
|
31 |
+
Pred and target are Variables.
|
32 |
+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
|
33 |
+
If normalize is False, assumes the images are already between [-1,+1]
|
34 |
+
Inputs pred and target are Nx3xHxW
|
35 |
+
Output pytorch Variable N long
|
36 |
+
"""
|
37 |
+
|
38 |
+
if normalize:
|
39 |
+
target = 2 * target - 1
|
40 |
+
pred = 2 * pred - 1
|
41 |
+
|
42 |
+
return self.model(target, pred)
|
43 |
+
|
44 |
+
|
45 |
+
def normalize_tensor(in_feat, eps=1e-10):
|
46 |
+
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
|
47 |
+
return in_feat / (norm_factor + eps)
|
48 |
+
|
49 |
+
|
50 |
+
def l2(p0, p1, range=255.):
|
51 |
+
return .5 * np.mean((p0 / range - p1 / range) ** 2)
|
52 |
+
|
53 |
+
|
54 |
+
def psnr(p0, p1, peak=255.):
|
55 |
+
return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
|
56 |
+
|
57 |
+
|
58 |
+
def dssim(p0, p1, range=255.):
|
59 |
+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
|
60 |
+
|
61 |
+
|
62 |
+
def rgb2lab(in_img, mean_cent=False):
|
63 |
+
from skimage import color
|
64 |
+
img_lab = color.rgb2lab(in_img)
|
65 |
+
if (mean_cent):
|
66 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
67 |
+
return img_lab
|
68 |
+
|
69 |
+
|
70 |
+
def tensor2np(tensor_obj):
|
71 |
+
# change dimension of a tensor object into a numpy array
|
72 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
|
73 |
+
|
74 |
+
|
75 |
+
def np2tensor(np_obj):
|
76 |
+
# change dimenion of np array into tensor array
|
77 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
78 |
+
|
79 |
+
|
80 |
+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
|
81 |
+
# image tensor to lab tensor
|
82 |
+
from skimage import color
|
83 |
+
|
84 |
+
img = tensor2im(image_tensor)
|
85 |
+
img_lab = color.rgb2lab(img)
|
86 |
+
if (mc_only):
|
87 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
88 |
+
if (to_norm and not mc_only):
|
89 |
+
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
|
90 |
+
img_lab = img_lab / 100.
|
91 |
+
|
92 |
+
return np2tensor(img_lab)
|
93 |
+
|
94 |
+
|
95 |
+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
|
96 |
+
from skimage import color
|
97 |
+
import warnings
|
98 |
+
warnings.filterwarnings("ignore")
|
99 |
+
|
100 |
+
lab = tensor2np(lab_tensor) * 100.
|
101 |
+
lab[:, :, 0] = lab[:, :, 0] + 50
|
102 |
+
|
103 |
+
rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
|
104 |
+
if (return_inbnd):
|
105 |
+
# convert back to lab, see if we match
|
106 |
+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
|
107 |
+
mask = 1. * np.isclose(lab_back, lab, atol=2.)
|
108 |
+
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
|
109 |
+
return (im2tensor(rgb_back), mask)
|
110 |
+
else:
|
111 |
+
return im2tensor(rgb_back)
|
112 |
+
|
113 |
+
|
114 |
+
def rgb2lab(input):
|
115 |
+
from skimage import color
|
116 |
+
return color.rgb2lab(input / 255.)
|
117 |
+
|
118 |
+
|
119 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
|
120 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
121 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
122 |
+
return image_numpy.astype(imtype)
|
123 |
+
|
124 |
+
|
125 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
|
126 |
+
return torch.Tensor((image / factor - cent)
|
127 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
128 |
+
|
129 |
+
|
130 |
+
def tensor2vec(vector_tensor):
|
131 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
132 |
+
|
133 |
+
|
134 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
135 |
+
""" ap = voc_ap(rec, prec, [use_07_metric])
|
136 |
+
Compute VOC AP given precision and recall.
|
137 |
+
If use_07_metric is true, uses the
|
138 |
+
VOC 07 11 point method (default:False).
|
139 |
+
"""
|
140 |
+
if use_07_metric:
|
141 |
+
# 11 point metric
|
142 |
+
ap = 0.
|
143 |
+
for t in np.arange(0., 1.1, 0.1):
|
144 |
+
if np.sum(rec >= t) == 0:
|
145 |
+
p = 0
|
146 |
+
else:
|
147 |
+
p = np.max(prec[rec >= t])
|
148 |
+
ap = ap + p / 11.
|
149 |
+
else:
|
150 |
+
# correct AP calculation
|
151 |
+
# first append sentinel values at the end
|
152 |
+
mrec = np.concatenate(([0.], rec, [1.]))
|
153 |
+
mpre = np.concatenate(([0.], prec, [0.]))
|
154 |
+
|
155 |
+
# compute the precision envelope
|
156 |
+
for i in range(mpre.size - 1, 0, -1):
|
157 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
158 |
+
|
159 |
+
# to calculate area under PR curve, look for points
|
160 |
+
# where X axis (recall) changes value
|
161 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
162 |
+
|
163 |
+
# and sum (\Delta recall) * prec
|
164 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
165 |
+
return ap
|
166 |
+
|
167 |
+
|
168 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
|
169 |
+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
170 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
171 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
172 |
+
return image_numpy.astype(imtype)
|
173 |
+
|
174 |
+
|
175 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
|
176 |
+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
177 |
+
return torch.Tensor((image / factor - cent)
|
178 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
179 |
+
|
180 |
+
|
181 |
+
############################################################
|
182 |
+
# base_model.py #
|
183 |
+
############################################################
|
184 |
+
|
185 |
+
|
186 |
+
class BaseModel(torch.nn.Module):
|
187 |
+
def __init__(self):
|
188 |
+
super().__init__()
|
189 |
+
|
190 |
+
def name(self):
|
191 |
+
return 'BaseModel'
|
192 |
+
|
193 |
+
def initialize(self, use_gpu=True):
|
194 |
+
self.use_gpu = use_gpu
|
195 |
+
|
196 |
+
def forward(self):
|
197 |
+
pass
|
198 |
+
|
199 |
+
def get_image_paths(self):
|
200 |
+
pass
|
201 |
+
|
202 |
+
def optimize_parameters(self):
|
203 |
+
pass
|
204 |
+
|
205 |
+
def get_current_visuals(self):
|
206 |
+
return self.input
|
207 |
+
|
208 |
+
def get_current_errors(self):
|
209 |
+
return {}
|
210 |
+
|
211 |
+
def save(self, label):
|
212 |
+
pass
|
213 |
+
|
214 |
+
# helper saving function that can be used by subclasses
|
215 |
+
def save_network(self, network, path, network_label, epoch_label):
|
216 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
217 |
+
save_path = os.path.join(path, save_filename)
|
218 |
+
torch.save(network.state_dict(), save_path)
|
219 |
+
|
220 |
+
# helper loading function that can be used by subclasses
|
221 |
+
def load_network(self, network, network_label, epoch_label):
|
222 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
223 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
224 |
+
print('Loading network from %s' % save_path)
|
225 |
+
network.load_state_dict(torch.load(save_path, map_location='cpu'))
|
226 |
+
|
227 |
+
def update_learning_rate():
|
228 |
+
pass
|
229 |
+
|
230 |
+
def get_image_paths(self):
|
231 |
+
return self.image_paths
|
232 |
+
|
233 |
+
def save_done(self, flag=False):
|
234 |
+
np.save(os.path.join(self.save_dir, 'done_flag'), flag)
|
235 |
+
np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
|
236 |
+
|
237 |
+
|
238 |
+
############################################################
|
239 |
+
# dist_model.py #
|
240 |
+
############################################################
|
241 |
+
|
242 |
+
import os
|
243 |
+
from collections import OrderedDict
|
244 |
+
from scipy.ndimage import zoom
|
245 |
+
from tqdm import tqdm
|
246 |
+
|
247 |
+
|
248 |
+
class DistModel(BaseModel):
|
249 |
+
def name(self):
|
250 |
+
return self.model_name
|
251 |
+
|
252 |
+
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False,
|
253 |
+
model_path=None,
|
254 |
+
use_gpu=True, printNet=False, spatial=False,
|
255 |
+
is_train=False, lr=.0001, beta1=0.5, version='0.1'):
|
256 |
+
'''
|
257 |
+
INPUTS
|
258 |
+
model - ['net-lin'] for linearly calibrated network
|
259 |
+
['net'] for off-the-shelf network
|
260 |
+
['L2'] for L2 distance in Lab colorspace
|
261 |
+
['SSIM'] for ssim in RGB colorspace
|
262 |
+
net - ['squeeze','alex','vgg']
|
263 |
+
model_path - if None, will look in weights/[NET_NAME].pth
|
264 |
+
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
|
265 |
+
use_gpu - bool - whether or not to use a GPU
|
266 |
+
printNet - bool - whether or not to print network architecture out
|
267 |
+
spatial - bool - whether to output an array containing varying distances across spatial dimensions
|
268 |
+
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
|
269 |
+
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
|
270 |
+
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
|
271 |
+
is_train - bool - [True] for training mode
|
272 |
+
lr - float - initial learning rate
|
273 |
+
beta1 - float - initial momentum term for adam
|
274 |
+
version - 0.1 for latest, 0.0 was original (with a bug)
|
275 |
+
'''
|
276 |
+
BaseModel.initialize(self, use_gpu=use_gpu)
|
277 |
+
|
278 |
+
self.model = model
|
279 |
+
self.net = net
|
280 |
+
self.is_train = is_train
|
281 |
+
self.spatial = spatial
|
282 |
+
self.model_name = '%s [%s]' % (model, net)
|
283 |
+
|
284 |
+
if (self.model == 'net-lin'): # pretrained net + linear layer
|
285 |
+
self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
|
286 |
+
use_dropout=True, spatial=spatial, version=version, lpips=True)
|
287 |
+
kw = dict(map_location='cpu')
|
288 |
+
if (model_path is None):
|
289 |
+
import inspect
|
290 |
+
model_path = os.path.abspath(
|
291 |
+
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth'))
|
292 |
+
|
293 |
+
if (not is_train):
|
294 |
+
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
|
295 |
+
|
296 |
+
elif (self.model == 'net'): # pretrained network
|
297 |
+
self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
|
298 |
+
elif (self.model in ['L2', 'l2']):
|
299 |
+
self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
|
300 |
+
self.model_name = 'L2'
|
301 |
+
elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
|
302 |
+
self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
|
303 |
+
self.model_name = 'SSIM'
|
304 |
+
else:
|
305 |
+
raise ValueError("Model [%s] not recognized." % self.model)
|
306 |
+
|
307 |
+
self.trainable_parameters = list(self.net.parameters())
|
308 |
+
|
309 |
+
if self.is_train: # training mode
|
310 |
+
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
|
311 |
+
self.rankLoss = BCERankingLoss()
|
312 |
+
self.trainable_parameters += list(self.rankLoss.net.parameters())
|
313 |
+
self.lr = lr
|
314 |
+
self.old_lr = lr
|
315 |
+
self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999))
|
316 |
+
else: # test mode
|
317 |
+
self.net.eval()
|
318 |
+
|
319 |
+
# if (use_gpu):
|
320 |
+
# self.net.to(gpu_ids[0])
|
321 |
+
# self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
|
322 |
+
# if (self.is_train):
|
323 |
+
# self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
|
324 |
+
|
325 |
+
if (printNet):
|
326 |
+
print('---------- Networks initialized -------------')
|
327 |
+
print_network(self.net)
|
328 |
+
print('-----------------------------------------------')
|
329 |
+
|
330 |
+
def forward(self, in0, in1, retPerLayer=False):
|
331 |
+
''' Function computes the distance between image patches in0 and in1
|
332 |
+
INPUTS
|
333 |
+
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
|
334 |
+
OUTPUT
|
335 |
+
computed distances between in0 and in1
|
336 |
+
'''
|
337 |
+
|
338 |
+
return self.net(in0, in1, retPerLayer=retPerLayer)
|
339 |
+
|
340 |
+
# ***** TRAINING FUNCTIONS *****
|
341 |
+
def optimize_parameters(self):
|
342 |
+
self.forward_train()
|
343 |
+
self.optimizer_net.zero_grad()
|
344 |
+
self.backward_train()
|
345 |
+
self.optimizer_net.step()
|
346 |
+
self.clamp_weights()
|
347 |
+
|
348 |
+
def clamp_weights(self):
|
349 |
+
for module in self.net.modules():
|
350 |
+
if (hasattr(module, 'weight') and module.kernel_size == (1, 1)):
|
351 |
+
module.weight.data = torch.clamp(module.weight.data, min=0)
|
352 |
+
|
353 |
+
def set_input(self, data):
|
354 |
+
self.input_ref = data['ref']
|
355 |
+
self.input_p0 = data['p0']
|
356 |
+
self.input_p1 = data['p1']
|
357 |
+
self.input_judge = data['judge']
|
358 |
+
|
359 |
+
# if (self.use_gpu):
|
360 |
+
# self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
|
361 |
+
# self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
|
362 |
+
# self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
|
363 |
+
# self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
|
364 |
+
|
365 |
+
# self.var_ref = Variable(self.input_ref, requires_grad=True)
|
366 |
+
# self.var_p0 = Variable(self.input_p0, requires_grad=True)
|
367 |
+
# self.var_p1 = Variable(self.input_p1, requires_grad=True)
|
368 |
+
|
369 |
+
def forward_train(self): # run forward pass
|
370 |
+
# print(self.net.module.scaling_layer.shift)
|
371 |
+
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
|
372 |
+
|
373 |
+
assert False, "We shoud've not get here when using LPIPS as a metric"
|
374 |
+
|
375 |
+
self.d0 = self(self.var_ref, self.var_p0)
|
376 |
+
self.d1 = self(self.var_ref, self.var_p1)
|
377 |
+
self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
|
378 |
+
|
379 |
+
self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
|
380 |
+
|
381 |
+
self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.)
|
382 |
+
|
383 |
+
return self.loss_total
|
384 |
+
|
385 |
+
def backward_train(self):
|
386 |
+
torch.mean(self.loss_total).backward()
|
387 |
+
|
388 |
+
def compute_accuracy(self, d0, d1, judge):
|
389 |
+
''' d0, d1 are Variables, judge is a Tensor '''
|
390 |
+
d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
|
391 |
+
judge_per = judge.cpu().numpy().flatten()
|
392 |
+
return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
|
393 |
+
|
394 |
+
def get_current_errors(self):
|
395 |
+
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
|
396 |
+
('acc_r', self.acc_r)])
|
397 |
+
|
398 |
+
for key in retDict.keys():
|
399 |
+
retDict[key] = np.mean(retDict[key])
|
400 |
+
|
401 |
+
return retDict
|
402 |
+
|
403 |
+
def get_current_visuals(self):
|
404 |
+
zoom_factor = 256 / self.var_ref.data.size()[2]
|
405 |
+
|
406 |
+
ref_img = tensor2im(self.var_ref.data)
|
407 |
+
p0_img = tensor2im(self.var_p0.data)
|
408 |
+
p1_img = tensor2im(self.var_p1.data)
|
409 |
+
|
410 |
+
ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
|
411 |
+
p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
|
412 |
+
p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
|
413 |
+
|
414 |
+
return OrderedDict([('ref', ref_img_vis),
|
415 |
+
('p0', p0_img_vis),
|
416 |
+
('p1', p1_img_vis)])
|
417 |
+
|
418 |
+
def save(self, path, label):
|
419 |
+
if (self.use_gpu):
|
420 |
+
self.save_network(self.net.module, path, '', label)
|
421 |
+
else:
|
422 |
+
self.save_network(self.net, path, '', label)
|
423 |
+
self.save_network(self.rankLoss.net, path, 'rank', label)
|
424 |
+
|
425 |
+
def update_learning_rate(self, nepoch_decay):
|
426 |
+
lrd = self.lr / nepoch_decay
|
427 |
+
lr = self.old_lr - lrd
|
428 |
+
|
429 |
+
for param_group in self.optimizer_net.param_groups:
|
430 |
+
param_group['lr'] = lr
|
431 |
+
|
432 |
+
print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
|
433 |
+
self.old_lr = lr
|
434 |
+
|
435 |
+
|
436 |
+
def score_2afc_dataset(data_loader, func, name=''):
|
437 |
+
''' Function computes Two Alternative Forced Choice (2AFC) score using
|
438 |
+
distance function 'func' in dataset 'data_loader'
|
439 |
+
INPUTS
|
440 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
441 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
442 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
443 |
+
OUTPUTS
|
444 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
445 |
+
[1] - dictionary with following elements
|
446 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
447 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
448 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
449 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
450 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
451 |
+
CONSTS
|
452 |
+
N - number of test triplets in data_loader
|
453 |
+
'''
|
454 |
+
|
455 |
+
d0s = []
|
456 |
+
d1s = []
|
457 |
+
gts = []
|
458 |
+
|
459 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
460 |
+
d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
|
461 |
+
d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
|
462 |
+
gts += data['judge'].cpu().numpy().flatten().tolist()
|
463 |
+
|
464 |
+
d0s = np.array(d0s)
|
465 |
+
d1s = np.array(d1s)
|
466 |
+
gts = np.array(gts)
|
467 |
+
scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
|
468 |
+
|
469 |
+
return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
|
470 |
+
|
471 |
+
|
472 |
+
def score_jnd_dataset(data_loader, func, name=''):
|
473 |
+
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
|
474 |
+
INPUTS
|
475 |
+
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
|
476 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
477 |
+
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
|
478 |
+
OUTPUTS
|
479 |
+
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
|
480 |
+
[1] - dictionary with following elements
|
481 |
+
ds - N array containing distances between two patches shown to human evaluator
|
482 |
+
sames - N array containing fraction of people who thought the two patches were identical
|
483 |
+
CONSTS
|
484 |
+
N - number of test triplets in data_loader
|
485 |
+
'''
|
486 |
+
|
487 |
+
ds = []
|
488 |
+
gts = []
|
489 |
+
|
490 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
491 |
+
ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
|
492 |
+
gts += data['same'].cpu().numpy().flatten().tolist()
|
493 |
+
|
494 |
+
sames = np.array(gts)
|
495 |
+
ds = np.array(ds)
|
496 |
+
|
497 |
+
sorted_inds = np.argsort(ds)
|
498 |
+
ds_sorted = ds[sorted_inds]
|
499 |
+
sames_sorted = sames[sorted_inds]
|
500 |
+
|
501 |
+
TPs = np.cumsum(sames_sorted)
|
502 |
+
FPs = np.cumsum(1 - sames_sorted)
|
503 |
+
FNs = np.sum(sames_sorted) - TPs
|
504 |
+
|
505 |
+
precs = TPs / (TPs + FPs)
|
506 |
+
recs = TPs / (TPs + FNs)
|
507 |
+
score = voc_ap(recs, precs)
|
508 |
+
|
509 |
+
return (score, dict(ds=ds, sames=sames))
|
510 |
+
|
511 |
+
|
512 |
+
############################################################
|
513 |
+
# networks_basic.py #
|
514 |
+
############################################################
|
515 |
+
|
516 |
+
import torch.nn as nn
|
517 |
+
from torch.autograd import Variable
|
518 |
+
import numpy as np
|
519 |
+
|
520 |
+
|
521 |
+
def spatial_average(in_tens, keepdim=True):
|
522 |
+
return in_tens.mean([2, 3], keepdim=keepdim)
|
523 |
+
|
524 |
+
|
525 |
+
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
|
526 |
+
in_H = in_tens.shape[2]
|
527 |
+
scale_factor = 1. * out_H / in_H
|
528 |
+
|
529 |
+
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
|
530 |
+
|
531 |
+
|
532 |
+
# Learned perceptual metric
|
533 |
+
class PNetLin(nn.Module):
|
534 |
+
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False,
|
535 |
+
version='0.1', lpips=True):
|
536 |
+
super(PNetLin, self).__init__()
|
537 |
+
|
538 |
+
self.pnet_type = pnet_type
|
539 |
+
self.pnet_tune = pnet_tune
|
540 |
+
self.pnet_rand = pnet_rand
|
541 |
+
self.spatial = spatial
|
542 |
+
self.lpips = lpips
|
543 |
+
self.version = version
|
544 |
+
self.scaling_layer = ScalingLayer()
|
545 |
+
|
546 |
+
if (self.pnet_type in ['vgg', 'vgg16']):
|
547 |
+
net_type = vgg16
|
548 |
+
self.chns = [64, 128, 256, 512, 512]
|
549 |
+
elif (self.pnet_type == 'alex'):
|
550 |
+
net_type = alexnet
|
551 |
+
self.chns = [64, 192, 384, 256, 256]
|
552 |
+
elif (self.pnet_type == 'squeeze'):
|
553 |
+
net_type = squeezenet
|
554 |
+
self.chns = [64, 128, 256, 384, 384, 512, 512]
|
555 |
+
self.L = len(self.chns)
|
556 |
+
|
557 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
558 |
+
|
559 |
+
if (lpips):
|
560 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
561 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
562 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
563 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
564 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
565 |
+
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
566 |
+
if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet
|
567 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
568 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
569 |
+
self.lins += [self.lin5, self.lin6]
|
570 |
+
|
571 |
+
def forward(self, in0, in1, retPerLayer=False):
|
572 |
+
# v0.0 - original release had a bug, where input was not scaled
|
573 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
|
574 |
+
in0, in1)
|
575 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
576 |
+
feats0, feats1, diffs = {}, {}, {}
|
577 |
+
|
578 |
+
for kk in range(self.L):
|
579 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
580 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
581 |
+
|
582 |
+
if (self.lpips):
|
583 |
+
if (self.spatial):
|
584 |
+
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
|
585 |
+
else:
|
586 |
+
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
|
587 |
+
else:
|
588 |
+
if (self.spatial):
|
589 |
+
res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
|
590 |
+
else:
|
591 |
+
res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
|
592 |
+
|
593 |
+
val = res[0]
|
594 |
+
for l in range(1, self.L):
|
595 |
+
val += res[l]
|
596 |
+
|
597 |
+
if (retPerLayer):
|
598 |
+
return (val, res)
|
599 |
+
else:
|
600 |
+
return val
|
601 |
+
|
602 |
+
|
603 |
+
class ScalingLayer(nn.Module):
|
604 |
+
def __init__(self):
|
605 |
+
super(ScalingLayer, self).__init__()
|
606 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
607 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
608 |
+
|
609 |
+
def forward(self, inp):
|
610 |
+
return (inp - self.shift) / self.scale
|
611 |
+
|
612 |
+
|
613 |
+
class NetLinLayer(nn.Module):
|
614 |
+
''' A single linear layer which does a 1x1 conv '''
|
615 |
+
|
616 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
617 |
+
super(NetLinLayer, self).__init__()
|
618 |
+
|
619 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
620 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
621 |
+
self.model = nn.Sequential(*layers)
|
622 |
+
|
623 |
+
|
624 |
+
class Dist2LogitLayer(nn.Module):
|
625 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
626 |
+
|
627 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
628 |
+
super(Dist2LogitLayer, self).__init__()
|
629 |
+
|
630 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
|
631 |
+
layers += [nn.LeakyReLU(0.2, True), ]
|
632 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
|
633 |
+
layers += [nn.LeakyReLU(0.2, True), ]
|
634 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
|
635 |
+
if (use_sigmoid):
|
636 |
+
layers += [nn.Sigmoid(), ]
|
637 |
+
self.model = nn.Sequential(*layers)
|
638 |
+
|
639 |
+
def forward(self, d0, d1, eps=0.1):
|
640 |
+
return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
|
641 |
+
|
642 |
+
|
643 |
+
class BCERankingLoss(nn.Module):
|
644 |
+
def __init__(self, chn_mid=32):
|
645 |
+
super(BCERankingLoss, self).__init__()
|
646 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
647 |
+
# self.parameters = list(self.net.parameters())
|
648 |
+
self.loss = torch.nn.BCELoss()
|
649 |
+
|
650 |
+
def forward(self, d0, d1, judge):
|
651 |
+
per = (judge + 1.) / 2.
|
652 |
+
self.logit = self.net(d0, d1)
|
653 |
+
return self.loss(self.logit, per)
|
654 |
+
|
655 |
+
|
656 |
+
# L2, DSSIM metrics
|
657 |
+
class FakeNet(nn.Module):
|
658 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
659 |
+
super(FakeNet, self).__init__()
|
660 |
+
self.use_gpu = use_gpu
|
661 |
+
self.colorspace = colorspace
|
662 |
+
|
663 |
+
|
664 |
+
class L2(FakeNet):
|
665 |
+
|
666 |
+
def forward(self, in0, in1, retPerLayer=None):
|
667 |
+
assert (in0.size()[0] == 1) # currently only supports batchSize 1
|
668 |
+
|
669 |
+
if (self.colorspace == 'RGB'):
|
670 |
+
(N, C, X, Y) = in0.size()
|
671 |
+
value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
|
672 |
+
dim=3).view(N)
|
673 |
+
return value
|
674 |
+
elif (self.colorspace == 'Lab'):
|
675 |
+
value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
676 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
|
677 |
+
ret_var = Variable(torch.Tensor((value,)))
|
678 |
+
# if (self.use_gpu):
|
679 |
+
# ret_var = ret_var.cuda()
|
680 |
+
return ret_var
|
681 |
+
|
682 |
+
|
683 |
+
class DSSIM(FakeNet):
|
684 |
+
|
685 |
+
def forward(self, in0, in1, retPerLayer=None):
|
686 |
+
assert (in0.size()[0] == 1) # currently only supports batchSize 1
|
687 |
+
|
688 |
+
if (self.colorspace == 'RGB'):
|
689 |
+
value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype('float')
|
690 |
+
elif (self.colorspace == 'Lab'):
|
691 |
+
value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
|
692 |
+
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
|
693 |
+
ret_var = Variable(torch.Tensor((value,)))
|
694 |
+
# if (self.use_gpu):
|
695 |
+
# ret_var = ret_var.cuda()
|
696 |
+
return ret_var
|
697 |
+
|
698 |
+
|
699 |
+
def print_network(net):
|
700 |
+
num_params = 0
|
701 |
+
for param in net.parameters():
|
702 |
+
num_params += param.numel()
|
703 |
+
print('Network', net)
|
704 |
+
print('Total number of parameters: %d' % num_params)
|
705 |
+
|
706 |
+
|
707 |
+
############################################################
|
708 |
+
# pretrained_networks.py #
|
709 |
+
############################################################
|
710 |
+
|
711 |
+
from collections import namedtuple
|
712 |
+
import torch
|
713 |
+
from torchvision import models as tv
|
714 |
+
|
715 |
+
|
716 |
+
class squeezenet(torch.nn.Module):
|
717 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
718 |
+
super(squeezenet, self).__init__()
|
719 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
720 |
+
self.slice1 = torch.nn.Sequential()
|
721 |
+
self.slice2 = torch.nn.Sequential()
|
722 |
+
self.slice3 = torch.nn.Sequential()
|
723 |
+
self.slice4 = torch.nn.Sequential()
|
724 |
+
self.slice5 = torch.nn.Sequential()
|
725 |
+
self.slice6 = torch.nn.Sequential()
|
726 |
+
self.slice7 = torch.nn.Sequential()
|
727 |
+
self.N_slices = 7
|
728 |
+
for x in range(2):
|
729 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
730 |
+
for x in range(2, 5):
|
731 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
732 |
+
for x in range(5, 8):
|
733 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
734 |
+
for x in range(8, 10):
|
735 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
736 |
+
for x in range(10, 11):
|
737 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
738 |
+
for x in range(11, 12):
|
739 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
740 |
+
for x in range(12, 13):
|
741 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
742 |
+
if not requires_grad:
|
743 |
+
for param in self.parameters():
|
744 |
+
param.requires_grad = False
|
745 |
+
|
746 |
+
def forward(self, X):
|
747 |
+
h = self.slice1(X)
|
748 |
+
h_relu1 = h
|
749 |
+
h = self.slice2(h)
|
750 |
+
h_relu2 = h
|
751 |
+
h = self.slice3(h)
|
752 |
+
h_relu3 = h
|
753 |
+
h = self.slice4(h)
|
754 |
+
h_relu4 = h
|
755 |
+
h = self.slice5(h)
|
756 |
+
h_relu5 = h
|
757 |
+
h = self.slice6(h)
|
758 |
+
h_relu6 = h
|
759 |
+
h = self.slice7(h)
|
760 |
+
h_relu7 = h
|
761 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
|
762 |
+
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
763 |
+
|
764 |
+
return out
|
765 |
+
|
766 |
+
|
767 |
+
class alexnet(torch.nn.Module):
|
768 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
769 |
+
super(alexnet, self).__init__()
|
770 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
771 |
+
self.slice1 = torch.nn.Sequential()
|
772 |
+
self.slice2 = torch.nn.Sequential()
|
773 |
+
self.slice3 = torch.nn.Sequential()
|
774 |
+
self.slice4 = torch.nn.Sequential()
|
775 |
+
self.slice5 = torch.nn.Sequential()
|
776 |
+
self.N_slices = 5
|
777 |
+
for x in range(2):
|
778 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
779 |
+
for x in range(2, 5):
|
780 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
781 |
+
for x in range(5, 8):
|
782 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
783 |
+
for x in range(8, 10):
|
784 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
785 |
+
for x in range(10, 12):
|
786 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
787 |
+
if not requires_grad:
|
788 |
+
for param in self.parameters():
|
789 |
+
param.requires_grad = False
|
790 |
+
|
791 |
+
def forward(self, X):
|
792 |
+
h = self.slice1(X)
|
793 |
+
h_relu1 = h
|
794 |
+
h = self.slice2(h)
|
795 |
+
h_relu2 = h
|
796 |
+
h = self.slice3(h)
|
797 |
+
h_relu3 = h
|
798 |
+
h = self.slice4(h)
|
799 |
+
h_relu4 = h
|
800 |
+
h = self.slice5(h)
|
801 |
+
h_relu5 = h
|
802 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
803 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
804 |
+
|
805 |
+
return out
|
806 |
+
|
807 |
+
|
808 |
+
class vgg16(torch.nn.Module):
|
809 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
810 |
+
super(vgg16, self).__init__()
|
811 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
812 |
+
self.slice1 = torch.nn.Sequential()
|
813 |
+
self.slice2 = torch.nn.Sequential()
|
814 |
+
self.slice3 = torch.nn.Sequential()
|
815 |
+
self.slice4 = torch.nn.Sequential()
|
816 |
+
self.slice5 = torch.nn.Sequential()
|
817 |
+
self.N_slices = 5
|
818 |
+
for x in range(4):
|
819 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
820 |
+
for x in range(4, 9):
|
821 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
822 |
+
for x in range(9, 16):
|
823 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
824 |
+
for x in range(16, 23):
|
825 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
826 |
+
for x in range(23, 30):
|
827 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
828 |
+
if not requires_grad:
|
829 |
+
for param in self.parameters():
|
830 |
+
param.requires_grad = False
|
831 |
+
|
832 |
+
def forward(self, X):
|
833 |
+
h = self.slice1(X)
|
834 |
+
h_relu1_2 = h
|
835 |
+
h = self.slice2(h)
|
836 |
+
h_relu2_2 = h
|
837 |
+
h = self.slice3(h)
|
838 |
+
h_relu3_3 = h
|
839 |
+
h = self.slice4(h)
|
840 |
+
h_relu4_3 = h
|
841 |
+
h = self.slice5(h)
|
842 |
+
h_relu5_3 = h
|
843 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
844 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
845 |
+
|
846 |
+
return out
|
847 |
+
|
848 |
+
|
849 |
+
class resnet(torch.nn.Module):
|
850 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
851 |
+
super(resnet, self).__init__()
|
852 |
+
if (num == 18):
|
853 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
854 |
+
elif (num == 34):
|
855 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
856 |
+
elif (num == 50):
|
857 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
858 |
+
elif (num == 101):
|
859 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
860 |
+
elif (num == 152):
|
861 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
862 |
+
self.N_slices = 5
|
863 |
+
|
864 |
+
self.conv1 = self.net.conv1
|
865 |
+
self.bn1 = self.net.bn1
|
866 |
+
self.relu = self.net.relu
|
867 |
+
self.maxpool = self.net.maxpool
|
868 |
+
self.layer1 = self.net.layer1
|
869 |
+
self.layer2 = self.net.layer2
|
870 |
+
self.layer3 = self.net.layer3
|
871 |
+
self.layer4 = self.net.layer4
|
872 |
+
|
873 |
+
def forward(self, X):
|
874 |
+
h = self.conv1(X)
|
875 |
+
h = self.bn1(h)
|
876 |
+
h = self.relu(h)
|
877 |
+
h_relu1 = h
|
878 |
+
h = self.maxpool(h)
|
879 |
+
h = self.layer1(h)
|
880 |
+
h_conv2 = h
|
881 |
+
h = self.layer2(h)
|
882 |
+
h_conv3 = h
|
883 |
+
h = self.layer3(h)
|
884 |
+
h_conv4 = h
|
885 |
+
h = self.layer4(h)
|
886 |
+
h_conv5 = h
|
887 |
+
|
888 |
+
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
|
889 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
890 |
+
|
891 |
+
return out
|
DH-AISP/2/saicinpainting/evaluation/losses/ssim.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class SSIM(torch.nn.Module):
|
7 |
+
"""SSIM. Modified from:
|
8 |
+
https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, window_size=11, size_average=True):
|
12 |
+
super().__init__()
|
13 |
+
self.window_size = window_size
|
14 |
+
self.size_average = size_average
|
15 |
+
self.channel = 1
|
16 |
+
self.register_buffer('window', self._create_window(window_size, self.channel))
|
17 |
+
|
18 |
+
def forward(self, img1, img2):
|
19 |
+
assert len(img1.shape) == 4
|
20 |
+
|
21 |
+
channel = img1.size()[1]
|
22 |
+
|
23 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
24 |
+
window = self.window
|
25 |
+
else:
|
26 |
+
window = self._create_window(self.window_size, channel)
|
27 |
+
|
28 |
+
# window = window.to(img1.get_device())
|
29 |
+
window = window.type_as(img1)
|
30 |
+
|
31 |
+
self.window = window
|
32 |
+
self.channel = channel
|
33 |
+
|
34 |
+
return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
35 |
+
|
36 |
+
def _gaussian(self, window_size, sigma):
|
37 |
+
gauss = torch.Tensor([
|
38 |
+
np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)
|
39 |
+
])
|
40 |
+
return gauss / gauss.sum()
|
41 |
+
|
42 |
+
def _create_window(self, window_size, channel):
|
43 |
+
_1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
|
44 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
45 |
+
return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
46 |
+
|
47 |
+
def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
|
48 |
+
mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel)
|
49 |
+
mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel)
|
50 |
+
|
51 |
+
mu1_sq = mu1.pow(2)
|
52 |
+
mu2_sq = mu2.pow(2)
|
53 |
+
mu1_mu2 = mu1 * mu2
|
54 |
+
|
55 |
+
sigma1_sq = F.conv2d(
|
56 |
+
img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq
|
57 |
+
sigma2_sq = F.conv2d(
|
58 |
+
img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq
|
59 |
+
sigma12 = F.conv2d(
|
60 |
+
img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2
|
61 |
+
|
62 |
+
C1 = 0.01 ** 2
|
63 |
+
C2 = 0.03 ** 2
|
64 |
+
|
65 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
|
66 |
+
((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
67 |
+
|
68 |
+
if size_average:
|
69 |
+
return ssim_map.mean()
|
70 |
+
|
71 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
72 |
+
|
73 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
74 |
+
return
|
DH-AISP/2/saicinpainting/evaluation/masks/README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Current algorithm
|
2 |
+
|
3 |
+
## Choice of mask objects
|
4 |
+
|
5 |
+
For identification of the objects which are suitable for mask obtaining, panoptic segmentation model
|
6 |
+
from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances
|
7 |
+
belong either to "stuff" or "things" types. We consider that instances of objects should have category belong
|
8 |
+
to "things". Besides, we set upper bound on area which is taken by the object — we consider that too big
|
9 |
+
area indicates either of the instance being a background or a main object which should not be removed.
|
10 |
+
|
11 |
+
## Choice of position for mask
|
12 |
+
|
13 |
+
We consider that input image has size 2^n x 2^m. We downsample it using
|
14 |
+
[COUNTLESS](https://github.com/william-silversmith/countless) algorithm so the width is equal to
|
15 |
+
64 = 2^8 = 2^{downsample_levels}.
|
16 |
+
|
17 |
+
### Augmentation
|
18 |
+
|
19 |
+
There are several parameters for augmentation:
|
20 |
+
- Scaling factor. We limit scaling to the case when a mask after scaling with pivot point in its center fits inside the
|
21 |
+
image completely.
|
22 |
+
-
|
23 |
+
|
24 |
+
### Shift
|
25 |
+
|
26 |
+
|
27 |
+
## Select
|
DH-AISP/2/saicinpainting/evaluation/masks/__init__.py
ADDED
File without changes
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless)
|
2 |
+
|
3 |
+
Python COUNTLESS Downsampling
|
4 |
+
=============================
|
5 |
+
|
6 |
+
To install:
|
7 |
+
|
8 |
+
`pip install -r requirements.txt`
|
9 |
+
|
10 |
+
To test:
|
11 |
+
|
12 |
+
`python test.py`
|
13 |
+
|
14 |
+
To benchmark countless2d:
|
15 |
+
|
16 |
+
`python python/countless2d.py python/images/gray_segmentation.png`
|
17 |
+
|
18 |
+
To benchmark countless3d:
|
19 |
+
|
20 |
+
`python python/countless3d.py`
|
21 |
+
|
22 |
+
Adjust N and the list of algorithms inside each script to modify the run parameters.
|
23 |
+
|
24 |
+
|
25 |
+
Python3 is slightly faster than Python2.
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py
ADDED
File without changes
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py
ADDED
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function, division
|
2 |
+
|
3 |
+
"""
|
4 |
+
COUNTLESS performance test in Python.
|
5 |
+
|
6 |
+
python countless2d.py ./images/NAMEOFIMAGE
|
7 |
+
"""
|
8 |
+
|
9 |
+
import six
|
10 |
+
from six.moves import range
|
11 |
+
from collections import defaultdict
|
12 |
+
from functools import reduce
|
13 |
+
import operator
|
14 |
+
import io
|
15 |
+
import os
|
16 |
+
from PIL import Image
|
17 |
+
import math
|
18 |
+
import numpy as np
|
19 |
+
import random
|
20 |
+
import sys
|
21 |
+
import time
|
22 |
+
from tqdm import tqdm
|
23 |
+
from scipy import ndimage
|
24 |
+
|
25 |
+
def simplest_countless(data):
|
26 |
+
"""
|
27 |
+
Vectorized implementation of downsampling a 2D
|
28 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
29 |
+
|
30 |
+
data is a 2D numpy array with even dimensions.
|
31 |
+
"""
|
32 |
+
sections = []
|
33 |
+
|
34 |
+
# This loop splits the 2D array apart into four arrays that are
|
35 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
36 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
37 |
+
factor = (2,2)
|
38 |
+
for offset in np.ndindex(factor):
|
39 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
40 |
+
sections.append(part)
|
41 |
+
|
42 |
+
a, b, c, d = sections
|
43 |
+
|
44 |
+
ab = a * (a == b) # PICK(A,B)
|
45 |
+
ac = a * (a == c) # PICK(A,C)
|
46 |
+
bc = b * (b == c) # PICK(B,C)
|
47 |
+
|
48 |
+
a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
|
49 |
+
|
50 |
+
return a + (a == 0) * d # AB || AC || BC || D
|
51 |
+
|
52 |
+
def quick_countless(data):
|
53 |
+
"""
|
54 |
+
Vectorized implementation of downsampling a 2D
|
55 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
56 |
+
|
57 |
+
data is a 2D numpy array with even dimensions.
|
58 |
+
"""
|
59 |
+
sections = []
|
60 |
+
|
61 |
+
# This loop splits the 2D array apart into four arrays that are
|
62 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
63 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
64 |
+
factor = (2,2)
|
65 |
+
for offset in np.ndindex(factor):
|
66 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
67 |
+
sections.append(part)
|
68 |
+
|
69 |
+
a, b, c, d = sections
|
70 |
+
|
71 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
72 |
+
bc = b * (b == c) # PICK(B,C)
|
73 |
+
|
74 |
+
a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C)
|
75 |
+
return a + (a == 0) * d # AB || AC || BC || D
|
76 |
+
|
77 |
+
def quickest_countless(data):
|
78 |
+
"""
|
79 |
+
Vectorized implementation of downsampling a 2D
|
80 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
81 |
+
|
82 |
+
data is a 2D numpy array with even dimensions.
|
83 |
+
"""
|
84 |
+
sections = []
|
85 |
+
|
86 |
+
# This loop splits the 2D array apart into four arrays that are
|
87 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
88 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
89 |
+
factor = (2,2)
|
90 |
+
for offset in np.ndindex(factor):
|
91 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
92 |
+
sections.append(part)
|
93 |
+
|
94 |
+
a, b, c, d = sections
|
95 |
+
|
96 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
97 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
98 |
+
return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D
|
99 |
+
|
100 |
+
def quick_countless_xor(data):
|
101 |
+
"""
|
102 |
+
Vectorized implementation of downsampling a 2D
|
103 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
104 |
+
|
105 |
+
data is a 2D numpy array with even dimensions.
|
106 |
+
"""
|
107 |
+
sections = []
|
108 |
+
|
109 |
+
# This loop splits the 2D array apart into four arrays that are
|
110 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
111 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
112 |
+
factor = (2,2)
|
113 |
+
for offset in np.ndindex(factor):
|
114 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
115 |
+
sections.append(part)
|
116 |
+
|
117 |
+
a, b, c, d = sections
|
118 |
+
|
119 |
+
ab = a ^ (a ^ b) # a or b
|
120 |
+
ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c
|
121 |
+
ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d
|
122 |
+
return ab
|
123 |
+
|
124 |
+
def stippled_countless(data):
|
125 |
+
"""
|
126 |
+
Vectorized implementation of downsampling a 2D
|
127 |
+
image by 2 on each side using the COUNTLESS algorithm
|
128 |
+
that treats zero as "background" and inflates lone
|
129 |
+
pixels.
|
130 |
+
|
131 |
+
data is a 2D numpy array with even dimensions.
|
132 |
+
"""
|
133 |
+
sections = []
|
134 |
+
|
135 |
+
# This loop splits the 2D array apart into four arrays that are
|
136 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
137 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
138 |
+
factor = (2,2)
|
139 |
+
for offset in np.ndindex(factor):
|
140 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
141 |
+
sections.append(part)
|
142 |
+
|
143 |
+
a, b, c, d = sections
|
144 |
+
|
145 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
146 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
147 |
+
|
148 |
+
nonzero = a + (a == 0) * (b + (b == 0) * c)
|
149 |
+
return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D
|
150 |
+
|
151 |
+
def zero_corrected_countless(data):
|
152 |
+
"""
|
153 |
+
Vectorized implementation of downsampling a 2D
|
154 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
155 |
+
|
156 |
+
data is a 2D numpy array with even dimensions.
|
157 |
+
"""
|
158 |
+
# allows us to prevent losing 1/2 a bit of information
|
159 |
+
# at the top end by using a bigger type. Without this 255 is handled incorrectly.
|
160 |
+
data, upgraded = upgrade_type(data)
|
161 |
+
|
162 |
+
# offset from zero, raw countless doesn't handle 0 correctly
|
163 |
+
# we'll remove the extra 1 at the end.
|
164 |
+
data += 1
|
165 |
+
|
166 |
+
sections = []
|
167 |
+
|
168 |
+
# This loop splits the 2D array apart into four arrays that are
|
169 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
170 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
171 |
+
factor = (2,2)
|
172 |
+
for offset in np.ndindex(factor):
|
173 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
174 |
+
sections.append(part)
|
175 |
+
|
176 |
+
a, b, c, d = sections
|
177 |
+
|
178 |
+
ab = a * (a == b) # PICK(A,B)
|
179 |
+
ac = a * (a == c) # PICK(A,C)
|
180 |
+
bc = b * (b == c) # PICK(B,C)
|
181 |
+
|
182 |
+
a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
|
183 |
+
|
184 |
+
result = a + (a == 0) * d - 1 # a or d - 1
|
185 |
+
|
186 |
+
if upgraded:
|
187 |
+
return downgrade_type(result)
|
188 |
+
|
189 |
+
# only need to reset data if we weren't upgraded
|
190 |
+
# b/c no copy was made in that case
|
191 |
+
data -= 1
|
192 |
+
|
193 |
+
return result
|
194 |
+
|
195 |
+
def countless_extreme(data):
|
196 |
+
nonzeros = np.count_nonzero(data)
|
197 |
+
# print("nonzeros", nonzeros)
|
198 |
+
|
199 |
+
N = reduce(operator.mul, data.shape)
|
200 |
+
|
201 |
+
if nonzeros == N:
|
202 |
+
print("quick")
|
203 |
+
return quick_countless(data)
|
204 |
+
elif np.count_nonzero(data + 1) == N:
|
205 |
+
print("quick")
|
206 |
+
# print("upper", nonzeros)
|
207 |
+
return quick_countless(data)
|
208 |
+
else:
|
209 |
+
return countless(data)
|
210 |
+
|
211 |
+
|
212 |
+
def countless(data):
|
213 |
+
"""
|
214 |
+
Vectorized implementation of downsampling a 2D
|
215 |
+
image by 2 on each side using the COUNTLESS algorithm.
|
216 |
+
|
217 |
+
data is a 2D numpy array with even dimensions.
|
218 |
+
"""
|
219 |
+
# allows us to prevent losing 1/2 a bit of information
|
220 |
+
# at the top end by using a bigger type. Without this 255 is handled incorrectly.
|
221 |
+
data, upgraded = upgrade_type(data)
|
222 |
+
|
223 |
+
# offset from zero, raw countless doesn't handle 0 correctly
|
224 |
+
# we'll remove the extra 1 at the end.
|
225 |
+
data += 1
|
226 |
+
|
227 |
+
sections = []
|
228 |
+
|
229 |
+
# This loop splits the 2D array apart into four arrays that are
|
230 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
231 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
232 |
+
factor = (2,2)
|
233 |
+
for offset in np.ndindex(factor):
|
234 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
235 |
+
sections.append(part)
|
236 |
+
|
237 |
+
a, b, c, d = sections
|
238 |
+
|
239 |
+
ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
|
240 |
+
ab_ac |= b * (b == c) # PICK(B,C)
|
241 |
+
result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1
|
242 |
+
|
243 |
+
if upgraded:
|
244 |
+
return downgrade_type(result)
|
245 |
+
|
246 |
+
# only need to reset data if we weren't upgraded
|
247 |
+
# b/c no copy was made in that case
|
248 |
+
data -= 1
|
249 |
+
|
250 |
+
return result
|
251 |
+
|
252 |
+
def upgrade_type(arr):
|
253 |
+
dtype = arr.dtype
|
254 |
+
|
255 |
+
if dtype == np.uint8:
|
256 |
+
return arr.astype(np.uint16), True
|
257 |
+
elif dtype == np.uint16:
|
258 |
+
return arr.astype(np.uint32), True
|
259 |
+
elif dtype == np.uint32:
|
260 |
+
return arr.astype(np.uint64), True
|
261 |
+
|
262 |
+
return arr, False
|
263 |
+
|
264 |
+
def downgrade_type(arr):
|
265 |
+
dtype = arr.dtype
|
266 |
+
|
267 |
+
if dtype == np.uint64:
|
268 |
+
return arr.astype(np.uint32)
|
269 |
+
elif dtype == np.uint32:
|
270 |
+
return arr.astype(np.uint16)
|
271 |
+
elif dtype == np.uint16:
|
272 |
+
return arr.astype(np.uint8)
|
273 |
+
|
274 |
+
return arr
|
275 |
+
|
276 |
+
def odd_to_even(image):
|
277 |
+
"""
|
278 |
+
To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one.
|
279 |
+
Works by mirroring the starting 1 pixel edge of the image on odd shaped sides.
|
280 |
+
|
281 |
+
e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled)
|
282 |
+
|
283 |
+
For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample.
|
284 |
+
|
285 |
+
"""
|
286 |
+
shape = np.array(image.shape)
|
287 |
+
|
288 |
+
offset = (shape % 2)[:2] # x,y offset
|
289 |
+
|
290 |
+
# detect if we're dealing with an even
|
291 |
+
# image. if so it's fine, just return.
|
292 |
+
if not np.any(offset):
|
293 |
+
return image
|
294 |
+
|
295 |
+
oddshape = image.shape[:2] + offset
|
296 |
+
oddshape = np.append(oddshape, shape[2:])
|
297 |
+
oddshape = oddshape.astype(int)
|
298 |
+
|
299 |
+
newimg = np.empty(shape=oddshape, dtype=image.dtype)
|
300 |
+
|
301 |
+
ox,oy = offset
|
302 |
+
sx,sy = oddshape
|
303 |
+
|
304 |
+
newimg[0,0] = image[0,0] # corner
|
305 |
+
newimg[ox:sx,0] = image[:,0] # x axis line
|
306 |
+
newimg[0,oy:sy] = image[0,:] # y axis line
|
307 |
+
|
308 |
+
return newimg
|
309 |
+
|
310 |
+
def counting(array):
|
311 |
+
factor = (2, 2, 1)
|
312 |
+
shape = array.shape
|
313 |
+
|
314 |
+
while len(shape) < 4:
|
315 |
+
array = np.expand_dims(array, axis=-1)
|
316 |
+
shape = array.shape
|
317 |
+
|
318 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
|
319 |
+
output = np.zeros(output_shape, dtype=array.dtype)
|
320 |
+
|
321 |
+
for chan in range(0, shape[3]):
|
322 |
+
for z in range(0, shape[2]):
|
323 |
+
for x in range(0, shape[0], 2):
|
324 |
+
for y in range(0, shape[1], 2):
|
325 |
+
block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block
|
326 |
+
|
327 |
+
hashtable = defaultdict(int)
|
328 |
+
for subx, suby in np.ndindex(block.shape[0], block.shape[1]):
|
329 |
+
hashtable[block[subx, suby]] += 1
|
330 |
+
|
331 |
+
best = (0, 0)
|
332 |
+
for segid, val in six.iteritems(hashtable):
|
333 |
+
if best[1] < val:
|
334 |
+
best = (segid, val)
|
335 |
+
|
336 |
+
output[ x // 2, y // 2, chan ] = best[0]
|
337 |
+
|
338 |
+
return output
|
339 |
+
|
340 |
+
def ndzoom(array):
|
341 |
+
if len(array.shape) == 3:
|
342 |
+
ratio = ( 1 / 2.0, 1 / 2.0, 1.0 )
|
343 |
+
else:
|
344 |
+
ratio = ( 1 / 2.0, 1 / 2.0)
|
345 |
+
return ndimage.interpolation.zoom(array, ratio, order=1)
|
346 |
+
|
347 |
+
def countless_if(array):
|
348 |
+
factor = (2, 2, 1)
|
349 |
+
shape = array.shape
|
350 |
+
|
351 |
+
if len(shape) < 3:
|
352 |
+
array = array[ :,:, np.newaxis ]
|
353 |
+
shape = array.shape
|
354 |
+
|
355 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
|
356 |
+
output = np.zeros(output_shape, dtype=array.dtype)
|
357 |
+
|
358 |
+
for chan in range(0, shape[2]):
|
359 |
+
for x in range(0, shape[0], 2):
|
360 |
+
for y in range(0, shape[1], 2):
|
361 |
+
block = array[ x:x+2, y:y+2, chan ] # 2x2 block
|
362 |
+
|
363 |
+
if block[0,0] == block[1,0]:
|
364 |
+
pick = block[0,0]
|
365 |
+
elif block[0,0] == block[0,1]:
|
366 |
+
pick = block[0,0]
|
367 |
+
elif block[1,0] == block[0,1]:
|
368 |
+
pick = block[1,0]
|
369 |
+
else:
|
370 |
+
pick = block[1,1]
|
371 |
+
|
372 |
+
output[ x // 2, y // 2, chan ] = pick
|
373 |
+
|
374 |
+
return np.squeeze(output)
|
375 |
+
|
376 |
+
def downsample_with_averaging(array):
|
377 |
+
"""
|
378 |
+
Downsample x by factor using averaging.
|
379 |
+
|
380 |
+
@return: The downsampled array, of the same type as x.
|
381 |
+
"""
|
382 |
+
|
383 |
+
if len(array.shape) == 3:
|
384 |
+
factor = (2,2,1)
|
385 |
+
else:
|
386 |
+
factor = (2,2)
|
387 |
+
|
388 |
+
if np.array_equal(factor[:3], np.array([1,1,1])):
|
389 |
+
return array
|
390 |
+
|
391 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
|
392 |
+
temp = np.zeros(output_shape, float)
|
393 |
+
counts = np.zeros(output_shape, np.int)
|
394 |
+
for offset in np.ndindex(factor):
|
395 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
396 |
+
indexing_expr = tuple(np.s_[:s] for s in part.shape)
|
397 |
+
temp[indexing_expr] += part
|
398 |
+
counts[indexing_expr] += 1
|
399 |
+
return np.cast[array.dtype](temp / counts)
|
400 |
+
|
401 |
+
def downsample_with_max_pooling(array):
|
402 |
+
|
403 |
+
factor = (2,2)
|
404 |
+
|
405 |
+
if np.all(np.array(factor, int) == 1):
|
406 |
+
return array
|
407 |
+
|
408 |
+
sections = []
|
409 |
+
|
410 |
+
for offset in np.ndindex(factor):
|
411 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
412 |
+
sections.append(part)
|
413 |
+
|
414 |
+
output = sections[0].copy()
|
415 |
+
|
416 |
+
for section in sections[1:]:
|
417 |
+
np.maximum(output, section, output)
|
418 |
+
|
419 |
+
return output
|
420 |
+
|
421 |
+
def striding(array):
|
422 |
+
"""Downsample x by factor using striding.
|
423 |
+
|
424 |
+
@return: The downsampled array, of the same type as x.
|
425 |
+
"""
|
426 |
+
factor = (2,2)
|
427 |
+
if np.all(np.array(factor, int) == 1):
|
428 |
+
return array
|
429 |
+
return array[tuple(np.s_[::f] for f in factor)]
|
430 |
+
|
431 |
+
def benchmark():
|
432 |
+
filename = sys.argv[1]
|
433 |
+
img = Image.open(filename)
|
434 |
+
data = np.array(img.getdata(), dtype=np.uint8)
|
435 |
+
|
436 |
+
if len(data.shape) == 1:
|
437 |
+
n_channels = 1
|
438 |
+
reshape = (img.height, img.width)
|
439 |
+
else:
|
440 |
+
n_channels = min(data.shape[1], 3)
|
441 |
+
data = data[:, :n_channels]
|
442 |
+
reshape = (img.height, img.width, n_channels)
|
443 |
+
|
444 |
+
data = data.reshape(reshape).astype(np.uint8)
|
445 |
+
|
446 |
+
methods = [
|
447 |
+
simplest_countless,
|
448 |
+
quick_countless,
|
449 |
+
quick_countless_xor,
|
450 |
+
quickest_countless,
|
451 |
+
stippled_countless,
|
452 |
+
zero_corrected_countless,
|
453 |
+
countless,
|
454 |
+
downsample_with_averaging,
|
455 |
+
downsample_with_max_pooling,
|
456 |
+
ndzoom,
|
457 |
+
striding,
|
458 |
+
# countless_if,
|
459 |
+
# counting,
|
460 |
+
]
|
461 |
+
|
462 |
+
formats = {
|
463 |
+
1: 'L',
|
464 |
+
3: 'RGB',
|
465 |
+
4: 'RGBA'
|
466 |
+
}
|
467 |
+
|
468 |
+
if not os.path.exists('./results'):
|
469 |
+
os.mkdir('./results')
|
470 |
+
|
471 |
+
N = 500
|
472 |
+
img_size = float(img.width * img.height) / 1024.0 / 1024.0
|
473 |
+
print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename))
|
474 |
+
print("Algorithm\tMPx/sec\tMB/sec\tSec")
|
475 |
+
for fn in methods:
|
476 |
+
print(fn.__name__, end='')
|
477 |
+
sys.stdout.flush()
|
478 |
+
|
479 |
+
start = time.time()
|
480 |
+
# tqdm is here to show you what's going on the first time you run it.
|
481 |
+
# Feel free to remove it to get slightly more accurate timing results.
|
482 |
+
for _ in tqdm(range(N), desc=fn.__name__, disable=True):
|
483 |
+
result = fn(data)
|
484 |
+
end = time.time()
|
485 |
+
print("\r", end='')
|
486 |
+
|
487 |
+
total_time = (end - start)
|
488 |
+
mpx = N * img_size / total_time
|
489 |
+
mbytes = N * img_size * n_channels / total_time
|
490 |
+
# Output in tab separated format to enable copy-paste into excel/numbers
|
491 |
+
print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
|
492 |
+
outimg = Image.fromarray(np.squeeze(result), formats[n_channels])
|
493 |
+
outimg.save('./results/{}.png'.format(fn.__name__, "PNG"))
|
494 |
+
|
495 |
+
if __name__ == '__main__':
|
496 |
+
benchmark()
|
497 |
+
|
498 |
+
|
499 |
+
# Example results:
|
500 |
+
# N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
|
501 |
+
# Function MPx/sec MB/sec Sec
|
502 |
+
# simplest_countless 752.855 752.855 0.01
|
503 |
+
# quick_countless 920.328 920.328 0.01
|
504 |
+
# zero_corrected_countless 534.143 534.143 0.01
|
505 |
+
# countless 644.247 644.247 0.01
|
506 |
+
# downsample_with_averaging 372.575 372.575 0.01
|
507 |
+
# downsample_with_max_pooling 974.060 974.060 0.01
|
508 |
+
# ndzoom 137.517 137.517 0.04
|
509 |
+
# striding 38550.588 38550.588 0.00
|
510 |
+
# countless_if 4.377 4.377 1.14
|
511 |
+
# counting 0.117 0.117 42.85
|
512 |
+
|
513 |
+
# Run without non-numpy implementations:
|
514 |
+
# N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
|
515 |
+
# Algorithm MPx/sec MB/sec Sec
|
516 |
+
# simplest_countless 800.522 800.522 2.50
|
517 |
+
# quick_countless 945.420 945.420 2.12
|
518 |
+
# quickest_countless 947.256 947.256 2.11
|
519 |
+
# stippled_countless 544.049 544.049 3.68
|
520 |
+
# zero_corrected_countless 575.310 575.310 3.48
|
521 |
+
# countless 646.684 646.684 3.09
|
522 |
+
# downsample_with_averaging 385.132 385.132 5.19
|
523 |
+
# downsample_with_max_poolin 988.361 988.361 2.02
|
524 |
+
# ndzoom 163.104 163.104 12.26
|
525 |
+
# striding 81589.340 81589.340 0.02
|
526 |
+
|
527 |
+
|
528 |
+
|
529 |
+
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from six.moves import range
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import io
|
5 |
+
import time
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import sys
|
9 |
+
from collections import defaultdict
|
10 |
+
from copy import deepcopy
|
11 |
+
from itertools import combinations
|
12 |
+
from functools import reduce
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from memory_profiler import profile
|
16 |
+
|
17 |
+
def countless5(a,b,c,d,e):
|
18 |
+
"""First stage of generalizing from countless2d.
|
19 |
+
|
20 |
+
You have five slots: A, B, C, D, E
|
21 |
+
|
22 |
+
You can decide if something is the winner by first checking for
|
23 |
+
matches of three, then matches of two, then picking just one if
|
24 |
+
the other two tries fail. In countless2d, you just check for matches
|
25 |
+
of two and then pick one of them otherwise.
|
26 |
+
|
27 |
+
Unfortunately, you need to check ABC, ABD, ABE, BCD, BDE, & CDE.
|
28 |
+
Then you need to check AB, AC, AD, BC, BD
|
29 |
+
We skip checking E because if none of these match, we pick E. We can
|
30 |
+
skip checking AE, BE, CE, DE since if any of those match, E is our boy
|
31 |
+
so it's redundant.
|
32 |
+
|
33 |
+
So countless grows cominatorially in complexity.
|
34 |
+
"""
|
35 |
+
sections = [ a,b,c,d,e ]
|
36 |
+
|
37 |
+
p2 = lambda q,r: q * (q == r) # q if p == q else 0
|
38 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0
|
39 |
+
|
40 |
+
lor = lambda x,y: x + (x == 0) * y
|
41 |
+
|
42 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
43 |
+
results3 = reduce(lor, results3)
|
44 |
+
|
45 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
46 |
+
results2 = reduce(lor, results2)
|
47 |
+
|
48 |
+
return reduce(lor, (results3, results2, e))
|
49 |
+
|
50 |
+
def countless8(a,b,c,d,e,f,g,h):
|
51 |
+
"""Extend countless5 to countless8. Same deal, except we also
|
52 |
+
need to check for matches of length 4."""
|
53 |
+
sections = [ a, b, c, d, e, f, g, h ]
|
54 |
+
|
55 |
+
p2 = lambda q,r: q * (q == r)
|
56 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
|
57 |
+
p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
|
58 |
+
|
59 |
+
lor = lambda x,y: x + (x == 0) * y
|
60 |
+
|
61 |
+
results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
|
62 |
+
results4 = reduce(lor, results4)
|
63 |
+
|
64 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
65 |
+
results3 = reduce(lor, results3)
|
66 |
+
|
67 |
+
# We can always use our shortcut of omitting the last element
|
68 |
+
# for N choose 2
|
69 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
70 |
+
results2 = reduce(lor, results2)
|
71 |
+
|
72 |
+
return reduce(lor, [ results4, results3, results2, h ])
|
73 |
+
|
74 |
+
def dynamic_countless3d(data):
|
75 |
+
"""countless8 + dynamic programming. ~2x faster"""
|
76 |
+
sections = []
|
77 |
+
|
78 |
+
# shift zeros up one so they don't interfere with bitwise operators
|
79 |
+
# we'll shift down at the end
|
80 |
+
data += 1
|
81 |
+
|
82 |
+
# This loop splits the 2D array apart into four arrays that are
|
83 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
84 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
85 |
+
factor = (2,2,2)
|
86 |
+
for offset in np.ndindex(factor):
|
87 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
88 |
+
sections.append(part)
|
89 |
+
|
90 |
+
pick = lambda a,b: a * (a == b)
|
91 |
+
lor = lambda x,y: x + (x == 0) * y
|
92 |
+
|
93 |
+
subproblems2 = {}
|
94 |
+
|
95 |
+
results2 = None
|
96 |
+
for x,y in combinations(range(7), 2):
|
97 |
+
res = pick(sections[x], sections[y])
|
98 |
+
subproblems2[(x,y)] = res
|
99 |
+
if results2 is not None:
|
100 |
+
results2 += (results2 == 0) * res
|
101 |
+
else:
|
102 |
+
results2 = res
|
103 |
+
|
104 |
+
subproblems3 = {}
|
105 |
+
|
106 |
+
results3 = None
|
107 |
+
for x,y,z in combinations(range(8), 3):
|
108 |
+
res = pick(subproblems2[(x,y)], sections[z])
|
109 |
+
|
110 |
+
if z != 7:
|
111 |
+
subproblems3[(x,y,z)] = res
|
112 |
+
|
113 |
+
if results3 is not None:
|
114 |
+
results3 += (results3 == 0) * res
|
115 |
+
else:
|
116 |
+
results3 = res
|
117 |
+
|
118 |
+
results3 = reduce(lor, (results3, results2, sections[-1]))
|
119 |
+
|
120 |
+
# free memory
|
121 |
+
results2 = None
|
122 |
+
subproblems2 = None
|
123 |
+
res = None
|
124 |
+
|
125 |
+
results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) )
|
126 |
+
results4 = reduce(lor, results4)
|
127 |
+
subproblems3 = None # free memory
|
128 |
+
|
129 |
+
final_result = lor(results4, results3) - 1
|
130 |
+
data -= 1
|
131 |
+
return final_result
|
132 |
+
|
133 |
+
def countless3d(data):
|
134 |
+
"""Now write countless8 in such a way that it could be used
|
135 |
+
to process an image."""
|
136 |
+
sections = []
|
137 |
+
|
138 |
+
# shift zeros up one so they don't interfere with bitwise operators
|
139 |
+
# we'll shift down at the end
|
140 |
+
data += 1
|
141 |
+
|
142 |
+
# This loop splits the 2D array apart into four arrays that are
|
143 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
144 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
145 |
+
factor = (2,2,2)
|
146 |
+
for offset in np.ndindex(factor):
|
147 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
148 |
+
sections.append(part)
|
149 |
+
|
150 |
+
p2 = lambda q,r: q * (q == r)
|
151 |
+
p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
|
152 |
+
p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
|
153 |
+
|
154 |
+
lor = lambda x,y: x + (x == 0) * y
|
155 |
+
|
156 |
+
results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
|
157 |
+
results4 = reduce(lor, results4)
|
158 |
+
|
159 |
+
results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
|
160 |
+
results3 = reduce(lor, results3)
|
161 |
+
|
162 |
+
results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
|
163 |
+
results2 = reduce(lor, results2)
|
164 |
+
|
165 |
+
final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1
|
166 |
+
data -= 1
|
167 |
+
return final_result
|
168 |
+
|
169 |
+
def countless_generalized(data, factor):
|
170 |
+
assert len(data.shape) == len(factor)
|
171 |
+
|
172 |
+
sections = []
|
173 |
+
|
174 |
+
mode_of = reduce(lambda x,y: x * y, factor)
|
175 |
+
majority = int(math.ceil(float(mode_of) / 2))
|
176 |
+
|
177 |
+
data += 1
|
178 |
+
|
179 |
+
# This loop splits the 2D array apart into four arrays that are
|
180 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
181 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
182 |
+
for offset in np.ndindex(factor):
|
183 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
184 |
+
sections.append(part)
|
185 |
+
|
186 |
+
def pick(elements):
|
187 |
+
eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) )
|
188 |
+
anded = reduce(lambda p,q: p & q, eq)
|
189 |
+
return elements[0] * anded
|
190 |
+
|
191 |
+
def logical_or(x,y):
|
192 |
+
return x + (x == 0) * y
|
193 |
+
|
194 |
+
result = ( pick(combo) for combo in combinations(sections, majority) )
|
195 |
+
result = reduce(logical_or, result)
|
196 |
+
for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds
|
197 |
+
partial_result = ( pick(combo) for combo in combinations(sections, i) )
|
198 |
+
partial_result = reduce(logical_or, partial_result)
|
199 |
+
result = logical_or(result, partial_result)
|
200 |
+
|
201 |
+
partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) )
|
202 |
+
partial_result = reduce(logical_or, partial_result)
|
203 |
+
result = logical_or(result, partial_result)
|
204 |
+
|
205 |
+
result = logical_or(result, sections[-1]) - 1
|
206 |
+
data -= 1
|
207 |
+
return result
|
208 |
+
|
209 |
+
def dynamic_countless_generalized(data, factor):
|
210 |
+
assert len(data.shape) == len(factor)
|
211 |
+
|
212 |
+
sections = []
|
213 |
+
|
214 |
+
mode_of = reduce(lambda x,y: x * y, factor)
|
215 |
+
majority = int(math.ceil(float(mode_of) / 2))
|
216 |
+
|
217 |
+
data += 1 # offset from zero
|
218 |
+
|
219 |
+
# This loop splits the 2D array apart into four arrays that are
|
220 |
+
# all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
|
221 |
+
# and (1,1) representing the A, B, C, and D positions from Figure 1.
|
222 |
+
for offset in np.ndindex(factor):
|
223 |
+
part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
224 |
+
sections.append(part)
|
225 |
+
|
226 |
+
pick = lambda a,b: a * (a == b)
|
227 |
+
lor = lambda x,y: x + (x == 0) * y # logical or
|
228 |
+
|
229 |
+
subproblems = [ {}, {} ]
|
230 |
+
results2 = None
|
231 |
+
for x,y in combinations(range(len(sections) - 1), 2):
|
232 |
+
res = pick(sections[x], sections[y])
|
233 |
+
subproblems[0][(x,y)] = res
|
234 |
+
if results2 is not None:
|
235 |
+
results2 = lor(results2, res)
|
236 |
+
else:
|
237 |
+
results2 = res
|
238 |
+
|
239 |
+
results = [ results2 ]
|
240 |
+
for r in range(3, majority+1):
|
241 |
+
r_results = None
|
242 |
+
for combo in combinations(range(len(sections)), r):
|
243 |
+
res = pick(subproblems[0][combo[:-1]], sections[combo[-1]])
|
244 |
+
|
245 |
+
if combo[-1] != len(sections) - 1:
|
246 |
+
subproblems[1][combo] = res
|
247 |
+
|
248 |
+
if r_results is not None:
|
249 |
+
r_results = lor(r_results, res)
|
250 |
+
else:
|
251 |
+
r_results = res
|
252 |
+
results.append(r_results)
|
253 |
+
subproblems[0] = subproblems[1]
|
254 |
+
subproblems[1] = {}
|
255 |
+
|
256 |
+
results.reverse()
|
257 |
+
final_result = lor(reduce(lor, results), sections[-1]) - 1
|
258 |
+
data -= 1
|
259 |
+
return final_result
|
260 |
+
|
261 |
+
def downsample_with_averaging(array):
|
262 |
+
"""
|
263 |
+
Downsample x by factor using averaging.
|
264 |
+
|
265 |
+
@return: The downsampled array, of the same type as x.
|
266 |
+
"""
|
267 |
+
factor = (2,2,2)
|
268 |
+
|
269 |
+
if np.array_equal(factor[:3], np.array([1,1,1])):
|
270 |
+
return array
|
271 |
+
|
272 |
+
output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
|
273 |
+
temp = np.zeros(output_shape, float)
|
274 |
+
counts = np.zeros(output_shape, np.int)
|
275 |
+
for offset in np.ndindex(factor):
|
276 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
277 |
+
indexing_expr = tuple(np.s_[:s] for s in part.shape)
|
278 |
+
temp[indexing_expr] += part
|
279 |
+
counts[indexing_expr] += 1
|
280 |
+
return np.cast[array.dtype](temp / counts)
|
281 |
+
|
282 |
+
def downsample_with_max_pooling(array):
|
283 |
+
|
284 |
+
factor = (2,2,2)
|
285 |
+
|
286 |
+
sections = []
|
287 |
+
|
288 |
+
for offset in np.ndindex(factor):
|
289 |
+
part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
|
290 |
+
sections.append(part)
|
291 |
+
|
292 |
+
output = sections[0].copy()
|
293 |
+
|
294 |
+
for section in sections[1:]:
|
295 |
+
np.maximum(output, section, output)
|
296 |
+
|
297 |
+
return output
|
298 |
+
|
299 |
+
def striding(array):
|
300 |
+
"""Downsample x by factor using striding.
|
301 |
+
|
302 |
+
@return: The downsampled array, of the same type as x.
|
303 |
+
"""
|
304 |
+
factor = (2,2,2)
|
305 |
+
if np.all(np.array(factor, int) == 1):
|
306 |
+
return array
|
307 |
+
return array[tuple(np.s_[::f] for f in factor)]
|
308 |
+
|
309 |
+
def benchmark():
|
310 |
+
def countless3d_generalized(img):
|
311 |
+
return countless_generalized(img, (2,8,1))
|
312 |
+
def countless3d_dynamic_generalized(img):
|
313 |
+
return dynamic_countless_generalized(img, (8,8,1))
|
314 |
+
|
315 |
+
methods = [
|
316 |
+
# countless3d,
|
317 |
+
# dynamic_countless3d,
|
318 |
+
countless3d_generalized,
|
319 |
+
# countless3d_dynamic_generalized,
|
320 |
+
# striding,
|
321 |
+
# downsample_with_averaging,
|
322 |
+
# downsample_with_max_pooling
|
323 |
+
]
|
324 |
+
|
325 |
+
data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1
|
326 |
+
|
327 |
+
N = 5
|
328 |
+
|
329 |
+
print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N)
|
330 |
+
|
331 |
+
for fn in methods:
|
332 |
+
start = time.time()
|
333 |
+
for _ in range(N):
|
334 |
+
result = fn(data)
|
335 |
+
end = time.time()
|
336 |
+
|
337 |
+
total_time = (end - start)
|
338 |
+
mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0
|
339 |
+
mbytes = mpx * np.dtype(data.dtype).itemsize
|
340 |
+
# Output in tab separated format to enable copy-paste into excel/numbers
|
341 |
+
print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
|
342 |
+
|
343 |
+
if __name__ == '__main__':
|
344 |
+
benchmark()
|
345 |
+
|
346 |
+
# Algorithm MPx MB/sec Sec N=5
|
347 |
+
# countless3d 10.564 10.564 60.58
|
348 |
+
# dynamic_countless3d 22.717 22.717 28.17
|
349 |
+
# countless3d_generalized 9.702 9.702 65.96
|
350 |
+
# countless3d_dynamic_generalized 22.720 22.720 28.17
|
351 |
+
# striding 253360.506 253360.506 0.00
|
352 |
+
# downsample_with_averaging 224.098 224.098 2.86
|
353 |
+
# downsample_with_max_pooling 690.474 690.474 0.93
|
354 |
+
|
355 |
+
|
356 |
+
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg
ADDED
Git LFS Details
|
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png
ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png
ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png
ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png
ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png
ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png
ADDED