Artyom committed on
Commit
bd1c686
1 Parent(s): 6721043
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. DH-AISP/1/__pycache__/awb.cpython-36.pyc +0 -0
  3. DH-AISP/1/awb.py +184 -0
  4. DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint +2 -0
  5. DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 +3 -0
  6. DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index +0 -0
  7. DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta +3 -0
  8. DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py +451 -0
  9. DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc +0 -0
  10. DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc +0 -0
  11. DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc +0 -0
  12. DH-AISP/2/focal_frequency_loss/__init__.py +3 -0
  13. DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc +0 -0
  14. DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc +0 -0
  15. DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py +114 -0
  16. DH-AISP/2/model_convnext2_hdr.py +592 -0
  17. DH-AISP/2/myFFCResblock0.py +60 -0
  18. DH-AISP/2/perceptual.py +30 -0
  19. DH-AISP/2/pytorch_msssim/__init__.py +133 -0
  20. DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc +0 -0
  21. DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc +0 -0
  22. DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth +3 -0
  23. DH-AISP/2/saicinpainting/__init__.py +0 -0
  24. DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc +0 -0
  25. DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc +0 -0
  26. DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc +0 -0
  27. DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc +0 -0
  28. DH-AISP/2/saicinpainting/evaluation/__init__.py +33 -0
  29. DH-AISP/2/saicinpainting/evaluation/data.py +168 -0
  30. DH-AISP/2/saicinpainting/evaluation/evaluator.py +220 -0
  31. DH-AISP/2/saicinpainting/evaluation/losses/__init__.py +0 -0
  32. DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py +528 -0
  33. DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py +0 -0
  34. DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py +328 -0
  35. DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py +323 -0
  36. DH-AISP/2/saicinpainting/evaluation/losses/lpips.py +891 -0
  37. DH-AISP/2/saicinpainting/evaluation/losses/ssim.py +74 -0
  38. DH-AISP/2/saicinpainting/evaluation/masks/README.md +27 -0
  39. DH-AISP/2/saicinpainting/evaluation/masks/__init__.py +0 -0
  40. DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md +25 -0
  41. DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py +0 -0
  42. DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py +529 -0
  43. DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py +356 -0
  44. DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg +3 -0
  45. DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png +0 -0
  46. DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png +0 -0
  47. DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png +0 -0
  48. DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png +0 -0
  49. DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png +0 -0
  50. DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png +0 -0
.gitattributes CHANGED
@@ -39,3 +39,6 @@ SCBC/Input/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
39
  SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text
40
  SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text
41
  PolyuColor/resources/average_shading.png filter=lfs diff=lfs merge=lfs -text
42
+ DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
43
+ DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta filter=lfs diff=lfs merge=lfs -text
44
+ DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text
DH-AISP/1/__pycache__/awb.cpython-36.pyc ADDED
Binary file (3.82 kB).
 
DH-AISP/1/awb.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from glob import glob
5
+
6
+
7
+ def dynamic(rgb):
8
+
9
+ rgb = rgb[:-1, :-1, :] # drop the last row and column
10
+ h, w, _ = rgb.shape
11
+ col = 4
12
+ row = 3
13
+ h1 = h // row
14
+ w1 = w // col
15
+
16
+ r, g, b = cv2.split(rgb)
17
+ r_mask = r < 0.95
18
+ g_mask = g < 0.95
19
+ b_mask = b < 0.95
20
+ mask = r_mask * g_mask * b_mask
21
+ r *= mask
22
+ g *= mask
23
+ b *= mask
24
+ rgb = np.stack((r, g, b), axis=2)
25
+
26
+ y, cr, cb = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb))
27
+ cr -= 0.5
28
+ cb -= 0.5
29
+
30
+ mr, mb, dr, db = 0, 0, 0, 0
31
+ for r in range(row):
32
+ for c in range(col):
33
+ cr_1 = cr[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1]
34
+ cb_1 = cb[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1]
35
+ mr_1 = np.mean(cr_1)
36
+ mb_1 = np.mean(cb_1)
37
+ dr_1 = np.mean(np.abs(cr_1 - mr))
38
+ db_1 = np.mean(np.abs(cb_1 - mb))
39
+
40
+ mr += mr_1
41
+ mb += mb_1
42
+ dr += dr_1
43
+ db += db_1
44
+
45
+ mr /= col * row
46
+ mb /= col * row
47
+ dr /= col * row
48
+ db /= col * row
49
+
50
+ cb_mask = np.abs(cb - (mb + db * np.sign(mb))) < 1.5 * db
51
+ cr_mask = np.abs(cr - (1.5 * mr + dr * np.sign(mr))) < 1.5 * dr
52
+
53
+ mask = cb_mask * cr_mask
54
+ y_white = y * mask
55
+
56
+ hist_y = np.zeros(256, dtype=int)
57
+ y_white_uint8 = (y_white * 255).astype(int)
58
+
59
+ for v in range(255):
60
+ hist_y[v] = np.sum(y_white_uint8 == v)
61
+
62
+ thr_sum = 0.05 * np.sum(mask)
63
+ sum_v = 0
64
+ thr = 0
65
+ for v in range(255, -1, -1):
66
+ sum_v = sum_v + hist_y[v]
67
+ if sum_v > thr_sum:
68
+ thr = v
69
+ break
70
+
71
+ white_mask = y_white_uint8 > thr
72
+ cv2.imwrite(r'V:\Project\3_MEWDR\data\2nd_awb\t.png', (white_mask + 0) * 255)
73
+
74
+ r, g, b = cv2.split(rgb)
75
+ r_ave = np.sum(r[white_mask]) / np.sum(white_mask)
76
+ g_ave = np.sum(g[white_mask]) / np.sum(white_mask)
77
+ b_ave = np.sum(b[white_mask]) / np.sum(white_mask)
78
+
79
+ return 1 / r_ave, 1 / g_ave, 1 / b_ave
80
+
81
+
82
+ def perf_ref(rgb, eps):
83
+ h, w, _ = rgb.shape
84
+
85
+ r, g, b = cv2.split(rgb)
86
+ r_mask = r < 0.95
87
+ g_mask = g < 0.95
88
+ b_mask = b < 0.95
89
+ mask = r_mask * g_mask * b_mask
90
+ r *= mask
91
+ g *= mask
92
+ b *= mask
93
+ rgb = np.stack((r, g, b), axis=2)
94
+ rgb = np.clip(rgb * 255, 0, 255).astype(int)
95
+
96
+ hist_rgb = np.zeros(255 * 3, dtype=int)
97
+ rgb_sum = np.sum(rgb, axis=2)
98
+
99
+ for v in range(255 * 3):
100
+ hist_rgb[v] = np.sum(rgb_sum == v)
101
+
102
+ thr_sum = eps * h * w
103
+ sum_v = 0
104
+ thr = 0
105
+ for v in range(255 * 3 - 1, -1, -1):
106
+ sum_v = sum_v + hist_rgb[v]
107
+ if sum_v > thr_sum:
108
+ thr = v
109
+ break
110
+
111
+ thr_mask = rgb_sum > thr
112
+ r_ave = np.sum(r[thr_mask]) / np.sum(thr_mask)
113
+ g_ave = np.sum(g[thr_mask]) / np.sum(thr_mask)
114
+ b_ave = np.sum(b[thr_mask]) / np.sum(thr_mask)
115
+
116
+ # k = (r_ave + g_ave + b_ave) / 3.
117
+ # k = 255
118
+
119
+ # print(k)
120
+
121
+ # r = np.clip(r * k / r_ave, 0, 255)
122
+ # g = np.clip(g * k / g_ave, 0, 255)
123
+ # b = np.clip(b * k / b_ave, 0, 255)
124
+
125
+ return 1 / r_ave, 1 / g_ave, 1 / b_ave
126
+
127
+
128
+ def awb_v(in_image, bayer, eps):
129
+
130
+ assert bayer in ['GBRG', 'RGGB']
131
+
132
+ if bayer == 'GBRG':
133
+ g = in_image[0::2, 0::2] # [0,0]
134
+ b = in_image[0::2, 1::2] # [0,1]
135
+ r = in_image[1::2, 0::2] # [1,0]
136
+ else:
137
+ r = in_image[0::2, 0::2] # [0,0]
138
+ g = in_image[0::2, 1::2] # [0,1]
139
+ b = in_image[1::2, 1::2] # [1,1]
140
+
141
+ rgb = cv2.merge((r, g, b)) * 1
142
+
143
+ r_gain, g_gain, b_gain = perf_ref(rgb, eps)
144
+
145
+ return r_gain / g_gain, b_gain / g_gain
146
+
147
+
148
+ def main():
149
+ path = r'V:\Project\3_MEWDR\data\2nd_raw'
150
+ # out_path = r'V:\Project\3_MEWDR\data\2nd_awb'
151
+
152
+ files = glob(os.path.join(path, '*.png'))
153
+
154
+ for f in files:
155
+ img = cv2.imread(f, cv2.CV_16UC1)
156
+ img = (img.astype(float) - 2048) / (15400 - 2048) * 4
157
+
158
+ g = img[0::2, 0::2] # [0,0]
159
+ b = img[0::2, 1::2] # [0,1]
160
+ r = img[1::2, 0::2] # [1,0]
161
+ # g_ = img[1::2, 1::2]
162
+
163
+ rgb = cv2.merge((r, g, b))
164
+
165
+ # save_name = f.replace('.png', '_rgb.png').replace('2nd_raw', '2nd_awb')
166
+
167
+ r_gain, g_gain, b_gain = perf_ref(rgb, eps=0.1)
168
+ # r_gain, g_gain, b_gain = dynamic(rgb.astype(np.float32))
169
+
170
+ r *= r_gain / g_gain
171
+ b *= b_gain / g_gain
172
+ print(r_gain / g_gain, b_gain / g_gain)
173
+
174
+ out_rgb = np.clip(cv2.merge((r, g, b)) * 255, 0, 255)
175
+
176
+ save_name = f.replace('.png', '_awb4_dyn.png').replace('2nd_raw', '2nd_awb')
177
+
178
+ cv2.imwrite(save_name, cv2.cvtColor(out_rgb.astype(np.uint8), cv2.COLOR_RGB2BGR))
179
+
180
+ # break
181
+
182
+
183
+ if __name__ == '__main__':
184
+ main()
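For reference, a minimal sketch of calling the white-balance entry point from another script (an illustration only; the file name is hypothetical and the 2048/15400 black/white levels simply mirror the constants hard-coded in main() above):

import cv2
import numpy as np
from awb import awb_v

raw = cv2.imread('sample_raw.png', cv2.IMREAD_UNCHANGED).astype(np.float64)  # 16-bit Bayer PNG (hypothetical file)
raw = np.clip((raw - 2048) / (15400 - 2048), 0, 1)                           # normalize to [0, 1]
r_gain, b_gain = awb_v(raw, bayer='RGGB', eps=0.1)                           # gains relative to the green channel
print('r_gain=%.3f b_gain=%.3f' % (r_gain, b_gain))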
DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint ADDED
@@ -0,0 +1,2 @@
1
+ model_checkpoint_path: "model.ckpt"
2
+ all_model_checkpoint_paths: "model.ckpt"
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6997bfa5624aba66e2497088cc8f379db63bac343a0a648e08f6a5840a48259f
3
+ size 175070404
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index ADDED
Binary file (6.36 kB).
 
DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79f84947bf3a5a9e851539308b85b43ecc6a8e93ed2c7ab9adb23f0fd6796286
3
+ size 124053471
DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py ADDED
@@ -0,0 +1,451 @@
1
+ # uniform content loss + adaptive threshold + per_class_input + recursive G
2
+ # improvement upon cqf37
3
+ from __future__ import division
4
+ import os
5
+ import tensorflow.compat.v1 as tf
6
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
7
+ import tf_slim as slim
8
+ import tensorflow as tf2
9
+ tf2.test.is_gpu_available()
10
+ import numpy as np
11
+ import glob
12
+ # import scipy.io as sio
13
+ import cv2
14
+ import json
15
+ from fractions import Fraction
16
+ import pdb
17
+ import sys
18
+ import argparse
19
+
20
+ from awb import awb_v
21
+
22
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
23
+ input_dir = '../data/'
24
+ cha1 = 32
25
+
26
+ # get train IDs
27
+ train_fns = glob.glob(input_dir + '*.png')
28
+ train_ids = [os.path.basename(train_fn) for train_fn in train_fns]
29
+
30
+ result_dir = './mid/'
31
+ checkpoint_dir = './daylight_isp_03_3_unet_sid_5/'
32
+
33
+ if not os.path.exists(result_dir):
34
+ os.mkdir(result_dir)
35
+
36
+ #run python tensorflow2to1_1214_5202x3464_01_unetpp3.py ./data/ ./result/ ./daylight_isp_03/
37
+
38
+ # DEBUG = 0
39
+ # if DEBUG == 1:
40
+ # save_freq = 2
41
+ # test_ids = test_ids[0:5]
42
+
43
+ def json_read(fname, **kwargs):
44
+ with open(fname) as j:
45
+ data = json.load(j, **kwargs)
46
+ return data
47
+
48
+ def fraction_from_json(json_object):
49
+ if 'Fraction' in json_object:
50
+ return Fraction(*json_object['Fraction'])
51
+ return json_object
52
+
53
+ def fractions2floats(fractions):
54
+ floats = []
55
+ for fraction in fractions:
56
+ floats.append(float(fraction.numerator) / fraction.denominator)
57
+ return floats
58
+
59
+ def tv_loss(input_, output):
60
+ I = tf.image.rgb_to_grayscale(input_)
61
+ L = tf.log(I+0.0001)
62
+ dx = L[:, :-1, :-1, :] - L[:, :-1, 1:, :]
63
+ dy = L[:, :-1, :-1, :] - L[:, 1:, :-1, :]
64
+
65
+ alpha = tf.constant(1.2)
66
+ lamda = tf.constant(1.5)
67
+ dx = tf.divide(lamda, tf.pow(tf.abs(dx),alpha)+ tf.constant(0.0001))
68
+ dy = tf.divide(lamda, tf.pow(tf.abs(dy),alpha)+ tf.constant(0.0001))
69
+ shape = output.get_shape()
70
+ x_loss = dx *((output[:, :-1, :-1, :] - output[:, :-1, 1:, :])**2)
71
+ y_loss = dy *((output[:, :-1, :-1, :] - output[:, 1:, :-1, :])**2)
72
+ tvloss = tf.reduce_mean(x_loss + y_loss)/2.0
73
+ return tvloss
74
+
75
+ def lrelu(x):
76
+ return tf.maximum(x * 0.2, x)
77
+
78
+
79
+ def upsample_and_concat_3(x1, x2, output_channels, in_channels, name):
80
+ with tf.variable_scope(name):
81
+ x1 = slim.conv2d(x1, output_channels, [3, 3], rate=1, activation_fn=lrelu, scope='conv_2to1')
82
+ deconv = tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2])
83
+ deconv_output = tf.concat([deconv, x2], 3)
84
+ deconv_output.set_shape([None, None, None, output_channels * 2])
85
+ return deconv_output
86
+
87
+
88
+ def upsample_and_concat_h(x1, x2, output_channels, in_channels, name):
89
+ with tf.variable_scope(name):
90
+ #deconv = tf.image.resize_images(x1, [x1.shape[1].value*2, x1.shape[2].value*2])
91
+ pool_size = 2
92
+ deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
93
+ deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])
94
+
95
+ deconv_output = tf.concat([deconv, x2], 3)
96
+ deconv_output.set_shape([None, None, None, output_channels * 2])
97
+
98
+ return deconv_output
99
+
100
+ def upsample_and_concat_h_only(x1, output_channels, in_channels, name):
101
+ with tf.variable_scope(name):
102
+ x1 = tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2])
103
+ x1.set_shape([None, None, None, output_channels])
104
+ return x1
105
+
106
+
107
+ def conv_block(input, output_channels, name):
108
+ with tf.variable_scope(name):
109
+ conv = slim.conv2d(input, output_channels, [3, 3], activation_fn=lrelu, scope='conv1')
110
+ conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2')
111
+ return conv
112
+
113
+
114
+ def conv_block_up(input, output_channels, name):
115
+ with tf.variable_scope(name):
116
+ conv = slim.conv2d(input, output_channels, [1, 1], scope='conv0')
117
+ conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv1')
118
+ conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2')
119
+
120
+ return conv
121
+
122
+
123
+ def upsample_and_concat(x1, x2, output_channels, in_channels, p, name):
124
+ with tf.variable_scope(name):
125
+ pool_size = 2
126
+
127
+ deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
128
+ deconv_filter = tf.cast(deconv_filter, x1.dtype)
129
+ deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2[0]), strides=[1, pool_size, pool_size, 1])
130
+ # x2.append(deconv)
131
+ x2 = tf.concat(x2, axis=3)
132
+ deconv_output = tf.concat([x2, deconv], axis=3)
133
+ deconv_output.set_shape([None, None, None, output_channels * (p + 1)])
134
+
135
+ return deconv_output
136
+
137
+
138
+ def network(input, reuse=False):
139
+ with tf.variable_scope("generator_h", reuse=reuse):
140
+ conv1_h = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
141
+ conv1_h = slim.conv2d(conv1_h, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2')
142
+ pool1_h = slim.max_pool2d(conv1_h, [2, 2], padding='SAME')
143
+
144
+ conv2_h = slim.conv2d(pool1_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
145
+ conv2_h = slim.conv2d(conv2_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2')
146
+ pool2_h = slim.max_pool2d(conv2_h, [2, 2], padding='SAME')
147
+
148
+ conv3_h = slim.conv2d(pool2_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
149
+ conv3_h = slim.conv2d(conv3_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2')
150
+ pool3_h = slim.max_pool2d(conv3_h, [2, 2], padding='SAME')
151
+
152
+ conv4_h = slim.conv2d(pool3_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1')
153
+ conv4_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2')
154
+ conv6_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1')
155
+ conv6_h = slim.conv2d(conv6_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2')
156
+
157
+ up7_h = upsample_and_concat_3(conv6_h, conv3_h, cha1*4,cha1*8, name='up7')
158
+ conv7_h = slim.conv2d(up7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1')
159
+ conv7_h = slim.conv2d(conv7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2')
160
+
161
+ up8_h = upsample_and_concat_3(conv7_h, conv2_h, cha1*2,cha1*4, name='up8')
162
+ conv8_h = slim.conv2d(up8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
163
+ conv8_h = slim.conv2d(conv8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2')
164
+
165
+ up9_h = upsample_and_concat_3(conv8_h, conv1_h, cha1,cha1*2, name='up9')
166
+ conv9_h = slim.conv2d(up9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
167
+ conv9_h = slim.conv2d(conv9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2')
168
+
169
+ up10_h = upsample_and_concat_h_only(conv9_h, cha1,cha1, name='up10')
170
+ conv10_h = slim.conv2d(up10_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv10_1')
171
+ out = slim.conv2d(conv10_h, 3, [3, 3], rate=1, activation_fn=None, scope='g_conv10_2')
172
+ return out
173
+
174
+
175
+
176
+ def fix_orientation(image, orientation):
177
+ # 1 = Horizontal(normal)
178
+ # 2 = Mirror horizontal
179
+ # 3 = Rotate 180
180
+ # 4 = Mirror vertical
181
+ # 5 = Mirror horizontal and rotate 270 CW
182
+ # 6 = Rotate 90 CW
183
+ # 7 = Mirror horizontal and rotate 90 CW
184
+ # 8 = Rotate 270 CW
185
+
186
+ if type(orientation) is list:
187
+ orientation = orientation[0]
188
+
189
+ if orientation == 'Horizontal (normal)':
190
+ pass
191
+ elif orientation == 'Mirror horizontal':
192
+ image = cv2.flip(image, 0)
193
+ elif orientation == 'Rotate 180':
194
+ image = cv2.rotate(image, cv2.ROTATE_180)
195
+ elif orientation == 'Mirror vertical':
196
+ image = cv2.flip(image, 1)
197
+ elif orientation == 'Mirror horizontal and rotate 270 CW':
198
+ image = cv2.flip(image, 0)
199
+ image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
200
+ elif orientation == 'Rotate 90 CW':
201
+ image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
202
+ elif orientation == 'Mirror horizontal and rotate 90 CW':
203
+ image = cv2.flip(image, 0)
204
+ image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
205
+ elif orientation == 'Rotate 270 CW':
206
+ image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
207
+
208
+ return image
209
+
210
+ class ExposureFusion(object):
211
+ def __init__(self, sequence, best_exposedness=0.5, sigma=0.2, eps=1e-12, exponents=(1.0, 1.0, 1.0), layers=11):
212
+ self.sequence = sequence # [N, H, W, 3], (0..1), float32
213
+ self.img_num = sequence.shape[0]
214
+ self.best_exposedness = best_exposedness
215
+ self.sigma = sigma
216
+ self.eps = eps
217
+ self.exponents = exponents
218
+ self.layers = layers
219
+
220
+ @staticmethod
221
+ def cal_contrast(src):
222
+ gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
223
+ laplace_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
224
+ contrast = cv2.filter2D(gray, -1, laplace_kernel, borderType=cv2.BORDER_REPLICATE)
225
+ return np.abs(contrast)
226
+
227
+ @staticmethod
228
+ def cal_saturation(src):
229
+ mean = np.mean(src, axis=-1)
230
+ channels = [(src[:, :, c] - mean)**2 for c in range(3)]
231
+ saturation = np.sqrt(np.mean(channels, axis=0))
232
+ return saturation
233
+
234
+ @staticmethod
235
+ def cal_exposedness(src, best_exposedness, sigma):
236
+ exposedness = [np.exp(-((src[:, :, c] - best_exposedness) ** 2) / (2 * sigma ** 2)) for c in range(3)]  # Gaussian well-exposedness weight per channel
237
+ exposedness = np.prod(exposedness, axis=0)
238
+ return exposedness
239
+
240
+ def cal_weight_map(self):
241
+ #pdb.set_trace()
242
+ weights = []
243
+ for idx in range(self.sequence.shape[0]):
244
+ contrast = self.cal_contrast(self.sequence[idx])
245
+ saturation = self.cal_saturation(self.sequence[idx])
246
+ exposedness = self.cal_exposedness(self.sequence[idx], self.best_exposedness, self.sigma)
247
+ weight = np.power(contrast, self.exponents[0]) * np.power(saturation, self.exponents[1]) * np.power(exposedness, self.exponents[2])
248
+ # Gauss Blur
249
+ # weight = cv2.GaussianBlur(weight, (21, 21), 2.1)
250
+ weights.append(weight)
251
+ #pdb.set_trace()
252
+ weights = np.stack(weights, 0) + self.eps
253
+ # normalize
254
+ weights = weights / np.expand_dims(np.sum(weights, axis=0), axis=0)
255
+ return weights
256
+
257
+ def naive_fusion(self):
258
+ weights = self.cal_weight_map() # [N, H, W]
259
+ weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3]
260
+ naive_fusion = np.sum(weights * self.sequence * 255, axis=0)
261
+ naive_fusion = np.clip(naive_fusion, 0, 255).astype(np.uint8)
262
+ return naive_fusion
263
+
264
+ def build_gaussian_pyramid(self, high_res):
265
+ #pdb.set_trace()
266
+ gaussian_pyramid = [high_res]
267
+ for idx in range(1, self.layers):
268
+ kernel1=np.array([[0.0039,0.0156,0.0234,0.0156,0.0039],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0234,0.0938,0.1406,0.0938,0.0234],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0039,0.0156,0.0234,0.0156,0.0039]],dtype='float32')
269
+ gaussian_pyramid.append(cv2.filter2D(gaussian_pyramid[-1], -1,kernel=kernel1)[::2, ::2])
270
+ #gaussian_pyramid.append(cv2.GaussianBlur(gaussian_pyramid[-1], (5, 5), 0.83)[::2, ::2])
271
+ return gaussian_pyramid
272
+
273
+ def build_laplace_pyramid(self, gaussian_pyramid):
274
+ laplace_pyramid = [gaussian_pyramid[-1]]
275
+ for idx in range(1, self.layers):
276
+ size = (gaussian_pyramid[self.layers - idx - 1].shape[1], gaussian_pyramid[self.layers - idx - 1].shape[0])
277
+ upsampled = cv2.resize(gaussian_pyramid[self.layers - idx], size, interpolation=cv2.INTER_LINEAR)
278
+ laplace_pyramid.append(gaussian_pyramid[self.layers - idx - 1] - upsampled)
279
+ laplace_pyramid.reverse()
280
+ return laplace_pyramid
281
+
282
+ def multi_resolution_fusion(self):
283
+ #pdb.set_trace()
284
+ weights = self.cal_weight_map() # [N, H, W]
285
+ weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3]
286
+
287
+ image_gaussian_pyramid = [self.build_gaussian_pyramid(self.sequence[i] * 255) for i in range(self.img_num)]
288
+ image_laplace_pyramid = [self.build_laplace_pyramid(image_gaussian_pyramid[i]) for i in range(self.img_num)]
289
+ weights_gaussian_pyramid = [self.build_gaussian_pyramid(weights[i]) for i in range(self.img_num)]
290
+
291
+ fused_laplace_pyramid = [np.sum([image_laplace_pyramid[n][l] *
292
+ weights_gaussian_pyramid[n][l] for n in range(self.img_num)], axis=0) for l in range(self.layers)]
293
+
294
+ result = fused_laplace_pyramid[-1]
295
+ for k in range(1, self.layers):
296
+ size = (fused_laplace_pyramid[self.layers - k - 1].shape[1], fused_laplace_pyramid[self.layers - k - 1].shape[0])
297
+ upsampled = cv2.resize(result, size, interpolation=cv2.INTER_LINEAR)
298
+ result = upsampled + fused_laplace_pyramid[self.layers - k - 1]
299
+ #pdb.set_trace()
300
+ #result = np.clip(result, 0, 255).astype(np.uint8)
301
+
302
+
303
+ return result
304
+
305
+ h_pre1, w_pre1 = 6144, 8192
306
+ pad_1 = 0
307
+ pad_2 = 0
308
+ h_exp1, w_exp1 = h_pre1 // 2, w_pre1 // 2
309
+
310
+ sess = tf.Session()
311
+ in_image = tf.placeholder(tf.float32, [None, h_exp1, w_exp1, 4])
312
+
313
+ in_image1 = tf.nn.avg_pool(in_image,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')
314
+ in_image2 = tf.nn.avg_pool(in_image,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
315
+
316
+ out_image1 = network(in_image1)
317
+ out_image2 = network(in_image2, reuse=True)
318
+
319
+ t_vars = tf.trainable_variables()
320
+ for ele1 in t_vars:
321
+ print("variable: ", ele1)
322
+
323
+ saver = tf.train.Saver()
324
+ sess.run(tf.global_variables_initializer())
325
+
326
+ ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
327
+ if ckpt:
328
+ print('loaded ' + ckpt.model_checkpoint_path)
329
+ saver.restore(sess, ckpt.model_checkpoint_path)
330
+
331
+ in_pic4 = np.zeros([h_exp1, w_exp1, 4])
332
+ for k in range(len(train_ids)):
333
+
334
+ print(k)
335
+ train_id = train_ids[k]
336
+ in_path = input_dir + train_id[:-4] + '.png'
337
+ #raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
338
+ raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
339
+ #meta = np.load(input_dir1 + train_id[:-4] + '.npy').astype(np.float32)
340
+ #meta = scipy.io.loadmat(input_dir2 + train_id[:-4] + '.mat')
341
+ metadata = json_read(in_path[:-4] + '.json', object_hook=fraction_from_json)
342
+
343
+ white_level = float(metadata['white_level'])
344
+ black_level = float(metadata['black_level'][0].numerator)
345
+
346
+ orientation = metadata['orientation']
347
+
348
+ in_pic2 = np.clip((raw_image - black_level) /(white_level-black_level),0,1)
349
+
350
+ mean = np.mean(np.mean(in_pic2))
351
+ var = np.var(in_pic2)
352
+
353
+ bining = 4
354
+
355
+ if (mean < 0.01):
356
+ ratio = 6
357
+ elif (mean < 0.02):
358
+ ratio = 4
359
+ elif (mean < 0.037):
360
+ ratio = 3
361
+ else:
362
+ ratio = 2
363
+
364
+ if (var > 0.015):
365
+ ratio = ratio + 1
366
+
367
+ noise_profile = float(metadata['noise_profile'][0]) * ratio
368
+ if (noise_profile > 0.02):
369
+ bining = 8
370
+ ratio = np.clip(ratio - 1,2,4)
371
+
372
+ #r_gain, b_gain = awb_v(in_pic2, bayer='RGGB', eps=1)
373
+ r_gain1 = 1./metadata['as_shot_neutral'][0]
374
+ b_gain1 = 1./metadata['as_shot_neutral'][2]
375
+
376
+ #in_pic3 = np.pad(in_pic2, ((top_pad, btm_pad), (left_pad, right_pad)), mode='reflect') # GBRG to RGGB + reflect padding
377
+ h_pre,w_pre = in_pic2.shape
378
+
379
+ if (metadata['cfa_pattern'][0].numerator == 2):
380
+ in_pic2[0:h_pre-1,0:w_pre-1] = in_pic2[1:h_pre,1:w_pre]
381
+
382
+ r_gain, b_gain = awb_v(in_pic2 * (ratio**2), bayer='RGGB', eps=1)
383
+ in_pic3 = in_pic2
384
+
385
+ in_pic4[0:h_pre//2, 0:w_pre//2, 0] = in_pic3[0::2, 0::2] * r_gain
386
+ in_pic4[0:h_pre//2, 0:w_pre//2, 1] = in_pic3[0::2, 1::2]
387
+ in_pic4[0:h_pre//2, 0:w_pre//2, 2] = in_pic3[1::2, 1::2] * b_gain
388
+ in_pic4[0:h_pre//2, 0:w_pre//2, 3] = in_pic3[1::2, 0::2]
389
+
390
+ im1=np.clip(in_pic4*1,0,1)
391
+ in_np1 = np.expand_dims(im1,axis = 0)
392
+ if (bining == 4):
393
+ out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
394
+ else:
395
+ out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
396
+
397
+ out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
398
+ h_pre2,w_pre2,cc = out_np2.shape
399
+
400
+ if h_pre2 > w_pre2:
401
+ out_np_1 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
402
+ if w_pre2 > h_pre2:
403
+ out_np_1 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
404
+
405
+ im1=np.clip(in_pic4*ratio,0,1)
406
+ in_np1 = np.expand_dims(im1,axis = 0)
407
+ if (bining == 4):
408
+ out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
409
+ else:
410
+ out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
411
+
412
+ out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
413
+ h_pre2,w_pre2,cc = out_np2.shape
414
+
415
+ if h_pre2 > w_pre2:
416
+ out_np_2 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
417
+ if w_pre2 > h_pre2:
418
+ out_np_2 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
419
+
420
+
421
+ im1=np.clip(in_pic4*(ratio**2),0,1)
422
+ in_np1 = np.expand_dims(im1,axis = 0)
423
+
424
+ if (bining == 4):
425
+ out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1})
426
+ else:
427
+ out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1})
428
+
429
+ out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation)
430
+ h_pre2,w_pre2,cc = out_np2.shape
431
+
432
+ if h_pre2 > w_pre2:
433
+ out_np_3 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC)
434
+ if w_pre2 > h_pre2:
435
+ out_np_3 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC)
436
+
437
+ #pdb.set_trace()
438
+ '''sequence = np.stack([out_np_1, out_np_2, out_np_3], axis=0)
439
+ #sequence0 = sequence[0]
440
+ mef = ExposureFusion(sequence.astype(np.float32))
441
+ multi_res_fusion = mef.multi_resolution_fusion()
442
+ #pdb.set_trace()
443
+ result = reprocessing(multi_res_fusion)'''
444
+
445
+ #out_crop = multi_res_fusion
446
+
447
+ #np.save(result_dir + train_id[0:-4] + '_gray_{:}.npy'.format(gain), out_crop)
448
+ cv2.imwrite(result_dir + train_id[0:-4] + '_1.png', out_np_1[:,:,::-1]*255)
449
+ cv2.imwrite(result_dir + train_id[0:-4] + '_2.png', out_np_2[:,:,::-1]*255)
450
+ cv2.imwrite(result_dir + train_id[0:-4] + '_3.png', out_np_3[:,:,::-1]*255)
451
+
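The ExposureFusion class defined above is only exercised in the commented-out block near the end of the loop; a minimal sketch of how it would be driven on the three renderings written out above (assuming float32 RGB inputs roughly in [0, 1]; the output file name is hypothetical):

sequence = np.stack([out_np_1, out_np_2, out_np_3], axis=0).astype(np.float32)  # [3, H, W, 3]
mef = ExposureFusion(sequence, best_exposedness=0.5, sigma=0.2, layers=11)
fused = mef.multi_resolution_fusion()                  # Laplacian-pyramid blend, roughly in 0..255
fused = np.clip(fused, 0, 255).astype(np.uint8)
cv2.imwrite(result_dir + train_id[0:-4] + '_fused.png', fused[:, :, ::-1])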
DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc ADDED
Binary file (18.5 kB).
 
DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc ADDED
Binary file (1.55 kB).
 
DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc ADDED
Binary file (1.99 kB).
 
DH-AISP/2/focal_frequency_loss/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .focal_frequency_loss import FocalFrequencyLoss
2
+
3
+ __all__ = ['FocalFrequencyLoss']
DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (263 Bytes).
 
DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc ADDED
Binary file (4.01 kB).
 
DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # version adaptation for PyTorch > 1.7.1
5
+ IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
6
+ if IS_HIGH_VERSION:
7
+ import torch.fft
8
+
9
+
10
+ class FocalFrequencyLoss(nn.Module):
11
+ """The torch.nn.Module class that implements focal frequency loss - a
12
+ frequency domain loss function for optimizing generative models.
13
+
14
+ Ref:
15
+ Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
16
+ <https://arxiv.org/pdf/2012.12821.pdf>
17
+
18
+ Args:
19
+ loss_weight (float): weight for focal frequency loss. Default: 1.0
20
+ alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
21
+ patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
22
+ ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
23
+ log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
24
+ batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
25
+ """
26
+
27
+ def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
28
+ super(FocalFrequencyLoss, self).__init__()
29
+ self.loss_weight = loss_weight
30
+ self.alpha = alpha
31
+ self.patch_factor = patch_factor
32
+ self.ave_spectrum = ave_spectrum
33
+ self.log_matrix = log_matrix
34
+ self.batch_matrix = batch_matrix
35
+
36
+ def tensor2freq(self, x):
37
+ # crop image patches
38
+ patch_factor = self.patch_factor
39
+ _, _, h, w = x.shape
40
+ assert h % patch_factor == 0 and w % patch_factor == 0, (
41
+ 'Patch factor should be divisible by image height and width')
42
+ patch_list = []
43
+ patch_h = h // patch_factor
44
+ patch_w = w // patch_factor
45
+ for i in range(patch_factor):
46
+ for j in range(patch_factor):
47
+ patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
48
+
49
+ # stack to patch tensor
50
+ y = torch.stack(patch_list, 1)
51
+
52
+ # perform 2D DFT (real-to-complex, orthonormalization)
53
+ if IS_HIGH_VERSION:
54
+ freq = torch.fft.fft2(y, norm='ortho')
55
+ freq = torch.stack([freq.real, freq.imag], -1)
56
+ else:
57
+ freq = torch.rfft(y, 2, onesided=False, normalized=True)
58
+ return freq
59
+
60
+ def loss_formulation(self, recon_freq, real_freq, matrix=None):
61
+ # spectrum weight matrix
62
+ if matrix is not None:
63
+ # if the matrix is predefined
64
+ weight_matrix = matrix.detach()
65
+ else:
66
+ # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
67
+ matrix_tmp = (recon_freq - real_freq) ** 2
68
+ matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
69
+
70
+ # whether to adjust the spectrum weight matrix by logarithm
71
+ if self.log_matrix:
72
+ matrix_tmp = torch.log(matrix_tmp + 1.0)
73
+
74
+ # whether to calculate the spectrum weight matrix using batch-based statistics
75
+ if self.batch_matrix:
76
+ matrix_tmp = matrix_tmp / matrix_tmp.max()
77
+ else:
78
+ matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
79
+
80
+ matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
81
+ matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
82
+ weight_matrix = matrix_tmp.clone().detach()
83
+
84
+ assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
85
+ 'The values of spectrum weight matrix should be in the range [0, 1], '
86
+ 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
87
+
88
+ # frequency distance using (squared) Euclidean distance
89
+ tmp = (recon_freq - real_freq) ** 2
90
+ freq_distance = tmp[..., 0] + tmp[..., 1]
91
+
92
+ # dynamic spectrum weighting (Hadamard product)
93
+ loss = weight_matrix * freq_distance
94
+ return torch.mean(loss)
95
+
96
+ def forward(self, pred, target, matrix=None, **kwargs):
97
+ """Forward function to calculate focal frequency loss.
98
+
99
+ Args:
100
+ pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
101
+ target (torch.Tensor): of shape (N, C, H, W). Target tensor.
102
+ matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
103
+ Default: None (If set to None: calculated online, dynamic).
104
+ """
105
+ pred_freq = self.tensor2freq(pred)
106
+ target_freq = self.tensor2freq(target)
107
+
108
+ # whether to use minibatch average spectrum
109
+ if self.ave_spectrum:
110
+ pred_freq = torch.mean(pred_freq, 0, keepdim=True)
111
+ target_freq = torch.mean(target_freq, 0, keepdim=True)
112
+
113
+ # calculate focal frequency loss
114
+ return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
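A minimal usage sketch of the loss above (tensor shapes are arbitrary and chosen only for illustration):

import torch
from focal_frequency_loss import FocalFrequencyLoss

ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
fake = torch.randn(4, 3, 64, 64, requires_grad=True)  # e.g. network output
real = torch.randn(4, 3, 64, 64)                       # ground truth
loss = ffl(fake, real)                                 # scalar tensor
loss.backward()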
DH-AISP/2/model_convnext2_hdr.py ADDED
@@ -0,0 +1,592 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.models.layers import trunc_normal_, DropPath
5
+ from timm.models.registry import register_model
6
+
7
+ #import Convnext as PreConv
8
+ from myFFCResblock0 import myFFCResblock
9
+
10
+
11
+ # A ConvNet for the 2020s
12
+ # original implementation https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
13
+ # paper https://arxiv.org/pdf/2201.03545.pdf
14
+
15
+ class ConvNeXt0(nn.Module):
16
+ r""" ConvNeXt
17
+ A PyTorch impl of : `A ConvNet for the 2020s` -
18
+ https://arxiv.org/pdf/2201.03545.pdf
19
+ Args:
20
+ in_chans (int): Number of input image channels. Default: 3
21
+ num_classes (int): Number of classes for classification head. Default: 1000
22
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
23
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
24
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
25
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
26
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
27
+ """
28
+ def __init__(self, block, in_chans=3, num_classes=1000,
29
+ depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0.,
30
+ layer_scale_init_value=1e-6, head_init_scale=1.,
31
+ ):
32
+ super().__init__()
33
+
34
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
35
+ stem = nn.Sequential(
36
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
37
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
38
+ )
39
+ self.downsample_layers.append(stem)
40
+ for i in range(3):
41
+ downsample_layer = nn.Sequential(
42
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
43
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
44
+ )
45
+ self.downsample_layers.append(downsample_layer)
46
+
47
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
48
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
49
+ cur = 0
50
+ for i in range(4):
51
+ stage = nn.Sequential(
52
+ *[block(dim=dims[i], drop_path=dp_rates[cur + j],
53
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
54
+ )
55
+ self.stages.append(stage)
56
+ cur += depths[i]
57
+
58
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
59
+ self.head = nn.Linear(dims[-1], num_classes)
60
+
61
+ self.apply(self._init_weights)
62
+ self.head.weight.data.mul_(head_init_scale)
63
+ self.head.bias.data.mul_(head_init_scale)
64
+
65
+ def _init_weights(self, m):
66
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
67
+ trunc_normal_(m.weight, std=.02)
68
+ nn.init.constant_(m.bias, 0)
69
+
70
+ def forward_features(self, x):
71
+ for i in range(4):
72
+ x = self.downsample_layers[i](x)
73
+ x = self.stages[i](x)
74
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
75
+
76
+ def forward(self, x):
77
+ x = self.forward_features(x)
78
+ x = self.head(x)
79
+ return x
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+ def dwt_init(x):
89
+ x01 = x[:, :, 0::2, :] / 2 #x01.shape=[4,3,128,256]
90
+ x02 = x[:, :, 1::2, :] / 2 #x02.shape=[4,3,128,256]
91
+ x1 = x01[:, :, :, 0::2] #x1.shape=[4,3,128,128]
92
+ x2 = x02[:, :, :, 0::2] #x2.shape=[4,3,128,128]
93
+ x3 = x01[:, :, :, 1::2] #x3.shape=[4,3,128,128]
94
+ x4 = x02[:, :, :, 1::2] #x4.shape=[4,3,128,128]
95
+ x_LL = x1 + x2 + x3 + x4
96
+ x_HL = -x1 - x2 + x3 + x4
97
+ x_LH = -x1 + x2 - x3 + x4
98
+ x_HH = x1 - x2 - x3 + x4
99
+ return x_LL, torch.cat((x_HL, x_LH, x_HH), 1)
100
+
101
+ class DWT(nn.Module):
102
+ def __init__(self):
103
+ super(DWT, self).__init__()
104
+ self.requires_grad = False
105
+ def forward(self, x):
106
+ return dwt_init(x)
107
+
108
+ class DWT_transform(nn.Module):
109
+ def __init__(self, in_channels,out_channels):
110
+ super().__init__()
111
+ self.dwt = DWT()
112
+ self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
113
+ self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels, kernel_size=1, padding=0)
114
+ def forward(self, x):
115
+ dwt_low_frequency,dwt_high_frequency = self.dwt(x)
116
+ dwt_low_frequency = self.conv1x1_low(dwt_low_frequency)
117
+ dwt_high_frequency = self.conv1x1_high(dwt_high_frequency)
118
+ return dwt_low_frequency,dwt_high_frequency
119
+
120
+ def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
121
+ block = nn.Sequential()
122
+ if relu:
123
+ block.add_module('%s_relu' % name, nn.ReLU(inplace=True))
124
+ else:
125
+ block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
126
+ if not transposed:
127
+ block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))
128
+ else:
129
+ block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1))
130
+ block.add_module('%s_bili' % name, nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
131
+ if bn:
132
+ block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c))
133
+ if dropout:
134
+ block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True))
135
+ return block
136
+
137
+ # DW-GAN: A Discrete Wavelet Transform GAN for NonHomogeneous Dehazing 2021
138
+ # original implementation https://github.com/liuh127/NTIRE-2021-Dehazing-DWGAN/blob/main/model.py
139
+ # paper https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Fu_DW-GAN_A_Discrete_Wavelet_Transform_GAN_for_NonHomogeneous_Dehazing_CVPRW_2021_paper.pdf
140
+ class dwt_ffc_UNet2(nn.Module):
141
+ def __init__(self,output_nc=3, nf=16):
142
+ super(dwt_ffc_UNet2, self).__init__()
143
+ layer_idx = 1
144
+ name = 'layer%d' % layer_idx
145
+ layer1 = nn.Sequential()
146
+ layer1.add_module(name, nn.Conv2d(16, nf-1, 4, 2, 1, bias=False))
147
+ layer_idx += 1
148
+ name = 'layer%d' % layer_idx
149
+ layer2 = blockUNet(nf, nf*2-2, name, transposed=False, bn=True, relu=False, dropout=False)
150
+ layer_idx += 1
151
+ name = 'layer%d' % layer_idx
152
+ layer3 = blockUNet(nf*2, nf*4-4, name, transposed=False, bn=True, relu=False, dropout=False)
153
+ layer_idx += 1
154
+ name = 'layer%d' % layer_idx
155
+ layer4 = blockUNet(nf*4, nf*8-8, name, transposed=False, bn=True, relu=False, dropout=False)
156
+ layer_idx += 1
157
+ name = 'layer%d' % layer_idx
158
+ layer5 = blockUNet(nf*8, nf*8-16, name, transposed=False, bn=True, relu=False, dropout=False)
159
+ layer_idx += 1
160
+ name = 'layer%d' % layer_idx
161
+ layer6 = blockUNet(nf*4, nf*4, name, transposed=False, bn=False, relu=False, dropout=False)
162
+
163
+ layer_idx -= 1
164
+ name = 'dlayer%d' % layer_idx
165
+ dlayer6 = blockUNet(nf * 4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
166
+ layer_idx -= 1
167
+ name = 'dlayer%d' % layer_idx
168
+ dlayer5 = blockUNet(nf * 16+16, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
169
+ layer_idx -= 1
170
+ name = 'dlayer%d' % layer_idx
171
+ dlayer4 = blockUNet(nf * 16+8, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False)
172
+ layer_idx -= 1
173
+ name = 'dlayer%d' % layer_idx
174
+ dlayer3 = blockUNet(nf * 8+4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
175
+ layer_idx -= 1
176
+ name = 'dlayer%d' % layer_idx
177
+ dlayer2 = blockUNet(nf * 4+2, nf, name, transposed=True, bn=True, relu=True, dropout=False)
178
+ layer_idx -= 1
179
+ name = 'dlayer%d' % layer_idx
180
+ dlayer1 = blockUNet(nf * 2+1, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
181
+
182
+ self.initial_conv=nn.Conv2d(9,16,3,padding=1)
183
+ self.bn1=nn.BatchNorm2d(16)
184
+ self.layer1 = layer1
185
+ self.DWT_down_0= DWT_transform(9,1)
186
+ self.layer2 = layer2
187
+ self.DWT_down_1 = DWT_transform(16, 2)
188
+ self.layer3 = layer3
189
+ self.DWT_down_2 = DWT_transform(32, 4)
190
+ self.layer4 = layer4
191
+ self.DWT_down_3 = DWT_transform(64, 8)
192
+ self.layer5 = layer5
193
+ self.DWT_down_4 = DWT_transform(128, 16)
194
+ self.layer6 = layer6
195
+ self.dlayer6 = dlayer6
196
+ self.dlayer5 = dlayer5
197
+ self.dlayer4 = dlayer4
198
+ self.dlayer3 = dlayer3
199
+ self.dlayer2 = dlayer2
200
+ self.dlayer1 = dlayer1
201
+ self.tail_conv1 = nn.Conv2d(48, 32, 3, padding=1, bias=True)
202
+ self.bn2=nn.BatchNorm2d(32)
203
+ self.tail_conv2 = nn.Conv2d(nf*2, output_nc, 3,padding=1, bias=True)
204
+
205
+
206
+ self.FFCResNet = myFFCResblock(input_nc=64, output_nc=64)
207
+
208
+ def forward(self, x):
209
+ conv_start=self.initial_conv(x)
210
+ conv_start=self.bn1(conv_start)
211
+ conv_out1 = self.layer1(conv_start)
212
+ dwt_low_0,dwt_high_0=self.DWT_down_0(x)
213
+ out1=torch.cat([conv_out1, dwt_low_0], 1)
214
+ conv_out2 = self.layer2(out1)
215
+ dwt_low_1,dwt_high_1= self.DWT_down_1(out1)
216
+ out2 = torch.cat([conv_out2, dwt_low_1], 1)
217
+ conv_out3 = self.layer3(out2)
218
+
219
+ dwt_low_2,dwt_high_2 = self.DWT_down_2(out2)
220
+ out3 = torch.cat([conv_out3, dwt_low_2], 1)
221
+
222
+ # conv_out4 = self.layer4(out3)
223
+ # dwt_low_3,dwt_high_3 = self.DWT_down_3(out3)
224
+ # out4 = torch.cat([conv_out4, dwt_low_3], 1)
225
+
226
+ # conv_out5 = self.layer5(out4)
227
+ # dwt_low_4,dwt_high_4 = self.DWT_down_4(out4)
228
+ # out5 = torch.cat([conv_out5, dwt_low_4], 1)
229
+
230
+ # out6 = self.layer6(out5)
231
+
232
+
233
+ out3_ffc= self.FFCResNet(out3)
234
+
235
+
236
+ dout3 = self.dlayer6(out3_ffc)
237
+
238
+
239
+ Tout3_out2 = torch.cat([dout3, out2,dwt_high_1], 1)
240
+ Tout2 = self.dlayer2(Tout3_out2)
241
+ Tout2_out1 = torch.cat([Tout2, out1,dwt_high_0], 1)
242
+ Tout1 = self.dlayer1(Tout2_out1)
243
+
244
+ Tout1_outinit = torch.cat([Tout1, conv_start], 1)
245
+ tail1=self.tail_conv1(Tout1_outinit)
246
+ tail2=self.bn2(tail1)
247
+ dout1 = self.tail_conv2(tail2)
248
+
249
+
250
+ return dout1
251
+
252
+
253
+
254
+
255
+
256
+
257
+ class Block(nn.Module):
258
+ r""" ConvNeXt Block. There are two equivalent implementations:
259
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
260
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
261
+ We use (2) as we find it slightly faster in PyTorch
262
+
263
+ Args:
264
+ dim (int): Number of input channels.
265
+ drop_path (float): Stochastic depth rate. Default: 0.0
266
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
267
+ """
268
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
269
+ super().__init__()
270
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
271
+ self.norm = LayerNorm(dim, eps=1e-6)
272
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
273
+ self.act = nn.GELU()
274
+ self.pwconv2 = nn.Linear(4 * dim, dim)
275
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
276
+ requires_grad=True) if layer_scale_init_value > 0 else None
277
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
278
+
279
+ def forward(self, x):
280
+ input = x
281
+ x = self.dwconv(x)
282
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
283
+ x = self.norm(x)
284
+ x = self.pwconv1(x)
285
+ x = self.act(x)
286
+ x = self.pwconv2(x)
287
+ if self.gamma is not None:
288
+ x = self.gamma * x
289
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
290
+
291
+ x = input + self.drop_path(x)
292
+ return x
293
+
294
+
295
+ class ConvNeXt(nn.Module):
296
+ def __init__(self, block, in_chans=3, num_classes=1000,
297
+ depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0.,
298
+ layer_scale_init_value=1e-6, head_init_scale=1.,
299
+ ):
300
+ super().__init__()
301
+
302
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
303
+ stem = nn.Sequential(
304
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
305
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
306
+ )
307
+ self.downsample_layers.append(stem)
308
+ for i in range(3):
309
+ downsample_layer = nn.Sequential(
310
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
311
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
312
+ )
313
+ self.downsample_layers.append(downsample_layer)
314
+
315
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
316
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
317
+ cur = 0
318
+ for i in range(4):
319
+ stage = nn.Sequential(
320
+ *[block(dim=dims[i], drop_path=dp_rates[cur + j],
321
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
322
+ )
323
+ self.stages.append(stage)
324
+ cur += depths[i]
325
+
326
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
327
+ self.head = nn.Linear(dims[-1], num_classes)
328
+
329
+ self.head.weight.data.mul_(head_init_scale)
330
+ self.head.bias.data.mul_(head_init_scale)
331
+
332
+
333
+ def forward(self, x):
334
+ x_layer1 = self.downsample_layers[0](x)
335
+ x_layer1 = self.stages[0](x_layer1)
336
+
337
+
338
+
339
+ x_layer2 = self.downsample_layers[1](x_layer1)
340
+ x_layer2 = self.stages[1](x_layer2)
341
+
342
+
343
+ x_layer3 = self.downsample_layers[2](x_layer2)
344
+ out = self.stages[2](x_layer3)
345
+
346
+
347
+ return x_layer1, x_layer2, out
348
+
349
+ class LayerNorm(nn.Module):
350
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
351
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
352
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
353
+ with shape (batch_size, channels, height, width).
354
+ """
355
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
356
+ super().__init__()
357
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
358
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
359
+ self.eps = eps
360
+ self.data_format = data_format
361
+ if self.data_format not in ["channels_last", "channels_first"]:
362
+ raise NotImplementedError
363
+ self.normalized_shape = (normalized_shape, )
364
+
365
+ def forward(self, x):
366
+ if self.data_format == "channels_last":
367
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
368
+ elif self.data_format == "channels_first":
369
+ u = x.mean(1, keepdim=True)
370
+ s = (x - u).pow(2).mean(1, keepdim=True)
371
+ x = (x - u) / torch.sqrt(s + self.eps)
372
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
373
+ return x
374
+
375
+
376
+ class PALayer(nn.Module):
377
+ def __init__(self, channel):
378
+ super(PALayer, self).__init__()
379
+ self.pa = nn.Sequential(
380
+ nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
381
+ nn.ReLU(inplace=True),
382
+ nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
383
+ nn.Sigmoid()
384
+ )
385
+ def forward(self, x):
386
+ y = self.pa(x)
387
+ return x * y
388
+
389
+ class CALayer(nn.Module):
390
+ def __init__(self, channel):
391
+ super(CALayer, self).__init__()
392
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
393
+ self.ca = nn.Sequential(
394
+ nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
395
+ nn.ReLU(inplace=True),
396
+ nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
397
+ nn.Sigmoid()
398
+ )
399
+ def forward(self, x):
400
+ y = self.avg_pool(x)
401
+ y = self.ca(y)
402
+ return x * y
403
+
404
+ class CP_Attention_block(nn.Module):
405
+ def __init__(self, conv, dim, kernel_size):
406
+ super(CP_Attention_block, self).__init__()
407
+ self.conv1 = conv(dim, dim, kernel_size, bias=True)
408
+ self.act1 = nn.ReLU(inplace=True)
409
+ self.conv2 = conv(dim, dim, kernel_size, bias=True)
410
+ self.calayer = CALayer(dim)
411
+ self.palayer = PALayer(dim)
412
+ def forward(self, x):
413
+ res = self.act1(self.conv1(x))
414
+ res = res + x
415
+ res = self.conv2(res)
416
+ res = self.calayer(res)
417
+ res = self.palayer(res)
418
+ res += x
419
+ return res
420
+
421
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
422
+ return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
423
+
424
+ class knowledge_adaptation_convnext(nn.Module):
425
+ def __init__(self):
426
+ super(knowledge_adaptation_convnext, self).__init__()
427
+ self.encoder = ConvNeXt(Block, in_chans=9,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.)
428
+ '''pretrained_model = ConvNeXt0(Block, in_chans=3,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.)
429
+ #pretrained_model=nn.DataParallel(pretrained_model)
430
+ checkpoint=torch.load('./weights/convnext_xlarge_22k_1k_384_ema.pth')
431
+ #for k,v in checkpoint["model"].items():
432
+ #print(k)
433
+ #url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth"
434
+
435
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cuda:0")
436
+ pretrained_model.load_state_dict(checkpoint["model"])
437
+
438
+ pretrained_dict = pretrained_model.state_dict()
439
+ model_dict = self.encoder.state_dict()
440
+ key_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
441
+ model_dict.update(key_dict)
442
+ self.encoder.load_state_dict(model_dict)'''
443
+
444
+
445
+ self.up_block= nn.PixelShuffle(2)
446
+ self.attention0 = CP_Attention_block(default_conv, 1024, 3)
447
+ self.attention1 = CP_Attention_block(default_conv, 256, 3)
448
+ self.attention2 = CP_Attention_block(default_conv, 192, 3)
449
+ self.attention3 = CP_Attention_block(default_conv, 112, 3)
450
+ self.attention4 = CP_Attention_block(default_conv, 28, 3)
451
+ self.conv_process_1 = nn.Conv2d(28, 28, kernel_size=3,padding=1)
452
+ self.conv_process_2 = nn.Conv2d(28, 28, kernel_size=3,padding=1)
453
+ self.tail = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(28, 3, kernel_size=7, padding=0), nn.Tanh())
454
+ def forward(self, input):
455
+ x_layer1, x_layer2, x_output = self.encoder(input)
456
+
457
+ x_mid = self.attention0(x_output) #[1024,24,24]
458
+
459
+ x = self.up_block(x_mid) #[256,48,48]
460
+ x = self.attention1(x)
461
+
462
+ x = torch.cat((x, x_layer2), 1) #[768,48,48]
463
+
464
+ x = self.up_block(x) #[192,96,96]
465
+ x = self.attention2(x)
466
+ x = torch.cat((x, x_layer1), 1) #[448,96,96]
467
+ x = self.up_block(x) #[112,192,192]
468
+ x = self.attention3(x)
469
+
470
+ x = self.up_block(x) #[28,384,384]
471
+ x = self.attention4(x)
472
+
473
+ x=self.conv_process_1(x)
474
+ out=self.conv_process_2(x)
475
+ return out
476
+
477
+
478
+ class fusion_net(nn.Module):
479
+ def __init__(self):
480
+ super(fusion_net, self).__init__()
481
+ self.dwt_branch=dwt_ffc_UNet2()
482
+ self.knowledge_adaptation_branch=knowledge_adaptation_convnext()
483
+ self.fusion = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(31, 3, kernel_size=7, padding=0), nn.Tanh())
484
+ def forward(self, input):
485
+ dwt_branch=self.dwt_branch(input)
486
+ knowledge_adaptation_branch=self.knowledge_adaptation_branch(input)
487
+ x = torch.cat([dwt_branch, knowledge_adaptation_branch], 1)
488
+ x = self.fusion(x)
489
+ return x
490
+
491
+
492
+
493
+ class Discriminator(nn.Module):
494
+ def __init__(self):
495
+ super(Discriminator, self).__init__()
496
+ self.net = nn.Sequential(
497
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
498
+ nn.LeakyReLU(0.2),
499
+
500
+ nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
501
+ nn.BatchNorm2d(64),
502
+ nn.LeakyReLU(0.2),
503
+
504
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
505
+ nn.BatchNorm2d(128),
506
+ nn.LeakyReLU(0.2),
507
+
508
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
509
+ nn.BatchNorm2d(128),
510
+ nn.LeakyReLU(0.2),
511
+
512
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
513
+ nn.BatchNorm2d(256),
514
+ nn.LeakyReLU(0.2),
515
+
516
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
517
+ nn.BatchNorm2d(256),
518
+ nn.LeakyReLU(0.2),
519
+
520
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
521
+ nn.BatchNorm2d(512),
522
+ nn.LeakyReLU(0.2),
523
+
524
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
525
+ nn.BatchNorm2d(512),
526
+ nn.LeakyReLU(0.2),
527
+
528
+ nn.AdaptiveAvgPool2d(1),
529
+ nn.Conv2d(512, 1024, kernel_size=1),
530
+ nn.LeakyReLU(0.2),
531
+ nn.Conv2d(1024, 1, kernel_size=1)
532
+ )
533
+
534
+ def forward(self, x):
535
+ batch_size = x.size(0)
536
+ return torch.sigmoid(self.net(x).view(batch_size))
537
+
538
+
539
+ class Discriminator2(nn.Module):
540
+ def __init__(self):
541
+ super(Discriminator2, self).__init__()
542
+ self.net = nn.Sequential(
543
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
544
+ nn.LeakyReLU(0.2),
545
+
546
+ nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
547
+ nn.BatchNorm2d(64),
548
+ nn.LeakyReLU(0.2),
549
+
550
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
551
+ nn.BatchNorm2d(128),
552
+ nn.LeakyReLU(0.2),
553
+
554
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
555
+ nn.BatchNorm2d(128),
556
+ nn.LeakyReLU(0.2),
557
+
558
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
559
+ nn.BatchNorm2d(256),
560
+ nn.LeakyReLU(0.2),
561
+
562
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
563
+ nn.BatchNorm2d(256),
564
+ nn.LeakyReLU(0.2),
565
+
566
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
567
+ nn.BatchNorm2d(512),
568
+ nn.LeakyReLU(0.2),
569
+
570
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
571
+ nn.BatchNorm2d(512),
572
+ nn.LeakyReLU(0.2),
573
+
574
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
575
+ nn.BatchNorm2d(512),
576
+ nn.LeakyReLU(0.2),
577
+
578
+ nn.Conv2d(512, 1, kernel_size=3, padding=1),
579
+ )
580
+
581
+ def forward(self, x):
582
+ return self.net(x)
583
+
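A shape-check sketch for the two critics above (illustration only; the 256x256 input size is an assumption): Discriminator pools down to a single probability per image, while Discriminator2 returns a PatchGAN-style logit map.

```python
import torch

x = torch.rand(2, 3, 256, 256)
print(Discriminator()(x).shape)    # torch.Size([2])          -> one probability per image
print(Discriminator2()(x).shape)   # torch.Size([2, 1, 8, 8]) -> patch logit map (input / 32)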
584
+ if __name__ == '__main__':
585
+
586
+ device = torch.device("cuda:0")
587
+
588
+ # Create model
589
+ im = torch.rand(1, 3, 640, 640).to(device)
590
+ model_g = fusion_net().to(device)
591
+
592
+ out_data = model_g(im)
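A hedged variant of the smoke test above: the ConvNeXt encoder inside knowledge_adaptation_convnext is constructed with in_chans=9, so a consistent end-to-end check would feed a 9-channel stack; the expected output shape assumes dwt_ffc_UNet2 accepts the same stack and preserves resolution.

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_g = fusion_net().to(device).eval()
im = torch.rand(1, 9, 640, 640, device=device)   # 9-channel input implied by in_chans=9
with torch.no_grad():
    out_data = model_g(im)
print(out_data.shape)   # expected: torch.Size([1, 3, 640, 640])
```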
DH-AISP/2/myFFCResblock0.py ADDED
@@ -0,0 +1,60 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ from saicinpainting.training.modules.ffc0 import FFCResnetBlock
9
+ from saicinpainting.training.modules.ffc0 import FFC_BN_ACT
10
+
11
+
12
+
13
+
14
+ class myFFCResblock(nn.Module):
15
+ def __init__(self, input_nc, output_nc, n_blocks=2, norm_layer=nn.BatchNorm2d, #128--->64
16
+ padding_type='reflect', activation_layer=nn.ReLU,
17
+ resnet_conv_kwargs={},
18
+ spatial_transform_layers=None, spatial_transform_kwargs={},
19
+ add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
20
+ assert (n_blocks >= 0)
21
+
22
+ super().__init__()
23
+ self.initial = FFC_BN_ACT(input_nc, input_nc, kernel_size=3, padding=1, dilation=1,
24
+ norm_layer=norm_layer, activation_layer=activation_layer,
25
+ padding_type=padding_type,
26
+ **resnet_conv_kwargs)
27
+
28
+ self.ffcresblock = FFCResnetBlock(input_nc, padding_type=padding_type, activation_layer=activation_layer,
29
+ norm_layer=norm_layer, **resnet_conv_kwargs)
30
+
31
+
32
+
33
+ self.final = FFC_BN_ACT(input_nc, output_nc, kernel_size=3, padding=1, dilation=1,
34
+ norm_layer=norm_layer,
35
+ activation_layer=activation_layer,
36
+ padding_type=padding_type,
37
+ **resnet_conv_kwargs)
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+ def forward(self, x):
46
+
47
+ x_l, x_g = self.initial(x)
48
+
49
+ x_l, x_g = self.ffcresblock(x_l, x_g)
50
+ x_l, x_g = self.ffcresblock(x_l, x_g)
51
+
52
+ out_ = torch.cat([x_l, x_g], 1)
53
+
54
+ x_lout, x_gout =self.final(out_)
55
+
56
+ out = torch.cat([x_lout, x_gout], 1)
57
+ return out
58
+
59
+
60
+
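For intuition only, and not relying on the FFC modules themselves: myFFCResblock keeps a local stream x_l and a global stream x_g and re-joins them with torch.cat along the channel dimension. The toy helper below only mimics that split/concat plumbing with an assumed 50/50 ratio; the real split comes from the FFC kwargs.

```python
import torch

def split_local_global(x, ratio_g=0.5):
    # hypothetical helper: split channels into a "local" part and a "global" part
    c = x.shape[1]
    c_g = int(c * ratio_g)
    return x[:, :c - c_g], x[:, c - c_g:]

x = torch.rand(1, 64, 32, 32)
x_l, x_g = split_local_global(x)
out = torch.cat([x_l, x_g], dim=1)   # same re-join pattern used in myFFCResblock.forward
assert out.shape == x.shape
```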
DH-AISP/2/perceptual.py ADDED
@@ -0,0 +1,30 @@
1
+ # --- Imports --- #
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ # --- Perceptual loss network --- #
6
+ class LossNetwork(torch.nn.Module):
7
+ def __init__(self, vgg_model):
8
+ super(LossNetwork, self).__init__()
9
+ self.vgg_layers = vgg_model
10
+ self.layer_name_mapping = {
11
+ '3': "relu1_2",
12
+ '8': "relu2_2",
13
+ '15': "relu3_3"
14
+ }
15
+
16
+ def output_features(self, x):
17
+ output = {}
18
+ for name, module in self.vgg_layers._modules.items():
19
+ x = module(x)
20
+ if name in self.layer_name_mapping:
21
+ output[self.layer_name_mapping[name]] = x
22
+ return list(output.values())
23
+
24
+ def forward(self, dehaze, gt):
25
+ loss = []
26
+ dehaze_features = self.output_features(dehaze)
27
+ gt_features = self.output_features(gt)
28
+ for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
29
+ loss.append(F.mse_loss(dehaze_feature, gt_feature))
30
+ return sum(loss)/len(loss)
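One possible way to build the vgg_model argument, sketched under assumptions: the mapping above indexes layers '3', '8' and '15' of torchvision's VGG16 feature extractor, so slicing features[:16] is enough; depending on the torchvision version, pretrained weights are requested via pretrained=True or a weights= argument.

```python
import torch
from torchvision.models import vgg16

vgg_features = vgg16(pretrained=True).features[:16].eval()  # up to relu3_3 (index 15)
for p in vgg_features.parameters():
    p.requires_grad = False                                  # frozen loss network

loss_network = LossNetwork(vgg_features)
pred = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)
print(loss_network(pred, gt))   # mean MSE over relu1_2 / relu2_2 / relu3_3 features
```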
DH-AISP/2/pytorch_msssim/__init__.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
+
7
+ def gaussian(window_size, sigma):
8
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9
+ return gauss/gauss.sum()
10
+
11
+
12
+ def create_window(window_size, channel=1):
13
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16
+ return window
17
+
18
+
19
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
20
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
21
+ if val_range is None:
22
+ if torch.max(img1) > 128:
23
+ max_val = 255
24
+ else:
25
+ max_val = 1
26
+
27
+ if torch.min(img1) < -0.5:
28
+ min_val = -1
29
+ else:
30
+ min_val = 0
31
+ L = max_val - min_val
32
+ else:
33
+ L = val_range
34
+
35
+ padd = 0
36
+ (_, channel, height, width) = img1.size()
37
+ if window is None:
38
+ real_size = min(window_size, height, width)
39
+ window = create_window(real_size, channel=channel).to(img1.device)
40
+
41
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43
+
44
+ mu1_sq = mu1.pow(2)
45
+ mu2_sq = mu2.pow(2)
46
+ mu1_mu2 = mu1 * mu2
47
+
48
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51
+
52
+ C1 = (0.01 * L) ** 2
53
+ C2 = (0.03 * L) ** 2
54
+
55
+ v1 = 2.0 * sigma12 + C2
56
+ v2 = sigma1_sq + sigma2_sq + C2
57
+ cs = torch.mean(v1 / v2) # contrast sensitivity
58
+
59
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60
+
61
+ if size_average:
62
+ ret = ssim_map.mean()
63
+ else:
64
+ ret = ssim_map.mean(1).mean(1).mean(1)
65
+
66
+ if full:
67
+ return ret, cs
68
+ return ret
69
+
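For reference, the map computed above is the standard SSIM index with stabilizers C1 = (0.01 L)^2 and C2 = (0.03 L)^2, where L is the dynamic range inferred at the top of the function:

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}
```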
70
+
71
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72
+ device = img1.device
73
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74
+ levels = weights.size()[0]
75
+ mssim = []
76
+ mcs = []
77
+ for _ in range(levels):
78
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
79
+ mssim.append(sim)
80
+ mcs.append(cs)
81
+
82
+ img1 = F.avg_pool2d(img1, (2, 2))
83
+ img2 = F.avg_pool2d(img2, (2, 2))
84
+
85
+ mssim = torch.stack(mssim)
86
+ mcs = torch.stack(mcs)
87
+
88
+ # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
89
+ if normalize:
90
+ mssim = (mssim + 1) / 2
91
+ mcs = (mcs + 1) / 2
92
+
93
+ pow1 = mcs ** weights
94
+ pow2 = mssim ** weights
95
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
96
+ output = torch.prod(pow1[:-1]) * pow2[-1]
97
+ return output
98
+
99
+
100
+ # Classes to re-use window
101
+ class SSIM(torch.nn.Module):
102
+ def __init__(self, window_size=11, size_average=True, val_range=None):
103
+ super(SSIM, self).__init__()
104
+ self.window_size = window_size
105
+ self.size_average = size_average
106
+ self.val_range = val_range
107
+
108
+ # Assume 1 channel for SSIM
109
+ self.channel = 1
110
+ self.window = create_window(window_size)
111
+
112
+ def forward(self, img1, img2):
113
+ (_, channel, _, _) = img1.size()
114
+
115
+ if channel == self.channel and self.window.dtype == img1.dtype:
116
+ window = self.window
117
+ else:
118
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119
+ self.window = window
120
+ self.channel = channel
121
+
122
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123
+
124
+ class MSSSIM(torch.nn.Module):
125
+ def __init__(self, window_size=11, size_average=True, channel=3):
126
+ super(MSSSIM, self).__init__()
127
+ self.window_size = window_size
128
+ self.size_average = size_average
129
+ self.channel = channel
130
+
131
+ def forward(self, img1, img2):
132
+ # TODO: store window between calls if possible
133
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
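A minimal usage sketch: both wrappers return a similarity score, so a training loss is usually 1 - score; the inputs are assumed to be NCHW tensors in [0, 1] that are large enough for five dyadic downsamplings.

```python
import torch

img1 = torch.rand(2, 3, 256, 256)
img2 = torch.rand(2, 3, 256, 256)

ssim_loss = 1 - SSIM(window_size=11, size_average=True)(img1, img2)
msssim_loss = 1 - MSSSIM(window_size=11, size_average=True, channel=3)(img1, img2)
print(float(ssim_loss), float(msssim_loss))
```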
DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (3.9 kB). View file
 
DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (3.88 kB). View file
 
DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5952db983eb66b04c6a39348a0916164d9148ec99c4a3b8a77bf4e240657022
3
+ size 1491472482
DH-AISP/2/saicinpainting/__init__.py ADDED
File without changes
DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (168 Bytes). View file
 
DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (155 Bytes). View file
 
DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc ADDED
Binary file (6.1 kB). View file
 
DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc ADDED
Binary file (6.08 kB). View file
 
DH-AISP/2/saicinpainting/evaluation/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ import logging
2
+
3
+ import torch
4
+
5
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
6
+ from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
7
+
8
+
9
+ def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
10
+ logging.info(f'Make evaluator {kind}')
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ metrics = {}
13
+ if ssim:
14
+ metrics['ssim'] = SSIMScore()
15
+ if lpips:
16
+ metrics['lpips'] = LPIPSScore()
17
+ if fid:
18
+ metrics['fid'] = FIDScore().to(device)
19
+
20
+ if integral_kind is None:
21
+ integral_func = None
22
+ elif integral_kind == 'ssim_fid100_f1':
23
+ integral_func = ssim_fid100_f1
24
+ elif integral_kind == 'lpips_fid100_f1':
25
+ integral_func = lpips_fid100_f1
26
+ else:
27
+ raise ValueError(f'Unexpected integral_kind={integral_kind}')
28
+
29
+ if kind == 'default':
30
+ return InpaintingEvaluatorOnline(scores=metrics,
31
+ integral_func=integral_func,
32
+ integral_title=integral_kind,
33
+ **kwargs)
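A hedged usage sketch: build an evaluator with only the SSIM metric (LPIPS and FID need extra pretrained weights), feed it batches, then aggregate. This assumes the module's heavier imports (for example models.ade20k pulled in by base_loss) are available in the environment.

```python
import torch

evaluator = make_evaluator(kind='default', ssim=True, lpips=False, fid=False)
batch = {
    'image': torch.rand(2, 3, 256, 256),
    'mask': torch.rand(2, 1, 256, 256),
    'inpainted': torch.rand(2, 3, 256, 256),
}
evaluator(batch)                      # accumulate per-batch scores
results = evaluator.evaluation_end()  # {(metric, group): {'mean': ..., 'std': ...}}
```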
DH-AISP/2/saicinpainting/evaluation/data.py ADDED
@@ -0,0 +1,168 @@
1
+ import glob
2
+ import os
3
+
4
+ import cv2
5
+ import PIL.Image as Image
6
+ import numpy as np
7
+
8
+ from torch.utils.data import Dataset
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def load_image(fname, mode='RGB', return_orig=False):
13
+ img = np.array(Image.open(fname).convert(mode))
14
+ if img.ndim == 3:
15
+ img = np.transpose(img, (2, 0, 1))
16
+ out_img = img.astype('float32') / 255
17
+ if return_orig:
18
+ return out_img, img
19
+ else:
20
+ return out_img
21
+
22
+
23
+ def ceil_modulo(x, mod):
24
+ if x % mod == 0:
25
+ return x
26
+ return (x // mod + 1) * mod
27
+
28
+
29
+ def pad_img_to_modulo(img, mod):
30
+ channels, height, width = img.shape
31
+ out_height = ceil_modulo(height, mod)
32
+ out_width = ceil_modulo(width, mod)
33
+ return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
34
+
35
+
36
+ def pad_tensor_to_modulo(img, mod):
37
+ batch_size, channels, height, width = img.shape
38
+ out_height = ceil_modulo(height, mod)
39
+ out_width = ceil_modulo(width, mod)
40
+ return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')
41
+
42
+
43
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
44
+ if img.shape[0] == 1:
45
+ img = img[0]
46
+ else:
47
+ img = np.transpose(img, (1, 2, 0))
48
+
49
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
50
+
51
+ if img.ndim == 2:
52
+ img = img[None, ...]
53
+ else:
54
+ img = np.transpose(img, (2, 0, 1))
55
+ return img
56
+
57
+
58
+ class InpaintingDataset(Dataset):
59
+ def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
60
+ self.datadir = datadir
61
+ self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True)))
62
+ self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames]
63
+ self.pad_out_to_modulo = pad_out_to_modulo
64
+ self.scale_factor = scale_factor
65
+
66
+ def __len__(self):
67
+ return len(self.mask_filenames)
68
+
69
+ def __getitem__(self, i):
70
+ image = load_image(self.img_filenames[i], mode='RGB')
71
+ mask = load_image(self.mask_filenames[i], mode='L')
72
+ result = dict(image=image, mask=mask[None, ...])
73
+
74
+ if self.scale_factor is not None:
75
+ result['image'] = scale_image(result['image'], self.scale_factor)
76
+ result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
77
+
78
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
79
+ result['unpad_to_size'] = result['image'].shape[1:]
80
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
81
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
82
+
83
+ return result
84
+
85
+ class OurInpaintingDataset(Dataset):
86
+ def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
87
+ self.datadir = datadir
88
+ self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True)))
89
+ self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames]
90
+ self.pad_out_to_modulo = pad_out_to_modulo
91
+ self.scale_factor = scale_factor
92
+
93
+ def __len__(self):
94
+ return len(self.mask_filenames)
95
+
96
+ def __getitem__(self, i):
97
+ result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
98
+ mask=load_image(self.mask_filenames[i], mode='L')[None, ...])
99
+
100
+ if self.scale_factor is not None:
101
+ result['image'] = scale_image(result['image'], self.scale_factor)
102
+ result['mask'] = scale_image(result['mask'], self.scale_factor)
103
+
104
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
105
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
106
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
107
+
108
+ return result
109
+
110
+ class PrecomputedInpaintingResultsDataset(InpaintingDataset):
111
+ def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
112
+ super().__init__(datadir, **kwargs)
113
+ if not datadir.endswith('/'):
114
+ datadir += '/'
115
+ self.predictdir = predictdir
116
+ self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
117
+ for fname in self.mask_filenames]
118
+
119
+ def __getitem__(self, i):
120
+ result = super().__getitem__(i)
121
+ result['inpainted'] = load_image(self.pred_filenames[i])
122
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
123
+ result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
124
+ return result
125
+
126
+ class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
127
+ def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
128
+ super().__init__(datadir, **kwargs)
129
+ if not datadir.endswith('/'):
130
+ datadir += '/'
131
+ self.predictdir = predictdir
132
+ self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
133
+ for fname in self.mask_filenames]
134
+ # self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
135
+ # for fname in self.mask_filenames]
136
+
137
+ def __getitem__(self, i):
138
+ result = super().__getitem__(i)
139
+ result['inpainted'] = load_image(self.pred_filenames[i])
140
+
141
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
142
+ result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
143
+ return result
144
+
145
+ class InpaintingEvalOnlineDataset(Dataset):
146
+ def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
147
+ self.indir = indir
148
+ self.mask_generator = mask_generator
149
+ self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True)))
150
+ self.pad_out_to_modulo = pad_out_to_modulo
151
+ self.scale_factor = scale_factor
152
+
153
+ def __len__(self):
154
+ return len(self.img_filenames)
155
+
156
+ def __getitem__(self, i):
157
+ img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
158
+ mask = self.mask_generator(img, raw_image=raw_image)
159
+ result = dict(image=img, mask=mask)
160
+
161
+ if self.scale_factor is not None:
162
+ result['image'] = scale_image(result['image'], self.scale_factor)
163
+ result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
164
+
165
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
166
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
167
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
168
+ return result
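A hedged usage sketch: InpaintingDataset pairs <name>.jpg images with <name>_mask*.png masks found under a directory. The path below is a placeholder and the modulo setting is illustrative.

```python
from torch.utils.data import DataLoader

dataset = InpaintingDataset('/path/to/val_set',          # hypothetical directory
                            img_suffix='.jpg',
                            pad_out_to_modulo=8)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for batch in loader:
    image, mask = batch['image'], batch['mask']          # float32 CHW tensors in [0, 1]
    break
```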
DH-AISP/2/saicinpainting/evaluation/evaluator.py ADDED
@@ -0,0 +1,220 @@
1
+ import logging
2
+ import math
3
+ from typing import Dict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import tqdm
9
+ from torch.utils.data import DataLoader
10
+
11
+ from saicinpainting.evaluation.utils import move_to_device
12
+
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
+ class InpaintingEvaluator():
17
+ def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
18
+ integral_func=None, integral_title=None, clamp_image_range=None):
19
+ """
20
+ :param dataset: torch.utils.data.Dataset which contains images and masks
21
+ :param scores: dict {score_name: EvaluatorScore object}
22
+ :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
23
+ which are defined by share of area occluded by mask
24
+ :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
25
+ :param batch_size: batch_size for the dataloader
26
+ :param device: device to use
27
+ """
28
+ self.scores = scores
29
+ self.dataset = dataset
30
+
31
+ self.area_grouping = area_grouping
32
+ self.bins = bins
33
+
34
+ self.device = torch.device(device)
35
+
36
+ self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
37
+
38
+ self.integral_func = integral_func
39
+ self.integral_title = integral_title
40
+ self.clamp_image_range = clamp_image_range
41
+
42
+ def _get_bin_edges(self):
43
+ bin_edges = np.linspace(0, 1, self.bins + 1)
44
+
45
+ num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
46
+ interval_names = []
47
+ for idx_bin in range(self.bins):
48
+ start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
49
+ round(100 * bin_edges[idx_bin + 1], num_digits)
50
+ start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
51
+ end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
52
+ interval_names.append("{0}-{1}%".format(start_percent, end_percent))
53
+
54
+ groups = []
55
+ for batch in self.dataloader:
56
+ mask = batch['mask']
57
+ batch_size = mask.shape[0]
58
+ area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
59
+ bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
60
+ # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
61
+ bin_indices[bin_indices == self.bins] = self.bins - 1
62
+ groups.append(bin_indices)
63
+ groups = np.hstack(groups)
64
+
65
+ return groups, interval_names
66
+
67
+ def evaluate(self, model=None):
68
+ """
69
+ :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
70
+ :return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
71
+ name of the particular group arranged by area of mask (e.g. '10-20%')
72
+ and score statistics for the group as values.
73
+ """
74
+ results = dict()
75
+ if self.area_grouping:
76
+ groups, interval_names = self._get_bin_edges()
77
+ else:
78
+ groups = None
79
+
80
+ for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
81
+ score.to(self.device)
82
+ with torch.no_grad():
83
+ score.reset()
84
+ for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
85
+ batch = move_to_device(batch, self.device)
86
+ image_batch, mask_batch = batch['image'], batch['mask']
87
+ if self.clamp_image_range is not None:
88
+ image_batch = torch.clamp(image_batch,
89
+ min=self.clamp_image_range[0],
90
+ max=self.clamp_image_range[1])
91
+ if model is None:
92
+ assert 'inpainted' in batch, \
93
+ 'Model is None, so we expected precomputed inpainting results at key "inpainted"'
94
+ inpainted_batch = batch['inpainted']
95
+ else:
96
+ inpainted_batch = model(image_batch, mask_batch)
97
+ score(inpainted_batch, image_batch, mask_batch)
98
+ total_results, group_results = score.get_value(groups=groups)
99
+
100
+ results[(score_name, 'total')] = total_results
101
+ if groups is not None:
102
+ for group_index, group_values in group_results.items():
103
+ group_name = interval_names[group_index]
104
+ results[(score_name, group_name)] = group_values
105
+
106
+ if self.integral_func is not None:
107
+ results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
108
+
109
+ return results
110
+
111
+
112
+ def ssim_fid100_f1(metrics, fid_scale=100):
113
+ ssim = metrics[('ssim', 'total')]['mean']
114
+ fid = metrics[('fid', 'total')]['mean']
115
+ fid_rel = max(0, fid_scale - fid) / fid_scale
116
+ f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
117
+ return f1
118
+
119
+
120
+ def lpips_fid100_f1(metrics, fid_scale=100):
121
+ neg_lpips = 1 - metrics[('lpips', 'total')]['mean'] # invert, so bigger is better
122
+ fid = metrics[('fid', 'total')]['mean']
123
+ fid_rel = max(0, fid_scale - fid) / fid_scale
124
+ f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
125
+ return f1
126
+
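A worked example of the aggregate defined above: with mean SSIM 0.85 and mean FID 20, fid_rel = (100 - 20) / 100 = 0.8 and f1 = 2 * 0.85 * 0.8 / (0.85 + 0.8 + 1e-3) ≈ 0.824.

```python
metrics = {('ssim', 'total'): {'mean': 0.85},
           ('fid', 'total'): {'mean': 20.0}}
print(round(ssim_fid100_f1(metrics), 3))   # 0.824
```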
127
+
128
+
129
+ class InpaintingEvaluatorOnline(nn.Module):
130
+ def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
131
+ integral_func=None, integral_title=None, clamp_image_range=None):
132
+ """
133
+ :param scores: dict {score_name: EvaluatorScore object}
134
+ :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
135
+ :param device: device to use
136
+ """
137
+ super().__init__()
138
+ LOGGER.info(f'{type(self)} init called')
139
+ self.scores = nn.ModuleDict(scores)
140
+ self.image_key = image_key
141
+ self.inpainted_key = inpainted_key
142
+ self.bins_num = bins
143
+ self.bin_edges = np.linspace(0, 1, self.bins_num + 1)
144
+
145
+ num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
146
+ self.interval_names = []
147
+ for idx_bin in range(self.bins_num):
148
+ start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
149
+ round(100 * self.bin_edges[idx_bin + 1], num_digits)
150
+ start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
151
+ end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
152
+ self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))
153
+
154
+ self.groups = []
155
+
156
+ self.integral_func = integral_func
157
+ self.integral_title = integral_title
158
+ self.clamp_image_range = clamp_image_range
159
+
160
+ LOGGER.info(f'{type(self)} init done')
161
+
162
+ def _get_bins(self, mask_batch):
163
+ batch_size = mask_batch.shape[0]
164
+ area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
165
+ bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
166
+ return bin_indices
167
+
168
+ def forward(self, batch: Dict[str, torch.Tensor]):
169
+ """
170
+ Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
171
+ :param batch: batch dict with mandatory fields mask, image, inpainted (can be overridden by self.inpainted_key)
172
+ """
173
+ result = {}
174
+ with torch.no_grad():
175
+ image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
176
+ if self.clamp_image_range is not None:
177
+ image_batch = torch.clamp(image_batch,
178
+ min=self.clamp_image_range[0],
179
+ max=self.clamp_image_range[1])
180
+ self.groups.extend(self._get_bins(mask_batch))
181
+
182
+ for score_name, score in self.scores.items():
183
+ result[score_name] = score(inpainted_batch, image_batch, mask_batch)
184
+ return result
185
+
186
+ def process_batch(self, batch: Dict[str, torch.Tensor]):
187
+ return self(batch)
188
+
189
+ def evaluation_end(self, states=None):
190
+ """:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
191
+ name of the particular group arranged by area of mask (e.g. '10-20%')
192
+ and score statistics for the group as values.
193
+ """
194
+ LOGGER.info(f'{type(self)}: evaluation_end called')
195
+
196
+ self.groups = np.array(self.groups)
197
+
198
+ results = {}
199
+ for score_name, score in self.scores.items():
200
+ LOGGER.info(f'Getting value of {score_name}')
201
+ cur_states = [s[score_name] for s in states] if states is not None else None
202
+ total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
203
+ LOGGER.info(f'Getting value of {score_name} done')
204
+ results[(score_name, 'total')] = total_results
205
+
206
+ for group_index, group_values in group_results.items():
207
+ group_name = self.interval_names[group_index]
208
+ results[(score_name, group_name)] = group_values
209
+
210
+ if self.integral_func is not None:
211
+ results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
212
+
213
+ LOGGER.info(f'{type(self)}: reset scores')
214
+ self.groups = []
215
+ for sc in self.scores.values():
216
+ sc.reset()
217
+ LOGGER.info(f'{type(self)}: reset scores done')
218
+
219
+ LOGGER.info(f'{type(self)}: evaluation_end done')
220
+ return results
DH-AISP/2/saicinpainting/evaluation/losses/__init__.py ADDED
File without changes
DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py ADDED
@@ -0,0 +1,528 @@
1
+ import logging
2
+ from abc import abstractmethod, ABC
3
+
4
+ import numpy as np
5
+ import sklearn
6
+ import sklearn.svm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from joblib import Parallel, delayed
11
+ from scipy import linalg
12
+
13
+ from models.ade20k import SegmentationModule, NUM_CLASS, segm_options
14
+ from .fid.inception import InceptionV3
15
+ from .lpips import PerceptualLoss
16
+ from .ssim import SSIM
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+
20
+
21
+ def get_groupings(groups):
22
+ """
23
+ :param groups: group numbers for respective elements
24
+ :return: dict of kind {group_idx: indices of the corresponding group elements}
25
+ """
26
+ label_groups, count_groups = np.unique(groups, return_counts=True)
27
+
28
+ indices = np.argsort(groups)
29
+
30
+ grouping = dict()
31
+ cur_start = 0
32
+ for label, count in zip(label_groups, count_groups):
33
+ cur_end = cur_start + count
34
+ cur_indices = indices[cur_start:cur_end]
35
+ grouping[label] = cur_indices
36
+ cur_start = cur_end
37
+ return grouping
38
+
39
+
40
+ class EvaluatorScore(nn.Module):
41
+ @abstractmethod
42
+ def forward(self, pred_batch, target_batch, mask):
43
+ pass
44
+
45
+ @abstractmethod
46
+ def get_value(self, groups=None, states=None):
47
+ pass
48
+
49
+ @abstractmethod
50
+ def reset(self):
51
+ pass
52
+
53
+
54
+ class PairwiseScore(EvaluatorScore, ABC):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.individual_values = None
58
+
59
+ def get_value(self, groups=None, states=None):
60
+ """
61
+ :param groups:
62
+ :return:
63
+ total_results: dict of kind {'mean': score mean, 'std': score std}
64
+ group_results: None, if groups is None;
65
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
66
+ """
67
+ individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \
68
+ else self.individual_values
69
+
70
+ total_results = {
71
+ 'mean': individual_values.mean(),
72
+ 'std': individual_values.std()
73
+ }
74
+
75
+ if groups is None:
76
+ return total_results, None
77
+
78
+ group_results = dict()
79
+ grouping = get_groupings(groups)
80
+ for label, index in grouping.items():
81
+ group_scores = individual_values[index]
82
+ group_results[label] = {
83
+ 'mean': group_scores.mean(),
84
+ 'std': group_scores.std()
85
+ }
86
+ return total_results, group_results
87
+
88
+ def reset(self):
89
+ self.individual_values = []
90
+
91
+
92
+ class SSIMScore(PairwiseScore):
93
+ def __init__(self, window_size=11):
94
+ super().__init__()
95
+ self.score = SSIM(window_size=window_size, size_average=False).eval()
96
+ self.reset()
97
+
98
+ def forward(self, pred_batch, target_batch, mask=None):
99
+ batch_values = self.score(pred_batch, target_batch)
100
+ self.individual_values = np.hstack([
101
+ self.individual_values, batch_values.detach().cpu().numpy()
102
+ ])
103
+ return batch_values
104
+
105
+
106
+ class LPIPSScore(PairwiseScore):
107
+ def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True):
108
+ super().__init__()
109
+ self.score = PerceptualLoss(model=model, net=net, model_path=model_path,
110
+ use_gpu=use_gpu, spatial=False).eval()
111
+ self.reset()
112
+
113
+ def forward(self, pred_batch, target_batch, mask=None):
114
+ batch_values = self.score(pred_batch, target_batch).flatten()
115
+ self.individual_values = np.hstack([
116
+ self.individual_values, batch_values.detach().cpu().numpy()
117
+ ])
118
+ return batch_values
119
+
120
+
121
+ def fid_calculate_activation_statistics(act):
122
+ mu = np.mean(act, axis=0)
123
+ sigma = np.cov(act, rowvar=False)
124
+ return mu, sigma
125
+
126
+
127
+ def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6):
128
+ mu1, sigma1 = fid_calculate_activation_statistics(activations_pred)
129
+ mu2, sigma2 = fid_calculate_activation_statistics(activations_target)
130
+
131
+ diff = mu1 - mu2
132
+
133
+ # Product might be almost singular
134
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
135
+ if not np.isfinite(covmean).all():
136
+ msg = ('fid calculation produces singular product; '
137
+ 'adding %s to diagonal of cov estimates') % eps
138
+ LOGGER.warning(msg)
139
+ offset = np.eye(sigma1.shape[0]) * eps
140
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
141
+
142
+ # Numerical error might give slight imaginary component
143
+ if np.iscomplexobj(covmean):
144
+ # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
145
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
146
+ m = np.max(np.abs(covmean.imag))
147
+ raise ValueError('Imaginary component {}'.format(m))
148
+ covmean = covmean.real
149
+
150
+ tr_covmean = np.trace(covmean)
151
+
152
+ return (diff.dot(diff) + np.trace(sigma1) +
153
+ np.trace(sigma2) - 2 * tr_covmean)
154
+
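For reference, the quantity computed by calculate_frechet_distance above is the squared Fréchet distance between Gaussians fitted to the two sets of activations:

```latex
d^2\big(\mathcal{N}(\mu_1,\Sigma_1),\,\mathcal{N}(\mu_2,\Sigma_2)\big)
  = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}\!\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)
```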
155
+
156
+ class FIDScore(EvaluatorScore):
157
+ def __init__(self, dims=2048, eps=1e-6):
158
+ LOGGER.info("FIDscore init called")
159
+ super().__init__()
160
+ if getattr(FIDScore, '_MODEL', None) is None:
161
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
162
+ FIDScore._MODEL = InceptionV3([block_idx]).eval()
163
+ self.model = FIDScore._MODEL
164
+ self.eps = eps
165
+ self.reset()
166
+ LOGGER.info("FIDscore init done")
167
+
168
+ def forward(self, pred_batch, target_batch, mask=None):
169
+ activations_pred = self._get_activations(pred_batch)
170
+ activations_target = self._get_activations(target_batch)
171
+
172
+ self.activations_pred.append(activations_pred.detach().cpu())
173
+ self.activations_target.append(activations_target.detach().cpu())
174
+
175
+ return activations_pred, activations_target
176
+
177
+ def get_value(self, groups=None, states=None):
178
+ LOGGER.info("FIDscore get_value called")
179
+ activations_pred, activations_target = zip(*states) if states is not None \
180
+ else (self.activations_pred, self.activations_target)
181
+ activations_pred = torch.cat(activations_pred).cpu().numpy()
182
+ activations_target = torch.cat(activations_target).cpu().numpy()
183
+
184
+ total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
185
+ total_results = dict(mean=total_distance)
186
+
187
+ if groups is None:
188
+ group_results = None
189
+ else:
190
+ group_results = dict()
191
+ grouping = get_groupings(groups)
192
+ for label, index in grouping.items():
193
+ if len(index) > 1:
194
+ group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index],
195
+ eps=self.eps)
196
+ group_results[label] = dict(mean=group_distance)
197
+
198
+ else:
199
+ group_results[label] = dict(mean=float('nan'))
200
+
201
+ self.reset()
202
+
203
+ LOGGER.info("FIDscore get_value done")
204
+
205
+ return total_results, group_results
206
+
207
+ def reset(self):
208
+ self.activations_pred = []
209
+ self.activations_target = []
210
+
211
+ def _get_activations(self, batch):
212
+ activations = self.model(batch)[0]
213
+ if activations.shape[2] != 1 or activations.shape[3] != 1:
214
+ assert False, \
215
+ 'We should not have got here, because Inception always scales inputs to 299x299'
216
+ # activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
217
+ activations = activations.squeeze(-1).squeeze(-1)
218
+ return activations
219
+
220
+
221
+ class SegmentationAwareScore(EvaluatorScore):
222
+ def __init__(self, weights_path):
223
+ super().__init__()
224
+ self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval()
225
+ self.target_class_freq_by_image_total = []
226
+ self.target_class_freq_by_image_mask = []
227
+ self.pred_class_freq_by_image_mask = []
228
+
229
+ def forward(self, pred_batch, target_batch, mask):
230
+ pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
231
+ target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
232
+ mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy()
233
+
234
+ batch_target_class_freq_total = []
235
+ batch_target_class_freq_mask = []
236
+ batch_pred_class_freq_mask = []
237
+
238
+ for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat):
239
+ cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...]
240
+ cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
241
+ cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
242
+
243
+ self.target_class_freq_by_image_total.append(cur_target_class_freq_total)
244
+ self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask)
245
+ self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask)
246
+
247
+ batch_target_class_freq_total.append(cur_target_class_freq_total)
248
+ batch_target_class_freq_mask.append(cur_target_class_freq_mask)
249
+ batch_pred_class_freq_mask.append(cur_pred_class_freq_mask)
250
+
251
+ batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0)
252
+ batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0)
253
+ batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0)
254
+ return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask
255
+
256
+ def reset(self):
257
+ super().reset()
258
+ self.target_class_freq_by_image_total = []
259
+ self.target_class_freq_by_image_mask = []
260
+ self.pred_class_freq_by_image_mask = []
261
+
262
+
263
+ def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name):
264
+ assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0]
265
+ total_class_freq = target_class_freq_by_image_mask.sum(0)
266
+ distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0)
267
+ result = distr_values / (total_class_freq + 1e-3)
268
+ return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0}
269
+
270
+
271
+ def get_segmentation_idx2name():
272
+ return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()}
273
+
274
+
275
+ class SegmentationAwarePairwiseScore(SegmentationAwareScore):
276
+ def __init__(self, *args, **kwargs):
277
+ super().__init__(*args, **kwargs)
278
+ self.individual_values = []
279
+ self.segm_idx2name = get_segmentation_idx2name()
280
+
281
+ def forward(self, pred_batch, target_batch, mask):
282
+ cur_class_stats = super().forward(pred_batch, target_batch, mask)
283
+ score_values = self.calc_score(pred_batch, target_batch, mask)
284
+ self.individual_values.append(score_values)
285
+ return cur_class_stats + (score_values,)
286
+
287
+ @abstractmethod
288
+ def calc_score(self, pred_batch, target_batch, mask):
289
+ raise NotImplementedError()
290
+
291
+ def get_value(self, groups=None, states=None):
292
+ """
293
+ :param groups:
294
+ :return:
295
+ total_results: dict of kind {'mean': score mean, 'std': score std}
296
+ group_results: None, if groups is None;
297
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
298
+ """
299
+ if states is not None:
300
+ (target_class_freq_by_image_total,
301
+ target_class_freq_by_image_mask,
302
+ pred_class_freq_by_image_mask,
303
+ individual_values) = states
304
+ else:
305
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
306
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
307
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
308
+ individual_values = self.individual_values
309
+
310
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
311
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
312
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
313
+ individual_values = np.concatenate(individual_values, axis=0)
314
+
315
+ total_results = {
316
+ 'mean': individual_values.mean(),
317
+ 'std': individual_values.std(),
318
+ **distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name)
319
+ }
320
+
321
+ if groups is None:
322
+ return total_results, None
323
+
324
+ group_results = dict()
325
+ grouping = get_groupings(groups)
326
+ for label, index in grouping.items():
327
+ group_class_freq = target_class_freq_by_image_mask[index]
328
+ group_scores = individual_values[index]
329
+ group_results[label] = {
330
+ 'mean': group_scores.mean(),
331
+ 'std': group_scores.std(),
332
+ ** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name)
333
+ }
334
+ return total_results, group_results
335
+
336
+ def reset(self):
337
+ super().reset()
338
+ self.individual_values = []
339
+
340
+
341
+ class SegmentationClassStats(SegmentationAwarePairwiseScore):
342
+ def calc_score(self, pred_batch, target_batch, mask):
343
+ return 0
344
+
345
+ def get_value(self, groups=None, states=None):
346
+ """
347
+ :param groups:
348
+ :return:
349
+ total_results: dict of kind {'mean': score mean, 'std': score std}
350
+ group_results: None, if groups is None;
351
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
352
+ """
353
+ if states is not None:
354
+ (target_class_freq_by_image_total,
355
+ target_class_freq_by_image_mask,
356
+ pred_class_freq_by_image_mask,
357
+ _) = states
358
+ else:
359
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
360
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
361
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
362
+
363
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
364
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
365
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
366
+
367
+ target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32')
368
+ target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum()
369
+
370
+ target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32')
371
+ target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum()
372
+
373
+ pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3)
374
+
375
+ total_results = dict()
376
+ total_results.update({f'total_freq/{self.segm_idx2name[i]}': v
377
+ for i, v in enumerate(target_class_freq_by_image_total_marginal)
378
+ if v > 0})
379
+ total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
380
+ for i, v in enumerate(target_class_freq_by_image_mask_marginal)
381
+ if v > 0})
382
+ total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
383
+ for i, v in enumerate(pred_class_freq_diff)
384
+ if target_class_freq_by_image_total_marginal[i] > 0})
385
+
386
+ if groups is None:
387
+ return total_results, None
388
+
389
+ group_results = dict()
390
+ grouping = get_groupings(groups)
391
+ for label, index in grouping.items():
392
+ group_target_class_freq_by_image_total = target_class_freq_by_image_total[index]
393
+ group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index]
394
+ group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index]
395
+
396
+ group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32')
397
+ group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum()
398
+
399
+ group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32')
400
+ group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum()
401
+
402
+ group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / (
403
+ group_target_class_freq_by_image_mask.sum(0) + 1e-3)
404
+
405
+ cur_group_results = dict()
406
+ cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v
407
+ for i, v in enumerate(group_target_class_freq_by_image_total_marginal)
408
+ if v > 0})
409
+ cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
410
+ for i, v in enumerate(group_target_class_freq_by_image_mask_marginal)
411
+ if v > 0})
412
+ cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
413
+ for i, v in enumerate(group_pred_class_freq_diff)
414
+ if group_target_class_freq_by_image_total_marginal[i] > 0})
415
+
416
+ group_results[label] = cur_group_results
417
+ return total_results, group_results
418
+
419
+
420
+ class SegmentationAwareSSIM(SegmentationAwarePairwiseScore):
421
+ def __init__(self, *args, window_size=11, **kwargs):
422
+ super().__init__(*args, **kwargs)
423
+ self.score_impl = SSIM(window_size=window_size, size_average=False).eval()
424
+
425
+ def calc_score(self, pred_batch, target_batch, mask):
426
+ return self.score_impl(pred_batch, target_batch).detach().cpu().numpy()
427
+
428
+
429
+ class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore):
430
+ def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs):
431
+ super().__init__(*args, **kwargs)
432
+ self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path,
433
+ use_gpu=use_gpu, spatial=False).eval()
434
+
435
+ def calc_score(self, pred_batch, target_batch, mask):
436
+ return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy()
437
+
438
+
439
+ def calculate_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6):
440
+ activations_pred = activations_pred.copy()
441
+ activations_pred[img_i] = activations_target[img_i]
442
+ return calculate_frechet_distance(activations_pred, activations_target, eps=eps)
443
+
444
+
445
+ class SegmentationAwareFID(SegmentationAwarePairwiseScore):
446
+ def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs):
447
+ super().__init__(*args, **kwargs)
448
+ if getattr(FIDScore, '_MODEL', None) is None:
449
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
450
+ FIDScore._MODEL = InceptionV3([block_idx]).eval()
451
+ self.model = FIDScore._MODEL
452
+ self.eps = eps
453
+ self.n_jobs = n_jobs
454
+
455
+ def calc_score(self, pred_batch, target_batch, mask):
456
+ activations_pred = self._get_activations(pred_batch)
457
+ activations_target = self._get_activations(target_batch)
458
+ return activations_pred, activations_target
459
+
460
+ def get_value(self, groups=None, states=None):
461
+ """
462
+ :param groups:
463
+ :return:
464
+ total_results: dict of kind {'mean': score mean, 'std': score std}
465
+ group_results: None, if groups is None;
466
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
467
+ """
468
+ if states is not None:
469
+ (target_class_freq_by_image_total,
470
+ target_class_freq_by_image_mask,
471
+ pred_class_freq_by_image_mask,
472
+ activation_pairs) = states
473
+ else:
474
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
475
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
476
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
477
+ activation_pairs = self.individual_values
478
+
479
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
480
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
481
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
482
+ activations_pred, activations_target = zip(*activation_pairs)
483
+ activations_pred = np.concatenate(activations_pred, axis=0)
484
+ activations_target = np.concatenate(activations_target, axis=0)
485
+
486
+ total_results = {
487
+ 'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps),
488
+ 'std': 0,
489
+ **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target)
490
+ }
491
+
492
+ if groups is None:
493
+ return total_results, None
494
+
495
+ group_results = dict()
496
+ grouping = get_groupings(groups)
497
+ for label, index in grouping.items():
498
+ if len(index) > 1:
499
+ group_activations_pred = activations_pred[index]
500
+ group_activations_target = activations_target[index]
501
+ group_class_freq = target_class_freq_by_image_mask[index]
502
+ group_results[label] = {
503
+ 'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps),
504
+ 'std': 0,
505
+ **self.distribute_fid_to_classes(group_class_freq,
506
+ group_activations_pred,
507
+ group_activations_target)
508
+ }
509
+ else:
510
+ group_results[label] = dict(mean=float('nan'), std=0)
511
+ return total_results, group_results
512
+
513
+ def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target):
514
+ real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
515
+
516
+ fid_no_images = Parallel(n_jobs=self.n_jobs)(
517
+ delayed(calculate_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps)
518
+ for img_i in range(activations_pred.shape[0])
519
+ )
520
+ errors = real_fid - fid_no_images
521
+ return distribute_values_to_classes(class_freq, errors, self.segm_idx2name)
522
+
523
+ def _get_activations(self, batch):
524
+ activations = self.model(batch)[0]
525
+ if activations.shape[2] != 1 or activations.shape[3] != 1:
526
+ activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
527
+ activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy()
528
+ return activations
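A small worked example of get_groupings, which the per-group statistics throughout this file rely on:

```python
import numpy as np

groups = np.array([2, 0, 2, 1, 0])           # group index per sample
print(get_groupings(groups))
# {0: array([1, 4]), 1: array([3]), 2: array([0, 2])}
```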
DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py ADDED
File without changes
DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py ADDED
@@ -0,0 +1,328 @@
1
+ #!/usr/bin/env python3
2
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
3
+
4
+ The FID metric calculates the distance between two distributions of images.
5
+ Typically, we have summary statistics (mean & covariance matrix) of one
6
+ of these distributions, while the 2nd distribution is given by a GAN.
7
+
8
+ When run as a stand-alone program, it compares the distribution of
9
+ images that are stored as PNG/JPEG at a specified location with a
10
+ distribution given by summary statistics (in pickle format).
11
+
12
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
13
+ the pool_3 layer of the inception net for generated samples and real world
14
+ samples respectively.
15
+
16
+ See --help to see further details.
17
+
18
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
19
+ of Tensorflow
20
+
21
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
22
+
23
+ Licensed under the Apache License, Version 2.0 (the "License");
24
+ you may not use this file except in compliance with the License.
25
+ You may obtain a copy of the License at
26
+
27
+ http://www.apache.org/licenses/LICENSE-2.0
28
+
29
+ Unless required by applicable law or agreed to in writing, software
30
+ distributed under the License is distributed on an "AS IS" BASIS,
31
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32
+ See the License for the specific language governing permissions and
33
+ limitations under the License.
34
+ """
35
+ import os
36
+ import pathlib
37
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
38
+
39
+ import numpy as np
40
+ import torch
41
+ # from scipy.misc import imread
42
+ from imageio import imread
43
+ from PIL import Image, JpegImagePlugin
44
+ from scipy import linalg
45
+ from torch.nn.functional import adaptive_avg_pool2d
46
+ from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
47
+
48
+ try:
49
+ from tqdm import tqdm
50
+ except ImportError:
51
+ # If tqdm is not available, provide a mock version of it
52
+ def tqdm(x): return x
53
+
54
+ try:
55
+ from .inception import InceptionV3
56
+ except ModuleNotFoundError:
57
+ from inception import InceptionV3
58
+
59
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
60
+ parser.add_argument('path', type=str, nargs=2,
61
+ help=('Path to the generated images or '
62
+ 'to .npz statistic files'))
63
+ parser.add_argument('--batch-size', type=int, default=50,
64
+ help='Batch size to use')
65
+ parser.add_argument('--dims', type=int, default=2048,
66
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
67
+ help=('Dimensionality of Inception features to use. '
68
+ 'By default, uses pool3 features'))
69
+ parser.add_argument('-c', '--gpu', default='', type=str,
70
+ help='GPU to use (leave blank for CPU only)')
71
+ parser.add_argument('--resize', default=256)
72
+
73
+ transform = Compose([Resize(256), CenterCrop(256), ToTensor()])
74
+
75
+
76
+ def get_activations(files, model, batch_size=50, dims=2048,
77
+ cuda=False, verbose=False, keep_size=False):
78
+ """Calculates the activations of the pool_3 layer for all images.
79
+
80
+ Params:
81
+ -- files : List of image files paths
82
+ -- model : Instance of inception model
83
+ -- batch_size : Batch size of images for the model to process at once.
84
+ Make sure that the number of samples is a multiple of
85
+ the batch size, otherwise some samples are ignored. This
86
+ behavior is retained to match the original FID score
87
+ implementation.
88
+ -- dims : Dimensionality of features returned by Inception
89
+ -- cuda : If set to True, use GPU
90
+ -- verbose : If set to True and parameter out_step is given, the number
91
+ of calculated batches is reported.
92
+ Returns:
93
+ -- A numpy array of dimension (num images, dims) that contains the
94
+ activations of the given tensor when feeding inception with the
95
+ query tensor.
96
+ """
97
+ model.eval()
98
+
99
+ if len(files) % batch_size != 0:
100
+ print(('Warning: number of images is not a multiple of the '
101
+ 'batch size. Some samples are going to be ignored.'))
102
+ if batch_size > len(files):
103
+ print(('Warning: batch size is bigger than the data size. '
104
+ 'Setting batch size to data size'))
105
+ batch_size = len(files)
106
+
107
+ n_batches = len(files) // batch_size
108
+ n_used_imgs = n_batches * batch_size
109
+
110
+ pred_arr = np.empty((n_used_imgs, dims))
111
+
112
+ for i in tqdm(range(n_batches)):
113
+ if verbose:
114
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches),
115
+ end='', flush=True)
116
+ start = i * batch_size
117
+ end = start + batch_size
118
+
119
+ # # Official code goes below
120
+ # images = np.array([imread(str(f)).astype(np.float32)
121
+ # for f in files[start:end]])
122
+
123
+ # # Reshape to (n_images, 3, height, width)
124
+ # images = images.transpose((0, 3, 1, 2))
125
+ # images /= 255
126
+ # batch = torch.from_numpy(images).type(torch.FloatTensor)
127
+ # #
128
+
129
+ t = transform if not keep_size else ToTensor()
130
+
131
+ if isinstance(files[0], pathlib.Path):
132
+ images = [t(Image.open(str(f))) for f in files[start:end]]
133
+
134
+ elif isinstance(files[0], Image.Image):
135
+ images = [t(f) for f in files[start:end]]
136
+
137
+ else:
138
+ raise ValueError(f"Unknown data type for image: {type(files[0])}")
139
+
140
+ batch = torch.stack(images)
141
+
142
+ if cuda:
143
+ batch = batch.cuda()
144
+
145
+ pred = model(batch)[0]
146
+
147
+ # If model output is not scalar, apply global spatial average pooling.
148
+ # This happens if you choose a dimensionality not equal 2048.
149
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
150
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
151
+
152
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
153
+
154
+ if verbose:
155
+ print(' done')
156
+
157
+ return pred_arr
158
+
159
+
160
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
161
+ """Numpy implementation of the Frechet Distance.
162
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
163
+ and X_2 ~ N(mu_2, C_2) is
164
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
165
+
166
+ Stable version by Dougal J. Sutherland.
167
+
168
+ Params:
169
+ -- mu1 : Numpy array containing the activations of a layer of the
170
+ inception net (like returned by the function 'get_predictions')
171
+ for generated samples.
172
+ -- mu2 : The sample mean over activations, precalculated on a
173
+ representative data set.
174
+ -- sigma1: The covariance matrix over activations for generated samples.
175
+ -- sigma2: The covariance matrix over activations, precalculated on a
176
+ representative data set.
177
+
178
+ Returns:
179
+ -- : The Frechet Distance.
180
+ """
181
+
182
+ mu1 = np.atleast_1d(mu1)
183
+ mu2 = np.atleast_1d(mu2)
184
+
185
+ sigma1 = np.atleast_2d(sigma1)
186
+ sigma2 = np.atleast_2d(sigma2)
187
+
188
+ assert mu1.shape == mu2.shape, \
189
+ 'Training and test mean vectors have different lengths'
190
+ assert sigma1.shape == sigma2.shape, \
191
+ 'Training and test covariances have different dimensions'
192
+
193
+ diff = mu1 - mu2
194
+
195
+ # Product might be almost singular
196
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
197
+ if not np.isfinite(covmean).all():
198
+ msg = ('fid calculation produces singular product; '
199
+ 'adding %s to diagonal of cov estimates') % eps
200
+ print(msg)
201
+ offset = np.eye(sigma1.shape[0]) * eps
202
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
203
+
204
+ # Numerical error might give slight imaginary component
205
+ if np.iscomplexobj(covmean):
206
+ # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
207
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
208
+ m = np.max(np.abs(covmean.imag))
209
+ raise ValueError('Imaginary component {}'.format(m))
210
+ covmean = covmean.real
211
+
212
+ tr_covmean = np.trace(covmean)
213
+
214
+ return (diff.dot(diff) + np.trace(sigma1) +
215
+ np.trace(sigma2) - 2 * tr_covmean)
216
+
217
+
218
+ def calculate_activation_statistics(files, model, batch_size=50,
219
+ dims=2048, cuda=False, verbose=False, keep_size=False):
220
+ """Calculation of the statistics used by the FID.
221
+ Params:
222
+ -- files : List of image files paths
223
+ -- model : Instance of inception model
224
+ -- batch_size : The images numpy array is split into batches with
225
+ batch size batch_size. A reasonable batch size
226
+ depends on the hardware.
227
+ -- dims : Dimensionality of features returned by Inception
228
+ -- cuda : If set to True, use GPU
229
+ -- verbose : If set to True and parameter out_step is given, the
230
+ number of calculated batches is reported.
231
+ Returns:
232
+ -- mu : The mean over samples of the activations of the pool_3 layer of
233
+ the inception model.
234
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
235
+ the inception model.
236
+ """
237
+ act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size)
238
+ mu = np.mean(act, axis=0)
239
+ sigma = np.cov(act, rowvar=False)
240
+ return mu, sigma
241
+
242
+
243
+ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
244
+ if path.endswith('.npz'):
245
+ f = np.load(path)
246
+ m, s = f['mu'][:], f['sigma'][:]
247
+ f.close()
248
+ else:
249
+ path = pathlib.Path(path)
250
+ files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
251
+ m, s = calculate_activation_statistics(files, model, batch_size,
252
+ dims, cuda)
253
+
254
+ return m, s
255
+
256
+
257
+ def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False):
258
+ if isinstance(images, list): # exact paths to files are provided
259
+ m, s = calculate_activation_statistics(images, model, batch_size,
260
+ dims, cuda, keep_size=keep_size)
261
+
262
+ return m, s
263
+
264
+ else:
265
+ raise ValueError
266
+
267
+
268
+ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
269
+ """Calculates the FID of two paths"""
270
+ for p in paths:
271
+ if not os.path.exists(p):
272
+ raise RuntimeError('Invalid path: %s' % p)
273
+
274
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
275
+
276
+ model = InceptionV3([block_idx])
277
+ if cuda:
278
+ model.cuda()
279
+
280
+ m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
281
+ dims, cuda)
282
+ m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
283
+ dims, cuda)
284
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
285
+
286
+ return fid_value
287
+
288
+
289
+ def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False):
290
+ if use_globals:
291
+ global FID_MODEL # for multiprocessing
292
+
293
+ for imgs in images:
294
+ if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)):
295
+ pass
296
+ else:
297
+ raise RuntimeError('Invalid images')
298
+
299
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
300
+
301
+ if 'FID_MODEL' not in globals() or not use_globals:
302
+ model = InceptionV3([block_idx])
303
+ if cuda:
304
+ model.cuda()
305
+
306
+ if use_globals:
307
+ FID_MODEL = model
308
+
309
+ else:
310
+ model = FID_MODEL
311
+
312
+ m1, s1 = _compute_statistics_of_images(images[0], model, batch_size,
313
+ dims, cuda, keep_size=False)
314
+ m2, s2 = _compute_statistics_of_images(images[1], model, batch_size,
315
+ dims, cuda, keep_size=False)
316
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
317
+ return fid_value
318
+
319
+
320
+ if __name__ == '__main__':
321
+ args = parser.parse_args()
322
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
323
+
324
+ fid_value = calculate_fid_given_paths(args.path,
325
+ args.batch_size,
326
+ args.gpu != '',
327
+ args.dims)
328
+ print('FID: ', fid_value)
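
The script above boils down to the closed-form Frechet distance between two Gaussians fitted to Inception activations. A minimal, self-contained sketch of that computation, with random arrays standing in for real pool_3 features (the arrays are placeholders, not actual statistics):

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
act_real = rng.normal(size=(500, 64))            # stand-in for activations of real images
act_fake = rng.normal(loc=0.1, size=(500, 64))   # stand-in for activations of generated images

mu1, sigma1 = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
mu2, sigma2 = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)

diff = mu1 - mu2
covmean = linalg.sqrtm(sigma1.dot(sigma2)).real  # matrix square root of the covariance product
fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
print('toy FID:', fid)
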
DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py ADDED
@@ -0,0 +1,323 @@
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision import models
7
+
8
+ try:
9
+ from torchvision.models.utils import load_state_dict_from_url
10
+ except ImportError:
11
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
12
+
13
+ # Inception weights ported to Pytorch from
14
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
15
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
16
+
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+
20
+
21
+ class InceptionV3(nn.Module):
22
+ """Pretrained InceptionV3 network returning feature maps"""
23
+
24
+ # Index of default block of inception to return,
25
+ # corresponds to output of final average pooling
26
+ DEFAULT_BLOCK_INDEX = 3
27
+
28
+ # Maps feature dimensionality to their output blocks indices
29
+ BLOCK_INDEX_BY_DIM = {
30
+ 64: 0, # First max pooling features
31
+ 192: 1, # Second max pooling features
32
+ 768: 2, # Pre-aux classifier features
33
+ 2048: 3 # Final average pooling features
34
+ }
35
+
36
+ def __init__(self,
37
+ output_blocks=[DEFAULT_BLOCK_INDEX],
38
+ resize_input=True,
39
+ normalize_input=True,
40
+ requires_grad=False,
41
+ use_fid_inception=True):
42
+ """Build pretrained InceptionV3
43
+
44
+ Parameters
45
+ ----------
46
+ output_blocks : list of int
47
+ Indices of blocks to return features of. Possible values are:
48
+ - 0: corresponds to output of first max pooling
49
+ - 1: corresponds to output of second max pooling
50
+ - 2: corresponds to output which is fed to aux classifier
51
+ - 3: corresponds to output of final average pooling
52
+ resize_input : bool
53
+ If true, bilinearly resizes input to width and height 299 before
54
+ feeding input to model. As the network without fully connected
55
+ layers is fully convolutional, it should be able to handle inputs
56
+ of arbitrary size, so resizing might not be strictly needed
57
+ normalize_input : bool
58
+ If true, scales the input from range (0, 1) to the range the
59
+ pretrained Inception network expects, namely (-1, 1)
60
+ requires_grad : bool
61
+ If true, parameters of the model require gradients. Possibly useful
62
+ for finetuning the network
63
+ use_fid_inception : bool
64
+ If true, uses the pretrained Inception model used in Tensorflow's
65
+ FID implementation. If false, uses the pretrained Inception model
66
+ available in torchvision. The FID Inception model has different
67
+ weights and a slightly different structure from torchvision's
68
+ Inception model. If you want to compute FID scores, you are
69
+ strongly advised to set this parameter to true to get comparable
70
+ results.
71
+ """
72
+ super(InceptionV3, self).__init__()
73
+
74
+ self.resize_input = resize_input
75
+ self.normalize_input = normalize_input
76
+ self.output_blocks = sorted(output_blocks)
77
+ self.last_needed_block = max(output_blocks)
78
+
79
+ assert self.last_needed_block <= 3, \
80
+ 'Last possible output block index is 3'
81
+
82
+ self.blocks = nn.ModuleList()
83
+
84
+ if use_fid_inception:
85
+ inception = fid_inception_v3()
86
+ else:
87
+ inception = models.inception_v3(pretrained=True)
88
+
89
+ # Block 0: input to maxpool1
90
+ block0 = [
91
+ inception.Conv2d_1a_3x3,
92
+ inception.Conv2d_2a_3x3,
93
+ inception.Conv2d_2b_3x3,
94
+ nn.MaxPool2d(kernel_size=3, stride=2)
95
+ ]
96
+ self.blocks.append(nn.Sequential(*block0))
97
+
98
+ # Block 1: maxpool1 to maxpool2
99
+ if self.last_needed_block >= 1:
100
+ block1 = [
101
+ inception.Conv2d_3b_1x1,
102
+ inception.Conv2d_4a_3x3,
103
+ nn.MaxPool2d(kernel_size=3, stride=2)
104
+ ]
105
+ self.blocks.append(nn.Sequential(*block1))
106
+
107
+ # Block 2: maxpool2 to aux classifier
108
+ if self.last_needed_block >= 2:
109
+ block2 = [
110
+ inception.Mixed_5b,
111
+ inception.Mixed_5c,
112
+ inception.Mixed_5d,
113
+ inception.Mixed_6a,
114
+ inception.Mixed_6b,
115
+ inception.Mixed_6c,
116
+ inception.Mixed_6d,
117
+ inception.Mixed_6e,
118
+ ]
119
+ self.blocks.append(nn.Sequential(*block2))
120
+
121
+ # Block 3: aux classifier to final avgpool
122
+ if self.last_needed_block >= 3:
123
+ block3 = [
124
+ inception.Mixed_7a,
125
+ inception.Mixed_7b,
126
+ inception.Mixed_7c,
127
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
128
+ ]
129
+ self.blocks.append(nn.Sequential(*block3))
130
+
131
+ for param in self.parameters():
132
+ param.requires_grad = requires_grad
133
+
134
+ def forward(self, inp):
135
+ """Get Inception feature maps
136
+
137
+ Parameters
138
+ ----------
139
+ inp : torch.autograd.Variable
140
+ Input tensor of shape Bx3xHxW. Values are expected to be in
141
+ range (0, 1)
142
+
143
+ Returns
144
+ -------
145
+ List of torch.autograd.Variable, corresponding to the selected output
146
+ block, sorted ascending by index
147
+ """
148
+ outp = []
149
+ x = inp
150
+
151
+ if self.resize_input:
152
+ x = F.interpolate(x,
153
+ size=(299, 299),
154
+ mode='bilinear',
155
+ align_corners=False)
156
+
157
+ if self.normalize_input:
158
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
159
+
160
+ for idx, block in enumerate(self.blocks):
161
+ x = block(x)
162
+ if idx in self.output_blocks:
163
+ outp.append(x)
164
+
165
+ if idx == self.last_needed_block:
166
+ break
167
+
168
+ return outp
169
+
170
+
171
+ def fid_inception_v3():
172
+ """Build pretrained Inception model for FID computation
173
+
174
+ The Inception model for FID computation uses a different set of weights
175
+ and has a slightly different structure than torchvision's Inception.
176
+
177
+ This method first constructs torchvision's Inception and then patches the
178
+ necessary parts that are different in the FID Inception model.
179
+ """
180
+ LOGGER.info('fid_inception_v3 called')
181
+ inception = models.inception_v3(num_classes=1008,
182
+ aux_logits=False,
183
+ pretrained=False)
184
+ LOGGER.info('models.inception_v3 done')
185
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
186
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
187
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
188
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
189
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
190
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
191
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
192
+ inception.Mixed_7b = FIDInceptionE_1(1280)
193
+ inception.Mixed_7c = FIDInceptionE_2(2048)
194
+
195
+ LOGGER.info('fid_inception_v3 patching done')
196
+
197
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
198
+ LOGGER.info('fid_inception_v3 weights downloaded')
199
+
200
+ inception.load_state_dict(state_dict)
201
+ LOGGER.info('fid_inception_v3 weights loaded into model')
202
+
203
+ return inception
204
+
205
+
206
+ class FIDInceptionA(models.inception.InceptionA):
207
+ """InceptionA block patched for FID computation"""
208
+ def __init__(self, in_channels, pool_features):
209
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
210
+
211
+ def forward(self, x):
212
+ branch1x1 = self.branch1x1(x)
213
+
214
+ branch5x5 = self.branch5x5_1(x)
215
+ branch5x5 = self.branch5x5_2(branch5x5)
216
+
217
+ branch3x3dbl = self.branch3x3dbl_1(x)
218
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
219
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
220
+
221
+ # Patch: Tensorflow's average pool does not use the padded zeros in
222
+ # its average calculation
223
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
224
+ count_include_pad=False)
225
+ branch_pool = self.branch_pool(branch_pool)
226
+
227
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
228
+ return torch.cat(outputs, 1)
229
+
230
+
231
+ class FIDInceptionC(models.inception.InceptionC):
232
+ """InceptionC block patched for FID computation"""
233
+ def __init__(self, in_channels, channels_7x7):
234
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
235
+
236
+ def forward(self, x):
237
+ branch1x1 = self.branch1x1(x)
238
+
239
+ branch7x7 = self.branch7x7_1(x)
240
+ branch7x7 = self.branch7x7_2(branch7x7)
241
+ branch7x7 = self.branch7x7_3(branch7x7)
242
+
243
+ branch7x7dbl = self.branch7x7dbl_1(x)
244
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
245
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
246
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
247
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
248
+
249
+ # Patch: Tensorflow's average pool does not use the padded zeros in
250
+ # its average calculation
251
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
252
+ count_include_pad=False)
253
+ branch_pool = self.branch_pool(branch_pool)
254
+
255
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
256
+ return torch.cat(outputs, 1)
257
+
258
+
259
+ class FIDInceptionE_1(models.inception.InceptionE):
260
+ """First InceptionE block patched for FID computation"""
261
+ def __init__(self, in_channels):
262
+ super(FIDInceptionE_1, self).__init__(in_channels)
263
+
264
+ def forward(self, x):
265
+ branch1x1 = self.branch1x1(x)
266
+
267
+ branch3x3 = self.branch3x3_1(x)
268
+ branch3x3 = [
269
+ self.branch3x3_2a(branch3x3),
270
+ self.branch3x3_2b(branch3x3),
271
+ ]
272
+ branch3x3 = torch.cat(branch3x3, 1)
273
+
274
+ branch3x3dbl = self.branch3x3dbl_1(x)
275
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
276
+ branch3x3dbl = [
277
+ self.branch3x3dbl_3a(branch3x3dbl),
278
+ self.branch3x3dbl_3b(branch3x3dbl),
279
+ ]
280
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
281
+
282
+ # Patch: Tensorflow's average pool does not use the padded zeros in
283
+ # its average calculation
284
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
285
+ count_include_pad=False)
286
+ branch_pool = self.branch_pool(branch_pool)
287
+
288
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
289
+ return torch.cat(outputs, 1)
290
+
291
+
292
+ class FIDInceptionE_2(models.inception.InceptionE):
293
+ """Second InceptionE block patched for FID computation"""
294
+ def __init__(self, in_channels):
295
+ super(FIDInceptionE_2, self).__init__(in_channels)
296
+
297
+ def forward(self, x):
298
+ branch1x1 = self.branch1x1(x)
299
+
300
+ branch3x3 = self.branch3x3_1(x)
301
+ branch3x3 = [
302
+ self.branch3x3_2a(branch3x3),
303
+ self.branch3x3_2b(branch3x3),
304
+ ]
305
+ branch3x3 = torch.cat(branch3x3, 1)
306
+
307
+ branch3x3dbl = self.branch3x3dbl_1(x)
308
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
309
+ branch3x3dbl = [
310
+ self.branch3x3dbl_3a(branch3x3dbl),
311
+ self.branch3x3dbl_3b(branch3x3dbl),
312
+ ]
313
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
314
+
315
+ # Patch: The FID Inception model uses max pooling instead of average
316
+ # pooling. This is likely an error in this specific Inception
317
+ # implementation, as other Inception models use average pooling here
318
+ # (which matches the description in the paper).
319
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
320
+ branch_pool = self.branch_pool(branch_pool)
321
+
322
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
323
+ return torch.cat(outputs, 1)
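
A short usage sketch for the wrapper above, assuming the ported FID Inception weights can be downloaded from FID_WEIGHTS_URL on first use; the random batch only illustrates the expected input format (Bx3xHxW in [0, 1]):

import torch
from inception import InceptionV3  # inside the package: from .inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # pool_3 / final average pooling features
model = InceptionV3([block_idx]).eval()            # downloads the ported FID weights on first use

images = torch.rand(4, 3, 299, 299)                # batch of images in [0, 1]
with torch.no_grad():
    features = model(images)[0]                    # list of requested blocks; here shape (4, 2048, 1, 1)
print(features.squeeze(-1).squeeze(-1).shape)      # -> torch.Size([4, 2048])
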
DH-AISP/2/saicinpainting/evaluation/losses/lpips.py ADDED
@@ -0,0 +1,891 @@
1
+ ############################################################
2
+ # The contents below have been combined using files in the #
3
+ # following repository: #
4
+ # https://github.com/richzhang/PerceptualSimilarity #
5
+ ############################################################
6
+
7
+ ############################################################
8
+ # __init__.py #
9
+ ############################################################
10
+
11
+ import numpy as np
12
+ from skimage.metrics import structural_similarity
13
+ import torch
14
+
15
+ from saicinpainting.utils import get_shape
16
+
17
+
18
+ class PerceptualLoss(torch.nn.Module):
19
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True):
20
+ # VGG using our perceptually-learned weights (LPIPS metric)
21
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
22
+ super(PerceptualLoss, self).__init__()
23
+ self.use_gpu = use_gpu
24
+ self.spatial = spatial
25
+ self.model = DistModel()
26
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace,
27
+ model_path=model_path, spatial=self.spatial)
28
+
29
+ def forward(self, pred, target, normalize=True):
30
+ """
31
+ Pred and target are Variables.
32
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
33
+ If normalize is False, assumes the images are already between [-1,+1]
34
+ Inputs pred and target are Nx3xHxW
35
+ Output pytorch Variable N long
36
+ """
37
+
38
+ if normalize:
39
+ target = 2 * target - 1
40
+ pred = 2 * pred - 1
41
+
42
+ return self.model(target, pred)
43
+
44
+
45
+ def normalize_tensor(in_feat, eps=1e-10):
46
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
47
+ return in_feat / (norm_factor + eps)
48
+
49
+
50
+ def l2(p0, p1, range=255.):
51
+ return .5 * np.mean((p0 / range - p1 / range) ** 2)
52
+
53
+
54
+ def psnr(p0, p1, peak=255.):
55
+ return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
56
+
57
+
58
+ def dssim(p0, p1, range=255.):
59
+ return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
60
+
61
+
62
+ def rgb2lab(in_img, mean_cent=False):
63
+ from skimage import color
64
+ img_lab = color.rgb2lab(in_img)
65
+ if (mean_cent):
66
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
67
+ return img_lab
68
+
69
+
70
+ def tensor2np(tensor_obj):
71
+ # change dimension of a tensor object into a numpy array
72
+ return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
73
+
74
+
75
+ def np2tensor(np_obj):
76
+ # change dimension of np array into tensor array
77
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
78
+
79
+
80
+ def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
81
+ # image tensor to lab tensor
82
+ from skimage import color
83
+
84
+ img = tensor2im(image_tensor)
85
+ img_lab = color.rgb2lab(img)
86
+ if (mc_only):
87
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
88
+ if (to_norm and not mc_only):
89
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
90
+ img_lab = img_lab / 100.
91
+
92
+ return np2tensor(img_lab)
93
+
94
+
95
+ def tensorlab2tensor(lab_tensor, return_inbnd=False):
96
+ from skimage import color
97
+ import warnings
98
+ warnings.filterwarnings("ignore")
99
+
100
+ lab = tensor2np(lab_tensor) * 100.
101
+ lab[:, :, 0] = lab[:, :, 0] + 50
102
+
103
+ rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
104
+ if (return_inbnd):
105
+ # convert back to lab, see if we match
106
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
107
+ mask = 1. * np.isclose(lab_back, lab, atol=2.)
108
+ mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
109
+ return (im2tensor(rgb_back), mask)
110
+ else:
111
+ return im2tensor(rgb_back)
112
+
113
+
114
+ def rgb2lab(input):
115
+ from skimage import color
116
+ return color.rgb2lab(input / 255.)
117
+
118
+
119
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
120
+ image_numpy = image_tensor[0].cpu().float().numpy()
121
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
122
+ return image_numpy.astype(imtype)
123
+
124
+
125
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
126
+ return torch.Tensor((image / factor - cent)
127
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128
+
129
+
130
+ def tensor2vec(vector_tensor):
131
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
132
+
133
+
134
+ def voc_ap(rec, prec, use_07_metric=False):
135
+ """ ap = voc_ap(rec, prec, [use_07_metric])
136
+ Compute VOC AP given precision and recall.
137
+ If use_07_metric is true, uses the
138
+ VOC 07 11 point method (default:False).
139
+ """
140
+ if use_07_metric:
141
+ # 11 point metric
142
+ ap = 0.
143
+ for t in np.arange(0., 1.1, 0.1):
144
+ if np.sum(rec >= t) == 0:
145
+ p = 0
146
+ else:
147
+ p = np.max(prec[rec >= t])
148
+ ap = ap + p / 11.
149
+ else:
150
+ # correct AP calculation
151
+ # first append sentinel values at the end
152
+ mrec = np.concatenate(([0.], rec, [1.]))
153
+ mpre = np.concatenate(([0.], prec, [0.]))
154
+
155
+ # compute the precision envelope
156
+ for i in range(mpre.size - 1, 0, -1):
157
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
158
+
159
+ # to calculate area under PR curve, look for points
160
+ # where X axis (recall) changes value
161
+ i = np.where(mrec[1:] != mrec[:-1])[0]
162
+
163
+ # and sum (\Delta recall) * prec
164
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
165
+ return ap
166
+
167
+
168
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
169
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
170
+ image_numpy = image_tensor[0].cpu().float().numpy()
171
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
172
+ return image_numpy.astype(imtype)
173
+
174
+
175
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
176
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
177
+ return torch.Tensor((image / factor - cent)
178
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
179
+
180
+
181
+ ############################################################
182
+ # base_model.py #
183
+ ############################################################
184
+
185
+
186
+ class BaseModel(torch.nn.Module):
187
+ def __init__(self):
188
+ super().__init__()
189
+
190
+ def name(self):
191
+ return 'BaseModel'
192
+
193
+ def initialize(self, use_gpu=True):
194
+ self.use_gpu = use_gpu
195
+
196
+ def forward(self):
197
+ pass
198
+
199
+ def get_image_paths(self):
200
+ pass
201
+
202
+ def optimize_parameters(self):
203
+ pass
204
+
205
+ def get_current_visuals(self):
206
+ return self.input
207
+
208
+ def get_current_errors(self):
209
+ return {}
210
+
211
+ def save(self, label):
212
+ pass
213
+
214
+ # helper saving function that can be used by subclasses
215
+ def save_network(self, network, path, network_label, epoch_label):
216
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
217
+ save_path = os.path.join(path, save_filename)
218
+ torch.save(network.state_dict(), save_path)
219
+
220
+ # helper loading function that can be used by subclasses
221
+ def load_network(self, network, network_label, epoch_label):
222
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
223
+ save_path = os.path.join(self.save_dir, save_filename)
224
+ print('Loading network from %s' % save_path)
225
+ network.load_state_dict(torch.load(save_path, map_location='cpu'))
226
+
227
+ def update_learning_rate(self):
228
+ pass
229
+
230
+ def get_image_paths(self):
231
+ return self.image_paths
232
+
233
+ def save_done(self, flag=False):
234
+ np.save(os.path.join(self.save_dir, 'done_flag'), flag)
235
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
236
+
237
+
238
+ ############################################################
239
+ # dist_model.py #
240
+ ############################################################
241
+
242
+ import os
243
+ from collections import OrderedDict
244
+ from scipy.ndimage import zoom
245
+ from tqdm import tqdm
246
+
247
+
248
+ class DistModel(BaseModel):
249
+ def name(self):
250
+ return self.model_name
251
+
252
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False,
253
+ model_path=None,
254
+ use_gpu=True, printNet=False, spatial=False,
255
+ is_train=False, lr=.0001, beta1=0.5, version='0.1'):
256
+ '''
257
+ INPUTS
258
+ model - ['net-lin'] for linearly calibrated network
259
+ ['net'] for off-the-shelf network
260
+ ['L2'] for L2 distance in Lab colorspace
261
+ ['SSIM'] for ssim in RGB colorspace
262
+ net - ['squeeze','alex','vgg']
263
+ model_path - if None, will look in weights/[NET_NAME].pth
264
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
265
+ use_gpu - bool - whether or not to use a GPU
266
+ printNet - bool - whether or not to print network architecture out
267
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
268
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
269
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
270
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
271
+ is_train - bool - [True] for training mode
272
+ lr - float - initial learning rate
273
+ beta1 - float - initial momentum term for adam
274
+ version - 0.1 for latest, 0.0 was original (with a bug)
275
+ '''
276
+ BaseModel.initialize(self, use_gpu=use_gpu)
277
+
278
+ self.model = model
279
+ self.net = net
280
+ self.is_train = is_train
281
+ self.spatial = spatial
282
+ self.model_name = '%s [%s]' % (model, net)
283
+
284
+ if (self.model == 'net-lin'): # pretrained net + linear layer
285
+ self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
286
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
287
+ kw = dict(map_location='cpu')
288
+ if (model_path is None):
289
+ import inspect
290
+ model_path = os.path.abspath(
291
+ os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth'))
292
+
293
+ if (not is_train):
294
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
295
+
296
+ elif (self.model == 'net'): # pretrained network
297
+ self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
298
+ elif (self.model in ['L2', 'l2']):
299
+ self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
300
+ self.model_name = 'L2'
301
+ elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
302
+ self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
303
+ self.model_name = 'SSIM'
304
+ else:
305
+ raise ValueError("Model [%s] not recognized." % self.model)
306
+
307
+ self.trainable_parameters = list(self.net.parameters())
308
+
309
+ if self.is_train: # training mode
310
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
311
+ self.rankLoss = BCERankingLoss()
312
+ self.trainable_parameters += list(self.rankLoss.net.parameters())
313
+ self.lr = lr
314
+ self.old_lr = lr
315
+ self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999))
316
+ else: # test mode
317
+ self.net.eval()
318
+
319
+ # if (use_gpu):
320
+ # self.net.to(gpu_ids[0])
321
+ # self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
322
+ # if (self.is_train):
323
+ # self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
324
+
325
+ if (printNet):
326
+ print('---------- Networks initialized -------------')
327
+ print_network(self.net)
328
+ print('-----------------------------------------------')
329
+
330
+ def forward(self, in0, in1, retPerLayer=False):
331
+ ''' Function computes the distance between image patches in0 and in1
332
+ INPUTS
333
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
334
+ OUTPUT
335
+ computed distances between in0 and in1
336
+ '''
337
+
338
+ return self.net(in0, in1, retPerLayer=retPerLayer)
339
+
340
+ # ***** TRAINING FUNCTIONS *****
341
+ def optimize_parameters(self):
342
+ self.forward_train()
343
+ self.optimizer_net.zero_grad()
344
+ self.backward_train()
345
+ self.optimizer_net.step()
346
+ self.clamp_weights()
347
+
348
+ def clamp_weights(self):
349
+ for module in self.net.modules():
350
+ if (hasattr(module, 'weight') and module.kernel_size == (1, 1)):
351
+ module.weight.data = torch.clamp(module.weight.data, min=0)
352
+
353
+ def set_input(self, data):
354
+ self.input_ref = data['ref']
355
+ self.input_p0 = data['p0']
356
+ self.input_p1 = data['p1']
357
+ self.input_judge = data['judge']
358
+
359
+ # if (self.use_gpu):
360
+ # self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
361
+ # self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
362
+ # self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
363
+ # self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
364
+
365
+ # self.var_ref = Variable(self.input_ref, requires_grad=True)
366
+ # self.var_p0 = Variable(self.input_p0, requires_grad=True)
367
+ # self.var_p1 = Variable(self.input_p1, requires_grad=True)
368
+
369
+ def forward_train(self): # run forward pass
370
+ # print(self.net.module.scaling_layer.shift)
371
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
372
+
373
+ assert False, "We shouldn't have gotten here when using LPIPS as a metric"
374
+
375
+ self.d0 = self(self.var_ref, self.var_p0)
376
+ self.d1 = self(self.var_ref, self.var_p1)
377
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
378
+
379
+ self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
380
+
381
+ self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.)
382
+
383
+ return self.loss_total
384
+
385
+ def backward_train(self):
386
+ torch.mean(self.loss_total).backward()
387
+
388
+ def compute_accuracy(self, d0, d1, judge):
389
+ ''' d0, d1 are Variables, judge is a Tensor '''
390
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
391
+ judge_per = judge.cpu().numpy().flatten()
392
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
393
+
394
+ def get_current_errors(self):
395
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
396
+ ('acc_r', self.acc_r)])
397
+
398
+ for key in retDict.keys():
399
+ retDict[key] = np.mean(retDict[key])
400
+
401
+ return retDict
402
+
403
+ def get_current_visuals(self):
404
+ zoom_factor = 256 / self.var_ref.data.size()[2]
405
+
406
+ ref_img = tensor2im(self.var_ref.data)
407
+ p0_img = tensor2im(self.var_p0.data)
408
+ p1_img = tensor2im(self.var_p1.data)
409
+
410
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
411
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
412
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
413
+
414
+ return OrderedDict([('ref', ref_img_vis),
415
+ ('p0', p0_img_vis),
416
+ ('p1', p1_img_vis)])
417
+
418
+ def save(self, path, label):
419
+ if (self.use_gpu):
420
+ self.save_network(self.net.module, path, '', label)
421
+ else:
422
+ self.save_network(self.net, path, '', label)
423
+ self.save_network(self.rankLoss.net, path, 'rank', label)
424
+
425
+ def update_learning_rate(self, nepoch_decay):
426
+ lrd = self.lr / nepoch_decay
427
+ lr = self.old_lr - lrd
428
+
429
+ for param_group in self.optimizer_net.param_groups:
430
+ param_group['lr'] = lr
431
+
432
+ print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
433
+ self.old_lr = lr
434
+
435
+
436
+ def score_2afc_dataset(data_loader, func, name=''):
437
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
438
+ distance function 'func' in dataset 'data_loader'
439
+ INPUTS
440
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
441
+ func - callable distance function - calling d=func(in0,in1) should take 2
442
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
443
+ OUTPUTS
444
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
445
+ [1] - dictionary with following elements
446
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
447
+ gts - N array in [0,1], preferred patch selected by human evaluators
448
+ (closer to "0" for left patch p0, "1" for right patch p1,
449
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
450
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
451
+ CONSTS
452
+ N - number of test triplets in data_loader
453
+ '''
454
+
455
+ d0s = []
456
+ d1s = []
457
+ gts = []
458
+
459
+ for data in tqdm(data_loader.load_data(), desc=name):
460
+ d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
461
+ d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
462
+ gts += data['judge'].cpu().numpy().flatten().tolist()
463
+
464
+ d0s = np.array(d0s)
465
+ d1s = np.array(d1s)
466
+ gts = np.array(gts)
467
+ scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
468
+
469
+ return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
470
+
471
+
472
+ def score_jnd_dataset(data_loader, func, name=''):
473
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
474
+ INPUTS
475
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
476
+ func - callable distance function - calling d=func(in0,in1) should take 2
477
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
478
+ OUTPUTS
479
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
480
+ [1] - dictionary with following elements
481
+ ds - N array containing distances between two patches shown to human evaluator
482
+ sames - N array containing fraction of people who thought the two patches were identical
483
+ CONSTS
484
+ N - number of test triplets in data_loader
485
+ '''
486
+
487
+ ds = []
488
+ gts = []
489
+
490
+ for data in tqdm(data_loader.load_data(), desc=name):
491
+ ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
492
+ gts += data['same'].cpu().numpy().flatten().tolist()
493
+
494
+ sames = np.array(gts)
495
+ ds = np.array(ds)
496
+
497
+ sorted_inds = np.argsort(ds)
498
+ ds_sorted = ds[sorted_inds]
499
+ sames_sorted = sames[sorted_inds]
500
+
501
+ TPs = np.cumsum(sames_sorted)
502
+ FPs = np.cumsum(1 - sames_sorted)
503
+ FNs = np.sum(sames_sorted) - TPs
504
+
505
+ precs = TPs / (TPs + FPs)
506
+ recs = TPs / (TPs + FNs)
507
+ score = voc_ap(recs, precs)
508
+
509
+ return (score, dict(ds=ds, sames=sames))
510
+
511
+
512
+ ############################################################
513
+ # networks_basic.py #
514
+ ############################################################
515
+
516
+ import torch.nn as nn
517
+ from torch.autograd import Variable
518
+ import numpy as np
519
+
520
+
521
+ def spatial_average(in_tens, keepdim=True):
522
+ return in_tens.mean([2, 3], keepdim=keepdim)
523
+
524
+
525
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
526
+ in_H = in_tens.shape[2]
527
+ scale_factor = 1. * out_H / in_H
528
+
529
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
530
+
531
+
532
+ # Learned perceptual metric
533
+ class PNetLin(nn.Module):
534
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False,
535
+ version='0.1', lpips=True):
536
+ super(PNetLin, self).__init__()
537
+
538
+ self.pnet_type = pnet_type
539
+ self.pnet_tune = pnet_tune
540
+ self.pnet_rand = pnet_rand
541
+ self.spatial = spatial
542
+ self.lpips = lpips
543
+ self.version = version
544
+ self.scaling_layer = ScalingLayer()
545
+
546
+ if (self.pnet_type in ['vgg', 'vgg16']):
547
+ net_type = vgg16
548
+ self.chns = [64, 128, 256, 512, 512]
549
+ elif (self.pnet_type == 'alex'):
550
+ net_type = alexnet
551
+ self.chns = [64, 192, 384, 256, 256]
552
+ elif (self.pnet_type == 'squeeze'):
553
+ net_type = squeezenet
554
+ self.chns = [64, 128, 256, 384, 384, 512, 512]
555
+ self.L = len(self.chns)
556
+
557
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
558
+
559
+ if (lpips):
560
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
561
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
562
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
563
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
564
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
565
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
566
+ if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet
567
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
568
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
569
+ self.lins += [self.lin5, self.lin6]
570
+
571
+ def forward(self, in0, in1, retPerLayer=False):
572
+ # v0.0 - original release had a bug, where input was not scaled
573
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
574
+ in0, in1)
575
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
576
+ feats0, feats1, diffs = {}, {}, {}
577
+
578
+ for kk in range(self.L):
579
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
580
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
581
+
582
+ if (self.lpips):
583
+ if (self.spatial):
584
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
585
+ else:
586
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
587
+ else:
588
+ if (self.spatial):
589
+ res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
590
+ else:
591
+ res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
592
+
593
+ val = res[0]
594
+ for l in range(1, self.L):
595
+ val += res[l]
596
+
597
+ if (retPerLayer):
598
+ return (val, res)
599
+ else:
600
+ return val
601
+
602
+
603
+ class ScalingLayer(nn.Module):
604
+ def __init__(self):
605
+ super(ScalingLayer, self).__init__()
606
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
607
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
608
+
609
+ def forward(self, inp):
610
+ return (inp - self.shift) / self.scale
611
+
612
+
613
+ class NetLinLayer(nn.Module):
614
+ ''' A single linear layer which does a 1x1 conv '''
615
+
616
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
617
+ super(NetLinLayer, self).__init__()
618
+
619
+ layers = [nn.Dropout(), ] if (use_dropout) else []
620
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
621
+ self.model = nn.Sequential(*layers)
622
+
623
+
624
+ class Dist2LogitLayer(nn.Module):
625
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
626
+
627
+ def __init__(self, chn_mid=32, use_sigmoid=True):
628
+ super(Dist2LogitLayer, self).__init__()
629
+
630
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
631
+ layers += [nn.LeakyReLU(0.2, True), ]
632
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
633
+ layers += [nn.LeakyReLU(0.2, True), ]
634
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
635
+ if (use_sigmoid):
636
+ layers += [nn.Sigmoid(), ]
637
+ self.model = nn.Sequential(*layers)
638
+
639
+ def forward(self, d0, d1, eps=0.1):
640
+ return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
641
+
642
+
643
+ class BCERankingLoss(nn.Module):
644
+ def __init__(self, chn_mid=32):
645
+ super(BCERankingLoss, self).__init__()
646
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
647
+ # self.parameters = list(self.net.parameters())
648
+ self.loss = torch.nn.BCELoss()
649
+
650
+ def forward(self, d0, d1, judge):
651
+ per = (judge + 1.) / 2.
652
+ self.logit = self.net(d0, d1)
653
+ return self.loss(self.logit, per)
654
+
655
+
656
+ # L2, DSSIM metrics
657
+ class FakeNet(nn.Module):
658
+ def __init__(self, use_gpu=True, colorspace='Lab'):
659
+ super(FakeNet, self).__init__()
660
+ self.use_gpu = use_gpu
661
+ self.colorspace = colorspace
662
+
663
+
664
+ class L2(FakeNet):
665
+
666
+ def forward(self, in0, in1, retPerLayer=None):
667
+ assert (in0.size()[0] == 1) # currently only supports batchSize 1
668
+
669
+ if (self.colorspace == 'RGB'):
670
+ (N, C, X, Y) = in0.size()
671
+ value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
672
+ dim=3).view(N)
673
+ return value
674
+ elif (self.colorspace == 'Lab'):
675
+ value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
676
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
677
+ ret_var = Variable(torch.Tensor((value,)))
678
+ # if (self.use_gpu):
679
+ # ret_var = ret_var.cuda()
680
+ return ret_var
681
+
682
+
683
+ class DSSIM(FakeNet):
684
+
685
+ def forward(self, in0, in1, retPerLayer=None):
686
+ assert (in0.size()[0] == 1) # currently only supports batchSize 1
687
+
688
+ if (self.colorspace == 'RGB'):
689
+ value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype('float')
690
+ elif (self.colorspace == 'Lab'):
691
+ value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
692
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
693
+ ret_var = Variable(torch.Tensor((value,)))
694
+ # if (self.use_gpu):
695
+ # ret_var = ret_var.cuda()
696
+ return ret_var
697
+
698
+
699
+ def print_network(net):
700
+ num_params = 0
701
+ for param in net.parameters():
702
+ num_params += param.numel()
703
+ print('Network', net)
704
+ print('Total number of parameters: %d' % num_params)
705
+
706
+
707
+ ############################################################
708
+ # pretrained_networks.py #
709
+ ############################################################
710
+
711
+ from collections import namedtuple
712
+ import torch
713
+ from torchvision import models as tv
714
+
715
+
716
+ class squeezenet(torch.nn.Module):
717
+ def __init__(self, requires_grad=False, pretrained=True):
718
+ super(squeezenet, self).__init__()
719
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
720
+ self.slice1 = torch.nn.Sequential()
721
+ self.slice2 = torch.nn.Sequential()
722
+ self.slice3 = torch.nn.Sequential()
723
+ self.slice4 = torch.nn.Sequential()
724
+ self.slice5 = torch.nn.Sequential()
725
+ self.slice6 = torch.nn.Sequential()
726
+ self.slice7 = torch.nn.Sequential()
727
+ self.N_slices = 7
728
+ for x in range(2):
729
+ self.slice1.add_module(str(x), pretrained_features[x])
730
+ for x in range(2, 5):
731
+ self.slice2.add_module(str(x), pretrained_features[x])
732
+ for x in range(5, 8):
733
+ self.slice3.add_module(str(x), pretrained_features[x])
734
+ for x in range(8, 10):
735
+ self.slice4.add_module(str(x), pretrained_features[x])
736
+ for x in range(10, 11):
737
+ self.slice5.add_module(str(x), pretrained_features[x])
738
+ for x in range(11, 12):
739
+ self.slice6.add_module(str(x), pretrained_features[x])
740
+ for x in range(12, 13):
741
+ self.slice7.add_module(str(x), pretrained_features[x])
742
+ if not requires_grad:
743
+ for param in self.parameters():
744
+ param.requires_grad = False
745
+
746
+ def forward(self, X):
747
+ h = self.slice1(X)
748
+ h_relu1 = h
749
+ h = self.slice2(h)
750
+ h_relu2 = h
751
+ h = self.slice3(h)
752
+ h_relu3 = h
753
+ h = self.slice4(h)
754
+ h_relu4 = h
755
+ h = self.slice5(h)
756
+ h_relu5 = h
757
+ h = self.slice6(h)
758
+ h_relu6 = h
759
+ h = self.slice7(h)
760
+ h_relu7 = h
761
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
762
+ out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
763
+
764
+ return out
765
+
766
+
767
+ class alexnet(torch.nn.Module):
768
+ def __init__(self, requires_grad=False, pretrained=True):
769
+ super(alexnet, self).__init__()
770
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
771
+ self.slice1 = torch.nn.Sequential()
772
+ self.slice2 = torch.nn.Sequential()
773
+ self.slice3 = torch.nn.Sequential()
774
+ self.slice4 = torch.nn.Sequential()
775
+ self.slice5 = torch.nn.Sequential()
776
+ self.N_slices = 5
777
+ for x in range(2):
778
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
779
+ for x in range(2, 5):
780
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
781
+ for x in range(5, 8):
782
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
783
+ for x in range(8, 10):
784
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
785
+ for x in range(10, 12):
786
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
787
+ if not requires_grad:
788
+ for param in self.parameters():
789
+ param.requires_grad = False
790
+
791
+ def forward(self, X):
792
+ h = self.slice1(X)
793
+ h_relu1 = h
794
+ h = self.slice2(h)
795
+ h_relu2 = h
796
+ h = self.slice3(h)
797
+ h_relu3 = h
798
+ h = self.slice4(h)
799
+ h_relu4 = h
800
+ h = self.slice5(h)
801
+ h_relu5 = h
802
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
803
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
804
+
805
+ return out
806
+
807
+
808
+ class vgg16(torch.nn.Module):
809
+ def __init__(self, requires_grad=False, pretrained=True):
810
+ super(vgg16, self).__init__()
811
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
812
+ self.slice1 = torch.nn.Sequential()
813
+ self.slice2 = torch.nn.Sequential()
814
+ self.slice3 = torch.nn.Sequential()
815
+ self.slice4 = torch.nn.Sequential()
816
+ self.slice5 = torch.nn.Sequential()
817
+ self.N_slices = 5
818
+ for x in range(4):
819
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
820
+ for x in range(4, 9):
821
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
822
+ for x in range(9, 16):
823
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
824
+ for x in range(16, 23):
825
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
826
+ for x in range(23, 30):
827
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
828
+ if not requires_grad:
829
+ for param in self.parameters():
830
+ param.requires_grad = False
831
+
832
+ def forward(self, X):
833
+ h = self.slice1(X)
834
+ h_relu1_2 = h
835
+ h = self.slice2(h)
836
+ h_relu2_2 = h
837
+ h = self.slice3(h)
838
+ h_relu3_3 = h
839
+ h = self.slice4(h)
840
+ h_relu4_3 = h
841
+ h = self.slice5(h)
842
+ h_relu5_3 = h
843
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
844
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
845
+
846
+ return out
847
+
848
+
849
+ class resnet(torch.nn.Module):
850
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
851
+ super(resnet, self).__init__()
852
+ if (num == 18):
853
+ self.net = tv.resnet18(pretrained=pretrained)
854
+ elif (num == 34):
855
+ self.net = tv.resnet34(pretrained=pretrained)
856
+ elif (num == 50):
857
+ self.net = tv.resnet50(pretrained=pretrained)
858
+ elif (num == 101):
859
+ self.net = tv.resnet101(pretrained=pretrained)
860
+ elif (num == 152):
861
+ self.net = tv.resnet152(pretrained=pretrained)
862
+ self.N_slices = 5
863
+
864
+ self.conv1 = self.net.conv1
865
+ self.bn1 = self.net.bn1
866
+ self.relu = self.net.relu
867
+ self.maxpool = self.net.maxpool
868
+ self.layer1 = self.net.layer1
869
+ self.layer2 = self.net.layer2
870
+ self.layer3 = self.net.layer3
871
+ self.layer4 = self.net.layer4
872
+
873
+ def forward(self, X):
874
+ h = self.conv1(X)
875
+ h = self.bn1(h)
876
+ h = self.relu(h)
877
+ h_relu1 = h
878
+ h = self.maxpool(h)
879
+ h = self.layer1(h)
880
+ h_conv2 = h
881
+ h = self.layer2(h)
882
+ h_conv3 = h
883
+ h = self.layer3(h)
884
+ h_conv4 = h
885
+ h = self.layer4(h)
886
+ h_conv5 = h
887
+
888
+ outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
889
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
890
+
891
+ return out
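
A minimal sketch of calling the combined module above as an LPIPS metric, assuming it is importable as `lpips` and that the linear-layer weights exist at the default model_path resolved inside DistModel.initialize (both are assumptions about the local setup):

import torch
from lpips import PerceptualLoss  # hypothetical import path for the file above

# net-lin + alex is the standard LPIPS configuration
loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=False)

pred = torch.rand(1, 3, 256, 256)     # images in [0, 1]
target = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    d = loss_fn(pred, target, normalize=True)  # normalize=True rescales [0, 1] -> [-1, 1] internally
print(d.shape)  # one distance per image, shape (1, 1, 1, 1) before squeezing
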
DH-AISP/2/saicinpainting/evaluation/losses/ssim.py ADDED
@@ -0,0 +1,74 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SSIM(torch.nn.Module):
7
+ """SSIM. Modified from:
8
+ https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
9
+ """
10
+
11
+ def __init__(self, window_size=11, size_average=True):
12
+ super().__init__()
13
+ self.window_size = window_size
14
+ self.size_average = size_average
15
+ self.channel = 1
16
+ self.register_buffer('window', self._create_window(window_size, self.channel))
17
+
18
+ def forward(self, img1, img2):
19
+ assert len(img1.shape) == 4
20
+
21
+ channel = img1.size()[1]
22
+
23
+ if channel == self.channel and self.window.data.type() == img1.data.type():
24
+ window = self.window
25
+ else:
26
+ window = self._create_window(self.window_size, channel)
27
+
28
+ # window = window.to(img1.get_device())
29
+ window = window.type_as(img1)
30
+
31
+ self.window = window
32
+ self.channel = channel
33
+
34
+ return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
35
+
36
+ def _gaussian(self, window_size, sigma):
37
+ gauss = torch.Tensor([
38
+ np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)
39
+ ])
40
+ return gauss / gauss.sum()
41
+
42
+ def _create_window(self, window_size, channel):
43
+ _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
44
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
45
+ return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
46
+
47
+ def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
48
+ mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel)
49
+ mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel)
50
+
51
+ mu1_sq = mu1.pow(2)
52
+ mu2_sq = mu2.pow(2)
53
+ mu1_mu2 = mu1 * mu2
54
+
55
+ sigma1_sq = F.conv2d(
56
+ img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq
57
+ sigma2_sq = F.conv2d(
58
+ img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq
59
+ sigma12 = F.conv2d(
60
+ img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2
61
+
62
+ C1 = 0.01 ** 2
63
+ C2 = 0.03 ** 2
64
+
65
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
66
+ ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
67
+
68
+ if size_average:
69
+ return ssim_map.mean()
70
+
71
+ return ssim_map.mean(1).mean(1).mean(1)
72
+
73
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
74
+ return
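For reference, a small usage sketch of the SSIM module above (the input shapes and [0, 1] value range are assumptions; the module expects 4-D `(N, C, H, W)` tensors and returns 1.0 for identical inputs):

```python
import torch

ssim = SSIM(window_size=11, size_average=True)

# Two batches of single-channel images in [0, 1]; identical images give SSIM = 1.
img1 = torch.rand(2, 1, 64, 64)
img2 = img1.clone()

score = ssim(img1, img2)
print(score.item())  # ~1.0 for identical inputs
```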
DH-AISP/2/saicinpainting/evaluation/masks/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # Current algorithm
2
+
3
+ ## Choice of mask objects
4
+
5
+ To identify objects suitable for mask extraction, we use a panoptic segmentation model
6
+ from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances
7
+ belong to either the "stuff" or the "things" type. We require that object instances have a category belonging
8
+ to "things". Besides, we set an upper bound on the area occupied by the object &mdash; too large an area
9
+ indicates that the instance is either the background or the main object, which should not be removed.
10
+
11
+ ## Choice of position for mask
12
+
13
+ We assume the input image has size 2^n x 2^m. We downsample it using the
14
+ [COUNTLESS](https://github.com/william-silversmith/countless) algorithm so that its width is equal to
15
+ 64 = 2^6 = 2^{downsample_levels}.
16
+
17
+ ### Augmentation
18
+
19
+ There are several parameters for augmentation:
20
+ - Scaling factor. We limit scaling to the case where the mask, scaled about its center as the pivot point, still fits inside the
21
+ image completely.
22
+ -
23
+
24
+ ### Shift
25
+
26
+
27
+ ## Select
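A minimal sketch of the repeated 2x downsampling described above, chained until the target width is reached. The import path and the use of `odd_to_even` plus `zero_corrected_countless` from the bundled countless2d.py are assumptions about how the levels would be composed.

```python
import numpy as np
# Assumed import path within this repository.
from saicinpainting.evaluation.masks.countless.countless2d import (
    odd_to_even, zero_corrected_countless,
)

def downsample_to_width(seg_map: np.ndarray, target_width: int = 64) -> np.ndarray:
    """Repeatedly halve a 2-D segmentation map with COUNTLESS until width <= target_width."""
    out = seg_map
    while out.shape[1] > target_width:
        out = odd_to_even(out)               # pad odd-sized axes so 2x2 blocks tile exactly
        out = zero_corrected_countless(out)  # mode-like 2x2 downsampling that handles label 0
    return out

labels = np.random.randint(0, 5, size=(512, 512), dtype=np.uint8)
print(downsample_to_width(labels).shape)  # (64, 64)
```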
DH-AISP/2/saicinpainting/evaluation/masks/__init__.py ADDED
File without changes
DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md ADDED
@@ -0,0 +1,25 @@
1
+ [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless)
2
+
3
+ Python COUNTLESS Downsampling
4
+ =============================
5
+
6
+ To install:
7
+
8
+ `pip install -r requirements.txt`
9
+
10
+ To test:
11
+
12
+ `python test.py`
13
+
14
+ To benchmark countless2d:
15
+
16
+ `python python/countless2d.py python/images/gray_segmentation.png`
17
+
18
+ To benchmark countless3d:
19
+
20
+ `python python/countless3d.py`
21
+
22
+ Adjust N and the list of algorithms inside each script to modify the run parameters.
23
+
24
+
25
+ Python3 is slightly faster than Python2.
DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py ADDED
File without changes
DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py ADDED
@@ -0,0 +1,529 @@
1
+ from __future__ import print_function, division
2
+
3
+ """
4
+ COUNTLESS performance test in Python.
5
+
6
+ python countless2d.py ./images/NAMEOFIMAGE
7
+ """
8
+
9
+ import six
10
+ from six.moves import range
11
+ from collections import defaultdict
12
+ from functools import reduce
13
+ import operator
14
+ import io
15
+ import os
16
+ from PIL import Image
17
+ import math
18
+ import numpy as np
19
+ import random
20
+ import sys
21
+ import time
22
+ from tqdm import tqdm
23
+ from scipy import ndimage
24
+
25
+ def simplest_countless(data):
26
+ """
27
+ Vectorized implementation of downsampling a 2D
28
+ image by 2 on each side using the COUNTLESS algorithm.
29
+
30
+ data is a 2D numpy array with even dimensions.
31
+ """
32
+ sections = []
33
+
34
+ # This loop splits the 2D array apart into four arrays that are
35
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
36
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
37
+ factor = (2,2)
38
+ for offset in np.ndindex(factor):
39
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
40
+ sections.append(part)
41
+
42
+ a, b, c, d = sections
43
+
44
+ ab = a * (a == b) # PICK(A,B)
45
+ ac = a * (a == c) # PICK(A,C)
46
+ bc = b * (b == c) # PICK(B,C)
47
+
48
+ a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
49
+
50
+ return a + (a == 0) * d # AB || AC || BC || D
51
+
52
+ def quick_countless(data):
53
+ """
54
+ Vectorized implementation of downsampling a 2D
55
+ image by 2 on each side using the COUNTLESS algorithm.
56
+
57
+ data is a 2D numpy array with even dimensions.
58
+ """
59
+ sections = []
60
+
61
+ # This loop splits the 2D array apart into four arrays that are
62
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
63
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
64
+ factor = (2,2)
65
+ for offset in np.ndindex(factor):
66
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
67
+ sections.append(part)
68
+
69
+ a, b, c, d = sections
70
+
71
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
72
+ bc = b * (b == c) # PICK(B,C)
73
+
74
+ a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C)
75
+ return a + (a == 0) * d # AB || AC || BC || D
76
+
77
+ def quickest_countless(data):
78
+ """
79
+ Vectorized implementation of downsampling a 2D
80
+ image by 2 on each side using the COUNTLESS algorithm.
81
+
82
+ data is a 2D numpy array with even dimensions.
83
+ """
84
+ sections = []
85
+
86
+ # This loop splits the 2D array apart into four arrays that are
87
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
88
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
89
+ factor = (2,2)
90
+ for offset in np.ndindex(factor):
91
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
92
+ sections.append(part)
93
+
94
+ a, b, c, d = sections
95
+
96
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
97
+ ab_ac |= b * (b == c) # PICK(B,C)
98
+ return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D
99
+
100
+ def quick_countless_xor(data):
101
+ """
102
+ Vectorized implementation of downsampling a 2D
103
+ image by 2 on each side using the COUNTLESS algorithm.
104
+
105
+ data is a 2D numpy array with even dimensions.
106
+ """
107
+ sections = []
108
+
109
+ # This loop splits the 2D array apart into four arrays that are
110
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
111
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
112
+ factor = (2,2)
113
+ for offset in np.ndindex(factor):
114
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
115
+ sections.append(part)
116
+
117
+ a, b, c, d = sections
118
+
119
+ ab = a ^ (a ^ b) # a or b
120
+ ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c
121
+ ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d
122
+ return ab
123
+
124
+ def stippled_countless(data):
125
+ """
126
+ Vectorized implementation of downsampling a 2D
127
+ image by 2 on each side using the COUNTLESS algorithm
128
+ that treats zero as "background" and inflates lone
129
+ pixels.
130
+
131
+ data is a 2D numpy array with even dimensions.
132
+ """
133
+ sections = []
134
+
135
+ # This loop splits the 2D array apart into four arrays that are
136
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
137
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
138
+ factor = (2,2)
139
+ for offset in np.ndindex(factor):
140
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
141
+ sections.append(part)
142
+
143
+ a, b, c, d = sections
144
+
145
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
146
+ ab_ac |= b * (b == c) # PICK(B,C)
147
+
148
+ nonzero = a + (a == 0) * (b + (b == 0) * c)
149
+ return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D
150
+
151
+ def zero_corrected_countless(data):
152
+ """
153
+ Vectorized implementation of downsampling a 2D
154
+ image by 2 on each side using the COUNTLESS algorithm.
155
+
156
+ data is a 2D numpy array with even dimensions.
157
+ """
158
+ # allows us to prevent losing 1/2 a bit of information
159
+ # at the top end by using a bigger type. Without this 255 is handled incorrectly.
160
+ data, upgraded = upgrade_type(data)
161
+
162
+ # offset from zero, raw countless doesn't handle 0 correctly
163
+ # we'll remove the extra 1 at the end.
164
+ data += 1
165
+
166
+ sections = []
167
+
168
+ # This loop splits the 2D array apart into four arrays that are
169
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
170
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
171
+ factor = (2,2)
172
+ for offset in np.ndindex(factor):
173
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
174
+ sections.append(part)
175
+
176
+ a, b, c, d = sections
177
+
178
+ ab = a * (a == b) # PICK(A,B)
179
+ ac = a * (a == c) # PICK(A,C)
180
+ bc = b * (b == c) # PICK(B,C)
181
+
182
+ a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
183
+
184
+ result = a + (a == 0) * d - 1 # a or d - 1
185
+
186
+ if upgraded:
187
+ return downgrade_type(result)
188
+
189
+ # only need to reset data if we weren't upgraded
190
+ # b/c no copy was made in that case
191
+ data -= 1
192
+
193
+ return result
194
+
195
+ def countless_extreme(data):
196
+ nonzeros = np.count_nonzero(data)
197
+ # print("nonzeros", nonzeros)
198
+
199
+ N = reduce(operator.mul, data.shape)
200
+
201
+ if nonzeros == N:
202
+ print("quick")
203
+ return quick_countless(data)
204
+ elif np.count_nonzero(data + 1) == N:
205
+ print("quick")
206
+ # print("upper", nonzeros)
207
+ return quick_countless(data)
208
+ else:
209
+ return countless(data)
210
+
211
+
212
+ def countless(data):
213
+ """
214
+ Vectorized implementation of downsampling a 2D
215
+ image by 2 on each side using the COUNTLESS algorithm.
216
+
217
+ data is a 2D numpy array with even dimensions.
218
+ """
219
+ # allows us to prevent losing 1/2 a bit of information
220
+ # at the top end by using a bigger type. Without this 255 is handled incorrectly.
221
+ data, upgraded = upgrade_type(data)
222
+
223
+ # offset from zero, raw countless doesn't handle 0 correctly
224
+ # we'll remove the extra 1 at the end.
225
+ data += 1
226
+
227
+ sections = []
228
+
229
+ # This loop splits the 2D array apart into four arrays that are
230
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
231
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
232
+ factor = (2,2)
233
+ for offset in np.ndindex(factor):
234
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
235
+ sections.append(part)
236
+
237
+ a, b, c, d = sections
238
+
239
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
240
+ ab_ac |= b * (b == c) # PICK(B,C)
241
+ result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1
242
+
243
+ if upgraded:
244
+ return downgrade_type(result)
245
+
246
+ # only need to reset data if we weren't upgraded
247
+ # b/c no copy was made in that case
248
+ data -= 1
249
+
250
+ return result
251
+
252
+ def upgrade_type(arr):
253
+ dtype = arr.dtype
254
+
255
+ if dtype == np.uint8:
256
+ return arr.astype(np.uint16), True
257
+ elif dtype == np.uint16:
258
+ return arr.astype(np.uint32), True
259
+ elif dtype == np.uint32:
260
+ return arr.astype(np.uint64), True
261
+
262
+ return arr, False
263
+
264
+ def downgrade_type(arr):
265
+ dtype = arr.dtype
266
+
267
+ if dtype == np.uint64:
268
+ return arr.astype(np.uint32)
269
+ elif dtype == np.uint32:
270
+ return arr.astype(np.uint16)
271
+ elif dtype == np.uint16:
272
+ return arr.astype(np.uint8)
273
+
274
+ return arr
275
+
276
+ def odd_to_even(image):
277
+ """
278
+ To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one.
279
+ Works by mirroring the starting 1 pixel edge of the image on odd shaped sides.
280
+
281
+ e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled)
282
+
283
+ For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample.
284
+
285
+ """
286
+ shape = np.array(image.shape)
287
+
288
+ offset = (shape % 2)[:2] # x,y offset
289
+
290
+ # detect if we're dealing with an even
291
+ # image. if so it's fine, just return.
292
+ if not np.any(offset):
293
+ return image
294
+
295
+ oddshape = image.shape[:2] + offset
296
+ oddshape = np.append(oddshape, shape[2:])
297
+ oddshape = oddshape.astype(int)
298
+
299
+ newimg = np.empty(shape=oddshape, dtype=image.dtype)
300
+
301
+ ox,oy = offset
302
+ sx,sy = oddshape
303
+
304
+ newimg[0,0] = image[0,0] # corner
305
+ newimg[ox:sx,0] = image[:,0] # x axis line
306
+ newimg[0,oy:sy] = image[0,:] # y axis line
307
+
308
+ return newimg
309
+
310
+ def counting(array):
311
+ factor = (2, 2, 1)
312
+ shape = array.shape
313
+
314
+ while len(shape) < 4:
315
+ array = np.expand_dims(array, axis=-1)
316
+ shape = array.shape
317
+
318
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
319
+ output = np.zeros(output_shape, dtype=array.dtype)
320
+
321
+ for chan in range(0, shape[3]):
322
+ for z in range(0, shape[2]):
323
+ for x in range(0, shape[0], 2):
324
+ for y in range(0, shape[1], 2):
325
+ block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block
326
+
327
+ hashtable = defaultdict(int)
328
+ for subx, suby in np.ndindex(block.shape[0], block.shape[1]):
329
+ hashtable[block[subx, suby]] += 1
330
+
331
+ best = (0, 0)
332
+ for segid, val in six.iteritems(hashtable):
333
+ if best[1] < val:
334
+ best = (segid, val)
335
+
336
+ output[ x // 2, y // 2, chan ] = best[0]
337
+
338
+ return output
339
+
340
+ def ndzoom(array):
341
+ if len(array.shape) == 3:
342
+ ratio = ( 1 / 2.0, 1 / 2.0, 1.0 )
343
+ else:
344
+ ratio = ( 1 / 2.0, 1 / 2.0)
345
+ return ndimage.zoom(array, ratio, order=1)  # ndimage.interpolation namespace is deprecated
346
+
347
+ def countless_if(array):
348
+ factor = (2, 2, 1)
349
+ shape = array.shape
350
+
351
+ if len(shape) < 3:
352
+ array = array[ :,:, np.newaxis ]
353
+ shape = array.shape
354
+
355
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
356
+ output = np.zeros(output_shape, dtype=array.dtype)
357
+
358
+ for chan in range(0, shape[2]):
359
+ for x in range(0, shape[0], 2):
360
+ for y in range(0, shape[1], 2):
361
+ block = array[ x:x+2, y:y+2, chan ] # 2x2 block
362
+
363
+ if block[0,0] == block[1,0]:
364
+ pick = block[0,0]
365
+ elif block[0,0] == block[0,1]:
366
+ pick = block[0,0]
367
+ elif block[1,0] == block[0,1]:
368
+ pick = block[1,0]
369
+ else:
370
+ pick = block[1,1]
371
+
372
+ output[ x // 2, y // 2, chan ] = pick
373
+
374
+ return np.squeeze(output)
375
+
376
+ def downsample_with_averaging(array):
377
+ """
378
+ Downsample x by factor using averaging.
379
+
380
+ @return: The downsampled array, of the same type as x.
381
+ """
382
+
383
+ if len(array.shape) == 3:
384
+ factor = (2,2,1)
385
+ else:
386
+ factor = (2,2)
387
+
388
+ if np.array_equal(factor[:3], np.array([1,1,1])):
389
+ return array
390
+
391
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
392
+ temp = np.zeros(output_shape, float)
393
+ counts = np.zeros(output_shape, int)  # np.int was removed in NumPy 1.24
394
+ for offset in np.ndindex(factor):
395
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
396
+ indexing_expr = tuple(np.s_[:s] for s in part.shape)
397
+ temp[indexing_expr] += part
398
+ counts[indexing_expr] += 1
399
+ return (temp / counts).astype(array.dtype)  # np.cast is removed in NumPy 2.0
400
+
401
+ def downsample_with_max_pooling(array):
402
+
403
+ factor = (2,2)
404
+
405
+ if np.all(np.array(factor, int) == 1):
406
+ return array
407
+
408
+ sections = []
409
+
410
+ for offset in np.ndindex(factor):
411
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
412
+ sections.append(part)
413
+
414
+ output = sections[0].copy()
415
+
416
+ for section in sections[1:]:
417
+ np.maximum(output, section, output)
418
+
419
+ return output
420
+
421
+ def striding(array):
422
+ """Downsample x by factor using striding.
423
+
424
+ @return: The downsampled array, of the same type as x.
425
+ """
426
+ factor = (2,2)
427
+ if np.all(np.array(factor, int) == 1):
428
+ return array
429
+ return array[tuple(np.s_[::f] for f in factor)]
430
+
431
+ def benchmark():
432
+ filename = sys.argv[1]
433
+ img = Image.open(filename)
434
+ data = np.array(img.getdata(), dtype=np.uint8)
435
+
436
+ if len(data.shape) == 1:
437
+ n_channels = 1
438
+ reshape = (img.height, img.width)
439
+ else:
440
+ n_channels = min(data.shape[1], 3)
441
+ data = data[:, :n_channels]
442
+ reshape = (img.height, img.width, n_channels)
443
+
444
+ data = data.reshape(reshape).astype(np.uint8)
445
+
446
+ methods = [
447
+ simplest_countless,
448
+ quick_countless,
449
+ quick_countless_xor,
450
+ quickest_countless,
451
+ stippled_countless,
452
+ zero_corrected_countless,
453
+ countless,
454
+ downsample_with_averaging,
455
+ downsample_with_max_pooling,
456
+ ndzoom,
457
+ striding,
458
+ # countless_if,
459
+ # counting,
460
+ ]
461
+
462
+ formats = {
463
+ 1: 'L',
464
+ 3: 'RGB',
465
+ 4: 'RGBA'
466
+ }
467
+
468
+ if not os.path.exists('./results'):
469
+ os.mkdir('./results')
470
+
471
+ N = 500
472
+ img_size = float(img.width * img.height) / 1024.0 / 1024.0
473
+ print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename))
474
+ print("Algorithm\tMPx/sec\tMB/sec\tSec")
475
+ for fn in methods:
476
+ print(fn.__name__, end='')
477
+ sys.stdout.flush()
478
+
479
+ start = time.time()
480
+ # tqdm is here to show you what's going on the first time you run it.
481
+ # Feel free to remove it to get slightly more accurate timing results.
482
+ for _ in tqdm(range(N), desc=fn.__name__, disable=True):
483
+ result = fn(data)
484
+ end = time.time()
485
+ print("\r", end='')
486
+
487
+ total_time = (end - start)
488
+ mpx = N * img_size / total_time
489
+ mbytes = N * img_size * n_channels / total_time
490
+ # Output in tab separated format to enable copy-paste into excel/numbers
491
+ print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
492
+ outimg = Image.fromarray(np.squeeze(result), formats[n_channels])
493
+ outimg.save('./results/{}.png'.format(fn.__name__, "PNG"))
494
+
495
+ if __name__ == '__main__':
496
+ benchmark()
497
+
498
+
499
+ # Example results:
500
+ # N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
501
+ # Function MPx/sec MB/sec Sec
502
+ # simplest_countless 752.855 752.855 0.01
503
+ # quick_countless 920.328 920.328 0.01
504
+ # zero_corrected_countless 534.143 534.143 0.01
505
+ # countless 644.247 644.247 0.01
506
+ # downsample_with_averaging 372.575 372.575 0.01
507
+ # downsample_with_max_pooling 974.060 974.060 0.01
508
+ # ndzoom 137.517 137.517 0.04
509
+ # striding 38550.588 38550.588 0.00
510
+ # countless_if 4.377 4.377 1.14
511
+ # counting 0.117 0.117 42.85
512
+
513
+ # Run without non-numpy implementations:
514
+ # N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
515
+ # Algorithm MPx/sec MB/sec Sec
516
+ # simplest_countless 800.522 800.522 2.50
517
+ # quick_countless 945.420 945.420 2.12
518
+ # quickest_countless 947.256 947.256 2.11
519
+ # stippled_countless 544.049 544.049 3.68
520
+ # zero_corrected_countless 575.310 575.310 3.48
521
+ # countless 646.684 646.684 3.09
522
+ # downsample_with_averaging 385.132 385.132 5.19
523
+ # downsample_with_max_poolin 988.361 988.361 2.02
524
+ # ndzoom 163.104 163.104 12.26
525
+ # striding 81589.340 81589.340 0.02
526
+
527
+
528
+
529
+
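To make the PICK/mode logic above concrete, here is a tiny worked example (a sketch with arbitrary values, assuming `simplest_countless` from the file above is importable): each 2x2 block yields a value that occurs at least twice in the block, or the D position when all four values differ.

```python
import numpy as np
# Assumption: simplest_countless is importable from the countless2d module shown above.
from countless2d import simplest_countless

labels = np.array([
    [7, 7, 3, 9],   # top-left block (7,7,1,2) -> 7    top-right (3,9,4,5) -> all differ -> D = 5
    [1, 2, 4, 5],
    [6, 8, 2, 2],   # bottom-left (6,8,6,8) -> 6       bottom-right (2,2,2,2) -> 2
    [6, 8, 2, 2],
], dtype=np.uint8)

print(simplest_countless(labels))
# [[7 5]
#  [6 2]]
```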
DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py ADDED
@@ -0,0 +1,356 @@
1
+ from six.moves import range
2
+ from PIL import Image
3
+ import numpy as np
4
+ import io
5
+ import time
6
+ import math
7
+ import random
8
+ import sys
9
+ from collections import defaultdict
10
+ from copy import deepcopy
11
+ from itertools import combinations
12
+ from functools import reduce
13
+ from tqdm import tqdm
14
+
15
+ from memory_profiler import profile
16
+
17
+ def countless5(a,b,c,d,e):
18
+ """First stage of generalizing from countless2d.
19
+
20
+ You have five slots: A, B, C, D, E
21
+
22
+ You can decide if something is the winner by first checking for
23
+ matches of three, then matches of two, then picking just one if
24
+ the other two tries fail. In countless2d, you just check for matches
25
+ of two and then pick one of them otherwise.
26
+
27
+ Unfortunately, you need to check ABC, ABD, ABE, BCD, BDE, & CDE.
28
+ Then you need to check AB, AC, AD, BC, BD, & CD.
29
+ We skip checking E because if none of these match, we pick E. We can
30
+ skip checking AE, BE, CE, DE since if any of those match, E is our boy
31
+ so it's redundant.
32
+
33
+ So countless grows combinatorially in complexity.
34
+ """
35
+ sections = [ a,b,c,d,e ]
36
+
37
+ p2 = lambda q,r: q * (q == r) # q if p == q else 0
38
+ p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0
39
+
40
+ lor = lambda x,y: x + (x == 0) * y
41
+
42
+ results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
43
+ results3 = reduce(lor, results3)
44
+
45
+ results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
46
+ results2 = reduce(lor, results2)
47
+
48
+ return reduce(lor, (results3, results2, e))
49
+
50
+ def countless8(a,b,c,d,e,f,g,h):
51
+ """Extend countless5 to countless8. Same deal, except we also
52
+ need to check for matches of length 4."""
53
+ sections = [ a, b, c, d, e, f, g, h ]
54
+
55
+ p2 = lambda q,r: q * (q == r)
56
+ p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
57
+ p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
58
+
59
+ lor = lambda x,y: x + (x == 0) * y
60
+
61
+ results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
62
+ results4 = reduce(lor, results4)
63
+
64
+ results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
65
+ results3 = reduce(lor, results3)
66
+
67
+ # We can always use our shortcut of omitting the last element
68
+ # for N choose 2
69
+ results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
70
+ results2 = reduce(lor, results2)
71
+
72
+ return reduce(lor, [ results4, results3, results2, h ])
73
+
74
+ def dynamic_countless3d(data):
75
+ """countless8 + dynamic programming. ~2x faster"""
76
+ sections = []
77
+
78
+ # shift zeros up one so they don't interfere with bitwise operators
79
+ # we'll shift down at the end
80
+ data += 1
81
+
82
+ # This loop splits the 3D array apart into eight arrays that are
84
+ # all the result of striding by 2 along each axis, offset by every
85
+ # combination of 0/1 per axis (the eight corners of each 2x2x2 block).
85
+ factor = (2,2,2)
86
+ for offset in np.ndindex(factor):
87
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
88
+ sections.append(part)
89
+
90
+ pick = lambda a,b: a * (a == b)
91
+ lor = lambda x,y: x + (x == 0) * y
92
+
93
+ subproblems2 = {}
94
+
95
+ results2 = None
96
+ for x,y in combinations(range(7), 2):
97
+ res = pick(sections[x], sections[y])
98
+ subproblems2[(x,y)] = res
99
+ if results2 is not None:
100
+ results2 += (results2 == 0) * res
101
+ else:
102
+ results2 = res
103
+
104
+ subproblems3 = {}
105
+
106
+ results3 = None
107
+ for x,y,z in combinations(range(8), 3):
108
+ res = pick(subproblems2[(x,y)], sections[z])
109
+
110
+ if z != 7:
111
+ subproblems3[(x,y,z)] = res
112
+
113
+ if results3 is not None:
114
+ results3 += (results3 == 0) * res
115
+ else:
116
+ results3 = res
117
+
118
+ results3 = reduce(lor, (results3, results2, sections[-1]))
119
+
120
+ # free memory
121
+ results2 = None
122
+ subproblems2 = None
123
+ res = None
124
+
125
+ results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) )
126
+ results4 = reduce(lor, results4)
127
+ subproblems3 = None # free memory
128
+
129
+ final_result = lor(results4, results3) - 1
130
+ data -= 1
131
+ return final_result
132
+
133
+ def countless3d(data):
134
+ """Now write countless8 in such a way that it could be used
135
+ to process an image."""
136
+ sections = []
137
+
138
+ # shift zeros up one so they don't interfere with bitwise operators
139
+ # we'll shift down at the end
140
+ data += 1
141
+
142
+ # This loop splits the 3D array apart into eight arrays that are
143
+ # all the result of striding by 2 along each axis, offset by every
144
+ # combination of 0/1 per axis (the eight corners of each 2x2x2 block).
145
+ factor = (2,2,2)
146
+ for offset in np.ndindex(factor):
147
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
148
+ sections.append(part)
149
+
150
+ p2 = lambda q,r: q * (q == r)
151
+ p3 = lambda q,r,s: q * ( (q == r) & (r == s) )
152
+ p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) )
153
+
154
+ lor = lambda x,y: x + (x == 0) * y
155
+
156
+ results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) )
157
+ results4 = reduce(lor, results4)
158
+
159
+ results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) )
160
+ results3 = reduce(lor, results3)
161
+
162
+ results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) )
163
+ results2 = reduce(lor, results2)
164
+
165
+ final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1
166
+ data -= 1
167
+ return final_result
168
+
169
+ def countless_generalized(data, factor):
170
+ assert len(data.shape) == len(factor)
171
+
172
+ sections = []
173
+
174
+ mode_of = reduce(lambda x,y: x * y, factor)
175
+ majority = int(math.ceil(float(mode_of) / 2))
176
+
177
+ data += 1
178
+
179
+ # This loop splits the array apart into one strided sub-array per
181
+ # offset in `factor`, i.e. one per position inside each downsampling block
182
+ # (for factor (2,2) these are the A, B, C, D positions from Figure 1).
182
+ for offset in np.ndindex(factor):
183
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
184
+ sections.append(part)
185
+
186
+ def pick(elements):
187
+ eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) )
188
+ anded = reduce(lambda p,q: p & q, eq)
189
+ return elements[0] * anded
190
+
191
+ def logical_or(x,y):
192
+ return x + (x == 0) * y
193
+
194
+ result = ( pick(combo) for combo in combinations(sections, majority) )
195
+ result = reduce(logical_or, result)
196
+ for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds
197
+ partial_result = ( pick(combo) for combo in combinations(sections, i) )
198
+ partial_result = reduce(logical_or, partial_result)
199
+ result = logical_or(result, partial_result)
200
+
201
+ partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) )
202
+ partial_result = reduce(logical_or, partial_result)
203
+ result = logical_or(result, partial_result)
204
+
205
+ result = logical_or(result, sections[-1]) - 1
206
+ data -= 1
207
+ return result
208
+
209
+ def dynamic_countless_generalized(data, factor):
210
+ assert len(data.shape) == len(factor)
211
+
212
+ sections = []
213
+
214
+ mode_of = reduce(lambda x,y: x * y, factor)
215
+ majority = int(math.ceil(float(mode_of) / 2))
216
+
217
+ data += 1 # offset from zero
218
+
219
+ # This loop splits the array apart into one strided sub-array per
220
+ # offset in `factor`, i.e. one per position inside each downsampling block
221
+ # (for factor (2,2) these are the A, B, C, D positions from Figure 1).
222
+ for offset in np.ndindex(factor):
223
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
224
+ sections.append(part)
225
+
226
+ pick = lambda a,b: a * (a == b)
227
+ lor = lambda x,y: x + (x == 0) * y # logical or
228
+
229
+ subproblems = [ {}, {} ]
230
+ results2 = None
231
+ for x,y in combinations(range(len(sections) - 1), 2):
232
+ res = pick(sections[x], sections[y])
233
+ subproblems[0][(x,y)] = res
234
+ if results2 is not None:
235
+ results2 = lor(results2, res)
236
+ else:
237
+ results2 = res
238
+
239
+ results = [ results2 ]
240
+ for r in range(3, majority+1):
241
+ r_results = None
242
+ for combo in combinations(range(len(sections)), r):
243
+ res = pick(subproblems[0][combo[:-1]], sections[combo[-1]])
244
+
245
+ if combo[-1] != len(sections) - 1:
246
+ subproblems[1][combo] = res
247
+
248
+ if r_results is not None:
249
+ r_results = lor(r_results, res)
250
+ else:
251
+ r_results = res
252
+ results.append(r_results)
253
+ subproblems[0] = subproblems[1]
254
+ subproblems[1] = {}
255
+
256
+ results.reverse()
257
+ final_result = lor(reduce(lor, results), sections[-1]) - 1
258
+ data -= 1
259
+ return final_result
260
+
261
+ def downsample_with_averaging(array):
262
+ """
263
+ Downsample x by factor using averaging.
264
+
265
+ @return: The downsampled array, of the same type as x.
266
+ """
267
+ factor = (2,2,2)
268
+
269
+ if np.array_equal(factor[:3], np.array([1,1,1])):
270
+ return array
271
+
272
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
273
+ temp = np.zeros(output_shape, float)
274
+ counts = np.zeros(output_shape, int)  # np.int was removed in NumPy 1.24
275
+ for offset in np.ndindex(factor):
276
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
277
+ indexing_expr = tuple(np.s_[:s] for s in part.shape)
278
+ temp[indexing_expr] += part
279
+ counts[indexing_expr] += 1
280
+ return (temp / counts).astype(array.dtype)  # np.cast is removed in NumPy 2.0
281
+
282
+ def downsample_with_max_pooling(array):
283
+
284
+ factor = (2,2,2)
285
+
286
+ sections = []
287
+
288
+ for offset in np.ndindex(factor):
289
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
290
+ sections.append(part)
291
+
292
+ output = sections[0].copy()
293
+
294
+ for section in sections[1:]:
295
+ np.maximum(output, section, output)
296
+
297
+ return output
298
+
299
+ def striding(array):
300
+ """Downsample x by factor using striding.
301
+
302
+ @return: The downsampled array, of the same type as x.
303
+ """
304
+ factor = (2,2,2)
305
+ if np.all(np.array(factor, int) == 1):
306
+ return array
307
+ return array[tuple(np.s_[::f] for f in factor)]
308
+
309
+ def benchmark():
310
+ def countless3d_generalized(img):
311
+ return countless_generalized(img, (2,8,1))
312
+ def countless3d_dynamic_generalized(img):
313
+ return dynamic_countless_generalized(img, (8,8,1))
314
+
315
+ methods = [
316
+ # countless3d,
317
+ # dynamic_countless3d,
318
+ countless3d_generalized,
319
+ # countless3d_dynamic_generalized,
320
+ # striding,
321
+ # downsample_with_averaging,
322
+ # downsample_with_max_pooling
323
+ ]
324
+
325
+ data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1
326
+
327
+ N = 5
328
+
329
+ print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N)
330
+
331
+ for fn in methods:
332
+ start = time.time()
333
+ for _ in range(N):
334
+ result = fn(data)
335
+ end = time.time()
336
+
337
+ total_time = (end - start)
338
+ mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0
339
+ mbytes = mpx * np.dtype(data.dtype).itemsize
340
+ # Output in tab separated format to enable copy-paste into excel/numbers
341
+ print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
342
+
343
+ if __name__ == '__main__':
344
+ benchmark()
345
+
346
+ # Algorithm MPx MB/sec Sec N=5
347
+ # countless3d 10.564 10.564 60.58
348
+ # dynamic_countless3d 22.717 22.717 28.17
349
+ # countless3d_generalized 9.702 9.702 65.96
350
+ # countless3d_dynamic_generalized 22.720 22.720 28.17
351
+ # striding 253360.506 253360.506 0.00
352
+ # downsample_with_averaging 224.098 224.098 2.86
353
+ # downsample_with_max_pooling 690.474 690.474 0.93
354
+
355
+
356
+
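Similarly, a small sketch of the 3D routine above, assuming `countless3d` is importable from the module shown (input dimensions must be even along every axis; the volume contents are arbitrary). Note that `countless3d` offsets the data by +1 in place and restores it afterwards, so values near the dtype maximum would need `upgrade_type` first.

```python
import numpy as np
# Assumption: countless3d is importable from the countless3d module shown above.
from countless3d import countless3d

volume = np.ones((8, 8, 8), dtype=np.uint8)
volume[:, :, 4:] = 3                 # two homogeneous halves along the last axis

small = countless3d(volume)          # mode-like value per 2x2x2 block
print(small.shape)                   # (4, 4, 4)
print(np.unique(small))              # [1 3] -- each half keeps its label
```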
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg ADDED

Git LFS Details

  • SHA256: 2b1ade0a290a0a79aceb49a170d085e28e5d2ea1face4fcd522d39a279d3fb4d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.58 MB
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png ADDED
DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png ADDED