diff --git a/.gitattributes b/.gitattributes index 98f55a07ac1f0e30d980e43f8c33bc85ce22d59f..a158fd953eb8cdf867295c1fc0033fdfff23a9a7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -39,3 +39,6 @@ SCBC/Input/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text SCBC/Output/IMG_20240215_213330.png filter=lfs diff=lfs merge=lfs -text SCBC/Output/IMG_20240215_214449.png filter=lfs diff=lfs merge=lfs -text PolyuColor/resources/average_shading.png filter=lfs diff=lfs merge=lfs -text +DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta filter=lfs diff=lfs merge=lfs -text +DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/DH-AISP/1/__pycache__/awb.cpython-36.pyc b/DH-AISP/1/__pycache__/awb.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f1385b3e54ceb205e9bdda458fdee2c3934d7ad Binary files /dev/null and b/DH-AISP/1/__pycache__/awb.cpython-36.pyc differ diff --git a/DH-AISP/1/awb.py b/DH-AISP/1/awb.py new file mode 100644 index 0000000000000000000000000000000000000000..60380cc6903e36e2cd3a075bafdb18202f1bd402 --- /dev/null +++ b/DH-AISP/1/awb.py @@ -0,0 +1,184 @@ +import os +import cv2 +import numpy as np +from glob import glob + + +def dynamic(rgb): + + rgb = rgb[:-1, :-1, :] # 删去一行一列 + h, w, _ = rgb.shape + col = 4 + row = 3 + h1 = h // row + w1 = w // col + + r, g, b = cv2.split(rgb) + r_mask = r < 0.95 + g_mask = g < 0.95 + b_mask = b < 0.95 + mask = r_mask * g_mask * b_mask + r *= mask + g *= mask + b *= mask + rgb = np.stack((r, g, b), axis=2) + + y, cr, cb = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb)) + cr -= 0.5 + cb -= 0.5 + + mr, mb, dr, db = 0, 0, 0, 0 + for r in range(row): + for c in range(col): + cr_1 = cr[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1] + cb_1 = cb[r * h1:(r + 1) * h1, c * w1:(c + 1) * w1] + mr_1 = np.mean(cr_1) + mb_1 = np.mean(cb_1) + dr_1 = np.mean(np.abs(cr_1 - mr)) + db_1 = np.mean(np.abs(cb_1 - mb)) + + mr += mr_1 + mb += mb_1 + dr += dr_1 + db += db_1 + + mr /= col * row + mb /= col * row + dr /= col * row + db /= col * row + + cb_mask = np.abs(cb - (mb + db * np.sign(mb))) < 1.5 * db + cr_mask = np.abs(cr - (1.5 * mr + dr * np.sign(mr))) < 1.5 * dr + + mask = cb_mask * cr_mask + y_white = y * mask + + hist_y = np.zeros(256, dtype=np.int) + y_white_uint8 = (y_white * 255).astype(np.int) + + for v in range(255): + hist_y[v] = np.sum(y_white_uint8 == v) + + thr_sum = 0.05 * np.sum(mask) + sum_v = 0 + thr = 0 + for v in range(255, -1, -1): + sum_v = sum_v + hist_y[v] + if sum_v > thr_sum: + thr = v + break + + white_mask = y_white_uint8 > thr + cv2.imwrite(r'V:\Project\3_MEWDR\data\2nd_awb\t.png', (white_mask + 0) * 255) + + r, g, b = cv2.split(rgb) + r_ave = np.sum(r[white_mask]) / np.sum(white_mask) + g_ave = np.sum(g[white_mask]) / np.sum(white_mask) + b_ave = np.sum(b[white_mask]) / np.sum(white_mask) + + return 1 / r_ave, 1 / g_ave, 1 / b_ave + + +def perf_ref(rgb, eps): + h, w, _ = rgb.shape + + r, g, b = cv2.split(rgb) + r_mask = r < 0.95 + g_mask = g < 0.95 + b_mask = b < 0.95 + mask = r_mask * g_mask * b_mask + r *= mask + g *= mask + b *= mask + rgb = np.stack((r, g, b), axis=2) + rgb = np.clip(rgb * 255, 0, 255).astype(np.int) + + hist_rgb = np.zeros(255 * 3, dtype=np.int) + rgb_sum = np.sum(rgb, axis=2) + + for v in range(255 * 3): + hist_rgb[v] = np.sum(rgb_sum == v) + + thr_sum = eps * h * w + sum_v = 0 + thr = 0 + for v in 
range(255 * 3 - 1, -1, -1): + sum_v = sum_v + hist_rgb[v] + if sum_v > thr_sum: + thr = v + break + + thr_mask = rgb_sum > thr + r_ave = np.sum(r[thr_mask]) / np.sum(thr_mask) + g_ave = np.sum(g[thr_mask]) / np.sum(thr_mask) + b_ave = np.sum(b[thr_mask]) / np.sum(thr_mask) + + # k = (r_ave + g_ave + b_ave) / 3. + # k = 255 + + # print(k) + + # r = np.clip(r * k / r_ave, 0, 255) + # g = np.clip(g * k / g_ave, 0, 255) + # b = np.clip(b * k / b_ave, 0, 255) + + return 1 / r_ave, 1 / g_ave, 1 / b_ave + + +def awb_v(in_image, bayer, eps): + + assert bayer in ['GBRG', 'RGGB'] + + if bayer == 'GBRG': + g = in_image[0::2, 0::2] # [0,0] + b = in_image[0::2, 1::2] # [0,1] + r = in_image[1::2, 0::2] # [1,0] + else: + r = in_image[0::2, 0::2] # [0,0] + g = in_image[0::2, 1::2] # [0,1] + b = in_image[1::2, 1::2] # [1,1] + + rgb = cv2.merge((r, g, b)) * 1 + + r_gain, g_gain, b_gain = perf_ref(rgb, eps) + + return r_gain / g_gain, b_gain / g_gain + + +def main(): + path = r'V:\Project\3_MEWDR\data\2nd_raw' + # out_path = r'V:\Project\3_MEWDR\data\2nd_awb' + + files = glob(os.path.join(path, '*.png')) + + for f in files: + img = cv2.imread(f, cv2.CV_16UC1) + img = (img.astype(np.float) - 2048) / (15400 - 2048) * 4 + + g = img[0::2, 0::2] # [0,0] + b = img[0::2, 1::2] # [0,1] + r = img[1::2, 0::2] # [1,0] + # g_ = img[1::2, 1::2] + + rgb = cv2.merge((r, g, b)) + + # save_name = f.replace('.png', '_rgb.png').replace('2nd_raw', '2nd_awb') + + r_gain, g_gain, b_gain = perf_ref(rgb, eps=0.1) + # r_gain, g_gain, b_gain = dynamic(rgb.astype(np.float32)) + + r *= r_gain / g_gain + b *= b_gain / g_gain + print(r_gain / g_gain, b_gain / g_gain) + + out_rgb = np.clip(cv2.merge((r, g, b)) * 255, 0, 255) + + save_name = f.replace('.png', '_awb4_dyn.png').replace('2nd_raw', '2nd_awb') + + cv2.imwrite(save_name, cv2.cvtColor(out_rgb.astype(np.uint8), cv2.COLOR_RGB2BGR)) + + # break + + +if __name__ == '__main__': + main() diff --git a/DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint new file mode 100644 index 0000000000000000000000000000000000000000..febd7d546081c498d644978b31e8c836a8931736 --- /dev/null +++ b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "model.ckpt" +all_model_checkpoint_paths: "model.ckpt" diff --git a/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..90df3a08d37de06e5f7420cdb25aea7a1bd6e802 --- /dev/null +++ b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6997bfa5624aba66e2497088cc8f379db63bac343a0a648e08f6a5840a48259f +size 175070404 diff --git a/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index new file mode 100644 index 0000000000000000000000000000000000000000..7fdc35323b7bc7e049b0966ea7a847e85d351b27 Binary files /dev/null and b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.index differ diff --git a/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta new file mode 100644 index 0000000000000000000000000000000000000000..31bfcf1dbbe064c5c140178cb9d60dc02025a083 --- /dev/null +++ b/DH-AISP/1/daylight_isp_03_3_unet_sid_5/model.ckpt.meta @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:79f84947bf3a5a9e851539308b85b43ecc6a8e93ed2c7ab9adb23f0fd6796286 +size 124053471 diff --git a/DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py b/DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py new file mode 100644 index 0000000000000000000000000000000000000000..c875a7d191d17c6e9b281c22aee3d0de0fda1855 --- /dev/null +++ b/DH-AISP/1/tensorflow2to1_3_unet_bining3_7.py @@ -0,0 +1,451 @@ +# uniform content loss + adaptive threshold + per_class_input + recursive G +# improvement upon cqf37 +from __future__ import division +import os +import tensorflow.compat.v1 as tf +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +import tf_slim as slim +import tensorflow as tf2 +tf2.test.is_gpu_available() +import numpy as np +import glob +# import scipy.io as sio +import cv2 +import json +from fractions import Fraction +import pdb +import sys +import argparse + +from awb import awb_v + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +input_dir = '../data/' +cha1 = 32 + +# get train IDs +train_fns = glob.glob(input_dir + '*.png') +train_ids = [os.path.basename(train_fn) for train_fn in train_fns] + +result_dir = './mid/' +checkpoint_dir = './daylight_isp_03_3_unet_sid_5/' + +if not os.path.exists(result_dir): + os.mkdir(result_dir) + +#run python tensorflow2to1_1214_5202x3464_01_unetpp3.py ./data/ ./result/ ./daylight_isp_03/ + +# DEBUG = 0 +# if DEBUG == 1: +# save_freq = 2 +# test_ids = test_ids[0:5] + +def json_read(fname, **kwargs): + with open(fname) as j: + data = json.load(j, **kwargs) + return data + +def fraction_from_json(json_object): + if 'Fraction' in json_object: + return Fraction(*json_object['Fraction']) + return json_object + +def fractions2floats(fractions): + floats = [] + for fraction in fractions: + floats.append(float(fraction.numerator) / fraction.denominator) + return floats + +def tv_loss(input_, output): + I = tf.image.rgb_to_grayscale(input_) + L = tf.log(I+0.0001) + dx = L[:, :-1, :-1, :] - L[:, :-1, 1:, :] + dy = L[:, :-1, :-1, :] - L[:, 1:, :-1, :] + + alpha = tf.constant(1.2) + lamda = tf.constant(1.5) + dx = tf.divide(lamda, tf.pow(tf.abs(dx),alpha)+ tf.constant(0.0001)) + dy = tf.divide(lamda, tf.pow(tf.abs(dy),alpha)+ tf.constant(0.0001)) + shape = output.get_shape() + x_loss = dx *((output[:, :-1, :-1, :] - output[:, :-1, 1:, :])**2) + y_loss = dy *((output[:, :-1, :-1, :] - output[:, 1:, :-1, :])**2) + tvloss = tf.reduce_mean(x_loss + y_loss)/2.0 + return tvloss + +def lrelu(x): + return tf.maximum(x * 0.2, x) + + +def upsample_and_concat_3(x1, x2, output_channels, in_channels, name): + with tf.variable_scope(name): + x1 = slim.conv2d(x1, output_channels, [3, 3], rate=1, activation_fn=lrelu, scope='conv_2to1') + deconv = tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2]) + deconv_output = tf.concat([deconv, x2], 3) + deconv_output.set_shape([None, None, None, output_channels * 2]) + return deconv_output + + +def upsample_and_concat_h(x1, x2, output_channels, in_channels, name): + with tf.variable_scope(name): + #deconv = tf.image.resize_images(x1, [x1.shape[1].value*2, x1.shape[2].value*2]) + pool_size = 2 + deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02)) + deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1]) + + deconv_output = tf.concat([deconv, x2], 3) + deconv_output.set_shape([None, None, None, output_channels * 2]) + + return deconv_output + +def upsample_and_concat_h_only(x1, output_channels, in_channels, name): + with tf.variable_scope(name): + x1 = 
tf.image.resize_images(x1, [x1.shape[1] * 2, x1.shape[2] * 2]) + x1.set_shape([None, None, None, output_channels]) + return x1 + + +def conv_block(input, output_channels, name): + with tf.variable_scope(name): + conv = slim.conv2d(input, output_channels, [3, 3], activation_fn=lrelu, scope='conv1') + conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2') + return conv + + +def conv_block_up(input, output_channels, name): + with tf.variable_scope(name): + conv = slim.conv2d(input, output_channels, [1, 1], scope='conv0') + conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv1') + conv = slim.conv2d(conv, output_channels, [3, 3], activation_fn=lrelu, scope='conv2') + + return conv + + +def upsample_and_concat(x1, x2, output_channels, in_channels, p, name): + with tf.variable_scope(name): + pool_size = 2 + + deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02)) + deconv_filter = tf.cast(deconv_filter, x1.dtype) + deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2[0]), strides=[1, pool_size, pool_size, 1]) + # x2.append(deconv) + x2 = tf.concat(x2, axis=3) + deconv_output = tf.concat([x2, deconv], axis=3) + deconv_output.set_shape([None, None, None, output_channels * (p + 1)]) + + return deconv_output + + +def network(input): + with tf.variable_scope("generator_h"): + conv1_h = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1') + conv1_h = slim.conv2d(conv1_h, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_2') + pool1_h = slim.max_pool2d(conv1_h, [2, 2], padding='SAME') + + conv2_h = slim.conv2d(pool1_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1') + conv2_h = slim.conv2d(conv2_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_2') + pool2_h = slim.max_pool2d(conv2_h, [2, 2], padding='SAME') + + conv3_h = slim.conv2d(pool2_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1') + conv3_h = slim.conv2d(conv3_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_2') + pool3_h = slim.max_pool2d(conv3_h, [2, 2], padding='SAME') + + conv4_h = slim.conv2d(pool3_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_1') + conv4_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv4_2') + conv6_h = slim.conv2d(conv4_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_1') + conv6_h = slim.conv2d(conv6_h, cha1*8, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv6_2') + + up7_h = upsample_and_concat_3(conv6_h, conv3_h, cha1*4,cha1*8, name='up7') + conv7_h = slim.conv2d(up7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_1') + conv7_h = slim.conv2d(conv7_h, cha1*4, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv7_2') + + up8_h = upsample_and_concat_3(conv7_h, conv2_h, cha1*2,cha1*4, name='up8') + conv8_h = slim.conv2d(up8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1') + conv8_h = slim.conv2d(conv8_h, cha1*2, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_2') + + up9_h = upsample_and_concat_3(conv8_h, conv1_h, cha1,cha1*2, name='up9') + conv9_h = slim.conv2d(up9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1') + conv9_h = slim.conv2d(conv9_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_2') + + up10_h = upsample_and_concat_h_only(conv9_h, cha1,cha1, name='up10') + conv10_h = slim.conv2d(up10_h, cha1, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv10_1') + out = 
slim.conv2d(conv10_h, 3, [3, 3], rate=1, activation_fn=None, scope='g_conv10_2') + return out + + + +def fix_orientation(image, orientation): + # 1 = Horizontal(normal) + # 2 = Mirror horizontal + # 3 = Rotate 180 + # 4 = Mirror vertical + # 5 = Mirror horizontal and rotate 270 CW + # 6 = Rotate 90 CW + # 7 = Mirror horizontal and rotate 90 CW + # 8 = Rotate 270 CW + + if type(orientation) is list: + orientation = orientation[0] + + if orientation == 'Horizontal (normal)': + pass + elif orientation == 'Mirror horizontal': + image = cv2.flip(image, 0) + elif orientation == 'Rotate 180': + image = cv2.rotate(image, cv2.ROTATE_180) + elif orientation == 'Mirror vertical': + image = cv2.flip(image, 1) + elif orientation == 'Mirror horizontal and rotate 270 CW': + image = cv2.flip(image, 0) + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif orientation == 'Rotate 90 CW': + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 'Mirror horizontal and rotate 90 CW': + image = cv2.flip(image, 0) + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 'Rotate 270 CW': + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + + return image + +class ExposureFusion(object): + def __init__(self, sequence, best_exposedness=0.5, sigma=0.2, eps=1e-12, exponents=(1.0, 1.0, 1.0), layers=11): + self.sequence = sequence # [N, H, W, 3], (0..1), float32 + self.img_num = sequence.shape[0] + self.best_exposedness = best_exposedness + self.sigma = sigma + self.eps = eps + self.exponents = exponents + self.layers = layers + + @staticmethod + def cal_contrast(src): + gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY) + laplace_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32) + contrast = cv2.filter2D(gray, -1, laplace_kernel, borderType=cv2.BORDER_REPLICATE) + return np.abs(contrast) + + @staticmethod + def cal_saturation(src): + mean = np.mean(src, axis=-1) + channels = [(src[:, :, c] - mean)**2 for c in range(3)] + saturation = np.sqrt(np.mean(channels, axis=0)) + return saturation + + @staticmethod + def cal_exposedness(src, best_exposedness, sigma): + exposedness = [gauss_curve(src[:, :, c], best_exposedness, sigma) for c in range(3)] + exposedness = np.prod(exposedness, axis=0) + return exposedness + + def cal_weight_map(self): + #pdb.set_trace() + weights = [] + for idx in range(self.sequence.shape[0]): + contrast = self.cal_contrast(self.sequence[idx]) + saturation = self.cal_saturation(self.sequence[idx]) + exposedness = self.cal_exposedness(self.sequence[idx], self.best_exposedness, self.sigma) + weight = np.power(contrast, self.exponents[0]) * np.power(saturation, self.exponents[1]) * np.power(exposedness, self.exponents[2]) + # Gauss Blur + # weight = cv2.GaussianBlur(weight, (21, 21), 2.1) + weights.append(weight) + #pdb.set_trace() + weights = np.stack(weights, 0) + self.eps + # normalize + weights = weights / np.expand_dims(np.sum(weights, axis=0), axis=0) + return weights + + def naive_fusion(self): + weights = self.cal_weight_map() # [N, H, W] + weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3] + naive_fusion = np.sum(weights * self.sequence * 255, axis=0) + naive_fusion = np.clip(naive_fusion, 0, 255).astype(np.uint8) + return naive_fusion + + def build_gaussian_pyramid(self, high_res): + #pdb.set_trace() + gaussian_pyramid = [high_res] + for idx in range(1, self.layers): + 
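+            # Note: kernel1 below is the 5x5 binomial approximation of a Gaussian,
+            # i.e. the outer product of [1, 4, 6, 4, 1] / 16 with itself
+            # (corner 1/256 = 0.0039, centre 36/256 = 0.1406); filtering with it
+            # before the [::2, ::2] decimation low-passes each pyramid level.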
kernel1=np.array([[0.0039,0.0156,0.0234,0.0156,0.0039],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0234,0.0938,0.1406,0.0938,0.0234],[0.0156,0.0625,0.0938,0.0625,0.0156],[0.0039,0.0156,0.0234,0.0156,0.0039]],dtype='float32') + gaussian_pyramid.append(cv2.filter2D(gaussian_pyramid[-1], -1,kernel=kernel1)[::2, ::2]) + #gaussian_pyramid.append(cv2.GaussianBlur(gaussian_pyramid[-1], (5, 5), 0.83)[::2, ::2]) + return gaussian_pyramid + + def build_laplace_pyramid(self, gaussian_pyramid): + laplace_pyramid = [gaussian_pyramid[-1]] + for idx in range(1, self.layers): + size = (gaussian_pyramid[self.layers - idx - 1].shape[1], gaussian_pyramid[self.layers - idx - 1].shape[0]) + upsampled = cv2.resize(gaussian_pyramid[self.layers - idx], size, interpolation=cv2.INTER_LINEAR) + laplace_pyramid.append(gaussian_pyramid[self.layers - idx - 1] - upsampled) + laplace_pyramid.reverse() + return laplace_pyramid + + def multi_resolution_fusion(self): + #pdb.set_trace() + weights = self.cal_weight_map() # [N, H, W] + weights = np.stack([weights, weights, weights], axis=-1) # [N, H, W, 3] + + image_gaussian_pyramid = [self.build_gaussian_pyramid(self.sequence[i] * 255) for i in range(self.img_num)] + image_laplace_pyramid = [self.build_laplace_pyramid(image_gaussian_pyramid[i]) for i in range(self.img_num)] + weights_gaussian_pyramid = [self.build_gaussian_pyramid(weights[i]) for i in range(self.img_num)] + + fused_laplace_pyramid = [np.sum([image_laplace_pyramid[n][l] * + weights_gaussian_pyramid[n][l] for n in range(self.img_num)], axis=0) for l in range(self.layers)] + + result = fused_laplace_pyramid[-1] + for k in range(1, self.layers): + size = (fused_laplace_pyramid[self.layers - k - 1].shape[1], fused_laplace_pyramid[self.layers - k - 1].shape[0]) + upsampled = cv2.resize(result, size, interpolation=cv2.INTER_LINEAR) + result = upsampled + fused_laplace_pyramid[self.layers - k - 1] + #pdb.set_trace() + #result = np.clip(result, 0, 255).astype(np.uint8) + + + return result + +h_pre1, w_pre1 = 6144, 8192 +pad_1 = 0 +pad_2 = 0 +h_exp1, w_exp1 = h_pre1 // 2, w_pre1 // 2 + +sess = tf.Session() +in_image = tf.placeholder(tf.float32, [None, h_exp1, w_exp1, 4]) + +in_image1 = tf.nn.avg_pool(in_image,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME') +in_image2 = tf.nn.avg_pool(in_image,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME') + +out_image1 = network(in_image1) +out_image2 = network(in_image2, reuse=True) + +t_vars = tf.trainable_variables() +for ele1 in t_vars: + print("variable: ", ele1) + +saver = tf.train.Saver() +sess.run(tf.global_variables_initializer()) + +ckpt = tf.train.get_checkpoint_state(checkpoint_dir) +if ckpt: + print('loaded ' + ckpt.model_checkpoint_path) + saver.restore(sess, ckpt.model_checkpoint_path) + +in_pic4 = np.zeros([h_exp1, w_exp1, 4]) +for k in range(len(train_ids)): + + print(k) + train_id = train_ids[k] + in_path = input_dir + train_id[:-4] + '.png' + #raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32) + raw_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED).astype(np.float32) + #meta = np.load(input_dir1 + train_id[:-4] + '.npy').astype(np.float32) + #meta = scipy.io.loadmat(input_dir2 + train_id[:-4] + '.mat') + metadata = json_read(in_path[:-4] + '.json', object_hook=fraction_from_json) + + white_level = float(metadata['white_level']) + black_level = float(metadata['black_level'][0].numerator) + + orientation = metadata['orientation'] + + in_pic2 = np.clip((raw_image - black_level) /(white_level-black_level),0,1) + + mean = np.mean(np.mean(in_pic2)) + 
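+    # The heuristic below picks a digital gain `ratio` from the mean brightness of the
+    # normalized raw (darker frames get a larger gain, from 6 down to 2), adds one for
+    # high-variance scenes, and when the scaled noise profile exceeds 0.02 it raises the
+    # binning factor from 4 to 8 and clips the reduced gain back to the 2..4 range.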
var = np.var(in_pic2) + + bining = 4 + + if (mean < 0.01): + ratio = 6 + elif (mean < 0.02): + ratio = 4 + elif (mean < 0.037): + ratio = 3 + else: + ratio = 2 + + if (var > 0.015): + ratio = ratio + 1 + + noise_profile = float(metadata['noise_profile'][0]) * ratio + if (noise_profile > 0.02): + bining = 8 + ratio = np.clip(ratio - 1,2,4) + + #r_gain, b_gain = awb_v(in_pic2, bayer='RGGB', eps=1) + r_gain1 = 1./metadata['as_shot_neutral'][0] + b_gain1 = 1./metadata['as_shot_neutral'][2] + + #in_pic3 = np.pad(in_pic2, ((top_pad, btm_pad), (left_pad, right_pad)), mode='reflect') # GBRG to RGGB + reflect padding + h_pre,w_pre = in_pic2.shape + + if (metadata['cfa_pattern'][0].numerator == 2): + in_pic2[0:h_pre-1,0:w_pre-1] = in_pic2[1:h_pre,1:w_pre] + + r_gain, b_gain = awb_v(in_pic2 * (ratio**2), bayer='RGGB', eps=1) + in_pic3 = in_pic2 + + in_pic4[0:h_pre//2, 0:w_pre//2, 0] = in_pic3[0::2, 0::2] * r_gain + in_pic4[0:h_pre//2, 0:w_pre//2, 1] = in_pic3[0::2, 1::2] + in_pic4[0:h_pre//2, 0:w_pre//2, 2] = in_pic3[1::2, 1::2] * b_gain + in_pic4[0:h_pre//2, 0:w_pre//2, 3] = in_pic3[1::2, 0::2] + + im1=np.clip(in_pic4*1,0,1) + in_np1 = np.expand_dims(im1,axis = 0) + if (bining == 4): + out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1}) + else: + out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1}) + + out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation) + h_pre2,w_pre2,cc = out_np2.shape + + if h_pre2 > w_pre2: + out_np_1 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC) + if w_pre2 > h_pre2: + out_np_1 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC) + + im1=np.clip(in_pic4*ratio,0,1) + in_np1 = np.expand_dims(im1,axis = 0) + if (bining == 4): + out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1}) + else: + out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1}) + + out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation) + h_pre2,w_pre2,cc = out_np2.shape + + if h_pre2 > w_pre2: + out_np_2 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC) + if w_pre2 > h_pre2: + out_np_2 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC) + + + im1=np.clip(in_pic4*(ratio**2),0,1) + in_np1 = np.expand_dims(im1,axis = 0) + + if (bining == 4): + out_np1 =sess.run(out_image1,feed_dict={in_image: in_np1}) + else: + out_np1 =sess.run(out_image2,feed_dict={in_image: in_np1}) + + out_np2 = fix_orientation(out_np1[0,0:h_pre//bining,0:w_pre//bining,:], orientation) + h_pre2,w_pre2,cc = out_np2.shape + + if h_pre2 > w_pre2: + out_np_3 = cv2.resize(out_np2, (768, 1024), cv2.INTER_CUBIC) + if w_pre2 > h_pre2: + out_np_3 = cv2.resize(out_np2, (1024, 768), cv2.INTER_CUBIC) + + #pdb.set_trace() + '''sequence = np.stack([out_np_1, out_np_2, out_np_3], axis=0) + #sequence0 = sequence[0] + mef = ExposureFusion(sequence.astype(np.float32)) + multi_res_fusion = mef.multi_resolution_fusion() + #pdb.set_trace() + result = reprocessing(multi_res_fusion)''' + + #out_crop = multi_res_fusion + + #np.save(result_dir + train_id[0:-4] + '_gray_{:}.npy'.format(gain), out_crop) + cv2.imwrite(result_dir + train_id[0:-4] + '_1.png', out_np_1[:,:,::-1]*255) + cv2.imwrite(result_dir + train_id[0:-4] + '_2.png', out_np_2[:,:,::-1]*255) + cv2.imwrite(result_dir + train_id[0:-4] + '_3.png', out_np_3[:,:,::-1]*255) + diff --git a/DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc b/DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cfa67106d587231ab73a96684227ef91190bcf9 Binary files 
/dev/null and b/DH-AISP/2/__pycache__/model_convnext2_hdr.cpython-37.pyc differ diff --git a/DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc b/DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c6a145921121f59338bdebaf2ded8b4bca34725 Binary files /dev/null and b/DH-AISP/2/__pycache__/myFFCResblock0.cpython-37.pyc differ diff --git a/DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc b/DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa8b30f63b527bf17f6dea87ce2f6a8124c89f06 Binary files /dev/null and b/DH-AISP/2/__pycache__/test_dataset_for_testing.cpython-37.pyc differ diff --git a/DH-AISP/2/focal_frequency_loss/__init__.py b/DH-AISP/2/focal_frequency_loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7c0fbb1f469cc985bd5e545ae71fd52e03072e --- /dev/null +++ b/DH-AISP/2/focal_frequency_loss/__init__.py @@ -0,0 +1,3 @@ +from .focal_frequency_loss import FocalFrequencyLoss + +__all__ = ['FocalFrequencyLoss'] diff --git a/DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc b/DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e497ad4b53916565e56cc63661062e32a4c21ef Binary files /dev/null and b/DH-AISP/2/focal_frequency_loss/__pycache__/__init__.cpython-37.pyc differ diff --git a/DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc b/DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47b31c58144d2b6431850775ad039f79018d218d Binary files /dev/null and b/DH-AISP/2/focal_frequency_loss/__pycache__/focal_frequency_loss.cpython-37.pyc differ diff --git a/DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py b/DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..34ff0078915376f775ce37bf6825012c8c9a5a55 --- /dev/null +++ b/DH-AISP/2/focal_frequency_loss/focal_frequency_loss.py @@ -0,0 +1,114 @@ +import torch +import torch.nn as nn + +# version adaptation for PyTorch > 1.7.1 +IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1) +if IS_HIGH_VERSION: + import torch.fft + + +class FocalFrequencyLoss(nn.Module): + """The torch.nn.Module class that implements focal frequency loss - a + frequency domain loss function for optimizing generative models. + + Ref: + Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021. + + + Args: + loss_weight (float): weight for focal frequency loss. Default: 1.0 + alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0 + patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1 + ave_spectrum (bool): whether to use minibatch average spectrum. Default: False + log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False + batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. 
Default: False + """ + + def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False): + super(FocalFrequencyLoss, self).__init__() + self.loss_weight = loss_weight + self.alpha = alpha + self.patch_factor = patch_factor + self.ave_spectrum = ave_spectrum + self.log_matrix = log_matrix + self.batch_matrix = batch_matrix + + def tensor2freq(self, x): + # crop image patches + patch_factor = self.patch_factor + _, _, h, w = x.shape + assert h % patch_factor == 0 and w % patch_factor == 0, ( + 'Patch factor should be divisible by image height and width') + patch_list = [] + patch_h = h // patch_factor + patch_w = w // patch_factor + for i in range(patch_factor): + for j in range(patch_factor): + patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]) + + # stack to patch tensor + y = torch.stack(patch_list, 1) + + # perform 2D DFT (real-to-complex, orthonormalization) + if IS_HIGH_VERSION: + freq = torch.fft.fft2(y, norm='ortho') + freq = torch.stack([freq.real, freq.imag], -1) + else: + freq = torch.rfft(y, 2, onesided=False, normalized=True) + return freq + + def loss_formulation(self, recon_freq, real_freq, matrix=None): + # spectrum weight matrix + if matrix is not None: + # if the matrix is predefined + weight_matrix = matrix.detach() + else: + # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance + matrix_tmp = (recon_freq - real_freq) ** 2 + matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha + + # whether to adjust the spectrum weight matrix by logarithm + if self.log_matrix: + matrix_tmp = torch.log(matrix_tmp + 1.0) + + # whether to calculate the spectrum weight matrix using batch-based statistics + if self.batch_matrix: + matrix_tmp = matrix_tmp / matrix_tmp.max() + else: + matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None] + + matrix_tmp[torch.isnan(matrix_tmp)] = 0.0 + matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0) + weight_matrix = matrix_tmp.clone().detach() + + assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, ( + 'The values of spectrum weight matrix should be in the range [0, 1], ' + 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item())) + + # frequency distance using (squared) Euclidean distance + tmp = (recon_freq - real_freq) ** 2 + freq_distance = tmp[..., 0] + tmp[..., 1] + + # dynamic spectrum weighting (Hadamard product) + loss = weight_matrix * freq_distance + return torch.mean(loss) + + def forward(self, pred, target, matrix=None, **kwargs): + """Forward function to calculate focal frequency loss. + + Args: + pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor. + target (torch.Tensor): of shape (N, C, H, W). Target tensor. + matrix (torch.Tensor, optional): Element-wise spectrum weight matrix. + Default: None (If set to None: calculated online, dynamic). 
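+
+        Example (minimal usage sketch):
+            ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
+            loss = ffl(pred, target)  # pred, target: (N, C, H, W) tensors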
+ """ + pred_freq = self.tensor2freq(pred) + target_freq = self.tensor2freq(target) + + # whether to use minibatch average spectrum + if self.ave_spectrum: + pred_freq = torch.mean(pred_freq, 0, keepdim=True) + target_freq = torch.mean(target_freq, 0, keepdim=True) + + # calculate focal frequency loss + return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight diff --git a/DH-AISP/2/model_convnext2_hdr.py b/DH-AISP/2/model_convnext2_hdr.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3b7414b9baf052ffd6026ea6edb65889bd5d16 --- /dev/null +++ b/DH-AISP/2/model_convnext2_hdr.py @@ -0,0 +1,592 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model + +#import Convnext as PreConv +from myFFCResblock0 import myFFCResblock + + +# A ConvNet for the 2020s +# original implementation https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py +# paper https://arxiv.org/pdf/2201.03545.pdf + +class ConvNeXt0(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + def __init__(self, block, in_chans=3, num_classes=1000, + depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0., + layer_scale_init_value=1e-6, head_init_scale=1., + ): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + self.head = nn.Linear(dims[-1], num_classes) + + self.apply(self._init_weights) + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return 
x + + + + + + + + +def dwt_init(x): + x01 = x[:, :, 0::2, :] / 2 #x01.shape=[4,3,128,256] + x02 = x[:, :, 1::2, :] / 2 #x02.shape=[4,3,128,256] + x1 = x01[:, :, :, 0::2] #x1.shape=[4,3,128,128] + x2 = x02[:, :, :, 0::2] #x2.shape=[4,3,128,128] + x3 = x01[:, :, :, 1::2] #x3.shape=[4,3,128,128] + x4 = x02[:, :, :, 1::2] #x4.shape=[4,3,128,128] + x_LL = x1 + x2 + x3 + x4 + x_HL = -x1 - x2 + x3 + x4 + x_LH = -x1 + x2 - x3 + x4 + x_HH = x1 - x2 - x3 + x4 + return x_LL, torch.cat((x_HL, x_LH, x_HH), 1) + +class DWT(nn.Module): + def __init__(self): + super(DWT, self).__init__() + self.requires_grad = False + def forward(self, x): + return dwt_init(x) + +class DWT_transform(nn.Module): + def __init__(self, in_channels,out_channels): + super().__init__() + self.dwt = DWT() + self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels, kernel_size=1, padding=0) + def forward(self, x): + dwt_low_frequency,dwt_high_frequency = self.dwt(x) + dwt_low_frequency = self.conv1x1_low(dwt_low_frequency) + dwt_high_frequency = self.conv1x1_high(dwt_high_frequency) + return dwt_low_frequency,dwt_high_frequency + +def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False): + block = nn.Sequential() + if relu: + block.add_module('%s_relu' % name, nn.ReLU(inplace=True)) + else: + block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) + if not transposed: + block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)) + else: + block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)) + block.add_module('%s_bili' % name, nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) + if bn: + block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c)) + if dropout: + block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True)) + return block + +# DW-GAN: A Discrete Wavelet Transform GAN for NonHomogeneous Dehazing 2021 +# original implementation https://github.com/liuh127/NTIRE-2021-Dehazing-DWGAN/blob/main/model.py +# paper https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Fu_DW-GAN_A_Discrete_Wavelet_Transform_GAN_for_NonHomogeneous_Dehazing_CVPRW_2021_paper.pdf +class dwt_ffc_UNet2(nn.Module): + def __init__(self,output_nc=3, nf=16): + super(dwt_ffc_UNet2, self).__init__() + layer_idx = 1 + name = 'layer%d' % layer_idx + layer1 = nn.Sequential() + layer1.add_module(name, nn.Conv2d(16, nf-1, 4, 2, 1, bias=False)) + layer_idx += 1 + name = 'layer%d' % layer_idx + layer2 = blockUNet(nf, nf*2-2, name, transposed=False, bn=True, relu=False, dropout=False) + layer_idx += 1 + name = 'layer%d' % layer_idx + layer3 = blockUNet(nf*2, nf*4-4, name, transposed=False, bn=True, relu=False, dropout=False) + layer_idx += 1 + name = 'layer%d' % layer_idx + layer4 = blockUNet(nf*4, nf*8-8, name, transposed=False, bn=True, relu=False, dropout=False) + layer_idx += 1 + name = 'layer%d' % layer_idx + layer5 = blockUNet(nf*8, nf*8-16, name, transposed=False, bn=True, relu=False, dropout=False) + layer_idx += 1 + name = 'layer%d' % layer_idx + layer6 = blockUNet(nf*4, nf*4, name, transposed=False, bn=False, relu=False, dropout=False) + + layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer6 = blockUNet(nf * 4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False) + layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer5 = blockUNet(nf * 16+16, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False) + 
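+        # Note: only dlayer6, dlayer2 and dlayer1 are used in forward(); dlayer5,
+        # dlayer4 and dlayer3 (like encoder layers 4-6) are constructed but bypassed,
+        # because the deeper encoder stages are commented out and the FFC residual
+        # block is applied directly to the third-scale features.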
layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer4 = blockUNet(nf * 16+8, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False) + layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer3 = blockUNet(nf * 8+4, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False) + layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer2 = blockUNet(nf * 4+2, nf, name, transposed=True, bn=True, relu=True, dropout=False) + layer_idx -= 1 + name = 'dlayer%d' % layer_idx + dlayer1 = blockUNet(nf * 2+1, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False) + + self.initial_conv=nn.Conv2d(9,16,3,padding=1) + self.bn1=nn.BatchNorm2d(16) + self.layer1 = layer1 + self.DWT_down_0= DWT_transform(9,1) + self.layer2 = layer2 + self.DWT_down_1 = DWT_transform(16, 2) + self.layer3 = layer3 + self.DWT_down_2 = DWT_transform(32, 4) + self.layer4 = layer4 + self.DWT_down_3 = DWT_transform(64, 8) + self.layer5 = layer5 + self.DWT_down_4 = DWT_transform(128, 16) + self.layer6 = layer6 + self.dlayer6 = dlayer6 + self.dlayer5 = dlayer5 + self.dlayer4 = dlayer4 + self.dlayer3 = dlayer3 + self.dlayer2 = dlayer2 + self.dlayer1 = dlayer1 + self.tail_conv1 = nn.Conv2d(48, 32, 3, padding=1, bias=True) + self.bn2=nn.BatchNorm2d(32) + self.tail_conv2 = nn.Conv2d(nf*2, output_nc, 3,padding=1, bias=True) + + + self.FFCResNet = myFFCResblock(input_nc=64, output_nc=64) + + def forward(self, x): + conv_start=self.initial_conv(x) + conv_start=self.bn1(conv_start) + conv_out1 = self.layer1(conv_start) + dwt_low_0,dwt_high_0=self.DWT_down_0(x) + out1=torch.cat([conv_out1, dwt_low_0], 1) + conv_out2 = self.layer2(out1) + dwt_low_1,dwt_high_1= self.DWT_down_1(out1) + out2 = torch.cat([conv_out2, dwt_low_1], 1) + conv_out3 = self.layer3(out2) + + dwt_low_2,dwt_high_2 = self.DWT_down_2(out2) + out3 = torch.cat([conv_out3, dwt_low_2], 1) + + # conv_out4 = self.layer4(out3) + # dwt_low_3,dwt_high_3 = self.DWT_down_3(out3) + # out4 = torch.cat([conv_out4, dwt_low_3], 1) + + # conv_out5 = self.layer5(out4) + # dwt_low_4,dwt_high_4 = self.DWT_down_4(out4) + # out5 = torch.cat([conv_out5, dwt_low_4], 1) + + # out6 = self.layer6(out5) + + + out3_ffc= self.FFCResNet(out3) + + + dout3 = self.dlayer6(out3_ffc) + + + Tout3_out2 = torch.cat([dout3, out2,dwt_high_1], 1) + Tout2 = self.dlayer2(Tout3_out2) + Tout2_out1 = torch.cat([Tout2, out1,dwt_high_0], 1) + Tout1 = self.dlayer1(Tout2_out1) + + Tout1_outinit = torch.cat([Tout1, conv_start], 1) + tail1=self.tail_conv1(Tout1_outinit) + tail2=self.bn2(tail1) + dout1 = self.tail_conv2(tail2) + + + return dout1 + + + + + + +class Block(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
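+    The block is shape-preserving: a (N, dim, H, W) input is returned with the same
+    shape, combining the 7x7 depthwise conv, the pointwise expansion/projection and
+    the optional layer scale through a residual connection.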
+ """ + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXt(nn.Module): + def __init__(self, block, in_chans=3, num_classes=1000, + depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0., + layer_scale_init_value=1e-6, head_init_scale=1., + ): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + self.head = nn.Linear(dims[-1], num_classes) + + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) + + + def forward(self, x): + x_layer1 = self.downsample_layers[0](x) + x_layer1 = self.stages[0](x_layer1) + + + + x_layer2 = self.downsample_layers[1](x_layer1) + x_layer2 = self.stages[1](x_layer2) + + + x_layer3 = self.downsample_layers[2](x_layer2) + out = self.stages[2](x_layer3) + + + return x_layer1, x_layer2, out + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class PALayer(nn.Module): + def __init__(self, channel): + super(PALayer, self).__init__() + self.pa = nn.Sequential( + nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), + nn.Sigmoid() + ) + def forward(self, x): + y = self.pa(x) + return x * y + +class CALayer(nn.Module): + def __init__(self, channel): + super(CALayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.ca = nn.Sequential( + nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + def forward(self, x): + y = self.avg_pool(x) + y = self.ca(y) + return x * y + +class CP_Attention_block(nn.Module): + def __init__(self, conv, dim, kernel_size): + super(CP_Attention_block, self).__init__() + self.conv1 = conv(dim, dim, kernel_size, bias=True) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = conv(dim, dim, kernel_size, bias=True) + self.calayer = CALayer(dim) + self.palayer = PALayer(dim) + def forward(self, x): + res = self.act1(self.conv1(x)) + res = res + x + res = self.conv2(res) + res = self.calayer(res) + res = self.palayer(res) + res += x + return res + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) + +class knowledge_adaptation_convnext(nn.Module): + def __init__(self): + super(knowledge_adaptation_convnext, self).__init__() + self.encoder = ConvNeXt(Block, in_chans=9,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024,2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.) + '''pretrained_model = ConvNeXt0(Block, in_chans=3,num_classes=1000, depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.) 
+ #pretrained_model=nn.DataParallel(pretrained_model) + checkpoint=torch.load('./weights/convnext_xlarge_22k_1k_384_ema.pth') + #for k,v in checkpoint["model"].items(): + #print(k) + #url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth" + + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cuda:0") + pretrained_model.load_state_dict(checkpoint["model"]) + + pretrained_dict = pretrained_model.state_dict() + model_dict = self.encoder.state_dict() + key_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + model_dict.update(key_dict) + self.encoder.load_state_dict(model_dict)''' + + + self.up_block= nn.PixelShuffle(2) + self.attention0 = CP_Attention_block(default_conv, 1024, 3) + self.attention1 = CP_Attention_block(default_conv, 256, 3) + self.attention2 = CP_Attention_block(default_conv, 192, 3) + self.attention3 = CP_Attention_block(default_conv, 112, 3) + self.attention4 = CP_Attention_block(default_conv, 28, 3) + self.conv_process_1 = nn.Conv2d(28, 28, kernel_size=3,padding=1) + self.conv_process_2 = nn.Conv2d(28, 28, kernel_size=3,padding=1) + self.tail = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(28, 3, kernel_size=7, padding=0), nn.Tanh()) + def forward(self, input): + x_layer1, x_layer2, x_output = self.encoder(input) + + x_mid = self.attention0(x_output) #[1024,24,24] + + x = self.up_block(x_mid) #[256,48,48] + x = self.attention1(x) + + x = torch.cat((x, x_layer2), 1) #[768,48,48] + + x = self.up_block(x) #[192,96,96] + x = self.attention2(x) + x = torch.cat((x, x_layer1), 1) #[448,96,96] + x = self.up_block(x) #[112,192,192] + x = self.attention3(x) + + x = self.up_block(x) #[28,384,384] + x = self.attention4(x) + + x=self.conv_process_1(x) + out=self.conv_process_2(x) + return out + + +class fusion_net(nn.Module): + def __init__(self): + super(fusion_net, self).__init__() + self.dwt_branch=dwt_ffc_UNet2() + self.knowledge_adaptation_branch=knowledge_adaptation_convnext() + self.fusion = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(31, 3, kernel_size=7, padding=0), nn.Tanh()) + def forward(self, input): + dwt_branch=self.dwt_branch(input) + knowledge_adaptation_branch=self.knowledge_adaptation_branch(input) + x = torch.cat([dwt_branch, knowledge_adaptation_branch], 1) + x = self.fusion(x) + return x + + + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(512, 1024, kernel_size=1), + nn.LeakyReLU(0.2), + nn.Conv2d(1024, 1, kernel_size=1) + ) + + def forward(self, x): + batch_size = x.size(0) + return torch.sigmoid(self.net(x).view(batch_size)) + + +class Discriminator2(nn.Module): + def __init__(self): + super(Discriminator2, self).__init__() + self.net 
= nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512, 1, kernel_size=3, padding=1), + ) + + def forward(self, x): + return self.net(x) + +if __name__ == '__main__': + + device = torch.device("cuda:0") + + # Create model + im = torch.rand(1, 3, 640, 640).to(device) + model_g = fusion_net().to(device) + + out_data = model_g(im) diff --git a/DH-AISP/2/myFFCResblock0.py b/DH-AISP/2/myFFCResblock0.py new file mode 100644 index 0000000000000000000000000000000000000000..09fc2bc5c69aa06d23c49c17b95c44626c677308 --- /dev/null +++ b/DH-AISP/2/myFFCResblock0.py @@ -0,0 +1,60 @@ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from saicinpainting.training.modules.ffc0 import FFCResnetBlock +from saicinpainting.training.modules.ffc0 import FFC_BN_ACT + + + + +class myFFCResblock(nn.Module): + def __init__(self, input_nc, output_nc, n_blocks=2, norm_layer=nn.BatchNorm2d, #128--->64 + padding_type='reflect', activation_layer=nn.ReLU, + resnet_conv_kwargs={}, + spatial_transform_layers=None, spatial_transform_kwargs={}, + add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}): + assert (n_blocks >= 0) + + super().__init__() + self.initial = FFC_BN_ACT(input_nc, input_nc, kernel_size=3, padding=1, dilation=1, + norm_layer=norm_layer, activation_layer=activation_layer, + padding_type=padding_type, + **resnet_conv_kwargs) + + self.ffcresblock = FFCResnetBlock(input_nc, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, **resnet_conv_kwargs) + + + + self.final = FFC_BN_ACT(input_nc, output_nc, kernel_size=3, padding=1, dilation=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **resnet_conv_kwargs) + + + + + + + + def forward(self, x): + + x_l, x_g = self.initial(x) + + x_l, x_g = self.ffcresblock(x_l, x_g) + x_l, x_g = self.ffcresblock(x_l, x_g) + + out_ = torch.cat([x_l, x_g], 1) + + x_lout, x_gout =self.final(out_) + + out = torch.cat([x_lout, x_gout], 1) + return out + + + diff --git a/DH-AISP/2/perceptual.py b/DH-AISP/2/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..790aaf4ad87a1b3d6110aa1152c9fa874d1b4c13 --- /dev/null +++ b/DH-AISP/2/perceptual.py @@ -0,0 +1,30 @@ +# --- Imports --- # +import torch +import torch.nn.functional as F + +# --- Perceptual loss network --- # +class LossNetwork(torch.nn.Module): + def __init__(self, vgg_model): + super(LossNetwork, self).__init__() + self.vgg_layers = vgg_model + self.layer_name_mapping = { + '3': "relu1_2", + '8': "relu2_2", + '15': "relu3_3" + } + + def output_features(self, x): + output = {} + for name, module in 
self.vgg_layers._modules.items(): + x = module(x) + if name in self.layer_name_mapping: + output[self.layer_name_mapping[name]] = x + return list(output.values()) + + def forward(self, dehaze, gt): + loss = [] + dehaze_features = self.output_features(dehaze) + gt_features = self.output_features(gt) + for dehaze_feature, gt_feature in zip(dehaze_features, gt_features): + loss.append(F.mse_loss(dehaze_feature, gt_feature)) + return sum(loss)/len(loss) \ No newline at end of file diff --git a/DH-AISP/2/pytorch_msssim/__init__.py b/DH-AISP/2/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3032b466d4dcae7bad974b1b71932e16292e25d1 --- /dev/null +++ b/DH-AISP/2/pytorch_msssim/__init__.py @@ -0,0 +1,133 @@ +import torch +import torch.nn.functional as F +from math import exp +import numpy as np + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs ** weights + pow2 = mssim ** weights + # From Matlab implementation 
https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 1 channel for SSIM + self.channel = 1 + self.window = create_window(window_size) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + # TODO: store window between calls if possible + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc b/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea03e038793b95503f0ad00c1cd0e1d0c53595ea Binary files /dev/null and b/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-36.pyc differ diff --git a/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc b/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a55e0429936ae83ba1bfe37fadaa646319617d6 Binary files /dev/null and b/DH-AISP/2/pytorch_msssim/__pycache__/__init__.cpython-37.pyc differ diff --git a/DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth b/DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth new file mode 100644 index 0000000000000000000000000000000000000000..d26ecf9a380c9e12f41bb64077aba5af7f233028 --- /dev/null +++ b/DH-AISP/2/result_low_light_hdr/checkpoint_gen.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5952db983eb66b04c6a39348a0916164d9148ec99c4a3b8a77bf4e240657022 +size 1491472482 diff --git a/DH-AISP/2/saicinpainting/__init__.py b/DH-AISP/2/saicinpainting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc b/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbb179b167a653991f019453c75f484e290824b5 Binary files /dev/null and b/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc b/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8769c3b942ae399de3bb608db6508a35d0572e01 Binary files /dev/null and b/DH-AISP/2/saicinpainting/__pycache__/__init__.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc b/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b63b45b56a668c67a4c8a8bda7d15285634d8135 Binary files 
/dev/null and b/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc b/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab0fd0f31b39085e8b478ff008a1fe0e800bdc8b Binary files /dev/null and b/DH-AISP/2/saicinpainting/__pycache__/utils.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/evaluation/__init__.py b/DH-AISP/2/saicinpainting/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c8117565b252ca069a808b31b8c52aaddd2289 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/__init__.py @@ -0,0 +1,33 @@ +import logging + +import torch + +from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 +from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore + + +def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): + logging.info(f'Make evaluator {kind}') + device = "cuda" if torch.cuda.is_available() else "cpu" + metrics = {} + if ssim: + metrics['ssim'] = SSIMScore() + if lpips: + metrics['lpips'] = LPIPSScore() + if fid: + metrics['fid'] = FIDScore().to(device) + + if integral_kind is None: + integral_func = None + elif integral_kind == 'ssim_fid100_f1': + integral_func = ssim_fid100_f1 + elif integral_kind == 'lpips_fid100_f1': + integral_func = lpips_fid100_f1 + else: + raise ValueError(f'Unexpected integral_kind={integral_kind}') + + if kind == 'default': + return InpaintingEvaluatorOnline(scores=metrics, + integral_func=integral_func, + integral_title=integral_kind, + **kwargs) diff --git a/DH-AISP/2/saicinpainting/evaluation/data.py b/DH-AISP/2/saicinpainting/evaluation/data.py new file mode 100644 index 0000000000000000000000000000000000000000..89a4ea4c9577e6131731444f149eec76978ec260 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/data.py @@ -0,0 +1,168 @@ +import glob +import os + +import cv2 +import PIL.Image as Image +import numpy as np + +from torch.utils.data import Dataset +import torch.nn.functional as F + + +def load_image(fname, mode='RGB', return_orig=False): + img = np.array(Image.open(fname).convert(mode)) + if img.ndim == 3: + img = np.transpose(img, (2, 0, 1)) + out_img = img.astype('float32') / 255 + if return_orig: + return out_img, img + else: + return out_img + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def pad_img_to_modulo(img, mod): + channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric') + + +def pad_tensor_to_modulo(img, mod): + batch_size, channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect') + + +def scale_image(img, factor, interpolation=cv2.INTER_AREA): + if img.shape[0] == 1: + img = img[0] + else: + img = np.transpose(img, (1, 2, 0)) + + img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation) + + if img.ndim == 2: + img = img[None, ...] 
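# Illustration of the padding helpers defined above: a channels-first image is
# padded on the bottom/right (symmetric mode) until both spatial sizes are
# multiples of `mod`, e.g. mod=8 for stride-8 encoders. The 250x333 shape is
# made up for this sketch; 256 = ceil_modulo(250, 8) and 336 = ceil_modulo(333, 8).
import numpy as np

chw = np.random.rand(3, 250, 333).astype('float32')
padded = np.pad(chw, ((0, 0), (0, 256 - 250), (0, 336 - 333)), mode='symmetric')
assert padded.shape == (3, 256, 336)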
+ else: + img = np.transpose(img, (2, 0, 1)) + return img + + +class InpaintingDataset(Dataset): + def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): + self.datadir = datadir + self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True))) + self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames] + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.mask_filenames) + + def __getitem__(self, i): + image = load_image(self.img_filenames[i], mode='RGB') + mask = load_image(self.mask_filenames[i], mode='L') + result = dict(image=image, mask=mask[None, ...]) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['unpad_to_size'] = result['image'].shape[1:] + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + + return result + +class OurInpaintingDataset(Dataset): + def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): + self.datadir = datadir + self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True))) + self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames] + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.mask_filenames) + + def __getitem__(self, i): + result = dict(image=load_image(self.img_filenames[i], mode='RGB'), + mask=load_image(self.mask_filenames[i], mode='L')[None, ...]) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + + return result + +class PrecomputedInpaintingResultsDataset(InpaintingDataset): + def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs): + super().__init__(datadir, **kwargs) + if not datadir.endswith('/'): + datadir += '/' + self.predictdir = predictdir + self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) + for fname in self.mask_filenames] + + def __getitem__(self, i): + result = super().__getitem__(i) + result['inpainted'] = load_image(self.pred_filenames[i]) + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) + return result + +class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset): + def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs): + super().__init__(datadir, **kwargs) + if not datadir.endswith('/'): + datadir += '/' + self.predictdir = predictdir + self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}') + for fname in 
self.mask_filenames] + # self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) + # for fname in self.mask_filenames] + + def __getitem__(self, i): + result = super().__getitem__(i) + result['inpainted'] = self.file_loader(self.pred_filenames[i]) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) + return result + +class InpaintingEvalOnlineDataset(Dataset): + def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs): + self.indir = indir + self.mask_generator = mask_generator + self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True))) + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.img_filenames) + + def __getitem__(self, i): + img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True) + mask = self.mask_generator(img, raw_image=raw_image) + result = dict(image=img, mask=mask) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + return result \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/evaluation/evaluator.py b/DH-AISP/2/saicinpainting/evaluation/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9e80402633c08a580929b38a5cb695cb7171d8 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/evaluator.py @@ -0,0 +1,220 @@ +import logging +import math +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import tqdm +from torch.utils.data import DataLoader + +from saicinpainting.evaluation.utils import move_to_device + +LOGGER = logging.getLogger(__name__) + + +class InpaintingEvaluator(): + def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda', + integral_func=None, integral_title=None, clamp_image_range=None): + """ + :param dataset: torch.utils.data.Dataset which contains images and masks + :param scores: dict {score_name: EvaluatorScore object} + :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples + which are defined by share of area occluded by mask + :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1) + :param batch_size: batch_size for the dataloader + :param device: device to use + """ + self.scores = scores + self.dataset = dataset + + self.area_grouping = area_grouping + self.bins = bins + + self.device = torch.device(device) + + self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size) + + self.integral_func = integral_func + self.integral_title = integral_title + self.clamp_image_range = clamp_image_range + + def _get_bin_edges(self): + bin_edges = np.linspace(0, 1, self.bins + 1) + + num_digits = max(0, math.ceil(math.log10(self.bins)) - 1) + interval_names = [] + for idx_bin in range(self.bins): + start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \ + round(100 * bin_edges[idx_bin 
+ 1], num_digits) + start_percent = '{:.{n}f}'.format(start_percent, n=num_digits) + end_percent = '{:.{n}f}'.format(end_percent, n=num_digits) + interval_names.append("{0}-{1}%".format(start_percent, end_percent)) + + groups = [] + for batch in self.dataloader: + mask = batch['mask'] + batch_size = mask.shape[0] + area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1) + bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1 + # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element + bin_indices[bin_indices == self.bins] = self.bins - 1 + groups.append(bin_indices) + groups = np.hstack(groups) + + return groups, interval_names + + def evaluate(self, model=None): + """ + :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch + :return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or + name of the particular group arranged by area of mask (e.g. '10-20%') + and score statistics for the group as values. + """ + results = dict() + if self.area_grouping: + groups, interval_names = self._get_bin_edges() + else: + groups = None + + for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'): + score.to(self.device) + with torch.no_grad(): + score.reset() + for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False): + batch = move_to_device(batch, self.device) + image_batch, mask_batch = batch['image'], batch['mask'] + if self.clamp_image_range is not None: + image_batch = torch.clamp(image_batch, + min=self.clamp_image_range[0], + max=self.clamp_image_range[1]) + if model is None: + assert 'inpainted' in batch, \ + 'Model is None, so we expected precomputed inpainting results at key "inpainted"' + inpainted_batch = batch['inpainted'] + else: + inpainted_batch = model(image_batch, mask_batch) + score(inpainted_batch, image_batch, mask_batch) + total_results, group_results = score.get_value(groups=groups) + + results[(score_name, 'total')] = total_results + if groups is not None: + for group_index, group_values in group_results.items(): + group_name = interval_names[group_index] + results[(score_name, group_name)] = group_values + + if self.integral_func is not None: + results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results)) + + return results + + +def ssim_fid100_f1(metrics, fid_scale=100): + ssim = metrics[('ssim', 'total')]['mean'] + fid = metrics[('fid', 'total')]['mean'] + fid_rel = max(0, fid_scale - fid) / fid_scale + f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3) + return f1 + + +def lpips_fid100_f1(metrics, fid_scale=100): + neg_lpips = 1 - metrics[('lpips', 'total')]['mean'] # invert, so bigger is better + fid = metrics[('fid', 'total')]['mean'] + fid_rel = max(0, fid_scale - fid) / fid_scale + f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3) + return f1 + + + +class InpaintingEvaluatorOnline(nn.Module): + def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted', + integral_func=None, integral_title=None, clamp_image_range=None): + """ + :param scores: dict {score_name: EvaluatorScore object} + :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1) + :param device: device to use + """ + super().__init__() + LOGGER.info(f'{type(self)} init called') + self.scores = nn.ModuleDict(scores) + self.image_key = image_key + self.inpainted_key = inpainted_key + self.bins_num = bins + self.bin_edges = 
np.linspace(0, 1, self.bins_num + 1) + + num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1) + self.interval_names = [] + for idx_bin in range(self.bins_num): + start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \ + round(100 * self.bin_edges[idx_bin + 1], num_digits) + start_percent = '{:.{n}f}'.format(start_percent, n=num_digits) + end_percent = '{:.{n}f}'.format(end_percent, n=num_digits) + self.interval_names.append("{0}-{1}%".format(start_percent, end_percent)) + + self.groups = [] + + self.integral_func = integral_func + self.integral_title = integral_title + self.clamp_image_range = clamp_image_range + + LOGGER.info(f'{type(self)} init done') + + def _get_bins(self, mask_batch): + batch_size = mask_batch.shape[0] + area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy() + bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1) + return bin_indices + + def forward(self, batch: Dict[str, torch.Tensor]): + """ + Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end + :param batch: batch dict with mandatory fields mask, image, inpainted (can be overriden by self.inpainted_key) + """ + result = {} + with torch.no_grad(): + image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key] + if self.clamp_image_range is not None: + image_batch = torch.clamp(image_batch, + min=self.clamp_image_range[0], + max=self.clamp_image_range[1]) + self.groups.extend(self._get_bins(mask_batch)) + + for score_name, score in self.scores.items(): + result[score_name] = score(inpainted_batch, image_batch, mask_batch) + return result + + def process_batch(self, batch: Dict[str, torch.Tensor]): + return self(batch) + + def evaluation_end(self, states=None): + """:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or + name of the particular group arranged by area of mask (e.g. '10-20%') + and score statistics for the group as values. 
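# Sketch of how the metrics dict described above is usually consumed: keys are
# (score_name, group) tuples, where group is 'total' or a mask-area bucket such
# as '10-20%', and each value holds at least a 'mean'. The numbers below are
# invented purely for illustration.
results = {
    ('ssim', 'total'): {'mean': 0.91, 'std': 0.04},
    ('ssim', '10-20%'): {'mean': 0.88, 'std': 0.05},
    ('fid', 'total'): {'mean': 7.3},
}
print('overall SSIM:', results[('ssim', 'total')]['mean'])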
+ """ + LOGGER.info(f'{type(self)}: evaluation_end called') + + self.groups = np.array(self.groups) + + results = {} + for score_name, score in self.scores.items(): + LOGGER.info(f'Getting value of {score_name}') + cur_states = [s[score_name] for s in states] if states is not None else None + total_results, group_results = score.get_value(groups=self.groups, states=cur_states) + LOGGER.info(f'Getting value of {score_name} done') + results[(score_name, 'total')] = total_results + + for group_index, group_values in group_results.items(): + group_name = self.interval_names[group_index] + results[(score_name, group_name)] = group_values + + if self.integral_func is not None: + results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results)) + + LOGGER.info(f'{type(self)}: reset scores') + self.groups = [] + for sc in self.scores.values(): + sc.reset() + LOGGER.info(f'{type(self)}: reset scores done') + + LOGGER.info(f'{type(self)}: evaluation_end done') + return results diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/__init__.py b/DH-AISP/2/saicinpainting/evaluation/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py b/DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cd5fa8d571b2da829b87f0784bd38978158ce7 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/losses/base_loss.py @@ -0,0 +1,528 @@ +import logging +from abc import abstractmethod, ABC + +import numpy as np +import sklearn +import sklearn.svm +import torch +import torch.nn as nn +import torch.nn.functional as F +from joblib import Parallel, delayed +from scipy import linalg + +from models.ade20k import SegmentationModule, NUM_CLASS, segm_options +from .fid.inception import InceptionV3 +from .lpips import PerceptualLoss +from .ssim import SSIM + +LOGGER = logging.getLogger(__name__) + + +def get_groupings(groups): + """ + :param groups: group numbers for respective elements + :return: dict of kind {group_idx: indices of the corresponding group elements} + """ + label_groups, count_groups = np.unique(groups, return_counts=True) + + indices = np.argsort(groups) + + grouping = dict() + cur_start = 0 + for label, count in zip(label_groups, count_groups): + cur_end = cur_start + count + cur_indices = indices[cur_start:cur_end] + grouping[label] = cur_indices + cur_start = cur_end + return grouping + + +class EvaluatorScore(nn.Module): + @abstractmethod + def forward(self, pred_batch, target_batch, mask): + pass + + @abstractmethod + def get_value(self, groups=None, states=None): + pass + + @abstractmethod + def reset(self): + pass + + +class PairwiseScore(EvaluatorScore, ABC): + def __init__(self): + super().__init__() + self.individual_values = None + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \ + else self.individual_values + + total_results = { + 'mean': individual_values.mean(), + 'std': individual_values.std() + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in 
grouping.items(): + group_scores = individual_values[index] + group_results[label] = { + 'mean': group_scores.mean(), + 'std': group_scores.std() + } + return total_results, group_results + + def reset(self): + self.individual_values = [] + + +class SSIMScore(PairwiseScore): + def __init__(self, window_size=11): + super().__init__() + self.score = SSIM(window_size=window_size, size_average=False).eval() + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + batch_values = self.score(pred_batch, target_batch) + self.individual_values = np.hstack([ + self.individual_values, batch_values.detach().cpu().numpy() + ]) + return batch_values + + +class LPIPSScore(PairwiseScore): + def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True): + super().__init__() + self.score = PerceptualLoss(model=model, net=net, model_path=model_path, + use_gpu=use_gpu, spatial=False).eval() + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + batch_values = self.score(pred_batch, target_batch).flatten() + self.individual_values = np.hstack([ + self.individual_values, batch_values.detach().cpu().numpy() + ]) + return batch_values + + +def fid_calculate_activation_statistics(act): + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): + mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) + mu2, sigma2 = fid_calculate_activation_statistics(activations_target) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + LOGGER.warning(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +class FIDScore(EvaluatorScore): + def __init__(self, dims=2048, eps=1e-6): + LOGGER.info("FIDscore init called") + super().__init__() + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.reset() + LOGGER.info("FIDscore init done") + + def forward(self, pred_batch, target_batch, mask=None): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + + self.activations_pred.append(activations_pred.detach().cpu()) + self.activations_target.append(activations_target.detach().cpu()) + + return activations_pred, activations_target + + def get_value(self, groups=None, states=None): + LOGGER.info("FIDscore get_value called") + activations_pred, activations_target = zip(*states) if states is not None \ + else (self.activations_pred, self.activations_target) + activations_pred = torch.cat(activations_pred).cpu().numpy() + activations_target = torch.cat(activations_target).cpu().numpy() + + total_distance = 
calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) + total_results = dict(mean=total_distance) + + if groups is None: + group_results = None + else: + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + if len(index) > 1: + group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index], + eps=self.eps) + group_results[label] = dict(mean=group_distance) + + else: + group_results[label] = dict(mean=float('nan')) + + self.reset() + + LOGGER.info("FIDscore get_value done") + + return total_results, group_results + + def reset(self): + self.activations_pred = [] + self.activations_target = [] + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + assert False, \ + 'We should not have got here, because Inception always scales inputs to 299x299' + # activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) + activations = activations.squeeze(-1).squeeze(-1) + return activations + + +class SegmentationAwareScore(EvaluatorScore): + def __init__(self, weights_path): + super().__init__() + self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval() + self.target_class_freq_by_image_total = [] + self.target_class_freq_by_image_mask = [] + self.pred_class_freq_by_image_mask = [] + + def forward(self, pred_batch, target_batch, mask): + pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() + target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() + mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy() + + batch_target_class_freq_total = [] + batch_target_class_freq_mask = [] + batch_pred_class_freq_mask = [] + + for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat): + cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...] + cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...] + cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...] 
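# Quick illustration of the np.bincount pattern used above for per-class pixel
# frequencies inside the mask; the class ids and minlength=5 are placeholders
# (the real code uses NUM_CLASS from the ADE20k segmentation model).
import numpy as np

segm = np.array([0, 2, 2, 4, 4, 4])              # flattened per-pixel class ids
mask = np.array([1, 1, 0, 1, 1, 0], dtype=bool)  # pixels inside the hole
freq_total = np.bincount(segm, minlength=5)       # -> [1, 0, 2, 0, 3]
freq_mask = np.bincount(segm[mask], minlength=5)  # -> [1, 0, 1, 0, 2]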
+ + self.target_class_freq_by_image_total.append(cur_target_class_freq_total) + self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask) + self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask) + + batch_target_class_freq_total.append(cur_target_class_freq_total) + batch_target_class_freq_mask.append(cur_target_class_freq_mask) + batch_pred_class_freq_mask.append(cur_pred_class_freq_mask) + + batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0) + batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0) + batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0) + return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask + + def reset(self): + super().reset() + self.target_class_freq_by_image_total = [] + self.target_class_freq_by_image_mask = [] + self.pred_class_freq_by_image_mask = [] + + +def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name): + assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0] + total_class_freq = target_class_freq_by_image_mask.sum(0) + distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0) + result = distr_values / (total_class_freq + 1e-3) + return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0} + + +def get_segmentation_idx2name(): + return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()} + + +class SegmentationAwarePairwiseScore(SegmentationAwareScore): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.individual_values = [] + self.segm_idx2name = get_segmentation_idx2name() + + def forward(self, pred_batch, target_batch, mask): + cur_class_stats = super().forward(pred_batch, target_batch, mask) + score_values = self.calc_score(pred_batch, target_batch, mask) + self.individual_values.append(score_values) + return cur_class_stats + (score_values,) + + @abstractmethod + def calc_score(self, pred_batch, target_batch, mask): + raise NotImplementedError() + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + individual_values) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + individual_values = self.individual_values + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + individual_values = np.concatenate(individual_values, axis=0) + + total_results = { + 'mean': individual_values.mean(), + 'std': individual_values.std(), + **distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name) + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = 
get_groupings(groups) + for label, index in grouping.items(): + group_class_freq = target_class_freq_by_image_mask[index] + group_scores = individual_values[index] + group_results[label] = { + 'mean': group_scores.mean(), + 'std': group_scores.std(), + ** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name) + } + return total_results, group_results + + def reset(self): + super().reset() + self.individual_values = [] + + +class SegmentationClassStats(SegmentationAwarePairwiseScore): + def calc_score(self, pred_batch, target_batch, mask): + return 0 + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + _) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + + target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32') + target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum() + + target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32') + target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum() + + pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3) + + total_results = dict() + total_results.update({f'total_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(target_class_freq_by_image_total_marginal) + if v > 0}) + total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(target_class_freq_by_image_mask_marginal) + if v > 0}) + total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v + for i, v in enumerate(pred_class_freq_diff) + if target_class_freq_by_image_total_marginal[i] > 0}) + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + group_target_class_freq_by_image_total = target_class_freq_by_image_total[index] + group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index] + group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index] + + group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32') + group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum() + + group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32') + group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum() + + group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / ( + 
group_target_class_freq_by_image_mask.sum(0) + 1e-3) + + cur_group_results = dict() + cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_target_class_freq_by_image_total_marginal) + if v > 0}) + cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_target_class_freq_by_image_mask_marginal) + if v > 0}) + cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_pred_class_freq_diff) + if group_target_class_freq_by_image_total_marginal[i] > 0}) + + group_results[label] = cur_group_results + return total_results, group_results + + +class SegmentationAwareSSIM(SegmentationAwarePairwiseScore): + def __init__(self, *args, window_size=11, **kwargs): + super().__init__(*args, **kwargs) + self.score_impl = SSIM(window_size=window_size, size_average=False).eval() + + def calc_score(self, pred_batch, target_batch, mask): + return self.score_impl(pred_batch, target_batch).detach().cpu().numpy() + + +class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore): + def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs): + super().__init__(*args, **kwargs) + self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path, + use_gpu=use_gpu, spatial=False).eval() + + def calc_score(self, pred_batch, target_batch, mask): + return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy() + + +def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6): + activations_pred = activations_pred.copy() + activations_pred[img_i] = activations_target[img_i] + return calculate_frechet_distance(activations_pred, activations_target, eps=eps) + + +class SegmentationAwareFID(SegmentationAwarePairwiseScore): + def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs): + super().__init__(*args, **kwargs) + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.n_jobs = n_jobs + + def calc_score(self, pred_batch, target_batch, mask): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + return activations_pred, activations_target + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + activation_pairs) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + activation_pairs = self.individual_values + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + activations_pred, activations_target = zip(*activation_pairs) + activations_pred = np.concatenate(activations_pred, axis=0) + activations_target = np.concatenate(activations_target, 
axis=0) + + total_results = { + 'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps), + 'std': 0, + **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target) + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + if len(index) > 1: + group_activations_pred = activations_pred[index] + group_activations_target = activations_target[index] + group_class_freq = target_class_freq_by_image_mask[index] + group_results[label] = { + 'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps), + 'std': 0, + **self.distribute_fid_to_classes(group_class_freq, + group_activations_pred, + group_activations_target) + } + else: + group_results[label] = dict(mean=float('nan'), std=0) + return total_results, group_results + + def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target): + real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) + + fid_no_images = Parallel(n_jobs=self.n_jobs)( + delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps) + for img_i in range(activations_pred.shape[0]) + ) + errors = real_fid - fid_no_images + return distribute_values_to_classes(class_freq, errors, self.segm_idx2name) + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) + activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy() + return activations diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py b/DH-AISP/2/saicinpainting/evaluation/losses/fid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py b/DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca8e602c21bb6a624d646da3f6479aea033b0ac --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/losses/fid/fid_score.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +"""Calculates the Frechet Inception Distance (FID) to evalulate GANs + +The FID metric calculates the distance between two distributions of images. +Typically, we have summary statistics (mean & covariance matrix) of one +of these distributions, while the 2nd distribution is given by a GAN. + +When run as a stand-alone program, it compares the distribution of +images that are stored as PNG/JPEG at a specified location with a +distribution given by summary statistics (in pickle format). + +The FID is calculated by assuming that X_1 and X_2 are the activations of +the pool_3 layer of the inception net for generated samples and real world +samples respectively. + +See --help to see further details. + +Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead +of Tensorflow + +Copyright 2018 Institute of Bioinformatics, JKU Linz + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import pathlib +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import numpy as np +import torch +# from scipy.misc import imread +from imageio import imread +from PIL import Image, JpegImagePlugin +from scipy import linalg +from torch.nn.functional import adaptive_avg_pool2d +from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor + +try: + from tqdm import tqdm +except ImportError: + # If not tqdm is not available, provide a mock version of it + def tqdm(x): return x + +try: + from .inception import InceptionV3 +except ModuleNotFoundError: + from inception import InceptionV3 + +parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) +parser.add_argument('path', type=str, nargs=2, + help=('Path to the generated images or ' + 'to .npz statistic files')) +parser.add_argument('--batch-size', type=int, default=50, + help='Batch size to use') +parser.add_argument('--dims', type=int, default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=('Dimensionality of Inception features to use. ' + 'By default, uses pool3 features')) +parser.add_argument('-c', '--gpu', default='', type=str, + help='GPU to use (leave blank for CPU only)') +parser.add_argument('--resize', default=256) + +transform = Compose([Resize(256), CenterCrop(256), ToTensor()]) + + +def get_activations(files, model, batch_size=50, dims=2048, + cuda=False, verbose=False, keep_size=False): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- cuda : If set to True, use GPU + -- verbose : If set to True and parameter out_step is given, the number + of calculated batches is reported. + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if len(files) % batch_size != 0: + print(('Warning: number of images is not a multiple of the ' + 'batch size. Some samples are going to be ignored.')) + if batch_size > len(files): + print(('Warning: batch size is bigger than the data size. 
' + 'Setting batch size to data size')) + batch_size = len(files) + + n_batches = len(files) // batch_size + n_used_imgs = n_batches * batch_size + + pred_arr = np.empty((n_used_imgs, dims)) + + for i in tqdm(range(n_batches)): + if verbose: + print('\rPropagating batch %d/%d' % (i + 1, n_batches), + end='', flush=True) + start = i * batch_size + end = start + batch_size + + # # Official code goes below + # images = np.array([imread(str(f)).astype(np.float32) + # for f in files[start:end]]) + + # # Reshape to (n_images, 3, height, width) + # images = images.transpose((0, 3, 1, 2)) + # images /= 255 + # batch = torch.from_numpy(images).type(torch.FloatTensor) + # # + + t = transform if not keep_size else ToTensor() + + if isinstance(files[0], pathlib.PosixPath): + images = [t(Image.open(str(f))) for f in files[start:end]] + + elif isinstance(files[0], Image.Image): + images = [t(f) for f in files[start:end]] + + else: + raise ValueError(f"Unknown data type for image: {type(files[0])}") + + batch = torch.stack(images) + + if cuda: + batch = batch.cuda() + + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.shape[2] != 1 or pred.shape[3] != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) + + if verbose: + print(' done') + + return pred_arr + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. 
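# Worked example of the closed-form distance documented above,
# d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)),
# on two toy 2-D Gaussians (not real Inception statistics).
import numpy as np
from scipy import linalg

mu1, sigma1 = np.zeros(2), np.eye(2)
mu2, sigma2 = np.array([1.0, 0.0]), 4.0 * np.eye(2)
diff = mu1 - mu2
covmean = linalg.sqrtm(sigma1.dot(sigma2))  # sqrt(4*I) = 2*I in this toy case
d2 = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)
print(d2)  # 3.0 = 1 from the means + 2 from the covariance term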
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_activation_statistics(files, model, batch_size=50, + dims=2048, cuda=False, verbose=False, keep_size=False): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- cuda : If set to True, use GPU + -- verbose : If set to True and parameter out_step is given, the + number of calculated batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. 
+ """ + act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def _compute_statistics_of_path(path, model, batch_size, dims, cuda): + if path.endswith('.npz'): + f = np.load(path) + m, s = f['mu'][:], f['sigma'][:] + f.close() + else: + path = pathlib.Path(path) + files = list(path.glob('*.jpg')) + list(path.glob('*.png')) + m, s = calculate_activation_statistics(files, model, batch_size, + dims, cuda) + + return m, s + + +def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False): + if isinstance(images, list): # exact paths to files are provided + m, s = calculate_activation_statistics(images, model, batch_size, + dims, cuda, keep_size=keep_size) + + return m, s + + else: + raise ValueError + + +def calculate_fid_given_paths(paths, batch_size, cuda, dims): + """Calculates the FID of two paths""" + for p in paths: + if not os.path.exists(p): + raise RuntimeError('Invalid path: %s' % p) + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]) + if cuda: + model.cuda() + + m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, + dims, cuda) + m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, + dims, cuda) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False): + if use_globals: + global FID_MODEL # for multiprocessing + + for imgs in images: + if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)): + pass + else: + raise RuntimeError('Invalid images') + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + if 'FID_MODEL' not in globals() or not use_globals: + model = InceptionV3([block_idx]) + if cuda: + model.cuda() + + if use_globals: + FID_MODEL = model + + else: + model = FID_MODEL + + m1, s1 = _compute_statistics_of_images(images[0], model, batch_size, + dims, cuda, keep_size=False) + m2, s2 = _compute_statistics_of_images(images[1], model, batch_size, + dims, cuda, keep_size=False) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + +if __name__ == '__main__': + args = parser.parse_args() + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + fid_value = calculate_fid_given_paths(args.path, + args.batch_size, + args.gpu != '', + args.dims) + print('FID: ', fid_value) diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py b/DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..e9bd0863b457aaa40c770eaa4acbb142b18fc18b --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/losses/fid/inception.py @@ -0,0 +1,323 @@ +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +LOGGER = logging.getLogger(__name__) + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index 
of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. 
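# Minimal usage sketch for this wrapper (assuming DH-AISP/2 is on sys.path and
# that the ported FID weights can be downloaded on first use): request the
# 2048-d final-average-pool block and run a batch of images scaled to [0, 1].
import torch
from saicinpainting.evaluation.losses.fid.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()
with torch.no_grad():
    feats = model(torch.rand(4, 3, 299, 299))[0]  # -> shape (4, 2048, 1, 1)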
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + LOGGER.info('fid_inception_v3 called') + inception = models.inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + LOGGER.info('models.inception_v3 done') + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + LOGGER.info('fid_inception_v3 patching done') + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + LOGGER.info('fid_inception_v3 weights downloaded') + + inception.load_state_dict(state_dict) + LOGGER.info('fid_inception_v3 weights loaded into model') + + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its 
average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/lpips.py b/DH-AISP/2/saicinpainting/evaluation/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f19b747f2457902695213f7efcde4fdc306c1f --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/losses/lpips.py @@ -0,0 +1,891 @@ +############################################################ +# The contents below have been combined using files in the # +# following repository: # +# https://github.com/richzhang/PerceptualSimilarity # +############################################################ + +############################################################ +# __init__.py # +############################################################ + +import numpy as np +from skimage.metrics import structural_similarity +import torch + +from saicinpainting.utils import get_shape + + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True): + # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + self.use_gpu = use_gpu + self.spatial = spatial + self.model = DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, + model_path=model_path, spatial=self.spatial) + + def forward(self, pred, target, normalize=True): + """ + Pred and target are Variables. 
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model(target, pred) + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +def l2(p0, p1, range=255.): + return .5 * np.mean((p0 / range - p1 / range) ** 2) + + +def psnr(p0, p1, peak=255.): + return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) + + +def dssim(p0, p1, range=255.): + return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. + + +def rgb2lab(in_img, mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if (mean_cent): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + return img_lab + + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) + + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if (mc_only): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + if (to_norm and not mc_only): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + img_lab = img_lab / 100. + + return np2tensor(img_lab) + + +def tensorlab2tensor(lab_tensor, return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor) * 100. + lab[:, :, 0] = lab[:, :, 0] + 50 + + rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) + if (return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1. * np.isclose(lab_back, lab, atol=2.) + mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) + return (im2tensor(rgb_back), mask) + else: + return im2tensor(rgb_back) + + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
+ else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): + # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): + # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +############################################################ +# base_model.py # +############################################################ + + +class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True): + self.use_gpu = use_gpu + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s' % save_path) + network.load_state_dict(torch.load(save_path, map_location='cpu')) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'), flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i') + + +############################################################ +# dist_model.py # +############################################################ + +import os +from collections import OrderedDict +from scipy.ndimage import zoom +from tqdm import tqdm + + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, + model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1'): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print 
network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + ''' + BaseModel.initialize(self, use_gpu=use_gpu) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.model_name = '%s [%s]' % (model, net) + + if (self.model == 'net-lin'): # pretrained net + linear layer + self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = dict(map_location='cpu') + if (model_path is None): + import inspect + model_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth')) + + if (not is_train): + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif (self.model == 'net'): # pretrained network + self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif (self.model in ['L2', 'l2']): + self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): + self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." 
% self.model) + + self.trainable_parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = BCERankingLoss() + self.trainable_parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + # if (use_gpu): + # self.net.to(gpu_ids[0]) + # self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + # if (self.is_train): + # self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if (printNet): + print('---------- Networks initialized -------------') + print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if (hasattr(module, 'weight') and module.kernel_size == (1, 1)): + module.weight.data = torch.clamp(module.weight.data, min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + # if (self.use_gpu): + # self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + # self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + # self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + # self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + # self.var_ref = Variable(self.input_ref, requires_grad=True) + # self.var_p0 = Variable(self.input_p0, requires_grad=True) + # self.var_p1 = Variable(self.input_p1, requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + assert False, "We shoud've not get here when using LPIPS as a metric" + + self.d0 = self(self.var_ref, self.var_p0) + self.d1 = self(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) + + self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.) 
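        # Note: self.var_judge holds the human-preference fraction in [0, 1];
        # the "* 2. - 1." above maps it to [-1, +1], the signed target that
        # BCERankingLoss expects before converting back via per = (judge + 1) / 2.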
+ + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self, d0, d1, judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) + + def get_current_errors(self): + retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), + ('acc_r', self.acc_r)]) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256 / self.var_ref.data.size()[2] + + ref_img = tensor2im(self.var_ref.data) + p0_img = tensor2im(self.var_p0.data) + p1_img = tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) + p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) + p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) + + return OrderedDict([('ref', ref_img_vis), + ('p0', p0_img_vis), + ('p1', p1_img_vis)]) + + def save(self, path, label): + if (self.use_gpu): + self.save_network(self.net.module, path, '', label) + else: + self.save_network(self.net, path, '', label) + self.save_network(self.rankLoss.net, path, 'rank', label) + + def update_learning_rate(self, nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group['lr'] = lr + + print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr)) + self.old_lr = lr + + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() + d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() + gts += data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s < d1s) * (1. 
- gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 + + return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) + + +def score_jnd_dataset(data_loader, func, name=''): + ''' Function computes JND score using distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return pytorch array of length N + OUTPUTS + [0] - JND score in [0,1], mAP score (area under precision-recall curve) + [1] - dictionary with following elements + ds - N array containing distances between two patches shown to human evaluator + sames - N array containing fraction of people who thought the two patches were identical + CONSTS + N - number of test triplets in data_loader + ''' + + ds = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() + gts += data['same'].cpu().numpy().flatten().tolist() + + sames = np.array(gts) + ds = np.array(ds) + + sorted_inds = np.argsort(ds) + ds_sorted = ds[sorted_inds] + sames_sorted = sames[sorted_inds] + + TPs = np.cumsum(sames_sorted) + FPs = np.cumsum(1 - sames_sorted) + FNs = np.sum(sames_sorted) - TPs + + precs = TPs / (TPs + FPs) + recs = TPs / (TPs + FNs) + score = voc_ap(recs, precs) + + return (score, dict(ds=ds, sames=sames)) + + +############################################################ +# networks_basic.py # +############################################################ + +import torch.nn as nn +from torch.autograd import Variable +import numpy as np + + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2, 3], keepdim=keepdim) + + +def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W + in_H = in_tens.shape[2] + scale_factor = 1. 
* out_H / in_H + + return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) + + +# Learned perceptual metric +class PNetLin(nn.Module): + def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, + version='0.1', lpips=True): + super(PNetLin, self).__init__() + + self.pnet_type = pnet_type + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips + self.version = version + self.scaling_layer = ScalingLayer() + + if (self.pnet_type in ['vgg', 'vgg16']): + net_type = vgg16 + self.chns = [64, 128, 256, 512, 512] + elif (self.pnet_type == 'alex'): + net_type = alexnet + self.chns = [64, 192, 384, 256, 256] + elif (self.pnet_type == 'squeeze'): + net_type = squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if (lpips): + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + + def forward(self, in0, in1, retPerLayer=False): + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( + in0, in1) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if (self.lpips): + if (self.spatial): + res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if (self.spatial): + res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + + val = res[0] + for l in range(1, self.L): + val += res[l] + + if (retPerLayer): + return (val, res) + else: + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class Dist2LogitLayer(nn.Module): + ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' + + def 
__init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] + layers += [nn.LeakyReLU(0.2, True), ] + layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] + layers += [nn.LeakyReLU(0.2, True), ] + layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] + if (use_sigmoid): + layers += [nn.Sigmoid(), ] + self.model = nn.Sequential(*layers) + + def forward(self, d0, d1, eps=0.1): + return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) + + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge + 1.) / 2. + self.logit = self.net(d0, d1) + return self.loss(self.logit, per) + + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace='Lab'): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace = colorspace + + +class L2(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert (in0.size()[0] == 1) # currently only supports batchSize 1 + + if (self.colorspace == 'RGB'): + (N, C, X, Y) = in0.size() + value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), + dim=3).view(N) + return value + elif (self.colorspace == 'Lab'): + value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), + tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') + ret_var = Variable(torch.Tensor((value,))) + # if (self.use_gpu): + # ret_var = ret_var.cuda() + return ret_var + + +class DSSIM(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert (in0.size()[0] == 1) # currently only supports batchSize 1 + + if (self.colorspace == 'RGB'): + value = dssim(1. * tensor2im(in0.data), 1. 
* tensor2im(in1.data), range=255.).astype('float') + elif (self.colorspace == 'Lab'): + value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), + tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') + ret_var = Variable(torch.Tensor((value,))) + # if (self.use_gpu): + # ret_var = ret_var.cuda() + return ret_var + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print('Network', net) + print('Total number of parameters: %d' % num_params) + + +############################################################ +# pretrained_networks.py # +############################################################ + +from collections import namedtuple +import torch +from torchvision import models as tv + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = 
self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if (num == 18): + self.net = tv.resnet18(pretrained=pretrained) + elif (num == 34): + self.net = tv.resnet34(pretrained=pretrained) + elif (num == 50): + self.net = tv.resnet50(pretrained=pretrained) + elif (num == 101): + self.net = tv.resnet101(pretrained=pretrained) + elif (num == 152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/DH-AISP/2/saicinpainting/evaluation/losses/ssim.py b/DH-AISP/2/saicinpainting/evaluation/losses/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ee43a0095408eca98e253dea194db788446f9c0a --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/losses/ssim.py @@ -0,0 +1,74 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +class SSIM(torch.nn.Module): + """SSIM. 
Modified from: + https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + """ + + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.register_buffer('window', self._create_window(window_size, self.channel)) + + def forward(self, img1, img2): + assert len(img1.shape) == 4 + + channel = img1.size()[1] + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = self._create_window(self.window_size, channel) + + # window = window.to(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) + + def _gaussian(self, window_size, sigma): + gauss = torch.Tensor([ + np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size) + ]) + return gauss / gauss.sum() + + def _create_window(self, window_size, channel): + _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(channel, 1, window_size, window_size).contiguous() + + def _ssim(self, img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel) + mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + + return ssim_map.mean(1).mean(1).mean(1) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + return diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/README.md b/DH-AISP/2/saicinpainting/evaluation/masks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf176bc10fae3b03f139727147c220f2a735c806 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/README.md @@ -0,0 +1,27 @@ +# Current algorithm + +## Choice of mask objects + +For identification of the objects which are suitable for mask obtaining, panoptic segmentation model +from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances +belong either to "stuff" or "things" types. We consider that instances of objects should have category belong +to "things". Besides, we set upper bound on area which is taken by the object — we consider that too big +area indicates either of the instance being a background or a main object which should not be removed. + +## Choice of position for mask + +We consider that input image has size 2^n x 2^m. We downsample it using +[COUNTLESS](https://github.com/william-silversmith/countless) algorithm so the width is equal to +64 = 2^8 = 2^{downsample_levels}. + +### Augmentation + +There are several parameters for augmentation: +- Scaling factor. 
We limit scaling to the case when a mask after scaling with pivot point in its center fits inside the + image completely. +- + +### Shift + + +## Select diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/__init__.py b/DH-AISP/2/saicinpainting/evaluation/masks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md b/DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md new file mode 100644 index 0000000000000000000000000000000000000000..67335464d794776140fd0308f408608f2231309b --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/README.md @@ -0,0 +1,25 @@ +[![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) + +Python COUNTLESS Downsampling +============================= + +To install: + +`pip install -r requirements.txt` + +To test: + +`python test.py` + +To benchmark countless2d: + +`python python/countless2d.py python/images/gray_segmentation.png` + +To benchmark countless3d: + +`python python/countless3d.py` + +Adjust N and the list of algorithms inside each script to modify the run parameters. + + +Python3 is slightly faster than Python2. \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py b/DH-AISP/2/saicinpainting/evaluation/masks/countless/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py b/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py new file mode 100644 index 0000000000000000000000000000000000000000..dc27b73affa20ab1a8a199542469a10aaf1f555a --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless2d.py @@ -0,0 +1,529 @@ +from __future__ import print_function, division + +""" +COUNTLESS performance test in Python. + +python countless2d.py ./images/NAMEOFIMAGE +""" + +import six +from six.moves import range +from collections import defaultdict +from functools import reduce +import operator +import io +import os +from PIL import Image +import math +import numpy as np +import random +import sys +import time +from tqdm import tqdm +from scipy import ndimage + +def simplest_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a * (a == b) # PICK(A,B) + ac = a * (a == c) # PICK(A,C) + bc = b * (b == c) # PICK(B,C) + + a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed + + return a + (a == 0) * d # AB || AC || BC || D + +def quick_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. 
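    Illustrative example: the output pixel is any label shared by a matching
    pair among A, B, C; if no pair matches, D is kept.

        >>> quick_countless(np.array([[5, 5], [2, 7]], dtype=np.uint8))
        array([[5]], dtype=uint8)
        >>> quick_countless(np.array([[1, 2], [3, 4]], dtype=np.uint8))
        array([[4]], dtype=uint8)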
+ """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + bc = b * (b == c) # PICK(B,C) + + a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C) + return a + (a == 0) * d # AB || AC || BC || D + +def quickest_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D + +def quick_countless_xor(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a ^ (a ^ b) # a or b + ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c + ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d + return ab + +def stippled_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm + that treats zero as "background" and inflates lone + pixels. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + + nonzero = a + (a == 0) * (b + (b == 0) * c) + return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D + +def zero_corrected_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + # allows us to prevent losing 1/2 a bit of information + # at the top end by using a bigger type. Without this 255 is handled incorrectly. + data, upgraded = upgrade_type(data) + + # offset from zero, raw countless doesn't handle 0 correctly + # we'll remove the extra 1 at the end. 
+ data += 1 + + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a * (a == b) # PICK(A,B) + ac = a * (a == c) # PICK(A,C) + bc = b * (b == c) # PICK(B,C) + + a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed + + result = a + (a == 0) * d - 1 # a or d - 1 + + if upgraded: + return downgrade_type(result) + + # only need to reset data if we weren't upgraded + # b/c no copy was made in that case + data -= 1 + + return result + +def countless_extreme(data): + nonzeros = np.count_nonzero(data) + # print("nonzeros", nonzeros) + + N = reduce(operator.mul, data.shape) + + if nonzeros == N: + print("quick") + return quick_countless(data) + elif np.count_nonzero(data + 1) == N: + print("quick") + # print("upper", nonzeros) + return quick_countless(data) + else: + return countless(data) + + +def countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + # allows us to prevent losing 1/2 a bit of information + # at the top end by using a bigger type. Without this 255 is handled incorrectly. + data, upgraded = upgrade_type(data) + + # offset from zero, raw countless doesn't handle 0 correctly + # we'll remove the extra 1 at the end. + data += 1 + + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1 + + if upgraded: + return downgrade_type(result) + + # only need to reset data if we weren't upgraded + # b/c no copy was made in that case + data -= 1 + + return result + +def upgrade_type(arr): + dtype = arr.dtype + + if dtype == np.uint8: + return arr.astype(np.uint16), True + elif dtype == np.uint16: + return arr.astype(np.uint32), True + elif dtype == np.uint32: + return arr.astype(np.uint64), True + + return arr, False + +def downgrade_type(arr): + dtype = arr.dtype + + if dtype == np.uint64: + return arr.astype(np.uint32) + elif dtype == np.uint32: + return arr.astype(np.uint16) + elif dtype == np.uint16: + return arr.astype(np.uint8) + + return arr + +def odd_to_even(image): + """ + To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one. + Works by mirroring the starting 1 pixel edge of the image on odd shaped sides. + + e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled) + + For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample. + + """ + shape = np.array(image.shape) + + offset = (shape % 2)[:2] # x,y offset + + # detect if we're dealing with an even + # image. if so it's fine, just return. 
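    # Note: offset[i] is 1 exactly where dimension i is odd, so an all-zero
    # offset means both x and y are already even and no mirroring is required.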
+ if not np.any(offset): + return image + + oddshape = image.shape[:2] + offset + oddshape = np.append(oddshape, shape[2:]) + oddshape = oddshape.astype(int) + + newimg = np.empty(shape=oddshape, dtype=image.dtype) + + ox,oy = offset + sx,sy = oddshape + + newimg[0,0] = image[0,0] # corner + newimg[ox:sx,0] = image[:,0] # x axis line + newimg[0,oy:sy] = image[0,:] # y axis line + + return newimg + +def counting(array): + factor = (2, 2, 1) + shape = array.shape + + while len(shape) < 4: + array = np.expand_dims(array, axis=-1) + shape = array.shape + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor)) + output = np.zeros(output_shape, dtype=array.dtype) + + for chan in range(0, shape[3]): + for z in range(0, shape[2]): + for x in range(0, shape[0], 2): + for y in range(0, shape[1], 2): + block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block + + hashtable = defaultdict(int) + for subx, suby in np.ndindex(block.shape[0], block.shape[1]): + hashtable[block[subx, suby]] += 1 + + best = (0, 0) + for segid, val in six.iteritems(hashtable): + if best[1] < val: + best = (segid, val) + + output[ x // 2, y // 2, chan ] = best[0] + + return output + +def ndzoom(array): + if len(array.shape) == 3: + ratio = ( 1 / 2.0, 1 / 2.0, 1.0 ) + else: + ratio = ( 1 / 2.0, 1 / 2.0) + return ndimage.interpolation.zoom(array, ratio, order=1) + +def countless_if(array): + factor = (2, 2, 1) + shape = array.shape + + if len(shape) < 3: + array = array[ :,:, np.newaxis ] + shape = array.shape + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor)) + output = np.zeros(output_shape, dtype=array.dtype) + + for chan in range(0, shape[2]): + for x in range(0, shape[0], 2): + for y in range(0, shape[1], 2): + block = array[ x:x+2, y:y+2, chan ] # 2x2 block + + if block[0,0] == block[1,0]: + pick = block[0,0] + elif block[0,0] == block[0,1]: + pick = block[0,0] + elif block[1,0] == block[0,1]: + pick = block[1,0] + else: + pick = block[1,1] + + output[ x // 2, y // 2, chan ] = pick + + return np.squeeze(output) + +def downsample_with_averaging(array): + """ + Downsample x by factor using averaging. + + @return: The downsampled array, of the same type as x. + """ + + if len(array.shape) == 3: + factor = (2,2,1) + else: + factor = (2,2) + + if np.array_equal(factor[:3], np.array([1,1,1])): + return array + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor)) + temp = np.zeros(output_shape, float) + counts = np.zeros(output_shape, np.int) + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + indexing_expr = tuple(np.s_[:s] for s in part.shape) + temp[indexing_expr] += part + counts[indexing_expr] += 1 + return np.cast[array.dtype](temp / counts) + +def downsample_with_max_pooling(array): + + factor = (2,2) + + if np.all(np.array(factor, int) == 1): + return array + + sections = [] + + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + output = sections[0].copy() + + for section in sections[1:]: + np.maximum(output, section, output) + + return output + +def striding(array): + """Downsample x by factor using striding. + + @return: The downsampled array, of the same type as x. 
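    Illustrative example: with the fixed factor of (2, 2) this keeps every
    other row and column.

        >>> striding(np.arange(16).reshape(4, 4))
        array([[ 0,  2],
               [ 8, 10]])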
+ """ + factor = (2,2) + if np.all(np.array(factor, int) == 1): + return array + return array[tuple(np.s_[::f] for f in factor)] + +def benchmark(): + filename = sys.argv[1] + img = Image.open(filename) + data = np.array(img.getdata(), dtype=np.uint8) + + if len(data.shape) == 1: + n_channels = 1 + reshape = (img.height, img.width) + else: + n_channels = min(data.shape[1], 3) + data = data[:, :n_channels] + reshape = (img.height, img.width, n_channels) + + data = data.reshape(reshape).astype(np.uint8) + + methods = [ + simplest_countless, + quick_countless, + quick_countless_xor, + quickest_countless, + stippled_countless, + zero_corrected_countless, + countless, + downsample_with_averaging, + downsample_with_max_pooling, + ndzoom, + striding, + # countless_if, + # counting, + ] + + formats = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' + } + + if not os.path.exists('./results'): + os.mkdir('./results') + + N = 500 + img_size = float(img.width * img.height) / 1024.0 / 1024.0 + print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename)) + print("Algorithm\tMPx/sec\tMB/sec\tSec") + for fn in methods: + print(fn.__name__, end='') + sys.stdout.flush() + + start = time.time() + # tqdm is here to show you what's going on the first time you run it. + # Feel free to remove it to get slightly more accurate timing results. + for _ in tqdm(range(N), desc=fn.__name__, disable=True): + result = fn(data) + end = time.time() + print("\r", end='') + + total_time = (end - start) + mpx = N * img_size / total_time + mbytes = N * img_size * n_channels / total_time + # Output in tab separated format to enable copy-paste into excel/numbers + print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time)) + outimg = Image.fromarray(np.squeeze(result), formats[n_channels]) + outimg.save('./results/{}.png'.format(fn.__name__, "PNG")) + +if __name__ == '__main__': + benchmark() + + +# Example results: +# N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png +# Function MPx/sec MB/sec Sec +# simplest_countless 752.855 752.855 0.01 +# quick_countless 920.328 920.328 0.01 +# zero_corrected_countless 534.143 534.143 0.01 +# countless 644.247 644.247 0.01 +# downsample_with_averaging 372.575 372.575 0.01 +# downsample_with_max_pooling 974.060 974.060 0.01 +# ndzoom 137.517 137.517 0.04 +# striding 38550.588 38550.588 0.00 +# countless_if 4.377 4.377 1.14 +# counting 0.117 0.117 42.85 + +# Run without non-numpy implementations: +# N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png +# Algorithm MPx/sec MB/sec Sec +# simplest_countless 800.522 800.522 2.50 +# quick_countless 945.420 945.420 2.12 +# quickest_countless 947.256 947.256 2.11 +# stippled_countless 544.049 544.049 3.68 +# zero_corrected_countless 575.310 575.310 3.48 +# countless 646.684 646.684 3.09 +# downsample_with_averaging 385.132 385.132 5.19 +# downsample_with_max_poolin 988.361 988.361 2.02 +# ndzoom 163.104 163.104 12.26 +# striding 81589.340 81589.340 0.02 + + + + diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py b/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py new file mode 100644 index 0000000000000000000000000000000000000000..810a71e4b1fa344dd2d731186516dbfa96c9cd03 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/countless3d.py @@ -0,0 +1,356 @@ +from six.moves import range +from PIL import Image +import numpy as np +import io +import time +import math +import random +import sys +from collections import 
defaultdict +from copy import deepcopy +from itertools import combinations +from functools import reduce +from tqdm import tqdm + +from memory_profiler import profile + +def countless5(a,b,c,d,e): + """First stage of generalizing from countless2d. + + You have five slots: A, B, C, D, E + + You can decide if something is the winner by first checking for + matches of three, then matches of two, then picking just one if + the other two tries fail. In countless2d, you just check for matches + of two and then pick one of them otherwise. + + Unfortunately, you need to check ABC, ABD, ABE, BCD, BDE, & CDE. + Then you need to check AB, AC, AD, BC, BD + We skip checking E because if none of these match, we pick E. We can + skip checking AE, BE, CE, DE since if any of those match, E is our boy + so it's redundant. + + So countless grows cominatorially in complexity. + """ + sections = [ a,b,c,d,e ] + + p2 = lambda q,r: q * (q == r) # q if p == q else 0 + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0 + + lor = lambda x,y: x + (x == 0) * y + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + return reduce(lor, (results3, results2, e)) + +def countless8(a,b,c,d,e,f,g,h): + """Extend countless5 to countless8. Same deal, except we also + need to check for matches of length 4.""" + sections = [ a, b, c, d, e, f, g, h ] + + p2 = lambda q,r: q * (q == r) + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) + p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) + + lor = lambda x,y: x + (x == 0) * y + + results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) + results4 = reduce(lor, results4) + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + # We can always use our shortcut of omitting the last element + # for N choose 2 + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + return reduce(lor, [ results4, results3, results2, h ]) + +def dynamic_countless3d(data): + """countless8 + dynamic programming. ~2x faster""" + sections = [] + + # shift zeros up one so they don't interfere with bitwise operators + # we'll shift down at the end + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. 
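    # Note: the comment above was carried over from the 2D version; here the
    # factor is (2,2,2), so the loop actually produces eight octant arrays
    # (2x2x2 neighbourhoods), not four.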
+ factor = (2,2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + pick = lambda a,b: a * (a == b) + lor = lambda x,y: x + (x == 0) * y + + subproblems2 = {} + + results2 = None + for x,y in combinations(range(7), 2): + res = pick(sections[x], sections[y]) + subproblems2[(x,y)] = res + if results2 is not None: + results2 += (results2 == 0) * res + else: + results2 = res + + subproblems3 = {} + + results3 = None + for x,y,z in combinations(range(8), 3): + res = pick(subproblems2[(x,y)], sections[z]) + + if z != 7: + subproblems3[(x,y,z)] = res + + if results3 is not None: + results3 += (results3 == 0) * res + else: + results3 = res + + results3 = reduce(lor, (results3, results2, sections[-1])) + + # free memory + results2 = None + subproblems2 = None + res = None + + results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) ) + results4 = reduce(lor, results4) + subproblems3 = None # free memory + + final_result = lor(results4, results3) - 1 + data -= 1 + return final_result + +def countless3d(data): + """Now write countless8 in such a way that it could be used + to process an image.""" + sections = [] + + # shift zeros up one so they don't interfere with bitwise operators + # we'll shift down at the end + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + p2 = lambda q,r: q * (q == r) + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) + p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) + + lor = lambda x,y: x + (x == 0) * y + + results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) + results4 = reduce(lor, results4) + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1 + data -= 1 + return final_result + +def countless_generalized(data, factor): + assert len(data.shape) == len(factor) + + sections = [] + + mode_of = reduce(lambda x,y: x * y, factor) + majority = int(math.ceil(float(mode_of) / 2)) + + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. 
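+ # (Here np.ndindex(factor) yields prod(factor) offset sections; the search starts at match size majority = ceil(prod(factor) / 2), which also covers any larger agreement, then works down to pairs.)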
+ for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + def pick(elements): + eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) ) + anded = reduce(lambda p,q: p & q, eq) + return elements[0] * anded + + def logical_or(x,y): + return x + (x == 0) * y + + result = ( pick(combo) for combo in combinations(sections, majority) ) + result = reduce(logical_or, result) + for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds + partial_result = ( pick(combo) for combo in combinations(sections, i) ) + partial_result = reduce(logical_or, partial_result) + result = logical_or(result, partial_result) + + partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) ) + partial_result = reduce(logical_or, partial_result) + result = logical_or(result, partial_result) + + result = logical_or(result, sections[-1]) - 1 + data -= 1 + return result + +def dynamic_countless_generalized(data, factor): + assert len(data.shape) == len(factor) + + sections = [] + + mode_of = reduce(lambda x,y: x * y, factor) + majority = int(math.ceil(float(mode_of) / 2)) + + data += 1 # offset from zero + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + pick = lambda a,b: a * (a == b) + lor = lambda x,y: x + (x == 0) * y # logical or + + subproblems = [ {}, {} ] + results2 = None + for x,y in combinations(range(len(sections) - 1), 2): + res = pick(sections[x], sections[y]) + subproblems[0][(x,y)] = res + if results2 is not None: + results2 = lor(results2, res) + else: + results2 = res + + results = [ results2 ] + for r in range(3, majority+1): + r_results = None + for combo in combinations(range(len(sections)), r): + res = pick(subproblems[0][combo[:-1]], sections[combo[-1]]) + + if combo[-1] != len(sections) - 1: + subproblems[1][combo] = res + + if r_results is not None: + r_results = lor(r_results, res) + else: + r_results = res + results.append(r_results) + subproblems[0] = subproblems[1] + subproblems[1] = {} + + results.reverse() + final_result = lor(reduce(lor, results), sections[-1]) - 1 + data -= 1 + return final_result + +def downsample_with_averaging(array): + """ + Downsample x by factor using averaging. + + @return: The downsampled array, of the same type as x. + """ + factor = (2,2,2) + + if np.array_equal(factor[:3], np.array([1,1,1])): + return array + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor)) + temp = np.zeros(output_shape, float) + counts = np.zeros(output_shape, np.int) + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + indexing_expr = tuple(np.s_[:s] for s in part.shape) + temp[indexing_expr] += part + counts[indexing_expr] += 1 + return np.cast[array.dtype](temp / counts) + +def downsample_with_max_pooling(array): + + factor = (2,2,2) + + sections = [] + + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + output = sections[0].copy() + + for section in sections[1:]: + np.maximum(output, section, output) + + return output + +def striding(array): + """Downsample x by factor using striding. 
+ + @return: The downsampled array, of the same type as x. + """ + factor = (2,2,2) + if np.all(np.array(factor, int) == 1): + return array + return array[tuple(np.s_[::f] for f in factor)] + +def benchmark(): + def countless3d_generalized(img): + return countless_generalized(img, (2,8,1)) + def countless3d_dynamic_generalized(img): + return dynamic_countless_generalized(img, (8,8,1)) + + methods = [ + # countless3d, + # dynamic_countless3d, + countless3d_generalized, + # countless3d_dynamic_generalized, + # striding, + # downsample_with_averaging, + # downsample_with_max_pooling + ] + + data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1 + + N = 5 + + print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N) + + for fn in methods: + start = time.time() + for _ in range(N): + result = fn(data) + end = time.time() + + total_time = (end - start) + mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0 + mbytes = mpx * np.dtype(data.dtype).itemsize + # Output in tab separated format to enable copy-paste into excel/numbers + print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time)) + +if __name__ == '__main__': + benchmark() + +# Algorithm MPx MB/sec Sec N=5 +# countless3d 10.564 10.564 60.58 +# dynamic_countless3d 22.717 22.717 28.17 +# countless3d_generalized 9.702 9.702 65.96 +# countless3d_dynamic_generalized 22.720 22.720 28.17 +# striding 253360.506 253360.506 0.00 +# downsample_with_averaging 224.098 224.098 2.86 +# downsample_with_max_pooling 690.474 690.474 0.93 + + + diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg new file mode 100644 index 0000000000000000000000000000000000000000..610d9212eb0ba1cc970ea467104dea8f68a7a839 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gcim.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b1ade0a290a0a79aceb49a170d085e28e5d2ea1face4fcd522d39a279d3fb4d +size 2582487 diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png new file mode 100644 index 0000000000000000000000000000000000000000..5995bfb41b65bb503e0b2cd99da3dfce41b619b9 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png new file mode 100644 index 0000000000000000000000000000000000000000..b8744331d17f2085bb1d9dc73f26c6d11ccab0a0 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/segmentation.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png new file mode 100644 index 0000000000000000000000000000000000000000..401f043b0850a7c3fb7e289abce386b145e6fe32 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/images/sparse.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png new file mode 100644 index 0000000000000000000000000000000000000000..557eca7295f50ac9398165b5da873eeb06d10e5c Binary files /dev/null and 
b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png new file mode 100644 index 0000000000000000000000000000000000000000..2121cef5c7376a47fda376a22832d3e8b9e6ff91 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png new file mode 100644 index 0000000000000000000000000000000000000000..5b4bf5d5fc400ce25388cc189fd18d61b82a5fd5 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png new file mode 100644 index 0000000000000000000000000000000000000000..91bcb420c88e1cad2c9a3152495211e018585aa4 Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png new file mode 100644 index 0000000000000000000000000000000000000000..5c6137442d6027a99ee7e3d1ba92a7bfbd49dffc Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png new file mode 100644 index 0000000000000000000000000000000000000000..9193f641f493ae085d226aa3f3468089e1f686ea Binary files /dev/null and b/DH-AISP/2/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png differ diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/requirements.txt b/DH-AISP/2/saicinpainting/evaluation/masks/countless/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cbf8c87bf9b4c9fe54cb39d722253c0ab59e63ad --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/requirements.txt @@ -0,0 +1,7 @@ +Pillow>=6.2.0 +numpy>=1.16 +scipy +tqdm +memory_profiler +six +pytest \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/countless/test.py b/DH-AISP/2/saicinpainting/evaluation/masks/countless/test.py new file mode 100644 index 0000000000000000000000000000000000000000..7809beb7aeeb3bcb10d03093a564917b1f2b4786 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/countless/test.py @@ -0,0 +1,195 @@ +from copy import deepcopy + +import numpy as np + +import countless2d +import countless3d + +def test_countless2d(): + def test_all_cases(fn, test_zero): + case1 = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) # all different + case2 = np.array([ [ 1, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same + case1z = np.array([ [ 0, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # all different + case2z = np.array([ [ 0, 0 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same + case3 = np.array([ [ 1, 1 ], [ 2, 2 ] ]).reshape((2,2,1,1)) 
# two groups are same + case4 = np.array([ [ 1, 2 ], [ 2, 2 ] ]).reshape((2,2,1,1)) # 3 are the same + case5 = np.array([ [ 5, 5 ], [ 5, 5 ] ]).reshape((2,2,1,1)) # all are the same + + is_255_handled = np.array([ [ 255, 255 ], [ 1, 2 ] ], dtype=np.uint8).reshape((2,2,1,1)) + + test = lambda case: fn(case) + + if test_zero: + assert test(case1z) == [[[[3]]]] # d + assert test(case2z) == [[[[0]]]] # a==b + else: + assert test(case1) == [[[[4]]]] # d + assert test(case2) == [[[[1]]]] # a==b + + assert test(case3) == [[[[1]]]] # a==b + assert test(case4) == [[[[2]]]] # b==c + assert test(case5) == [[[[5]]]] # a==b + + assert test(is_255_handled) == [[[[255]]]] + + assert fn(case1).dtype == case1.dtype + + test_all_cases(countless2d.simplest_countless, False) + test_all_cases(countless2d.quick_countless, False) + test_all_cases(countless2d.quickest_countless, False) + test_all_cases(countless2d.stippled_countless, False) + + + + methods = [ + countless2d.zero_corrected_countless, + countless2d.countless, + countless2d.countless_if, + # countless2d.counting, # counting doesn't respect order so harder to write a test + ] + + for fn in methods: + print(fn.__name__) + test_all_cases(fn, True) + +def test_stippled_countless2d(): + a = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) + b = np.array([ [ 0, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) + c = np.array([ [ 1, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1)) + d = np.array([ [ 1, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1)) + e = np.array([ [ 1, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1)) + f = np.array([ [ 0, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1)) + g = np.array([ [ 0, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1)) + h = np.array([ [ 0, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1)) + i = np.array([ [ 1, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1)) + j = np.array([ [ 1, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1)) + k = np.array([ [ 1, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1)) + l = np.array([ [ 1, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1)) + m = np.array([ [ 0, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1)) + n = np.array([ [ 0, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1)) + o = np.array([ [ 0, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1)) + z = np.array([ [ 0, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1)) + + test = countless2d.stippled_countless + + # Note: We only tested non-matching cases above, + # cases f,g,h,i,j,k prove their duals work as well + # b/c if two pixels are black, either one can be chosen + # if they are different or the same. 
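+ # stippled_countless treats 0 as missing data: it returns a non-zero pixel whenever one exists (cases b-o) and 0 only when the whole 2x2 block is zero (case z).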
+ + assert test(a) == [[[[4]]]] + assert test(b) == [[[[4]]]] + assert test(c) == [[[[4]]]] + assert test(d) == [[[[4]]]] + assert test(e) == [[[[1]]]] + assert test(f) == [[[[4]]]] + assert test(g) == [[[[4]]]] + assert test(h) == [[[[2]]]] + assert test(i) == [[[[4]]]] + assert test(j) == [[[[1]]]] + assert test(k) == [[[[1]]]] + assert test(l) == [[[[1]]]] + assert test(m) == [[[[2]]]] + assert test(n) == [[[[3]]]] + assert test(o) == [[[[4]]]] + assert test(z) == [[[[0]]]] + + bc = np.array([ [ 0, 2 ], [ 2, 4 ] ]).reshape((2,2,1,1)) + bd = np.array([ [ 0, 2 ], [ 3, 2 ] ]).reshape((2,2,1,1)) + cd = np.array([ [ 0, 2 ], [ 3, 3 ] ]).reshape((2,2,1,1)) + + assert test(bc) == [[[[2]]]] + assert test(bd) == [[[[2]]]] + assert test(cd) == [[[[3]]]] + + ab = np.array([ [ 1, 1 ], [ 0, 4 ] ]).reshape((2,2,1,1)) + ac = np.array([ [ 1, 2 ], [ 1, 0 ] ]).reshape((2,2,1,1)) + ad = np.array([ [ 1, 0 ], [ 3, 1 ] ]).reshape((2,2,1,1)) + + assert test(ab) == [[[[1]]]] + assert test(ac) == [[[[1]]]] + assert test(ad) == [[[[1]]]] + +def test_countless3d(): + def test_all_cases(fn): + alldifferent = [ + [ + [1,2], + [3,4], + ], + [ + [5,6], + [7,8] + ] + ] + allsame = [ + [ + [1,1], + [1,1], + ], + [ + [1,1], + [1,1] + ] + ] + + assert fn(np.array(alldifferent)) == [[[8]]] + assert fn(np.array(allsame)) == [[[1]]] + + twosame = deepcopy(alldifferent) + twosame[1][1][0] = 2 + + assert fn(np.array(twosame)) == [[[2]]] + + threemixed = [ + [ + [3,3], + [1,2], + ], + [ + [2,4], + [4,3] + ] + ] + assert fn(np.array(threemixed)) == [[[3]]] + + foursame = [ + [ + [4,4], + [1,2], + ], + [ + [2,4], + [4,3] + ] + ] + + assert fn(np.array(foursame)) == [[[4]]] + + fivesame = [ + [ + [5,4], + [5,5], + ], + [ + [2,4], + [5,5] + ] + ] + + assert fn(np.array(fivesame)) == [[[5]]] + + def countless3d_generalized(img): + return countless3d.countless_generalized(img, (2,2,2)) + def countless3d_dynamic_generalized(img): + return countless3d.dynamic_countless_generalized(img, (2,2,2)) + + methods = [ + countless3d.countless3d, + countless3d.dynamic_countless3d, + countless3d_generalized, + countless3d_dynamic_generalized, + ] + + for fn in methods: + test_all_cases(fn) \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/evaluation/masks/mask.py b/DH-AISP/2/saicinpainting/evaluation/masks/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..3e34d0675a781fba983cb542f18390255aaf2609 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/masks/mask.py @@ -0,0 +1,429 @@ +import enum +from copy import deepcopy + +import numpy as np +from skimage import img_as_ubyte +from skimage.transform import rescale, resize +try: + from detectron2 import model_zoo + from detectron2.config import get_cfg + from detectron2.engine import DefaultPredictor + DETECTRON_INSTALLED = True +except: + print("Detectron v2 is not installed") + DETECTRON_INSTALLED = False + +from .countless.countless2d import zero_corrected_countless + + +class ObjectMask(): + def __init__(self, mask): + self.height, self.width = mask.shape + (self.up, self.down), (self.left, self.right) = self._get_limits(mask) + self.mask = mask[self.up:self.down, self.left:self.right].copy() + + @staticmethod + def _get_limits(mask): + def indicator_limits(indicator): + lower = indicator.argmax() + upper = len(indicator) - indicator[::-1].argmax() + return lower, upper + + vertical_indicator = mask.any(axis=1) + vertical_limits = indicator_limits(vertical_indicator) + + horizontal_indicator = mask.any(axis=0) + horizontal_limits = 
indicator_limits(horizontal_indicator) + + return vertical_limits, horizontal_limits + + def _clean(self): + self.up, self.down, self.left, self.right = 0, 0, 0, 0 + self.mask = np.empty((0, 0)) + + def horizontal_flip(self, inplace=False): + if not inplace: + flipped = deepcopy(self) + return flipped.horizontal_flip(inplace=True) + + self.mask = self.mask[:, ::-1] + return self + + def vertical_flip(self, inplace=False): + if not inplace: + flipped = deepcopy(self) + return flipped.vertical_flip(inplace=True) + + self.mask = self.mask[::-1, :] + return self + + def image_center(self): + y_center = self.up + (self.down - self.up) / 2 + x_center = self.left + (self.right - self.left) / 2 + return y_center, x_center + + def rescale(self, scaling_factor, inplace=False): + if not inplace: + scaled = deepcopy(self) + return scaled.rescale(scaling_factor, inplace=True) + + scaled_mask = rescale(self.mask.astype(float), scaling_factor, order=0) > 0.5 + (up, down), (left, right) = self._get_limits(scaled_mask) + self.mask = scaled_mask[up:down, left:right] + + y_center, x_center = self.image_center() + mask_height, mask_width = self.mask.shape + self.up = int(round(y_center - mask_height / 2)) + self.down = self.up + mask_height + self.left = int(round(x_center - mask_width / 2)) + self.right = self.left + mask_width + return self + + def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False): + if not inplace: + cropped = deepcopy(self) + cropped.crop_to_canvas(vertical=vertical, horizontal=horizontal, inplace=True) + return cropped + + if vertical: + if self.up >= self.height or self.down <= 0: + self._clean() + else: + cut_up, cut_down = max(-self.up, 0), max(self.down - self.height, 0) + if cut_up != 0: + self.mask = self.mask[cut_up:] + self.up = 0 + if cut_down != 0: + self.mask = self.mask[:-cut_down] + self.down = self.height + + if horizontal: + if self.left >= self.width or self.right <= 0: + self._clean() + else: + cut_left, cut_right = max(-self.left, 0), max(self.right - self.width, 0) + if cut_left != 0: + self.mask = self.mask[:, cut_left:] + self.left = 0 + if cut_right != 0: + self.mask = self.mask[:, :-cut_right] + self.right = self.width + + return self + + def restore_full_mask(self, allow_crop=False): + cropped = self.crop_to_canvas(inplace=allow_crop) + mask = np.zeros((cropped.height, cropped.width), dtype=bool) + mask[cropped.up:cropped.down, cropped.left:cropped.right] = cropped.mask + return mask + + def shift(self, vertical=0, horizontal=0, inplace=False): + if not inplace: + shifted = deepcopy(self) + return shifted.shift(vertical=vertical, horizontal=horizontal, inplace=True) + + self.up += vertical + self.down += vertical + self.left += horizontal + self.right += horizontal + return self + + def area(self): + return self.mask.sum() + + +class RigidnessMode(enum.Enum): + soft = 0 + rigid = 1 + + +class SegmentationMask: + def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid, + max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4, + max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5, + max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True, + max_vertical_shift=0.1, position_shuffle=True): + """ + :param confidence_threshold: float; threshold for confidence of the panoptic segmentator to allow for + the instance. 
+ :param rigidness_mode: RigidnessMode object + when soft, checks intersection only with the object from which the mask_object was produced + when rigid, checks intersection with any foreground class object + :param max_object_area: float; allowed upper bound for to be considered as mask_object. + :param min_mask_area: float; lower bound for mask to be considered valid + :param downsample_levels: int; defines width of the resized segmentation to obtain shifted masks; + :param num_variants_per_mask: int; maximal number of the masks for the same object; + :param max_mask_intersection: float; maximum allowed area fraction of intersection for 2 masks + produced by horizontal shift of the same mask_object; higher value -> more diversity + :param max_foreground_coverage: float; maximum allowed area fraction of intersection for foreground object to be + covered by mask; lower value -> less the objects are covered + :param max_foreground_intersection: float; maximum allowed area of intersection for the mask with foreground + object; lower value -> mask is more on the background than on the objects + :param max_hidden_area: upper bound on part of the object hidden by shifting object outside the screen area; + :param max_scale_change: allowed scale change for the mask_object; + :param horizontal_flip: if horizontal flips are allowed; + :param max_vertical_shift: amount of vertical movement allowed; + :param position_shuffle: shuffle + """ + + assert DETECTRON_INSTALLED, 'Cannot use SegmentationMask without detectron2' + self.cfg = get_cfg() + self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")) + self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml") + self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold + self.predictor = DefaultPredictor(self.cfg) + + self.rigidness_mode = RigidnessMode(rigidness_mode) + self.max_object_area = max_object_area + self.min_mask_area = min_mask_area + self.downsample_levels = downsample_levels + self.num_variants_per_mask = num_variants_per_mask + self.max_mask_intersection = max_mask_intersection + self.max_foreground_coverage = max_foreground_coverage + self.max_foreground_intersection = max_foreground_intersection + self.max_hidden_area = max_hidden_area + self.position_shuffle = position_shuffle + + self.max_scale_change = max_scale_change + self.horizontal_flip = horizontal_flip + self.max_vertical_shift = max_vertical_shift + + def get_segmentation(self, img): + im = img_as_ubyte(img) + panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"] + return panoptic_seg, segment_info + + @staticmethod + def _is_power_of_two(n): + return (n != 0) and (n & (n-1) == 0) + + def identify_candidates(self, panoptic_seg, segments_info): + potential_mask_ids = [] + for segment in segments_info: + if not segment["isthing"]: + continue + mask = (panoptic_seg == segment["id"]).int().detach().cpu().numpy() + area = mask.sum().item() / np.prod(panoptic_seg.shape) + if area >= self.max_object_area: + continue + potential_mask_ids.append(segment["id"]) + return potential_mask_ids + + def downsample_mask(self, mask): + height, width = mask.shape + if not (self._is_power_of_two(height) and self._is_power_of_two(width)): + raise ValueError("Image sides are not power of 2.") + + num_iterations = width.bit_length() - 1 - self.downsample_levels + if num_iterations < 0: + raise ValueError(f"Width is lower than 2^{self.downsample_levels}.") + 
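+ # Each zero_corrected_countless pass halves both sides, so after num_iterations passes the mask is 2**downsample_levels pixels wide.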
+ if height.bit_length() - 1 < num_iterations: + raise ValueError("Height is too low to perform downsampling") + + downsampled = mask + for _ in range(num_iterations): + downsampled = zero_corrected_countless(downsampled) + + return downsampled + + def _augmentation_params(self): + scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change) + if self.horizontal_flip: + horizontal_flip = bool(np.random.choice(2)) + else: + horizontal_flip = False + vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift) + + return { + "scaling_factor": scaling_factor, + "horizontal_flip": horizontal_flip, + "vertical_shift": vertical_shift + } + + def _get_intersection(self, mask_array, mask_object): + intersection = mask_array[ + mask_object.up:mask_object.down, mask_object.left:mask_object.right + ] & mask_object.mask + return intersection + + def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks): + for existing_mask in prev_masks: + intersection_area = self._get_intersection(existing_mask, aug_mask).sum() + intersection_existing = intersection_area / existing_mask.sum() + intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area + if (intersection_existing > self.max_mask_intersection) or \ + (intersection_current > self.max_mask_intersection): + return False + return True + + def _check_foreground_intersection(self, aug_mask, foreground): + for existing_mask in foreground: + intersection_area = self._get_intersection(existing_mask, aug_mask).sum() + intersection_existing = intersection_area / existing_mask.sum() + if intersection_existing > self.max_foreground_coverage: + return False + intersection_mask = intersection_area / aug_mask.area() + if intersection_mask > self.max_foreground_intersection: + return False + return True + + def _move_mask(self, mask, foreground): + # Obtaining properties of the original mask_object: + orig_mask = ObjectMask(mask) + + chosen_masks = [] + chosen_parameters = [] + # to fix the case when resizing gives mask_object consisting only of False + scaling_factor_lower_bound = 0. + + for var_idx in range(self.num_variants_per_mask): + # Obtaining augmentation parameters and applying them to the downscaled mask_object + augmentation_params = self._augmentation_params() + augmentation_params["scaling_factor"] = min([ + augmentation_params["scaling_factor"], + 2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1., + 2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1. + ]) + augmentation_params["scaling_factor"] = max([ + augmentation_params["scaling_factor"], scaling_factor_lower_bound + ]) + + aug_mask = deepcopy(orig_mask) + aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True) + if augmentation_params["horizontal_flip"]: + aug_mask.horizontal_flip(inplace=True) + total_aug_area = aug_mask.area() + if total_aug_area == 0: + scaling_factor_lower_bound = 1. 
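+ # rescaling erased the mask entirely; force scale >= 1 for the remaining variants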
+ continue + + # Fix if the element vertical shift is too strong and shown area is too small: + vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area # share of area taken by rows + # number of rows which are allowed to be hidden from upper and lower parts of image respectively + max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area) + max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area) + # correcting vertical shift, so not too much area will be hidden + augmentation_params["vertical_shift"] = np.clip( + augmentation_params["vertical_shift"], + -(aug_mask.up + max_hidden_up) / aug_mask.height, + (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height + ) + # Applying vertical shift: + vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"])) + aug_mask.shift(vertical=vertical_shift, inplace=True) + aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True) + + # Choosing horizontal shift: + max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area) + horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area + max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area) + max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area) + allowed_shifts = np.arange(-max_hidden_left, aug_mask.width - + (aug_mask.right - aug_mask.left) + max_hidden_right + 1) + allowed_shifts = - (aug_mask.left - allowed_shifts) + + if self.position_shuffle: + np.random.shuffle(allowed_shifts) + + mask_is_found = False + for horizontal_shift in allowed_shifts: + aug_mask_left = deepcopy(aug_mask) + aug_mask_left.shift(horizontal=horizontal_shift, inplace=True) + aug_mask_left.crop_to_canvas(inplace=True) + + prev_masks = [mask] + chosen_masks + is_mask_suitable = self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks) & \ + self._check_foreground_intersection(aug_mask_left, foreground) + if is_mask_suitable: + aug_draw = aug_mask_left.restore_full_mask() + chosen_masks.append(aug_draw) + augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width + chosen_parameters.append(augmentation_params) + mask_is_found = True + break + + if not mask_is_found: + break + + return chosen_parameters + + def _prepare_mask(self, mask): + height, width = mask.shape + target_width = width if self._is_power_of_two(width) else (1 << width.bit_length()) + target_height = height if self._is_power_of_two(height) else (1 << height.bit_length()) + + return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32') + + def get_masks(self, im, return_panoptic=False): + panoptic_seg, segments_info = self.get_segmentation(im) + potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info) + + panoptic_seg_scaled = self._prepare_mask(panoptic_seg.detach().cpu().numpy()) + downsampled = self.downsample_mask(panoptic_seg_scaled) + scene_objects = [] + for segment in segments_info: + if not segment["isthing"]: + continue + mask = downsampled == segment["id"] + if not np.any(mask): + continue + scene_objects.append(mask) + + mask_set = [] + for mask_id in potential_mask_ids: + mask = downsampled == mask_id + if not np.any(mask): + continue + + if self.rigidness_mode is RigidnessMode.soft: + foreground = [mask] + elif self.rigidness_mode is RigidnessMode.rigid: + foreground = scene_objects + else: + raise ValueError(f'Unexpected rigidness_mode: {rigidness_mode}') + + 
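+ # _move_mask searches for valid rescaled/shifted placements on the downsampled mask and returns only the augmentation parameters; they are re-applied to the full-resolution mask below.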
masks_params = self._move_mask(mask, foreground) + + full_mask = ObjectMask((panoptic_seg == mask_id).detach().cpu().numpy()) + + for params in masks_params: + aug_mask = deepcopy(full_mask) + aug_mask.rescale(params["scaling_factor"], inplace=True) + if params["horizontal_flip"]: + aug_mask.horizontal_flip(inplace=True) + + vertical_shift = int(round(aug_mask.height * params["vertical_shift"])) + horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"])) + aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True) + aug_mask = aug_mask.restore_full_mask().astype('uint8') + if aug_mask.mean() <= self.min_mask_area: + continue + mask_set.append(aug_mask) + + if return_panoptic: + return mask_set, panoptic_seg.detach().cpu().numpy() + else: + return mask_set + + +def propose_random_square_crop(mask, min_overlap=0.5): + height, width = mask.shape + mask_ys, mask_xs = np.where(mask > 0.5) # mask==0 is known fragment and mask==1 is missing + + if height < width: + crop_size = height + obj_left, obj_right = mask_xs.min(), mask_xs.max() + obj_width = obj_right - obj_left + left_border = max(0, min(width - crop_size - 1, obj_left + obj_width * min_overlap - crop_size)) + right_border = max(left_border + 1, min(width - crop_size, obj_left + obj_width * min_overlap)) + start_x = np.random.randint(left_border, right_border) + return start_x, 0, start_x + crop_size, height + else: + crop_size = width + obj_top, obj_bottom = mask_ys.min(), mask_ys.max() + obj_height = obj_bottom - obj_top + top_border = max(0, min(height - crop_size - 1, obj_top + obj_height * min_overlap - crop_size)) + bottom_border = max(top_border + 1, min(height - crop_size, obj_top + obj_height * min_overlap)) + start_y = np.random.randint(top_border, bottom_border) + return 0, start_y, width, start_y + crop_size diff --git a/DH-AISP/2/saicinpainting/evaluation/refinement.py b/DH-AISP/2/saicinpainting/evaluation/refinement.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d3cbac689d99ab20e71daf6f42bb8ca3c9feb8 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/refinement.py @@ -0,0 +1,314 @@ +import torch +import torch.nn as nn +from torch.optim import Adam, SGD +from kornia.filters import gaussian_blur2d +from kornia.geometry.transform import resize +from kornia.morphology import erosion +from torch.nn import functional as F +import numpy as np +import cv2 + +from saicinpainting.evaluation.data import pad_tensor_to_modulo +from saicinpainting.evaluation.utils import move_to_device +from saicinpainting.training.modules.ffc import FFCResnetBlock +from saicinpainting.training.modules.pix2pixhd import ResnetBlock + +from tqdm import tqdm + + +def _pyrdown(im : torch.Tensor, downsize : tuple=None): + """downscale the image""" + if downsize is None: + downsize = (im.shape[2]//2, im.shape[3]//2) + assert im.shape[1] == 3, "Expected shape for the input to be (n,3,height,width)" + im = gaussian_blur2d(im, kernel_size=(5,5), sigma=(1.0,1.0)) + im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) + return im + +def _pyrdown_mask(mask : torch.Tensor, downsize : tuple=None, eps : float=1e-8, blur_mask : bool=True, round_up : bool=True): + """downscale the mask tensor + + Parameters + ---------- + mask : torch.Tensor + mask of size (B, 1, H, W) + downsize : tuple, optional + size to downscale to. 
If None, image is downscaled to half, by default None + eps : float, optional + threshold value for binarizing the mask, by default 1e-8 + blur_mask : bool, optional + if True, apply gaussian filter before downscaling, by default True + round_up : bool, optional + if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True + + Returns + ------- + torch.Tensor + downscaled mask + """ + + if downsize is None: + downsize = (mask.shape[2]//2, mask.shape[3]//2) + assert mask.shape[1] == 1, "Expected shape for the input to be (n,1,height,width)" + if blur_mask == True: + mask = gaussian_blur2d(mask, kernel_size=(5,5), sigma=(1.0,1.0)) + mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False) + else: + mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False) + if round_up: + mask[mask>=eps] = 1 + mask[mask<eps] = 0 + else: + mask[mask>=1.0-eps] = 1 + mask[mask<1.0-eps] = 0 + return mask + +def _erode_mask(mask : torch.Tensor, ekernel : torch.Tensor=None, eps : float=1e-8): + """erode the mask, and set gray pixels to 0""" + if ekernel is not None: + mask = erosion(mask, ekernel) + mask[mask>=1.0-eps] = 1 + mask[mask<1.0-eps] = 0 + return mask + + +def _l1_loss( + pred : torch.Tensor, pred_downscaled : torch.Tensor, ref : torch.Tensor, + mask : torch.Tensor, mask_downscaled : torch.Tensor, + image : torch.Tensor, on_pred : bool=True + ): + """l1 loss on src pixels, and downscaled predictions if on_pred=True""" + loss = torch.mean(torch.abs(pred[mask<1e-8] - image[mask<1e-8])) + if on_pred: + loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8])) + return loss + +def _infer( + image : torch.Tensor, mask : torch.Tensor, + forward_front : nn.Module, forward_rears : nn.Module, + ref_lower_res : torch.Tensor, orig_shape : tuple, devices : list, + scale_ind : int, n_iters : int=15, lr : float=0.002): + """Performs inference with refinement at a given scale.
+ + Parameters + ---------- + image : torch.Tensor + input image to be inpainted, of size (1,3,H,W) + mask : torch.Tensor + input inpainting mask, of size (1,1,H,W) + forward_front : nn.Module + the front part of the inpainting network + forward_rears : nn.Module + the rear part of the inpainting network + ref_lower_res : torch.Tensor + the inpainting at previous scale, used as reference image + orig_shape : tuple + shape of the original input image before padding + devices : list + list of available devices + scale_ind : int + the scale index + n_iters : int, optional + number of iterations of refinement, by default 15 + lr : float, optional + learning rate, by default 0.002 + + Returns + ------- + torch.Tensor + inpainted image + """ + masked_image = image * (1 - mask) + masked_image = torch.cat([masked_image, mask], dim=1) + + mask = mask.repeat(1,3,1,1) + if ref_lower_res is not None: + ref_lower_res = ref_lower_res.detach() + with torch.no_grad(): + z1,z2 = forward_front(masked_image) + # Inference + mask = mask.to(devices[-1]) + ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float() + ekernel = ekernel.to(devices[-1]) + image = image.to(devices[-1]) + z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) + z1.requires_grad, z2.requires_grad = True, True + + optimizer = Adam([z1,z2], lr=lr) + + pbar = tqdm(range(n_iters), leave=False) + for idi in pbar: + optimizer.zero_grad() + input_feat = (z1,z2) + for idd, forward_rear in enumerate(forward_rears): + output_feat = forward_rear(input_feat) + if idd < len(devices) - 1: + midz1, midz2 = output_feat + midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1]) + input_feat = (midz1, midz2) + else: + pred = output_feat + + if ref_lower_res is None: + break + losses = {} + ######################### multi-scale ############################# + # scaled loss with downsampler + pred_downscaled = _pyrdown(pred[:,:,:orig_shape[0],:orig_shape[1]]) + mask_downscaled = _pyrdown_mask(mask[:,:1,:orig_shape[0],:orig_shape[1]], blur_mask=False, round_up=False) + mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) + mask_downscaled = mask_downscaled.repeat(1,3,1,1) + losses["ms_l1"] = _l1_loss(pred, pred_downscaled, ref_lower_res, mask, mask_downscaled, image, on_pred=True) + + loss = sum(losses.values()) + pbar.set_description("Refining scale {} using scale {} ...current loss: {:.4f}".format(scale_ind+1, scale_ind, loss.item())) + if idi < n_iters - 1: + loss.backward() + optimizer.step() + del pred_downscaled + del loss + del pred + # "pred" is the prediction after Plug-n-Play module + inpainted = mask * pred + (1 - mask) * image + inpainted = inpainted.detach().cpu() + return inpainted + +def _get_image_mask_pyramid(batch : dict, min_side : int, max_scales : int, px_budget : int): + """Build the image mask pyramid + + Parameters + ---------- + batch : dict + batch containing image, mask, etc + min_side : int + minimum side length to limit the number of scales of the pyramid + max_scales : int + maximum number of scales allowed + px_budget : int + the product H*W cannot exceed this budget, because of resource constraints + + Returns + ------- + tuple + image-mask pyramid in the form of list of images and list of masks + """ + + assert batch['image'].shape[0] == 1, "refiner works on only batches of size 1!" 
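+ # The pyramid is built by repeated _pyrdown/_pyrdown_mask halving: n_scales = min(1 + round(max(0, log2(min(H, W) / min_side))), max_scales), and both lists are returned coarsest-first.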
+ + h, w = batch['unpad_to_size'] + h, w = h[0].item(), w[0].item() + + image = batch['image'][...,:h,:w] + mask = batch['mask'][...,:h,:w] + if h*w > px_budget: + #resize + ratio = np.sqrt(px_budget / float(h*w)) + h_orig, w_orig = h, w + h,w = int(h*ratio), int(w*ratio) + print(f"Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...") + image = resize(image, (h,w),interpolation='bilinear', align_corners=False) + mask = resize(mask, (h,w),interpolation='bilinear', align_corners=False) + mask[mask>1e-8] = 1 + breadth = min(h,w) + n_scales = min(1 + int(round(max(0,np.log2(breadth / min_side)))), max_scales) + ls_images = [] + ls_masks = [] + + ls_images.append(image) + ls_masks.append(mask) + + for _ in range(n_scales - 1): + image_p = _pyrdown(ls_images[-1]) + mask_p = _pyrdown_mask(ls_masks[-1]) + ls_images.append(image_p) + ls_masks.append(mask_p) + # reverse the lists because we want the lowest resolution image as index 0 + return ls_images[::-1], ls_masks[::-1] + +def refine_predict( + batch : dict, inpainter : nn.Module, gpu_ids : str, + modulo : int, n_iters : int, lr : float, min_side : int, + max_scales : int, px_budget : int + ): + """Refines the inpainting of the network + + Parameters + ---------- + batch : dict + image-mask batch, currently we assume the batchsize to be 1 + inpainter : nn.Module + the inpainting neural network + gpu_ids : str + the GPU ids of the machine to use. If only single GPU, use: "0," + modulo : int + pad the image to ensure dimension % modulo == 0 + n_iters : int + number of iterations of refinement for each scale + lr : float + learning rate + min_side : int + all sides of image on all scales should be >= min_side / sqrt(2) + max_scales : int + max number of downscaling scales for the image-mask pyramid + px_budget : int + pixels budget. 
Any image will be resized to satisfy height*width <= px_budget + + Returns + ------- + torch.Tensor + inpainted image of size (1,3,H,W) + """ + + assert not inpainter.training + assert not inpainter.add_noise_kwargs + assert inpainter.concat_mask + + gpu_ids = [f'cuda:{gpuid}' for gpuid in gpu_ids.replace(" ","").split(",") if gpuid.isdigit()] + n_resnet_blocks = 0 + first_resblock_ind = 0 + found_first_resblock = False + for idl in range(len(inpainter.generator.model)): + if isinstance(inpainter.generator.model[idl], FFCResnetBlock) or isinstance(inpainter.generator.model[idl], ResnetBlock): + n_resnet_blocks += 1 + found_first_resblock = True + elif not found_first_resblock: + first_resblock_ind += 1 + resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) + + devices = [torch.device(gpu_id) for gpu_id in gpu_ids] + + # split the model into front, and rear parts + forward_front = inpainter.generator.model[0:first_resblock_ind] + forward_front.to(devices[0]) + forward_rears = [] + for idd in range(len(gpu_ids)): + if idd < len(gpu_ids) - 1: + forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):first_resblock_ind+resblocks_per_gpu*(idd+1)]) + else: + forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):]) + forward_rears[idd].to(devices[idd]) + + ls_images, ls_masks = _get_image_mask_pyramid( + batch, + min_side, + max_scales, + px_budget + ) + image_inpainted = None + + for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): + orig_shape = image.shape[2:] + image = pad_tensor_to_modulo(image, modulo) + mask = pad_tensor_to_modulo(mask, modulo) + mask[mask >= 1e-8] = 1.0 + mask[mask < 1e-8] = 0.0 + image, mask = move_to_device(image, devices[0]), move_to_device(mask, devices[0]) + if image_inpainted is not None: + image_inpainted = move_to_device(image_inpainted, devices[-1]) + image_inpainted = _infer(image, mask, forward_front, forward_rears, image_inpainted, orig_shape, devices, ids, n_iters, lr) + image_inpainted = image_inpainted[:,:,:orig_shape[0], :orig_shape[1]] + # detach everything to save resources + image = image.detach().cpu() + mask = mask.detach().cpu() + + return image_inpainted diff --git a/DH-AISP/2/saicinpainting/evaluation/utils.py b/DH-AISP/2/saicinpainting/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7c15c9242ed8a9bc59fbb3b450cca394720bb8 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/utils.py @@ -0,0 +1,28 @@ +from enum import Enum + +import yaml +from easydict import EasyDict as edict +import torch.nn as nn +import torch + + +def load_yaml(path): + with open(path, 'r') as f: + return edict(yaml.safe_load(f)) + + +def move_to_device(obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return {name: move_to_device(val, device) for name, val in obj.items()} + raise ValueError(f'Unexpected type {type(obj)}') + + +class SmallMode(Enum): + DROP = "drop" + UPSCALE = "upscale" diff --git a/DH-AISP/2/saicinpainting/evaluation/vis.py b/DH-AISP/2/saicinpainting/evaluation/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c2910b4ef8c61efee72dabd0531a9b669ec8bf98 --- /dev/null +++ b/DH-AISP/2/saicinpainting/evaluation/vis.py @@ -0,0 +1,37 @@ +import numpy as np +from skimage import io +from skimage.segmentation import mark_boundaries 
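As an aside, here is a minimal usage sketch for the two helpers defined in saicinpainting/evaluation/utils.py above; the tensor shapes, the device choice, and the config path and key names in the comments are illustrative assumptions rather than values taken from this repository.

```python
import torch

from saicinpainting.evaluation.utils import load_yaml, move_to_device

# move_to_device recurses into dicts, lists and tuples, moving every tensor or
# nn.Module it finds; refine_predict uses it to move images and masks between devices.
batch = {'image': torch.rand(1, 3, 64, 64), 'mask': torch.rand(1, 1, 64, 64)}
batch = move_to_device(batch, torch.device('cpu'))  # e.g. torch.device('cuda:0') on GPU

# load_yaml wraps yaml.safe_load in an EasyDict, so keys read as attributes,
# e.g. cfg.refiner.n_iters (the path and key names here are hypothetical):
# cfg = load_yaml('configs/prediction/default.yaml')
```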
+ + +def save_item_for_vis(item, out_file): + mask = item['mask'] > 0.5 + if mask.ndim == 3: + mask = mask[0] + img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)), + mask, + color=(1., 0., 0.), + outline_color=(1., 1., 1.), + mode='thick') + + if 'inpainted' in item: + inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)), + mask, + color=(1., 0., 0.), + mode='outer') + img = np.concatenate((img, inp_img), axis=1) + + img = np.clip(img * 255, 0, 255).astype('uint8') + io.imsave(out_file, img) + + +def save_mask_for_sidebyside(item, out_file): + mask = item['mask']# > 0.5 + if mask.ndim == 3: + mask = mask[0] + mask = np.clip(mask * 255, 0, 255).astype('uint8') + io.imsave(out_file, mask) + +def save_img_for_sidebyside(item, out_file): + img = np.transpose(item['image'], (1, 2, 0)) + img = np.clip(img * 255, 0, 255).astype('uint8') + io.imsave(out_file, img) \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/training/__init__.py b/DH-AISP/2/saicinpainting/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1650cd7b365c2c9a710a1920597d818ba36549 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58231f99dd0082db8028ec3428ff612c743b076d Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/__pycache__/__init__.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/data/__init__.py b/DH-AISP/2/saicinpainting/training/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/training/data/aug.py b/DH-AISP/2/saicinpainting/training/data/aug.py new file mode 100644 index 0000000000000000000000000000000000000000..b1246250924e79511b58cd3d7ab79de8012f8949 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/data/aug.py @@ -0,0 +1,84 @@ +from albumentations import DualIAATransform, to_tuple +import imgaug.augmenters as iaa + +class IAAAffine2(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. 
+ + Targets: + image, mask + """ + + def __init__( + self, + scale=(0.7, 1.3), + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=(-0.1, 0.1), + order=1, + cval=0, + mode="reflect", + always_apply=False, + p=0.5, + ): + super(IAAAffine2, self).__init__(always_apply, p) + self.scale = dict(x=scale, y=scale) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = dict(x=shear, y=shear) + self.order = order + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") + + +class IAAPerspective2(DualIAATransform): + """Perform a random four point perspective transform of the input. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, + order=1, cval=0, mode="replicate"): + super(IAAPerspective2, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval) + + def get_transform_init_args_names(self): + return ("scale", "keep_size") diff --git a/DH-AISP/2/saicinpainting/training/data/datasets.py b/DH-AISP/2/saicinpainting/training/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f503dafffb970d8dbaca33934da417036d1e55 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/data/datasets.py @@ -0,0 +1,304 @@ +import glob +import logging +import os +import random + +import albumentations as A +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import webdataset +from omegaconf import open_dict, OmegaConf +from skimage.feature import canny +from skimage.transform import rescale, resize +from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset + +from saicinpainting.evaluation.data import InpaintingDataset as InpaintingEvaluationDataset, \ + OurInpaintingDataset as OurInpaintingEvaluationDataset, ceil_modulo, InpaintingEvalOnlineDataset +from saicinpainting.training.data.aug import IAAAffine2, IAAPerspective2 +from saicinpainting.training.data.masks import get_mask_generator + +LOGGER = logging.getLogger(__name__) + + +class InpaintingTrainDataset(Dataset): + def __init__(self, indir, mask_generator, transform): + self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True)) + self.mask_generator = mask_generator + self.transform = transform + self.iter_i = 0 + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, item): + path = self.in_files[item] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 
0, 1)) + # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks + mask = self.mask_generator(img, iter_i=self.iter_i) + self.iter_i += 1 + return dict(image=img, + mask=mask) + + +class InpaintingTrainWebDataset(IterableDataset): + def __init__(self, indir, mask_generator, transform, shuffle_buffer=200): + self.impl = webdataset.Dataset(indir).shuffle(shuffle_buffer).decode('rgb').to_tuple('jpg') + self.mask_generator = mask_generator + self.transform = transform + + def __iter__(self): + for iter_i, (img,) in enumerate(self.impl): + img = np.clip(img * 255, 0, 255).astype('uint8') + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + mask = self.mask_generator(img, iter_i=iter_i) + yield dict(image=img, + mask=mask) + + +class ImgSegmentationDataset(Dataset): + def __init__(self, indir, mask_generator, transform, out_size, segm_indir, semantic_seg_n_classes): + self.indir = indir + self.segm_indir = segm_indir + self.mask_generator = mask_generator + self.transform = transform + self.out_size = out_size + self.semantic_seg_n_classes = semantic_seg_n_classes + self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True)) + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, item): + path = self.in_files[item] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (self.out_size, self.out_size)) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + mask = self.mask_generator(img) + segm, segm_classes= self.load_semantic_segm(path) + result = dict(image=img, + mask=mask, + segm=segm, + segm_classes=segm_classes) + return result + + def load_semantic_segm(self, img_path): + segm_path = img_path.replace(self.indir, self.segm_indir).replace(".jpg", ".png") + mask = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE) + mask = cv2.resize(mask, (self.out_size, self.out_size)) + tensor = torch.from_numpy(np.clip(mask.astype(int)-1, 0, None)) + ohe = F.one_hot(tensor.long(), num_classes=self.semantic_seg_n_classes) # w x h x n_classes + return ohe.permute(2, 0, 1).float(), tensor.unsqueeze(0) + + +def get_transforms(transform_variant, out_size): + if transform_variant == 'default': + transform = A.Compose([ + A.RandomScale(scale_limit=0.2), # +/- 20% + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.7, 1.3), + rotate=(-40, 40), + shear=(-0.1, 0.1)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale05_1': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.5, 1.0), + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + 
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale03_12': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.3, 1.2), + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale03_07': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.3, 0.7), # scale 512 to 256 in average + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_light': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.02)), + IAAAffine2(scale=(0.8, 1.8), + rotate=(-20, 20), + shear=(-0.03, 0.03)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'non_space_transform': + transform = A.Compose([ + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'no_augs': + transform = A.Compose([ + A.ToFloat() + ]) + else: + raise ValueError(f'Unexpected transform_variant {transform_variant}') + return transform + + +def make_default_train_dataloader(indir, kind='default', out_size=512, mask_gen_kwargs=None, transform_variant='default', + mask_generator_kind="mixed", dataloader_kwargs=None, ddp_kwargs=None, **kwargs): + LOGGER.info(f'Make train dataloader {kind} from {indir}. 
Using mask generator={mask_generator_kind}') + + mask_generator = get_mask_generator(kind=mask_generator_kind, kwargs=mask_gen_kwargs) + transform = get_transforms(transform_variant, out_size) + + if kind == 'default': + dataset = InpaintingTrainDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + **kwargs) + elif kind == 'default_web': + dataset = InpaintingTrainWebDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + **kwargs) + elif kind == 'img_with_segm': + dataset = ImgSegmentationDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + else: + raise ValueError(f'Unknown train dataset kind {kind}') + + if dataloader_kwargs is None: + dataloader_kwargs = {} + + is_dataset_only_iterable = kind in ('default_web',) + + if ddp_kwargs is not None and not is_dataset_only_iterable: + dataloader_kwargs['shuffle'] = False + dataloader_kwargs['sampler'] = DistributedSampler(dataset, **ddp_kwargs) + + if is_dataset_only_iterable and 'shuffle' in dataloader_kwargs: + with open_dict(dataloader_kwargs): + del dataloader_kwargs['shuffle'] + + dataloader = DataLoader(dataset, **dataloader_kwargs) + return dataloader + + +def make_default_val_dataset(indir, kind='default', out_size=512, transform_variant='default', **kwargs): + if OmegaConf.is_list(indir) or isinstance(indir, (tuple, list)): + return ConcatDataset([ + make_default_val_dataset(idir, kind=kind, out_size=out_size, transform_variant=transform_variant, **kwargs) for idir in indir + ]) + + LOGGER.info(f'Make val dataloader {kind} from {indir}') + mask_generator = get_mask_generator(kind=kwargs.get("mask_generator_kind"), kwargs=kwargs.get("mask_gen_kwargs")) + + if transform_variant is not None: + transform = get_transforms(transform_variant, out_size) + + if kind == 'default': + dataset = InpaintingEvaluationDataset(indir, **kwargs) + elif kind == 'our_eval': + dataset = OurInpaintingEvaluationDataset(indir, **kwargs) + elif kind == 'img_with_segm': + dataset = ImgSegmentationDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + elif kind == 'online': + dataset = InpaintingEvalOnlineDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + else: + raise ValueError(f'Unknown val dataset kind {kind}') + + return dataset + + +def make_default_val_dataloader(*args, dataloader_kwargs=None, **kwargs): + dataset = make_default_val_dataset(*args, **kwargs) + + if dataloader_kwargs is None: + dataloader_kwargs = {} + dataloader = DataLoader(dataset, **dataloader_kwargs) + return dataloader + + +def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256*256, round_to_mod=16): + min_size = min(img_height, img_width, min_size) + max_size = min(img_height, img_width, max_size) + if random.random() < 0.5: + out_height = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) + out_width = min(max_size, ceil_modulo(area // out_height, round_to_mod)) + else: + out_width = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) + out_height = min(max_size, ceil_modulo(area // out_width, round_to_mod)) + + start_y = random.randint(0, img_height - out_height) + start_x = random.randint(0, img_width - out_width) + return (start_y, start_x, out_height, out_width) diff --git a/DH-AISP/2/saicinpainting/training/data/masks.py 
b/DH-AISP/2/saicinpainting/training/data/masks.py new file mode 100644 index 0000000000000000000000000000000000000000..e91fc74913356481065c5f5906acd50fb05f521c --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/data/masks.py @@ -0,0 +1,332 @@ +import math +import random +import hashlib +import logging +from enum import Enum + +import cv2 +import numpy as np + +from saicinpainting.evaluation.masks.mask import SegmentationMask +from saicinpainting.utils import LinearRamp + +LOGGER = logging.getLogger(__name__) + + +class DrawMethod(Enum): + LINE = 'line' + CIRCLE = 'circle' + SQUARE = 'square' + + +def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, + draw_method=DrawMethod.LINE): + draw_method = DrawMethod(draw_method) + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height) + if draw_method == DrawMethod.LINE: + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w) + elif draw_method == DrawMethod.CIRCLE: + cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1) + elif draw_method == DrawMethod.SQUARE: + radius = brush_w // 2 + mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1 + start_x, start_y = end_x, end_y + return mask[None, ...] + + +class RandomIrregularMaskGenerator: + def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None, + draw_method=DrawMethod.LINE): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + self.draw_method = draw_method + self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 + cur_max_len = int(max(1, self.max_len * coef)) + cur_max_width = int(max(1, self.max_width * coef)) + cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef) + return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len, + max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times, + draw_method=self.draw_method) + + +def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return mask[None, ...] 
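A minimal usage sketch (not part of the original masks.py): the two low-level helpers above share a simple contract — given an (H, W) shape they return a float32 mask of shape (1, H, W) in which 1 marks the pixels to be inpainted. Only the two function names come from this file; the arguments and the fixed seed below are illustrative assumptions.

    import numpy as np  # already imported at the top of masks.py; repeated so the sketch stands alone

    def _demo_low_level_masks():
        np.random.seed(0)  # illustrative seed, only to make the demo reproducible
        irregular = make_random_irregular_mask((256, 256), max_angle=4, max_len=60,
                                               max_width=20, min_times=1, max_times=5)
        rectangles = make_random_rectangle_mask((256, 256), margin=10, bbox_min_size=30,
                                                bbox_max_size=100, min_times=1, max_times=3)
        # Both helpers return (1, H, W) float32 arrays with values in {0.0, 1.0}.
        assert irregular.shape == rectangles.shape == (1, 256, 256)
        return irregular, rectangles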
+ + +class RandomRectangleMaskGenerator: + def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 + cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef) + cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef) + return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, min_times=self.min_times, + max_times=cur_max_times) + + +class RandomSegmentationMaskGenerator: + def __init__(self, **kwargs): + self.impl = None # will be instantiated in first call (effectively in subprocess) + self.kwargs = kwargs + + def __call__(self, img, iter_i=None, raw_image=None): + if self.impl is None: + self.impl = SegmentationMask(**self.kwargs) + + masks = self.impl.get_masks(np.transpose(img, (1, 2, 0))) + masks = [m for m in masks if len(np.unique(m)) > 1] + return np.random.choice(masks) + + +def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + step_x = np.random.randint(min_step, max_step + 1) + width_x = np.random.randint(min_width, min(step_x, max_width + 1)) + offset_x = np.random.randint(0, step_x) + + step_y = np.random.randint(min_step, max_step + 1) + width_y = np.random.randint(min_width, min(step_y, max_width + 1)) + offset_y = np.random.randint(0, step_y) + + for dy in range(width_y): + mask[offset_y + dy::step_y] = 1 + for dx in range(width_x): + mask[:, offset_x + dx::step_x] = 1 + return mask[None, ...] + + +class RandomSuperresMaskGenerator: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, img, iter_i=None): + return make_random_superres_mask(img.shape[1:], **self.kwargs) + + +class DumbAreaMaskGenerator: + min_ratio = 0.1 + max_ratio = 0.35 + default_ratio = 0.225 + + def __init__(self, is_training): + #Parameters: + # is_training(bool): If true - random rectangular mask, if false - central square mask + self.is_training = is_training + + def _random_vector(self, dimension): + if self.is_training: + lower_limit = math.sqrt(self.min_ratio) + upper_limit = math.sqrt(self.max_ratio) + mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension) + u = random.randint(0, dimension-mask_side-1) + v = u+mask_side + else: + margin = (math.sqrt(self.default_ratio) / 2) * dimension + u = round(dimension/2 - margin) + v = round(dimension/2 + margin) + return u, v + + def __call__(self, img, iter_i=None, raw_image=None): + c, height, width = img.shape + mask = np.zeros((height, width), np.float32) + x1, x2 = self._random_vector(width) + y1, y2 = self._random_vector(height) + mask[x1:x2, y1:y2] = 1 + return mask[None, ...] 
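A minimal usage sketch (not part of the original masks.py): the mask generator classes in this module are callables that take an image in CHW layout (only img.shape[1:] is read, plus an optional iter_i used for ramping) and return a (1, H, W) float mask. Only DumbAreaMaskGenerator is taken from the file; the shapes and arguments below are illustrative assumptions.

    import numpy as np  # matches the import at the top of masks.py

    def _demo_dumb_area_mask():
        img = np.zeros((3, 256, 256), dtype=np.float32)      # CHW layout; pixel values are ignored here
        val_gen = DumbAreaMaskGenerator(is_training=False)    # deterministic central square (~22.5% of the area)
        train_gen = DumbAreaMaskGenerator(is_training=True)   # random square covering 10-35% of the area
        central = val_gen(img)
        randomized = train_gen(img)
        assert central.shape == randomized.shape == (1, 256, 256)
        return central, randomized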
+ + +class OutpaintingMaskGenerator: + def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5, + right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False): + """ + is_fixed_randomness - get identical paddings for the same image if args are the same + """ + self.min_padding_percent = min_padding_percent + self.max_padding_percent = max_padding_percent + self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob] + self.is_fixed_randomness = is_fixed_randomness + + assert self.min_padding_percent <= self.max_padding_percent + assert self.max_padding_percent > 0 + assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]" + assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}" + assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}" + if len([x for x in self.probs if x > 0]) == 1: + LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side") + + def apply_padding(self, mask, coord): + mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h), + int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1 + return mask + + def get_padding(self, size): + n1 = int(self.min_padding_percent*size) + n2 = int(self.max_padding_percent*size) + return self.rnd.randint(n1, n2) / size + + @staticmethod + def _img2rs(img): + arr = np.ascontiguousarray(img.astype(np.uint8)) + str_hash = hashlib.sha1(arr).hexdigest() + res = hash(str_hash)%(2**32) + return res + + def __call__(self, img, iter_i=None, raw_image=None): + c, self.img_h, self.img_w = img.shape + mask = np.zeros((self.img_h, self.img_w), np.float32) + at_least_one_mask_applied = False + + if self.is_fixed_randomness: + assert raw_image is not None, f"Cant calculate hash on raw_image=None" + rs = self._img2rs(raw_image) + self.rnd = np.random.RandomState(rs) + else: + self.rnd = np.random + + coords = [[ + (0,0), + (1,self.get_padding(size=self.img_h)) + ], + [ + (0,0), + (self.get_padding(size=self.img_w),1) + ], + [ + (0,1-self.get_padding(size=self.img_h)), + (1,1) + ], + [ + (1-self.get_padding(size=self.img_w),0), + (1,1) + ]] + + for pp, coord in zip(self.probs, coords): + if self.rnd.random() < pp: + at_least_one_mask_applied = True + mask = self.apply_padding(mask=mask, coord=coord) + + if not at_least_one_mask_applied: + idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs)) + mask = self.apply_padding(mask=mask, coord=coords[idx]) + return mask[None, ...] 
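A minimal usage sketch (not part of the original masks.py): OutpaintingMaskGenerator marks padding bands along the image borders as the region to synthesize, and with is_fixed_randomness=True it seeds its RNG from a hash of raw_image, so the same image receives identical bands within a run. Only the class name comes from the file; the shapes and constructor arguments below are illustrative assumptions.

    import numpy as np  # matches the import at the top of masks.py

    def _demo_outpainting_mask():
        img = np.zeros((3, 256, 256), dtype=np.float32)   # CHW tensor-style input
        raw = np.zeros((256, 256, 3), dtype=np.uint8)     # HWC uint8 image, as loaded from disk
        gen = OutpaintingMaskGenerator(min_padding_percent=0.1, max_padding_percent=0.2,
                                       is_fixed_randomness=True)
        m1 = gen(img, raw_image=raw)
        m2 = gen(img, raw_image=raw)
        assert m1.shape == (1, 256, 256)                  # 1 marks the region to outpaint
        assert np.array_equal(m1, m2)                     # same raw_image -> same seed -> same bands
        return m1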
+ + +class MixedMaskGenerator: + def __init__(self, irregular_proba=1/3, irregular_kwargs=None, + box_proba=1/3, box_kwargs=None, + segm_proba=1/3, segm_kwargs=None, + squares_proba=0, squares_kwargs=None, + superres_proba=0, superres_kwargs=None, + outpainting_proba=0, outpainting_kwargs=None, + invert_proba=0): + self.probas = [] + self.gens = [] + + if irregular_proba > 0: + self.probas.append(irregular_proba) + if irregular_kwargs is None: + irregular_kwargs = {} + else: + irregular_kwargs = dict(irregular_kwargs) + irregular_kwargs['draw_method'] = DrawMethod.LINE + self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) + + if box_proba > 0: + self.probas.append(box_proba) + if box_kwargs is None: + box_kwargs = {} + self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) + + if segm_proba > 0: + self.probas.append(segm_proba) + if segm_kwargs is None: + segm_kwargs = {} + self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs)) + + if squares_proba > 0: + self.probas.append(squares_proba) + if squares_kwargs is None: + squares_kwargs = {} + else: + squares_kwargs = dict(squares_kwargs) + squares_kwargs['draw_method'] = DrawMethod.SQUARE + self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) + + if superres_proba > 0: + self.probas.append(superres_proba) + if superres_kwargs is None: + superres_kwargs = {} + self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) + + if outpainting_proba > 0: + self.probas.append(outpainting_proba) + if outpainting_kwargs is None: + outpainting_kwargs = {} + self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs)) + + self.probas = np.array(self.probas, dtype='float32') + self.probas /= self.probas.sum() + self.invert_proba = invert_proba + + def __call__(self, img, iter_i=None, raw_image=None): + kind = np.random.choice(len(self.probas), p=self.probas) + gen = self.gens[kind] + result = gen(img, iter_i=iter_i, raw_image=raw_image) + if self.invert_proba > 0 and random.random() < self.invert_proba: + result = 1 - result + return result + + +def get_mask_generator(kind, kwargs): + if kind is None: + kind = "mixed" + if kwargs is None: + kwargs = {} + + if kind == "mixed": + cl = MixedMaskGenerator + elif kind == "outpainting": + cl = OutpaintingMaskGenerator + elif kind == "dumb": + cl = DumbAreaMaskGenerator + else: + raise NotImplementedError(f"No such generator kind = {kind}") + return cl(**kwargs) diff --git a/DH-AISP/2/saicinpainting/training/losses/__init__.py b/DH-AISP/2/saicinpainting/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DH-AISP/2/saicinpainting/training/losses/adversarial.py b/DH-AISP/2/saicinpainting/training/losses/adversarial.py new file mode 100644 index 0000000000000000000000000000000000000000..d6db2967ce5074d94ed3b4c51fc743ff2f7831b1 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/adversarial.py @@ -0,0 +1,177 @@ +from typing import Tuple, Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseAdversarialLoss: + def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + """ + Prepare for generator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def pre_discriminator_step(self, real_batch: 
torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + """ + Prepare for discriminator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate generator loss + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total generator loss along with some values that might be interesting to log + """ + raise NotImplemented() + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate discriminator loss and call .backward() on it + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total discriminator loss along with some values that might be interesting to log + """ + raise NotImplemented() + + def interpolate_mask(self, mask, shape): + assert mask is not None + assert self.allow_scale_mask or shape == mask.shape[-2:] + if shape != mask.shape[-2:] and self.allow_scale_mask: + if self.mask_scale_mode == 'maxpool': + mask = F.adaptive_max_pool2d(mask, shape) + else: + mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode) + return mask + +def make_r1_gp(discr_real_pred, real_batch): + if torch.is_grad_enabled(): + grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0] + grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean() + else: + grad_penalty = 0 + real_batch.requires_grad = False + + return grad_penalty + +class NonSaturatingWithR1(BaseAdversarialLoss): + def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False, + mask_scale_mode='nearest', extra_mask_weight_for_gen=0, + use_unmasked_for_gen=True, use_unmasked_for_discr=True): + self.gp_coef = gp_coef + self.weight = weight + # use for discr => use for gen; + # otherwise we teach only the discr to pay attention to very small difference + assert use_unmasked_for_gen or (not use_unmasked_for_discr) + # mask as target => use unmasked for discr: + # if we don't care about unmasked regions at all + # then it doesn't matter if the value of mask_as_fake_target is true or false + assert use_unmasked_for_discr or (not mask_as_fake_target) + self.use_unmasked_for_gen = use_unmasked_for_gen + self.use_unmasked_for_discr = use_unmasked_for_discr + self.mask_as_fake_target = mask_as_fake_target + self.allow_scale_mask = allow_scale_mask + self.mask_scale_mode = mask_scale_mode + self.extra_mask_weight_for_gen 
= extra_mask_weight_for_gen + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + fake_loss = F.softplus(-discr_fake_pred) + if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ + not self.use_unmasked_for_gen: # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + if not self.use_unmasked_for_gen: + fake_loss = fake_loss * mask + else: + pixel_weights = 1 + mask * self.extra_mask_weight_for_gen + fake_loss = fake_loss * pixel_weights + + return fake_loss.mean() * self.weight, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_loss = F.softplus(-discr_real_pred) + grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef + fake_loss = F.softplus(discr_fake_pred) + + if not self.use_unmasked_for_discr or self.mask_as_fake_target: + # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + # use_unmasked_for_discr=False only makes sense for fakes; + # for reals there is no difference beetween two regions + fake_loss = fake_loss * mask + if self.mask_as_fake_target: + fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred) + + sum_discr_loss = real_loss + grad_penalty + fake_loss + metrics = dict(discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=grad_penalty) + return sum_discr_loss.mean(), metrics + +class BCELoss(BaseAdversarialLoss): + def __init__(self, weight): + self.weight = weight + self.bce_loss = nn.BCEWithLogitsLoss() + + def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device) + fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight + return fake_loss, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, + mask: torch.Tensor, + discr_real_pred: torch.Tensor, + discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device) + sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2 + metrics = dict(discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=0) + return sum_discr_loss, metrics + + +def make_discrim_loss(kind, **kwargs): + if kind == 'r1': + return NonSaturatingWithR1(**kwargs) + elif kind == 'bce': + return BCELoss(**kwargs) + raise ValueError(f'Unknown adversarial loss kind {kind}') diff --git a/DH-AISP/2/saicinpainting/training/losses/constants.py b/DH-AISP/2/saicinpainting/training/losses/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3e5e151342232be8e2c2a77fe6fd5798dc2a8c --- /dev/null +++ 
b/DH-AISP/2/saicinpainting/training/losses/constants.py @@ -0,0 +1,152 @@ +weights = {"ade20k": + [6.34517766497462, + 9.328358208955224, + 11.389521640091116, + 16.10305958132045, + 20.833333333333332, + 22.22222222222222, + 25.125628140703515, + 43.29004329004329, + 50.5050505050505, + 54.6448087431694, + 55.24861878453038, + 60.24096385542168, + 62.5, + 66.2251655629139, + 84.74576271186442, + 90.90909090909092, + 91.74311926605505, + 96.15384615384616, + 96.15384615384616, + 97.08737864077669, + 102.04081632653062, + 135.13513513513513, + 149.2537313432836, + 153.84615384615384, + 163.93442622950818, + 166.66666666666666, + 188.67924528301887, + 192.30769230769232, + 217.3913043478261, + 227.27272727272725, + 227.27272727272725, + 227.27272727272725, + 303.03030303030306, + 322.5806451612903, + 333.3333333333333, + 370.3703703703703, + 384.61538461538464, + 416.6666666666667, + 416.6666666666667, + 434.7826086956522, + 434.7826086956522, + 454.5454545454545, + 454.5454545454545, + 500.0, + 526.3157894736842, + 526.3157894736842, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 666.6666666666666, + 666.6666666666666, + 666.6666666666666, + 666.6666666666666, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 769.2307692307693, + 769.2307692307693, + 769.2307692307693, + 833.3333333333334, + 833.3333333333334, + 833.3333333333334, + 833.3333333333334, + 909.090909090909, + 1000.0, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1250.0, + 1250.0, + 1250.0, + 1250.0, + 1250.0, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 5000.0, + 5000.0, + 5000.0] +} \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/training/losses/distance_weighting.py b/DH-AISP/2/saicinpainting/training/losses/distance_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..93052003b1e47fd663c70aedcecd144171f49204 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/distance_weighting.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN + + +def dummy_distance_weighter(real_img, pred_img, mask): + return mask + + +def get_gauss_kernel(kernel_size, width_factor=1): + coords = 
torch.stack(torch.meshgrid(torch.arange(kernel_size), + torch.arange(kernel_size)), + dim=0).float() + diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor) + diff /= diff.sum() + return diff + + +class BlurMask(nn.Module): + def __init__(self, kernel_size=5, width_factor=1): + super().__init__() + self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False) + self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor)) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + result = self.filter(mask) * mask + return result + + +class EmulatedEDTMask(nn.Module): + def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1): + super().__init__() + self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate', + bias=False) + self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float)) + self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False) + self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor)) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + known_mask = 1 - mask + dilated_known_mask = (self.dilate_filter(known_mask) > 1).float() + result = self.blur_filter(1 - dilated_known_mask) * mask + return result + + +class PropagatePerceptualSim(nn.Module): + def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3): + super().__init__() + vgg = torchvision.models.vgg19(pretrained=True).features + vgg_avg_pooling = [] + + for weights in vgg.parameters(): + weights.requires_grad = False + + cur_level_i = 0 + for module in vgg.modules(): + if module.__class__.__name__ == 'Sequential': + continue + elif module.__class__.__name__ == 'MaxPool2d': + vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) + else: + vgg_avg_pooling.append(module) + if module.__class__.__name__ == 'ReLU': + cur_level_i += 1 + if cur_level_i == level: + break + + self.features = nn.Sequential(*vgg_avg_pooling) + + self.max_iters = max_iters + self.temperature = temperature + self.do_erode = erode_mask_size > 0 + if self.do_erode: + self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False) + self.erode_mask.weight.data.fill_(1) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img) + real_feats = self.features(real_img) + + vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True) + / self.temperature) + horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True) + / self.temperature) + + mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False) + if self.do_erode: + mask_scaled = (self.erode_mask(mask_scaled) > 1).float() + + cur_knowness = 1 - mask_scaled + + for iter_i in range(self.max_iters): + new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate') + new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate') + + new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate') + new_right_knowness = F.pad(cur_knowness[:, :, 
:, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate') + + new_knowness = torch.stack([new_top_knowness, new_bottom_knowness, + new_left_knowness, new_right_knowness], + dim=0).max(0).values + + cur_knowness = torch.max(cur_knowness, new_knowness) + + cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear') + result = torch.min(mask, 1 - cur_knowness) + + return result + + +def make_mask_distance_weighter(kind='none', **kwargs): + if kind == 'none': + return dummy_distance_weighter + if kind == 'blur': + return BlurMask(**kwargs) + if kind == 'edt': + return EmulatedEDTMask(**kwargs) + if kind == 'pps': + return PropagatePerceptualSim(**kwargs) + raise ValueError(f'Unknown mask distance weighter kind {kind}') diff --git a/DH-AISP/2/saicinpainting/training/losses/feature_matching.py b/DH-AISP/2/saicinpainting/training/losses/feature_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c019895c9178817837d1a6773367b178a861dc61 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/feature_matching.py @@ -0,0 +1,33 @@ +from typing import List + +import torch +import torch.nn.functional as F + + +def masked_l2_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l2 = F.mse_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l2).mean() + + +def masked_l1_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l1 = F.l1_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l1).mean() + + +def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): + if mask is None: + res = torch.stack([F.mse_loss(fake_feat, target_feat) + for fake_feat, target_feat in zip(fake_features, target_features)]).mean() + else: + res = 0 + norm = 0 + for fake_feat, target_feat in zip(fake_features, target_features): + cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) + error_weights = 1 - cur_mask + cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() + res = res + cur_val + norm += 1 + res = res / norm + return res diff --git a/DH-AISP/2/saicinpainting/training/losses/perceptual.py b/DH-AISP/2/saicinpainting/training/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..8c055c2b327ce7943682af5c5f9394b9fcbec506 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/perceptual.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from models.ade20k import ModelBuilder +from saicinpainting.utils import check_and_warn_input_range + + +IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] +IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] + + +class PerceptualLoss(nn.Module): + def __init__(self, normalize_inputs=True): + super(PerceptualLoss, self).__init__() + + self.normalize_inputs = normalize_inputs + self.mean_ = IMAGENET_MEAN + self.std_ = IMAGENET_STD + + vgg = torchvision.models.vgg19(pretrained=True).features + vgg_avg_pooling = [] + + for weights in vgg.parameters(): + weights.requires_grad = False + + for module in vgg.modules(): + if module.__class__.__name__ == 'Sequential': + continue + elif module.__class__.__name__ == 'MaxPool2d': + vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, 
stride=2, padding=0)) + else: + vgg_avg_pooling.append(module) + + self.vgg = nn.Sequential(*vgg_avg_pooling) + + def do_normalize_inputs(self, x): + return (x - self.mean_.to(x.device)) / self.std_.to(x.device) + + def partial_losses(self, input, target, mask=None): + check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses') + + # we expect input and target to be in [0, 1] range + losses = [] + + if self.normalize_inputs: + features_input = self.do_normalize_inputs(input) + features_target = self.do_normalize_inputs(target) + else: + features_input = input + features_target = target + + for layer in self.vgg[:30]: + + features_input = layer(features_input) + features_target = layer(features_target) + + if layer.__class__.__name__ == 'ReLU': + loss = F.mse_loss(features_input, features_target, reduction='none') + + if mask is not None: + cur_mask = F.interpolate(mask, size=features_input.shape[-2:], + mode='bilinear', align_corners=False) + loss = loss * (1 - cur_mask) + + loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) + losses.append(loss) + + return losses + + def forward(self, input, target, mask=None): + losses = self.partial_losses(input, target, mask=mask) + return torch.stack(losses).sum(dim=0) + + def get_global_features(self, input): + check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features') + + if self.normalize_inputs: + features_input = self.do_normalize_inputs(input) + else: + features_input = input + + features_input = self.vgg(features_input) + return features_input + + +class ResNetPL(nn.Module): + def __init__(self, weight=1, + weights_path=None, arch_encoder='resnet50dilated', segmentation=True): + super().__init__() + self.impl = ModelBuilder.get_encoder(weights_path=weights_path, + arch_encoder=arch_encoder, + arch_decoder='ppm_deepsup', + fc_dim=2048, + segmentation=segmentation) + self.impl.eval() + for w in self.impl.parameters(): + w.requires_grad_(False) + + self.weight = weight + + def forward(self, pred, target): + pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) + target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) + + pred_feats = self.impl(pred, return_feature_maps=True) + target_feats = self.impl(target, return_feature_maps=True) + + result = torch.stack([F.mse_loss(cur_pred, cur_target) + for cur_pred, cur_target + in zip(pred_feats, target_feats)]).sum() * self.weight + return result diff --git a/DH-AISP/2/saicinpainting/training/losses/segmentation.py b/DH-AISP/2/saicinpainting/training/losses/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4a9f94eaae84722db584277dbbf9bc41ede357 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/segmentation.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .constants import weights as constant_weights + + +class CrossEntropy2d(nn.Module): + def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): + """ + weight (Tensor, optional): a manual rescaling weight given to each class. 
+ If given, has to be a Tensor of size "nclasses" + """ + super(CrossEntropy2d, self).__init__() + self.reduction = reduction + self.ignore_label = ignore_label + self.weights = weights + if self.weights is not None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.weights = torch.FloatTensor(constant_weights[weights]).to(device) + + def forward(self, predict, target): + """ + Args: + predict:(n, c, h, w) + target:(n, 1, h, w) + """ + target = target.long() + assert not target.requires_grad + assert predict.dim() == 4, "{0}".format(predict.size()) + assert target.dim() == 4, "{0}".format(target.size()) + assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) + assert target.size(1) == 1, "{0}".format(target.size(1)) + assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) + assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) + target = target.squeeze(1) + n, c, h, w = predict.size() + target_mask = (target >= 0) * (target != self.ignore_label) + target = target[target_mask] + predict = predict.transpose(1, 2).transpose(2, 3).contiguous() + predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) + loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) + return loss diff --git a/DH-AISP/2/saicinpainting/training/losses/style_loss.py b/DH-AISP/2/saicinpainting/training/losses/style_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb42d7fbc5d17a47bec7365889868505f5fdfb5 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/losses/style_loss.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn +import torchvision.models as models + + +class PerceptualLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): + super(PerceptualLoss, self).__init__() + self.add_module('vgg', VGG19()) + self.criterion = torch.nn.L1Loss() + self.weights = weights + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + content_loss = 0.0 + content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) + content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) + content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) + content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) + content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) + + + return content_loss + + +class VGG19(torch.nn.Module): + def __init__(self): + super(VGG19, self).__init__() + features = models.vgg19(pretrained=True).features + self.relu1_1 = torch.nn.Sequential() + self.relu1_2 = torch.nn.Sequential() + + self.relu2_1 = torch.nn.Sequential() + self.relu2_2 = torch.nn.Sequential() + + self.relu3_1 = torch.nn.Sequential() + self.relu3_2 = torch.nn.Sequential() + self.relu3_3 = torch.nn.Sequential() + self.relu3_4 = torch.nn.Sequential() + + self.relu4_1 = torch.nn.Sequential() + self.relu4_2 = torch.nn.Sequential() + self.relu4_3 = torch.nn.Sequential() + self.relu4_4 = torch.nn.Sequential() + + self.relu5_1 = torch.nn.Sequential() + self.relu5_2 = torch.nn.Sequential() + self.relu5_3 = torch.nn.Sequential() + self.relu5_4 = torch.nn.Sequential() + + for x in range(2): + 
self.relu1_1.add_module(str(x), features[x]) + + for x in range(2, 4): + self.relu1_2.add_module(str(x), features[x]) + + for x in range(4, 7): + self.relu2_1.add_module(str(x), features[x]) + + for x in range(7, 9): + self.relu2_2.add_module(str(x), features[x]) + + for x in range(9, 12): + self.relu3_1.add_module(str(x), features[x]) + + for x in range(12, 14): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(14, 16): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(16, 18): + self.relu3_4.add_module(str(x), features[x]) + + for x in range(18, 21): + self.relu4_1.add_module(str(x), features[x]) + + for x in range(21, 23): + self.relu4_2.add_module(str(x), features[x]) + + for x in range(23, 25): + self.relu4_3.add_module(str(x), features[x]) + + for x in range(25, 27): + self.relu4_4.add_module(str(x), features[x]) + + for x in range(27, 30): + self.relu5_1.add_module(str(x), features[x]) + + for x in range(30, 32): + self.relu5_2.add_module(str(x), features[x]) + + for x in range(32, 34): + self.relu5_3.add_module(str(x), features[x]) + + for x in range(34, 36): + self.relu5_4.add_module(str(x), features[x]) + + # don't need the gradients, just want the features + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + relu1_1 = self.relu1_1(x) + relu1_2 = self.relu1_2(relu1_1) + + relu2_1 = self.relu2_1(relu1_2) + relu2_2 = self.relu2_2(relu2_1) + + relu3_1 = self.relu3_1(relu2_2) + relu3_2 = self.relu3_2(relu3_1) + relu3_3 = self.relu3_3(relu3_2) + relu3_4 = self.relu3_4(relu3_3) + + relu4_1 = self.relu4_1(relu3_4) + relu4_2 = self.relu4_2(relu4_1) + relu4_3 = self.relu4_3(relu4_2) + relu4_4 = self.relu4_4(relu4_3) + + relu5_1 = self.relu5_1(relu4_4) + relu5_2 = self.relu5_2(relu5_1) + relu5_3 = self.relu5_3(relu5_2) + relu5_4 = self.relu5_4(relu5_3) + + out = { + 'relu1_1': relu1_1, + 'relu1_2': relu1_2, + + 'relu2_1': relu2_1, + 'relu2_2': relu2_2, + + 'relu3_1': relu3_1, + 'relu3_2': relu3_2, + 'relu3_3': relu3_3, + 'relu3_4': relu3_4, + + 'relu4_1': relu4_1, + 'relu4_2': relu4_2, + 'relu4_3': relu4_3, + 'relu4_4': relu4_4, + + 'relu5_1': relu5_1, + 'relu5_2': relu5_2, + 'relu5_3': relu5_3, + 'relu5_4': relu5_4, + } + return out diff --git a/DH-AISP/2/saicinpainting/training/modules/__init__.py b/DH-AISP/2/saicinpainting/training/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82e1a9096a5bd8f3fb00e899d0239b078246cad4 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/__init__.py @@ -0,0 +1,31 @@ +import logging + +from saicinpainting.training.modules.ffc import FFCResNetGenerator +from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ + NLayerDiscriminator, MultidilatedNLayerDiscriminator + +def make_generator(config, kind, **kwargs): + logging.info(f'Make generator {kind}') + + if kind == 'pix2pixhd_multidilated': + return MultiDilatedGlobalGenerator(**kwargs) + + if kind == 'pix2pixhd_global': + return GlobalGenerator(**kwargs) + + if kind == 'ffc_resnet': + return FFCResNetGenerator(**kwargs) + + raise ValueError(f'Unknown generator kind {kind}') + + +def make_discriminator(kind, **kwargs): + logging.info(f'Make discriminator {kind}') + + if kind == 'pix2pixhd_nlayer_multidilated': + return MultidilatedNLayerDiscriminator(**kwargs) + + if kind == 'pix2pixhd_nlayer': + return NLayerDiscriminator(**kwargs) + + raise ValueError(f'Unknown discriminator kind {kind}') diff --git 
a/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70af05c150b8484a542d27a1b7e63b328ef25549 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfc3bc0b3325733a3c817de6b22ac43d97ea9f66 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/__init__.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b3639da553fde2f2d1e3700e7f719148d06132 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5796aa60e8558beedc4b2346b050de5d6bae33e Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/base.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf3847bf1992b4ff3c3c29eea8b5c0c1895eda9d Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d59bbad97630b204983c77e9da6fb8941db11a34 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/depthwise_sep_conv.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25344277b06fe90e8a3e80e17242398f0f198b3f Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..847f20226bdec4321d189e7e4bc41d4cd3bb8de0 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc0.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc0.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9274f2a068ad543fdd6a266326dc40b43322f9c0 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/ffc0.cpython-37.pyc differ diff --git 
a/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce4c245a7ebdeb9e22a6ab1a7f776849c468674 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a950f6bac89d40293d7423aca96c4b529992e98 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/multidilated_conv.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583026bc4530b4422990036e3fd7689f9a818827 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c247cef2f21b8e3238e6970f8bef5b19d6b6f2 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/pix2pixhd.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02470dc4537021d4cc99b4c06fefae357067f6e6 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638fb7b38dfe8ae079d75973acfa97bbf5d8a579 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/spatial_transform.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-36.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7adfb3f7bd56ec750700a1dcd08d96c964e2997f Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-36.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-37.pyc b/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd30cb61336155a0c3f5c537a0d41b8b6f1dbbf9 Binary files /dev/null and b/DH-AISP/2/saicinpainting/training/modules/__pycache__/squeeze_excitation.cpython-37.pyc differ diff --git a/DH-AISP/2/saicinpainting/training/modules/base.py b/DH-AISP/2/saicinpainting/training/modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a50c3fc7753a0bba64a5ab8c1ed64ff97e62313f --- 
/dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/base.py @@ -0,0 +1,80 @@ +import abc +from typing import Tuple, List + +import torch +import torch.nn as nn + +from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv +from saicinpainting.training.modules.multidilated_conv import MultidilatedConv + + +class BaseDiscriminator(nn.Module): + @abc.abstractmethod + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Predict scores and get intermediate activations. Useful for feature matching loss + :return tuple (scores, list of intermediate activations) + """ + raise NotImplemented() + + +def get_conv_block_ctor(kind='default'): + if not isinstance(kind, str): + return kind + if kind == 'default': + return nn.Conv2d + if kind == 'depthwise': + return DepthWiseSeperableConv + if kind == 'multidilated': + return MultidilatedConv + raise ValueError(f'Unknown convolutional block kind {kind}') + + +def get_norm_layer(kind='bn'): + if not isinstance(kind, str): + return kind + if kind == 'bn': + return nn.BatchNorm2d + if kind == 'in': + return nn.InstanceNorm2d + raise ValueError(f'Unknown norm block kind {kind}') + + +def get_activation(kind='tanh'): + if kind == 'tanh': + return nn.Tanh() + if kind == 'sigmoid': + return nn.Sigmoid() + if kind is False: + return nn.Identity() + raise ValueError(f'Unknown activation kind {kind}') + + +class SimpleMultiStepGenerator(nn.Module): + def __init__(self, steps: List[nn.Module]): + super().__init__() + self.steps = nn.ModuleList(steps) + + def forward(self, x): + cur_in = x + outs = [] + for step in self.steps: + cur_out = step(cur_in) + outs.append(cur_out) + cur_in = torch.cat((cur_in, cur_out), dim=1) + return torch.cat(outs[::-1], dim=1) + +def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): + if kind == 'convtranspose': + return [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(min(max_features, int(ngf * mult / 2))), activation] + elif kind == 'bilinear': + return [nn.Upsample(scale_factor=2, mode='bilinear'), + DepthWiseSeperableConv(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=1, padding=1), + norm_layer(min(max_features, int(ngf * mult / 2))), activation] + else: + raise Exception(f"Invalid deconv kind: {kind}") \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/training/modules/depthwise_sep_conv.py b/DH-AISP/2/saicinpainting/training/modules/depthwise_sep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..83dd15c3df1d9f40baf0091a373fa224532c9ddd --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/depthwise_sep_conv.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +class DepthWiseSeperableConv(nn.Module): + def __init__(self, in_dim, out_dim, *args, **kwargs): + super().__init__() + if 'groups' in kwargs: + # ignoring groups for Depthwise Sep Conv + del kwargs['groups'] + + self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) + self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out \ No newline at end of file diff --git a/DH-AISP/2/saicinpainting/training/modules/fake_fakes.py b/DH-AISP/2/saicinpainting/training/modules/fake_fakes.py new file mode 100644 index 
0000000000000000000000000000000000000000..45c4ad559cef2730b771a709197e00ae1c87683c --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/fake_fakes.py @@ -0,0 +1,47 @@ +import torch +from kornia import SamplePadding +from kornia.augmentation import RandomAffine, CenterCrop + + +class FakeFakesGenerator: + def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): + self.grad_aug = RandomAffine(degrees=360, + translate=0.2, + padding_mode=SamplePadding.REFLECTION, + keepdim=False, + p=1) + self.img_aug = RandomAffine(degrees=img_aug_degree, + translate=img_aug_translate, + padding_mode=SamplePadding.REFLECTION, + keepdim=True, + p=1) + self.aug_proba = aug_proba + + def __call__(self, input_images, masks): + blend_masks = self._fill_masks_with_gradient(masks) + blend_target = self._make_blend_target(input_images) + result = input_images * (1 - blend_masks) + blend_target * blend_masks + return result, blend_masks + + def _make_blend_target(self, input_images): + batch_size = input_images.shape[0] + permuted = input_images[torch.randperm(batch_size)] + augmented = self.img_aug(input_images) + is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() + result = augmented * is_aug + permuted * (1 - is_aug) + return result + + def _fill_masks_with_gradient(self, masks): + batch_size, _, height, width = masks.shape + grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ + .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) + grad = self.grad_aug(grad) + grad = CenterCrop((height, width))(grad) + grad *= masks + + grad_for_min = grad + (1 - masks) * 10 + grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] + grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 + grad.clamp_(min=0, max=1) + + return grad diff --git a/DH-AISP/2/saicinpainting/training/modules/ffc.py b/DH-AISP/2/saicinpainting/training/modules/ffc.py new file mode 100644 index 0000000000000000000000000000000000000000..85c91e20fdcf8217d128bc2c79be392a00b9dfc2 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/ffc.py @@ -0,0 +1,462 @@ +# Fast Fourier Convolution NeurIPS 2020 +# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py +# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from saicinpainting.training.modules.base import get_activation, BaseDiscriminator +from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper +from saicinpainting.training.modules.squeeze_excitation import SELayer +from saicinpainting.utils import get_shape + + +class FFCSE_block(nn.Module): + + def __init__(self, channels, ratio_g): + super(FFCSE_block, self).__init__() + in_cg = int(channels * ratio_g) + in_cl = channels - in_cg + r = 16 + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv1 = nn.Conv2d(channels, channels // r, + kernel_size=1, bias=True) + self.relu1 = nn.ReLU(inplace=True) + self.conv_a2l = None if in_cl == 0 else nn.Conv2d( + channels // r, in_cl, kernel_size=1, bias=True) + self.conv_a2g = None if in_cg == 0 else nn.Conv2d( + channels // r, in_cg, kernel_size=1, bias=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = x if type(x) is tuple else (x, 0) + id_l, id_g = x + + x = id_l if type(id_g) is int else torch.cat([id_l, 
id_g], dim=1) + x = self.avgpool(x) + x = self.relu1(self.conv1(x)) + + x_l = 0 if self.conv_a2l is None else id_l * \ + self.sigmoid(self.conv_a2l(x)) + x_g = 0 if self.conv_a2g is None else id_g * \ + self.sigmoid(self.conv_a2g(x)) + return x_l, x_g + + +class FourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels // + 2, kernel_size=1, groups=groups, bias=False), + nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True) + ) + self.fu = FourierUnit( + out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit( + out_channels // 
2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + + + + + split_w = w // split_no + + + xs = torch.cat(torch.split( + x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() + + # shape=(xs.shape[0], xs.shape[1], xs.shape[3], split_no*split_w) + # xs.resize_(shape) + + ww=torch.split(xs, split_w, dim=-1) + + + xs = torch.cat(ww, + dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + ratio_gin=0.5, ratio_gout=0.5, stride=1, padding=0, + dilation=1, groups=1, bias=False, enable_lfu=True, + padding_type='reflect', gated=False, **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + #groups_g = 1 if groups == 1 else int(groups * ratio_gout) + #groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module(in_cl, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module(in_cl, out_cg, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module(in_cg, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + #x_l, x_g = x if type(x) is tuple else (x, 0) + + x_split=torch.split(x, x.shape[1]//2, dim=1) + + x_l=x_split[0] + x_g=x_split[1] + + + + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l)+self.convg2l(x_g) + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, ratio_gin=0.5, ratio_gout=0.5, + stride=1, padding=0, dilation=1, groups=1, bias=False, + norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC(in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride, padding, dilation, + groups, bias, enable_lfu, 
padding_type=padding_type, **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, + spatial_transform_kwargs=None, inline=False, **conv_kwargs): + super().__init__() + + self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x_l, x_g): + # if self.inline: + # x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + # else: + # x_l, x_g = x if type(x) is tuple else (x, 0) + + + # x_split=torch.split(x, x.shape[1]//2, dim=1) + + # x_l=x_split[0] + # x_g=x_split[1] + + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1(torch.cat([x_l, x_g],dim=1)) + x_l, x_g = self.conv2(torch.cat([x_l, x_g],dim=1)) + + x_l, x_g = id_l + x_l, id_g + x_g + # if self.inline: + # out = torch.cat(out, dim=1) + return x_l, x_g + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), + init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={}, + spatial_transform_layers=None, spatial_transform_kwargs={}, + add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [nn.ReflectionPad2d(3), + FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer, + activation_layer=activation_layer, **init_conv_kwargs)] + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [FFC_BN_ACT(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs)] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + ### resnet blocks + for i in 
range(n_blocks): + cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation] + + if out_ffc: + model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, inline=True, **out_ffc_kwargs)] + + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class FFCNLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512, + init_conv_kwargs={}, conv_kwargs={}): + super().__init__() + self.n_layers = n_layers + + def _act_ctor(inplace=True): + return nn.LeakyReLU(negative_slope=0.2, inplace=inplace) + + kw = 3 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer, + activation_layer=_act_ctor, **init_conv_kwargs)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, max_features) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=2, padding=padw, + norm_layer=norm_layer, + activation_layer=_act_ctor, + **conv_kwargs) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=1, padding=padw, + norm_layer=norm_layer, + activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs), + **conv_kwargs), + ConcatTupleLayer() + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + feats = [] + for out in act[:-1]: + if isinstance(out, tuple): + if torch.is_tensor(out[1]): + out = torch.cat(out, dim=1) + else: + out = out[0] + feats.append(out) + return act[-1], feats diff --git a/DH-AISP/2/saicinpainting/training/modules/ffc0.py b/DH-AISP/2/saicinpainting/training/modules/ffc0.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c934dec73e682ce10548b6ce3fab70b8a963cb --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/ffc0.py @@ -0,0 +1,462 @@ +# Fast Fourier Convolution NeurIPS 2020 +# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py +# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from 
saicinpainting.training.modules.base import get_activation, BaseDiscriminator +from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper +from saicinpainting.training.modules.squeeze_excitation import SELayer +from saicinpainting.utils import get_shape + + +class FFCSE_block(nn.Module): + + def __init__(self, channels, ratio_g): + super(FFCSE_block, self).__init__() + in_cg = int(channels * ratio_g) + in_cl = channels - in_cg + r = 16 + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv1 = nn.Conv2d(channels, channels // r, + kernel_size=1, bias=True) + self.relu1 = nn.ReLU(inplace=True) + self.conv_a2l = None if in_cl == 0 else nn.Conv2d( + channels // r, in_cl, kernel_size=1, bias=True) + self.conv_a2g = None if in_cg == 0 else nn.Conv2d( + channels // r, in_cg, kernel_size=1, bias=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = x if type(x) is tuple else (x, 0) + id_l, id_g = x + + x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1) + x = self.avgpool(x) + x = self.relu1(self.conv1(x)) + + x_l = 0 if self.conv_a2l is None else id_l * \ + self.sigmoid(self.conv_a2l(x)) + x_g = 0 if self.conv_a2g is None else id_g * \ + self.sigmoid(self.conv_a2g(x)) + return x_l, x_g + + +class FourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = 
x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels // + 2, kernel_size=1, groups=groups, bias=False), + nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True) + ) + self.fu = FourierUnit( + out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit( + out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + + + + + split_w = w // split_no + + + xs = torch.cat(torch.split( + x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() + + # shape=(xs.shape[0], xs.shape[1], xs.shape[3], split_no*split_w) + # xs.resize_(shape) + + ww=torch.split(xs, split_w, dim=-1) + + + xs = torch.cat(ww, + dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + ratio_gin=0.5, ratio_gout=0.5, stride=1, padding=0, + dilation=1, groups=1, bias=False, enable_lfu=True, + padding_type='reflect', gated=False, **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." 
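# The four branch convolutions built below realise the FFC channel split:
# in_channels is divided into a local part (in_cl) and a global part
# (in_cg = in_channels * ratio_gin), and out_channels likewise via ratio_gout.
# convl2l/convl2g/convg2l/convg2g map between the two parts, with the
# global-to-global path implemented by SpectralTransform (a Fourier-domain
# convolution); branches with zero channels collapse to nn.Identity.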
+ self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + #groups_g = 1 if groups == 1 else int(groups * ratio_gout) + #groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module(in_cl, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module(in_cl, out_cg, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module(in_cg, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + #x_l, x_g = x if type(x) is tuple else (x, 0) + + x_split=torch.split(x, x.shape[1]//2, dim=1) + + x_l=x_split[0] + x_g=x_split[1] + + + + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, ratio_gin=0.5, ratio_gout=0.5, + stride=1, padding=0, dilation=1, groups=1, bias=False, + norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC(in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride, padding, dilation, + groups, bias, enable_lfu, padding_type=padding_type, **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, + spatial_transform_kwargs=None, inline=False, **conv_kwargs): + super().__init__() + + self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + 
activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x_l, x_g): + # if self.inline: + # x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + # else: + # x_l, x_g = x if type(x) is tuple else (x, 0) + + + # x_split=torch.split(x, x.shape[1]//2, dim=1) + + # x_l=x_split[0] + # x_g=x_split[1] + + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1(torch.cat([x_l, x_g],dim=1)) + x_l, x_g = self.conv2(torch.cat([x_l, x_g],dim=1)) + + x_l, x_g = id_l + x_l, id_g + x_g + # if self.inline: + # out = torch.cat(out, dim=1) + return x_l, x_g + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), + init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={}, + spatial_transform_layers=None, spatial_transform_kwargs={}, + add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [nn.ReflectionPad2d(3), + FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer, + activation_layer=activation_layer, **init_conv_kwargs)] + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [FFC_BN_ACT(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs)] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + ### resnet blocks + for i in range(n_blocks): + cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation] + + if out_ffc: + model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, inline=True, **out_ffc_kwargs)] + + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return 
self.model(input) + + +class FFCNLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512, + init_conv_kwargs={}, conv_kwargs={}): + super().__init__() + self.n_layers = n_layers + + def _act_ctor(inplace=True): + return nn.LeakyReLU(negative_slope=0.2, inplace=inplace) + + kw = 3 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer, + activation_layer=_act_ctor, **init_conv_kwargs)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, max_features) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=2, padding=padw, + norm_layer=norm_layer, + activation_layer=_act_ctor, + **conv_kwargs) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=1, padding=padw, + norm_layer=norm_layer, + activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs), + **conv_kwargs), + ConcatTupleLayer() + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + feats = [] + for out in act[:-1]: + if isinstance(out, tuple): + if torch.is_tensor(out[1]): + out = torch.cat(out, dim=1) + else: + out = out[0] + feats.append(out) + return act[-1], feats diff --git a/DH-AISP/2/saicinpainting/training/modules/multidilated_conv.py b/DH-AISP/2/saicinpainting/training/modules/multidilated_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d267ee2aa5eb84b6a9291d0eaaff322c6c2802d0 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/multidilated_conv.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import random +from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv + +class MultidilatedConv(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, + shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): + super().__init__() + convs = [] + self.equal_dim = equal_dim + assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode + if comb_mode in ('cat_out', 'cat_both'): + self.cat_out = True + if equal_dim: + assert out_dim % dilation_num == 0 + out_dims = [out_dim // dilation_num] * dilation_num + self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], []) + else: + out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] + out_dims.append(out_dim - sum(out_dims)) + index = [] + starts = [0] + out_dims[:-1] + lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] + for i in range(out_dims[-1]): + for j in range(dilation_num): + index += list(range(starts[j], starts[j] + lengths[j])) + starts[j] += lengths[j] + self.index = index + assert(len(index) == out_dim) + self.out_dims = out_dims + else: + self.cat_out = False + self.out_dims = [out_dim] * dilation_num + + if comb_mode in ('cat_in', 'cat_both'): + if equal_dim: + assert in_dim % dilation_num == 0 + in_dims = [in_dim // 
dilation_num] * dilation_num + else: + in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] + in_dims.append(in_dim - sum(in_dims)) + self.in_dims = in_dims + self.cat_in = True + else: + self.cat_in = False + self.in_dims = [in_dim] * dilation_num + + conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d + dilation = min_dilation + for i in range(dilation_num): + if isinstance(padding, int): + cur_padding = padding * dilation + else: + cur_padding = padding[i] + convs.append(conv_type( + self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs + )) + if i > 0 and shared_weights: + convs[-1].weight = convs[0].weight + convs[-1].bias = convs[0].bias + dilation *= 2 + self.convs = nn.ModuleList(convs) + + self.shuffle_in_channels = shuffle_in_channels + if self.shuffle_in_channels: + # shuffle list as shuffling of tensors is nondeterministic + in_channels_permute = list(range(in_dim)) + random.shuffle(in_channels_permute) + # save as buffer so it is saved and loaded with checkpoint + self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) + + def forward(self, x): + if self.shuffle_in_channels: + x = x[:, self.in_channels_permute] + + outs = [] + if self.cat_in: + if self.equal_dim: + x = x.chunk(len(self.convs), dim=1) + else: + new_x = [] + start = 0 + for dim in self.in_dims: + new_x.append(x[:, start:start+dim]) + start += dim + x = new_x + for i, conv in enumerate(self.convs): + if self.cat_in: + input = x[i] + else: + input = x + outs.append(conv(input)) + if self.cat_out: + out = torch.cat(outs, dim=1)[:, self.index] + else: + out = sum(outs) + return out diff --git a/DH-AISP/2/saicinpainting/training/modules/multiscale.py b/DH-AISP/2/saicinpainting/training/modules/multiscale.py new file mode 100644 index 0000000000000000000000000000000000000000..65f0a54925593e9da8106bfc6d65a4098ce001d7 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/multiscale.py @@ -0,0 +1,244 @@ +from typing import List, Tuple, Union, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation +from saicinpainting.training.modules.pix2pixhd import ResnetBlock + + +class ResNetHead(nn.Module): + def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)): + assert (n_blocks >= 0) + super(ResNetHead, self).__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), + activation] + + mult = 2 ** n_downsampling + + ### resnet blocks + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=conv_kind)] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class ResNetTail(nn.Module): + def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, + 
add_in_proj=None): + assert (n_blocks >= 0) + super(ResNetTail, self).__init__() + + mult = 2 ** n_downsampling + + model = [] + + if add_in_proj is not None: + model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1)) + + ### resnet blocks + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=conv_kind)] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, + output_padding=1), + up_norm_layer(int(ngf * mult / 2)), + up_activation] + self.model = nn.Sequential(*model) + + out_layers = [] + for _ in range(out_extra_layers_n): + out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0), + up_norm_layer(ngf), + up_activation] + out_layers += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + + if add_out_act: + out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act)) + + self.out_proj = nn.Sequential(*out_layers) + + def forward(self, input, return_last_act=False): + features = self.model(input) + out = self.out_proj(features) + if return_last_act: + return out, features + else: + return out + + +class MultiscaleResNet(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3, + norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, + out_cumulative=False, return_only_hr=False): + super().__init__() + + self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling, + n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type, + conv_kind=conv_kind, activation=activation) + for i in range(n_scales)]) + tail_in_feats = ngf * (2 ** n_downsampling) + ngf + self.tails = nn.ModuleList([ResNetTail(output_nc, + ngf=ngf, n_downsampling=n_downsampling, + n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type, + conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer, + up_activation=up_activation, add_out_act=add_out_act, + out_extra_layers_n=out_extra_layers_n, + add_in_proj=None if (i == n_scales - 1) else tail_in_feats) + for i in range(n_scales)]) + + self.out_cumulative = out_cumulative + self.return_only_hr = return_only_hr + + @property + def num_scales(self): + return len(self.heads) + + def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ + -> Union[torch.Tensor, List[torch.Tensor]]: + """ + :param ms_inputs: List of inputs of different resolutions from HR to LR + :param smallest_scales_num: int or None, number of smallest scales to take at input + :return: Depending on return_only_hr: + True: Only the most HR output + False: List of outputs of different resolutions from HR to LR + """ + if smallest_scales_num is None: + assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num) + smallest_scales_num = len(self.heads) + else: + assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num) + + cur_heads = self.heads[-smallest_scales_num:] + ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)] + + all_outputs = [] + prev_tail_features = None + for i in 
range(len(ms_features)): + scale_i = -i - 1 + + cur_tail_input = ms_features[-i - 1] + if prev_tail_features is not None: + if prev_tail_features.shape != cur_tail_input.shape: + prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:], + mode='bilinear', align_corners=False) + cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1) + + cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True) + + prev_tail_features = cur_tail_feats + all_outputs.append(cur_out) + + if self.out_cumulative: + all_outputs_cum = [all_outputs[0]] + for i in range(1, len(ms_features)): + cur_out = all_outputs[i] + cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:], + mode='bilinear', align_corners=False) + all_outputs_cum.append(cur_out_cum) + all_outputs = all_outputs_cum + + if self.return_only_hr: + return all_outputs[-1] + else: + return all_outputs[::-1] + + +class MultiscaleDiscriminatorSimple(nn.Module): + def __init__(self, ms_impl): + super().__init__() + self.ms_impl = nn.ModuleList(ms_impl) + + @property + def num_scales(self): + return len(self.ms_impl) + + def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ + -> List[Tuple[torch.Tensor, List[torch.Tensor]]]: + """ + :param ms_inputs: List of inputs of different resolutions from HR to LR + :param smallest_scales_num: int or None, number of smallest scales to take at input + :return: List of pairs (prediction, features) for different resolutions from HR to LR + """ + if smallest_scales_num is None: + assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num) + smallest_scales_num = len(self.heads) + else: + assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \ + (len(self.ms_impl), len(ms_inputs), smallest_scales_num) + + return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)] + + +class SingleToMultiScaleInputMixin: + def forward(self, x: torch.Tensor) -> List: + orig_height, orig_width = x.shape[2:] + factors = [2 ** i for i in range(self.num_scales)] + ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False) + for f in factors] + return super().forward(ms_inputs) + + +class GeneratorMultiToSingleOutputMixin: + def forward(self, x): + return super().forward(x)[0] + + +class DiscriminatorMultiToSingleOutputMixin: + def forward(self, x): + out_feat_tuples = super().forward(x) + return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist] + + +class DiscriminatorMultiToSingleOutputStackedMixin: + def __init__(self, *args, return_feats_only_levels=None, **kwargs): + super().__init__(*args, **kwargs) + self.return_feats_only_levels = return_feats_only_levels + + def forward(self, x): + out_feat_tuples = super().forward(x) + outs = [out for out, _ in out_feat_tuples] + scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:], + mode='bilinear', align_corners=False) + for cur_out in outs[1:]] + out = torch.cat(scaled_outs, dim=1) + if self.return_feats_only_levels is not None: + feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels] + else: + feat_lists = [flist for _, flist in out_feat_tuples] + feats = [f for flist in feat_lists for f in flist] + return out, feats + + +class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, 
MultiscaleDiscriminatorSimple): + pass + + +class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet): + pass diff --git a/DH-AISP/2/saicinpainting/training/modules/pix2pixhd.py b/DH-AISP/2/saicinpainting/training/modules/pix2pixhd.py new file mode 100644 index 0000000000000000000000000000000000000000..08c6afd777a88cd232592acbbf0ef25db8d43217 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/pix2pixhd.py @@ -0,0 +1,669 @@ +# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py +import collections +from functools import partial +import functools +import logging +from collections import defaultdict + +import numpy as np +import torch.nn as nn + +from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation +from saicinpainting.training.modules.ffc import FFCResnetBlock +from saicinpainting.training.modules.multidilated_conv import MultidilatedConv + +class DotDict(defaultdict): + # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary + """dot.notation access to dictionary attributes""" + __getattr__ = defaultdict.get + __setattr__ = defaultdict.__setitem__ + __delattr__ = defaultdict.__delitem__ + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=None): + super(ResnetBlock, self).__init__() + self.in_dim = in_dim + self.dim = dim + if second_dilation is None: + second_dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout, + conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups, + second_dilation=second_dilation) + + if self.in_dim is not None: + self.input_conv = nn.Conv2d(in_dim, dim, 1) + + self.out_channnels = dim + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=1): + conv_layer = get_conv_block_ctor(conv_kind) + + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(dilation)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(dilation)] + elif padding_type == 'zero': + p = dilation + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + if in_dim is None: + in_dim = dim + + conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(second_dilation)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(second_dilation)] + elif padding_type == 'zero': + p = second_dilation + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + x_before = x + if self.in_dim is not None: + x = self.input_conv(x) + out = x + self.conv_block(x_before) + return out + +class ResnetBlock5x5(nn.Module): + def __init__(self, dim, padding_type, 
norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=None): + super(ResnetBlock5x5, self).__init__() + self.in_dim = in_dim + self.dim = dim + if second_dilation is None: + second_dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout, + conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups, + second_dilation=second_dilation) + + if self.in_dim is not None: + self.input_conv = nn.Conv2d(in_dim, dim, 1) + + self.out_channnels = dim + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=1): + conv_layer = get_conv_block_ctor(conv_kind) + + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(dilation * 2)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(dilation * 2)] + elif padding_type == 'zero': + p = dilation * 2 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + if in_dim is None: + in_dim = dim + + conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(second_dilation * 2)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(second_dilation * 2)] + elif padding_type == 'zero': + p = second_dilation * 2 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + x_before = x + if self.in_dim is not None: + x = self.input_conv(x) + out = x + self.conv_block(x_before) + return out + + +class MultidilatedResnetBlock(nn.Module): + def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False): + super().__init__() + self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1): + conv_block = [] + conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class MultiDilatedGlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, + n_blocks=3, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', + deconv_kind='convtranspose', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True), + add_out_act=True, max_features=1024, multidilation_kwargs={}, + ffc_positions=None, ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + 
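# get_norm_layer (defined in base.py above) resolves string names such as
# 'bn' -> nn.BatchNorm2d and 'in' -> nn.InstanceNorm2d, so configs can select
# the normalisation by name; when `affine` is given it is bound into the
# constructor with functools.partial for both the downsampling and upsampling
# norm layers.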
up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + ### resnet blocks + for i in range(n_blocks): + if ffc_positions is not None and i in ffc_positions: + model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU, + inline=True, **ffc_kwargs)] + model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type, + conv_layer=resnet_conv_layer, activation=activation, + norm_layer=norm_layer)] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +class ConfigGlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, + n_blocks=3, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', + deconv_kind='convtranspose', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True), + add_out_act=True, max_features=1024, + manual_block_spec=[], + resnet_block_kind='multidilatedresnetblock', + resnet_conv_kind='multidilated', + resnet_dilation=1, + multidilation_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + if len(manual_block_spec) == 0: + manual_block_spec = [ + DotDict(lambda : None, { + 'n_blocks': n_blocks, + 'use_default': True}) + ] + + ### resnet blocks + for block_spec in manual_block_spec: + def make_and_add_blocks(model, block_spec): + block_spec = DotDict(lambda : None, block_spec) + if not block_spec.use_default: + resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs) + resnet_conv_kind = block_spec.resnet_conv_kind + resnet_block_kind = block_spec.resnet_block_kind + if block_spec.resnet_dilation is not None: + 
resnet_dilation = block_spec.resnet_dilation + for i in range(block_spec.n_blocks): + if resnet_block_kind == "multidilatedresnetblock": + model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type, + conv_layer=resnet_conv_layer, activation=activation, + norm_layer=norm_layer)] + if resnet_block_kind == "resnetblock": + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind)] + if resnet_block_kind == "resnetblock5x5": + model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind)] + if resnet_block_kind == "resnetblockdwdil": + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)] + make_and_add_blocks(model, block_spec) + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs): + blocks = [] + for i in range(dilated_blocks_n): + if dilation_block_kind == 'simple': + blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1))) + elif dilation_block_kind == 'multi': + blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs)) + else: + raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"') + return blocks + + +class GlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, + up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0, + dilated_blocks_n_middle=0, + add_out_act=True, + max_features=1024, is_resblock_depthwise=False, + ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None, + dilation_block_kind='simple', multidilation_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + if ffc_positions is not None: + ffc_positions = collections.Counter(ffc_positions) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type, + activation=activation, norm_layer=norm_layer) + if dilation_block_kind == 'simple': + 
dilated_block_kwargs['conv_kind'] = conv_kind + elif dilation_block_kind == 'multi': + dilated_block_kwargs['conv_layer'] = functools.partial( + get_conv_block_ctor('multidilated'), **multidilation_kwargs) + + # dilated blocks at the start of the bottleneck sausage + if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0: + model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs) + + # resnet blocks + for i in range(n_blocks): + # dilated blocks at the middle of the bottleneck sausage + if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0: + model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs) + + if ffc_positions is not None and i in ffc_positions: + for _ in range(ffc_positions[i]): # same position can occur more than once + model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU, + inline=True, **ffc_kwargs)] + + if is_resblock_depthwise: + resblock_groups = feats_num_bottleneck + else: + resblock_groups = 1 + + model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation, + norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups, + dilation=dilation, second_dilation=second_dilation)] + + + # dilated blocks at the end of the bottleneck sausage + if dilated_blocks_n is not None and dilated_blocks_n > 0: + model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs) + + # upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation] + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class GlobalGeneratorGated(GlobalGenerator): + def __init__(self, *args, **kwargs): + real_kwargs=dict( + conv_kind='gated_bn_relu', + activation=nn.Identity(), + norm_layer=nn.Identity + ) + real_kwargs.update(kwargs) + super().__init__(*args, **real_kwargs) + + +class GlobalGeneratorFromSuperChannels(nn.Module): + def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True): + super().__init__() + self.n_downsampling = n_downsampling + norm_layer = get_norm_layer(norm_layer) + if type(norm_layer) == functools.partial: + use_bias = (norm_layer.func == nn.InstanceNorm2d) + else: + use_bias = (norm_layer == nn.InstanceNorm2d) + + channels = self.convert_super_channels(super_channels) + self.channels = channels + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias), + norm_layer(channels[0]), + nn.ReLU(True)] + + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(channels[1+i]), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + + n_blocks1 = n_blocks // 3 + n_blocks2 = n_blocks1 + n_blocks3 = n_blocks - n_blocks1 - n_blocks2 + + for i in range(n_blocks1): + c = n_downsampling + dim = channels[c] + model += [ResnetBlock(dim, 
padding_type=padding_type, norm_layer=norm_layer)] + + for i in range(n_blocks2): + c = n_downsampling+1 + dim = channels[c] + kwargs = {} + if i == 0: + kwargs = {"in_dim": channels[c-1]} + model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] + + for i in range(n_blocks3): + c = n_downsampling+2 + dim = channels[c] + kwargs = {} + if i == 0: + kwargs = {"in_dim": channels[c-1]} + model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(channels[n_downsampling+3+i], + channels[n_downsampling+3+i+1], + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(channels[n_downsampling+3+i+1]), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)] + + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def convert_super_channels(self, super_channels): + n_downsampling = self.n_downsampling + result = [] + cnt = 0 + + if n_downsampling == 2: + N1 = 10 + elif n_downsampling == 3: + N1 = 13 + else: + raise NotImplementedError + + for i in range(0, N1): + if i in [1,4,7,10]: + channel = super_channels[cnt] * (2 ** cnt) + config = {'channel': channel} + result.append(channel) + logging.info(f"Downsample channels {result[-1]}") + cnt += 1 + + for i in range(3): + for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)): + if len(super_channels) == 6: + channel = super_channels[3] * 4 + else: + channel = super_channels[i + 3] * 4 + config = {'channel': channel} + if counter == 0: + result.append(channel) + logging.info(f"Bottleneck channels {result[-1]}") + cnt = 2 + + for i in range(N1+9, N1+21): + if i in [22, 25,28]: + cnt -= 1 + if len(super_channels) == 6: + channel = super_channels[5 - cnt] * (2 ** cnt) + else: + channel = super_channels[7 - cnt] * (2 ** cnt) + result.append(int(channel)) + logging.info(f"Upsample channels {result[-1]}") + return result + + def forward(self, input): + return self.model(input) + + +# Defines the PatchGAN discriminator with the specified arguments. 
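The class that follows is the PatchGAN discriminator referenced in the comment above: a stack of 4x4 stride-2 convolutions ending in a single-channel map of per-patch real/fake logits, with get_all_activations returning every intermediate feature map so the feature-matching loss can compare real and fake activations layer by layer. A minimal standalone sketch of the same layer pattern (illustrative only, not part of this diff; the widths mirror the default ndf=64, n_layers=3 configuration):

import torch
import torch.nn as nn

# Same kernel size, padding and width schedule as NLayerDiscriminator(ndf=64, n_layers=3):
# three stride-2 convs (-> 64 -> 128 -> 256), one stride-1 conv to 512, then a 1-channel head.
kw, padw, ndf = 4, 2, 64
toy_patchgan = nn.Sequential(
    nn.Conv2d(3, ndf, kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf, ndf * 2, kw, stride=2, padding=padw), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=padw), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf * 4, ndf * 8, kw, stride=1, padding=padw), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf * 8, 1, kw, stride=1, padding=padw),
)
logits = toy_patchgan(torch.randn(1, 3, 256, 256))
print(logits.shape)  # a small spatial map of patch logits, not a single scalar per image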
+class NLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] + + +class MultidilatedNLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] + + +class NLayerDiscriminatorAsGen(NLayerDiscriminator): + def forward(self, x): + return super().forward(x)[0] diff --git a/DH-AISP/2/saicinpainting/training/modules/spatial_transform.py b/DH-AISP/2/saicinpainting/training/modules/spatial_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..2de024ba08c549605a08b64d096f1f0db7b7722a --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/spatial_transform.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry.transform import rotate + + +class LearnableSpatialTransformWrapper(nn.Module): + def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): + super().__init__() + self.impl = impl + self.angle = torch.rand(1) * angle_init_range + if train_angle: + self.angle = nn.Parameter(self.angle, requires_grad=True) + self.pad_coef = pad_coef + + def forward(self, x): + if torch.is_tensor(x): + return 
self.inverse_transform(self.impl(self.transform(x)), x) + elif isinstance(x, tuple): + x_trans = tuple(self.transform(elem) for elem in x) + y_trans = self.impl(x_trans) + return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) + else: + raise ValueError(f'Unexpected input type {type(x)}') + + def transform(self, x): + height, width = x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') + x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) + return x_padded_rotated + + def inverse_transform(self, y_padded_rotated, orig_x): + height, width = orig_x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + + y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) + y_height, y_width = y_padded.shape[2:] + y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] + return y + + +if __name__ == '__main__': + layer = LearnableSpatialTransformWrapper(nn.Identity()) + x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() + y = layer(x) + assert x.shape == y.shape + assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) + print('all ok') diff --git a/DH-AISP/2/saicinpainting/training/modules/squeeze_excitation.py b/DH-AISP/2/saicinpainting/training/modules/squeeze_excitation.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d902bb30c071acbc0fa919a134c80fed86bd6c --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/modules/squeeze_excitation.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res diff --git a/DH-AISP/2/saicinpainting/training/trainers/__init__.py b/DH-AISP/2/saicinpainting/training/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c59241f553efe4e2dd6b198e2e5656a2b1488857 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/trainers/__init__.py @@ -0,0 +1,30 @@ +import logging +import torch +from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule + + +def get_training_model_class(kind): + if kind == 'default': + return DefaultInpaintingTrainingModule + + raise ValueError(f'Unknown trainer module {kind}') + + +def make_training_model(config): + kind = config.training_model.kind + kwargs = dict(config.training_model) + kwargs.pop('kind') + kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' + + logging.info(f'Make training model {kind}') + + cls = get_training_model_class(kind) + return cls(config, **kwargs) + + +def load_checkpoint(train_config, path, map_location='cuda', strict=True): + model: torch.nn.Module = make_training_model(train_config) + state = torch.load(path, map_location=map_location) + model.load_state_dict(state['state_dict'], strict=strict) + model.on_load_checkpoint(state) + return model diff --git a/DH-AISP/2/saicinpainting/training/trainers/base.py b/DH-AISP/2/saicinpainting/training/trainers/base.py new file mode 100644 index 
0000000000000000000000000000000000000000..f1b1c66fc96e7edfba7b1ee193272f92b5db7438 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/trainers/base.py @@ -0,0 +1,291 @@ +import copy +import logging +from typing import Dict, Tuple + +import pandas as pd +import pytorch_lightning as ptl +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DistributedSampler + +from saicinpainting.evaluation import make_evaluator +from saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader +from saicinpainting.training.losses.adversarial import make_discrim_loss +from saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL +from saicinpainting.training.modules import make_generator, make_discriminator +from saicinpainting.training.visualizers import make_visualizer +from saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \ + get_has_ddp_rank + +LOGGER = logging.getLogger(__name__) + + +def make_optimizer(parameters, kind='adamw', **kwargs): + if kind == 'adam': + optimizer_class = torch.optim.Adam + elif kind == 'adamw': + optimizer_class = torch.optim.AdamW + else: + raise ValueError(f'Unknown optimizer kind {kind}') + return optimizer_class(parameters, **kwargs) + + +def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999): + with torch.no_grad(): + res_params = dict(result.named_parameters()) + new_params = dict(new_iterate_model.named_parameters()) + + for k in res_params.keys(): + res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay) + + +def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'): + batch_size, _, height, width = base_tensor.shape + cur_height, cur_width = height, width + result = [] + align_corners = False if scale_mode in ('bilinear', 'bicubic') else None + for _ in range(scales): + cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device) + cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners) + result.append(cur_sample_scaled) + cur_height //= 2 + cur_width //= 2 + return torch.cat(result, dim=1) + + +class BaseInpaintingTrainingModule(ptl.LightningModule): + def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100, + average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000, + average_generator_period=10, store_discr_outputs_for_vis=False, + **kwargs): + super().__init__(*args, **kwargs) + LOGGER.info('BaseInpaintingTrainingModule init called') + + self.config = config + + self.generator = make_generator(config, **self.config.generator) + self.use_ddp = use_ddp + + if not get_has_ddp_rank(): + LOGGER.info(f'Generator\n{self.generator}') + + if not predict_only: + self.save_hyperparameters(self.config) + self.discriminator = make_discriminator(**self.config.discriminator) + self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial) + self.visualizer = make_visualizer(**self.config.visualizer) + self.val_evaluator = make_evaluator(**self.config.evaluator) + self.test_evaluator = make_evaluator(**self.config.evaluator) + + if not get_has_ddp_rank(): + LOGGER.info(f'Discriminator\n{self.discriminator}') + + extra_val = self.config.data.get('extra_val', ()) + if extra_val: + self.extra_val_titles = list(extra_val) + self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator) + 
for k in extra_val}) + else: + self.extra_evaluators = {} + + self.average_generator = average_generator + self.generator_avg_beta = generator_avg_beta + self.average_generator_start_step = average_generator_start_step + self.average_generator_period = average_generator_period + self.generator_average = None + self.last_generator_averaging_step = -1 + self.store_discr_outputs_for_vis = store_discr_outputs_for_vis + + if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0: + self.loss_l1 = nn.L1Loss(reduction='none') + + if self.config.losses.get("mse", {"weight": 0})['weight'] > 0: + self.loss_mse = nn.MSELoss(reduction='none') + + if self.config.losses.perceptual.weight > 0: + self.loss_pl = PerceptualLoss() + + if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0: + self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl) + else: + self.loss_resnet_pl = None + + self.visualize_each_iters = visualize_each_iters + LOGGER.info('BaseInpaintingTrainingModule init done') + + def configure_optimizers(self): + discriminator_params = list(self.discriminator.parameters()) + return [ + dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)), + dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)), + ] + + def train_dataloader(self): + kwargs = dict(self.config.data.train) + if self.use_ddp: + kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank, + shuffle=True) + dataloader = make_default_train_dataloader(**self.config.data.train) + return dataloader + + def val_dataloader(self): + res = [make_default_val_dataloader(**self.config.data.val)] + + if self.config.data.visual_test is not None: + res = res + [make_default_val_dataloader(**self.config.data.visual_test)] + else: + res = res + res + + extra_val = self.config.data.get('extra_val', ()) + if extra_val: + res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles] + + return res + + def training_step(self, batch, batch_idx, optimizer_idx=None): + self._is_training_step = True + return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx) + + def validation_step(self, batch, batch_idx, dataloader_idx): + extra_val_key = None + if dataloader_idx == 0: + mode = 'val' + elif dataloader_idx == 1: + mode = 'test' + else: + mode = 'extra_val' + extra_val_key = self.extra_val_titles[dataloader_idx - 2] + self._is_training_step = False + return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key) + + def training_step_end(self, batch_parts_outputs): + if self.training and self.average_generator \ + and self.global_step >= self.average_generator_start_step \ + and self.global_step >= self.last_generator_averaging_step + self.average_generator_period: + if self.generator_average is None: + self.generator_average = copy.deepcopy(self.generator) + else: + update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta) + self.last_generator_averaging_step = self.global_step + + full_loss = (batch_parts_outputs['loss'].mean() + if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used + else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True)) + log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()} + self.log_dict(log_info, on_step=True, on_epoch=False) + return full_loss + + def 
validation_epoch_end(self, outputs): + outputs = [step_out for out_group in outputs for step_out in out_group] + averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs) + self.log_dict({k: v.mean() for k, v in averaged_logs.items()}) + + pd.set_option('display.max_columns', 500) + pd.set_option('display.width', 1000) + + # standard validation + val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s] + val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states) + val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0) + val_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{val_evaluator_res_df}') + + for k, v in flatten_dict(val_evaluator_res).items(): + self.log(f'val_{k}', v) + + # standard visual test + test_evaluator_states = [s['test_evaluator_state'] for s in outputs + if 'test_evaluator_state' in s] + test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states) + test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0) + test_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{test_evaluator_res_df}') + + for k, v in flatten_dict(test_evaluator_res).items(): + self.log(f'test_{k}', v) + + # extra validations + if self.extra_evaluators: + for cur_eval_title, cur_evaluator in self.extra_evaluators.items(): + cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state' + cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s] + cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states) + cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0) + cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{cur_evaluator_res_df}') + for k, v in flatten_dict(cur_evaluator_res).items(): + self.log(f'extra_val_{cur_eval_title}_{k}', v) + + def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None): + if optimizer_idx == 0: # step for generator + set_requires_grad(self.generator, True) + set_requires_grad(self.discriminator, False) + elif optimizer_idx == 1: # step for discriminator + set_requires_grad(self.generator, False) + set_requires_grad(self.discriminator, True) + + batch = self(batch) + + total_loss = 0 + metrics = {} + + if optimizer_idx is None or optimizer_idx == 0: # step for generator + total_loss, metrics = self.generator_loss(batch) + + elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator + if self.config.losses.adversarial.weight > 0: + total_loss, metrics = self.discriminator_loss(batch) + + if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'): + if self.config.losses.adversarial.weight > 0: + if self.store_discr_outputs_for_vis: + with torch.no_grad(): + self.store_discr_outputs(batch) + vis_suffix = f'_{mode}' + if mode == 'extra_val': + vis_suffix += f'_{extra_val_key}' + self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix) + + metrics_prefix = f'{mode}_' + if mode == 'extra_val': + metrics_prefix += f'{extra_val_key}_' + result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, 
metrics_prefix)) + if mode == 'val': + result['val_evaluator_state'] = self.val_evaluator.process_batch(batch) + elif mode == 'test': + result['test_evaluator_state'] = self.test_evaluator.process_batch(batch) + elif mode == 'extra_val': + result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch) + + return result + + def get_current_generator(self, no_average=False): + if not no_average and not self.training and self.average_generator and self.generator_average is not None: + return self.generator_average + return self.generator + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" + raise NotImplementedError() + + def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def store_discr_outputs(self, batch): + out_size = batch['image'].shape[2:] + discr_real_out, _ = self.discriminator(batch['image']) + discr_fake_out, _ = self.discriminator(batch['predicted_image']) + batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest') + batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest') + batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake'] + + def get_ddp_rank(self): + return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None diff --git a/DH-AISP/2/saicinpainting/training/trainers/default.py b/DH-AISP/2/saicinpainting/training/trainers/default.py new file mode 100644 index 0000000000000000000000000000000000000000..86c7f0fab42924bfc93a031e851117634c70f593 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/trainers/default.py @@ -0,0 +1,175 @@ +import logging + +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +from saicinpainting.training.data.datasets import make_constant_area_crop_params +from saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter +from saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss +from saicinpainting.training.modules.fake_fakes import FakeFakesGenerator +from saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise +from saicinpainting.utils import add_prefix_to_keys, get_ramp + +LOGGER = logging.getLogger(__name__) + + +def make_constant_area_crop_batch(batch, **kwargs): + crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2], + img_width=batch['image'].shape[3], + **kwargs) + batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width] + batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width] + return batch + + +class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): + def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image', + add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None, + distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False, + fake_fakes_proba=0, fake_fakes_generator_kwargs=None, + **kwargs): + super().__init__(*args, **kwargs) + self.concat_mask = concat_mask + self.rescale_size_getter 
= get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None + self.image_to_discriminator = image_to_discriminator + self.add_noise_kwargs = add_noise_kwargs + self.noise_fill_hole = noise_fill_hole + self.const_area_crop_kwargs = const_area_crop_kwargs + self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \ + if distance_weighter_kwargs is not None else None + self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr + + self.fake_fakes_proba = fake_fakes_proba + if self.fake_fakes_proba > 1e-3: + self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {})) + + def forward(self, batch): + if self.training and self.rescale_size_getter is not None: + cur_size = self.rescale_size_getter(self.global_step) + batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False) + batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest') + + if self.training and self.const_area_crop_kwargs is not None: + batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs) + + img = batch['image'] + mask = batch['mask'] + + masked_img = img * (1 - mask) + + if self.add_noise_kwargs is not None: + noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs) + if self.noise_fill_hole: + masked_img = masked_img + mask * noise[:, :masked_img.shape[1]] + masked_img = torch.cat([masked_img, noise], dim=1) + + if self.concat_mask: + masked_img = torch.cat([masked_img, mask], dim=1) + + batch['predicted_image'] = self.generator(masked_img) + batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image'] + + if self.fake_fakes_proba > 1e-3: + if self.training and torch.rand(1).item() < self.fake_fakes_proba: + batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask) + batch['use_fake_fakes'] = True + else: + batch['fake_fakes'] = torch.zeros_like(img) + batch['fake_fakes_masks'] = torch.zeros_like(mask) + batch['use_fake_fakes'] = False + + batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \ + if self.refine_mask_for_losses is not None and self.training \ + else mask + + return batch + + def generator_loss(self, batch): + img = batch['image'] + predicted_img = batch[self.image_to_discriminator] + original_mask = batch['mask'] + supervised_mask = batch['mask_for_losses'] + + # L1 + l1_value = masked_l1_loss(predicted_img, img, supervised_mask, + self.config.losses.l1.weight_known, + self.config.losses.l1.weight_missing) + + total_loss = l1_value + metrics = dict(gen_l1=l1_value) + + # vgg-based perceptual loss + if self.config.losses.perceptual.weight > 0: + pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight + total_loss = total_loss + pl_value + metrics['gen_pl'] = pl_value + + # discriminator + # adversarial_loss calls backward by itself + mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask + self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img, + generator=self.generator, discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(img) + discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) + adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img, + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + 
mask=mask_for_discr) + total_loss = total_loss + adv_gen_loss + metrics['gen_adv'] = adv_gen_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + # feature matching + if self.config.losses.feature_matching.weight > 0: + need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False) + mask_for_fm = supervised_mask if need_mask_in_fm else None + fm_value = feature_matching_loss(discr_fake_features, discr_real_features, + mask=mask_for_fm) * self.config.losses.feature_matching.weight + total_loss = total_loss + fm_value + metrics['gen_fm'] = fm_value + + if self.loss_resnet_pl is not None: + resnet_pl_value = self.loss_resnet_pl(predicted_img, img) + total_loss = total_loss + resnet_pl_value + metrics['gen_resnet_pl'] = resnet_pl_value + + return total_loss, metrics + + def discriminator_loss(self, batch): + total_loss = 0 + metrics = {} + + predicted_img = batch[self.image_to_discriminator].detach() + self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img, + generator=self.generator, discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(batch['image']) + discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) + adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'], + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=batch['mask']) + total_loss = total_loss + adv_discr_loss + metrics['discr_adv'] = adv_discr_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + + if batch.get('use_fake_fakes', False): + fake_fakes = batch['fake_fakes'] + self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes, + generator=self.generator, discriminator=self.discriminator) + discr_fake_fakes_pred, _ = self.discriminator(fake_fakes) + fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss( + real_batch=batch['image'], + fake_batch=fake_fakes, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_fakes_pred, + mask=batch['mask'] + ) + total_loss = total_loss + fake_fakes_adv_discr_loss + metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss + metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_')) + + return total_loss, metrics diff --git a/DH-AISP/2/saicinpainting/training/visualizers/__init__.py b/DH-AISP/2/saicinpainting/training/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4770d1f15a6790ab9606c7b9881f798c8e2d9545 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/visualizers/__init__.py @@ -0,0 +1,15 @@ +import logging + +from saicinpainting.training.visualizers.directory import DirectoryVisualizer +from saicinpainting.training.visualizers.noop import NoopVisualizer + + +def make_visualizer(kind, **kwargs): + logging.info(f'Make visualizer {kind}') + + if kind == 'directory': + return DirectoryVisualizer(**kwargs) + if kind == 'noop': + return NoopVisualizer() + + raise ValueError(f'Unknown visualizer kind {kind}') diff --git a/DH-AISP/2/saicinpainting/training/visualizers/base.py b/DH-AISP/2/saicinpainting/training/visualizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..675f01682ddf5e31b6cc341735378c6f3b242e49 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/visualizers/base.py @@ -0,0 +1,73 @@ +import abc +from typing import Dict, List + +import numpy as np 
+import torch +from skimage import color +from skimage.segmentation import mark_boundaries + +from . import colors + +COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation + + +class BaseVisualizer: + @abc.abstractmethod + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + """ + Take a batch, make an image from it and visualize + """ + raise NotImplementedError() + + +def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], + last_without_mask=True, rescale_keys=None, mask_only_first=None, + black_mask=False) -> np.ndarray: + mask = images_dict['mask'] > 0.5 + result = [] + for i, k in enumerate(keys): + img = images_dict[k] + img = np.transpose(img, (1, 2, 0)) + + if rescale_keys is not None and k in rescale_keys: + img = img - img.min() + img /= img.max() + 1e-5 + if len(img.shape) == 2: + img = np.expand_dims(img, 2) + + if img.shape[2] == 1: + img = np.repeat(img, 3, axis=2) + elif (img.shape[2] > 3): + img_classes = img.argmax(2) + img = color.label2rgb(img_classes, colors=COLORS) + + if mask_only_first: + need_mark_boundaries = i == 0 + else: + need_mark_boundaries = i < len(keys) - 1 or not last_without_mask + + if need_mark_boundaries: + if black_mask: + img = img * (1 - mask[0][..., None]) + img = mark_boundaries(img, + mask[0], + color=(1., 0., 0.), + outline_color=(1., 1., 1.), + mode='thick') + result.append(img) + return np.concatenate(result, axis=1) + + +def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, + last_without_mask=True, rescale_keys=None) -> np.ndarray: + batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() + if k in keys or k == 'mask'} + + batch_size = next(iter(batch.values())).shape[0] + items_to_vis = min(batch_size, max_items) + result = [] + for i in range(items_to_vis): + cur_dct = {k: tens[i] for k, tens in batch.items()} + result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, + rescale_keys=rescale_keys)) + return np.concatenate(result, axis=0) diff --git a/DH-AISP/2/saicinpainting/training/visualizers/colors.py b/DH-AISP/2/saicinpainting/training/visualizers/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9e39182c58cb06a1c5e97a7e6c497cc3388ebe --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/visualizers/colors.py @@ -0,0 +1,76 @@ +import random +import colorsys + +import numpy as np +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + + +def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): + # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib + """ + Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks + :param nlabels: Number of labels (size of colormap) + :param type: 'bright' for strong colors, 'soft' for pastel colors + :param first_color_black: Option to use first color as black, True or False + :param last_color_black: Option to use last color as black, True or False + :param verbose: Prints the number of labels and shows the colormap. 
True or False + :return: colormap for matplotlib + """ + if type not in ('bright', 'soft'): + print ('Please choose "bright" or "soft" for type') + return + + if verbose: + print('Number of labels: ' + str(nlabels)) + + # Generate color map for bright colors, based on hsv + if type == 'bright': + randHSVcolors = [(np.random.uniform(low=0.0, high=1), + np.random.uniform(low=0.2, high=1), + np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] + + # Convert HSV list to RGB + randRGBcolors = [] + for HSVcolor in randHSVcolors: + randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Generate soft pastel colors, by limiting the RGB spectrum + if type == 'soft': + low = 0.6 + high = 0.95 + randRGBcolors = [(np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high)) for i in range(nlabels)] + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Display colorbar + if verbose: + from matplotlib import colors, colorbar + from matplotlib import pyplot as plt + fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) + + bounds = np.linspace(0, nlabels, nlabels + 1) + norm = colors.BoundaryNorm(bounds, nlabels) + + cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, + boundaries=bounds, format='%1i', orientation=u'horizontal') + + return randRGBcolors, random_colormap + diff --git a/DH-AISP/2/saicinpainting/training/visualizers/directory.py b/DH-AISP/2/saicinpainting/training/visualizers/directory.py new file mode 100644 index 0000000000000000000000000000000000000000..bc42e00500c7a5b70b2cef83b03e45b5bb471ff8 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/visualizers/directory.py @@ -0,0 +1,36 @@ +import os + +import cv2 +import numpy as np + +from saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch +from saicinpainting.utils import check_and_warn_input_range + + +class DirectoryVisualizer(BaseVisualizer): + DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') + + def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, + last_without_mask=True, rescale_keys=None): + self.outdir = outdir + os.makedirs(self.outdir, exist_ok=True) + self.key_order = key_order + self.max_items_in_batch = max_items_in_batch + self.last_without_mask = last_without_mask + self.rescale_keys = rescale_keys + + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') + vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, + last_without_mask=self.last_without_mask, + rescale_keys=self.rescale_keys) + + vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') + + curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') + os.makedirs(curoutdir, exist_ok=True) + rank_suffix = f'_r{rank}' if rank is not None else '' + out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg') + + vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_fname, vis_img) diff --git 
a/DH-AISP/2/saicinpainting/training/visualizers/noop.py b/DH-AISP/2/saicinpainting/training/visualizers/noop.py new file mode 100644 index 0000000000000000000000000000000000000000..4175089a54a8484d51e6c879c1a99c4e4d961d15 --- /dev/null +++ b/DH-AISP/2/saicinpainting/training/visualizers/noop.py @@ -0,0 +1,9 @@ +from saicinpainting.training.visualizers.base import BaseVisualizer + + +class NoopVisualizer(BaseVisualizer): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + pass diff --git a/DH-AISP/2/saicinpainting/utils.py b/DH-AISP/2/saicinpainting/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d67ed8bc793dd5113224fa322adb88f3ed9b22 --- /dev/null +++ b/DH-AISP/2/saicinpainting/utils.py @@ -0,0 +1,177 @@ +import bisect +import functools +import logging +import numbers +import os +import signal +import sys +import traceback +import warnings + +import torch +from pytorch_lightning import seed_everything + +LOGGER = logging.getLogger(__name__) + +import platform +if platform.system() != 'Linux': + signal.SIGUSR1 = 1 + +def check_and_warn_input_range(tensor, min_value, max_value, name): + actual_min = tensor.min() + actual_max = tensor.max() + if actual_min < min_value or actual_max > max_value: + warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}") + + +def sum_dict_with_prefix(target, cur_dict, prefix, default=0): + for k, v in cur_dict.items(): + target_key = prefix + k + target[target_key] = target.get(target_key, default) + v + + +def average_dicts(dict_list): + result = {} + norm = 1e-3 + for dct in dict_list: + sum_dict_with_prefix(result, dct, '') + norm += 1 + for k in list(result): + result[k] /= norm + return result + + +def add_prefix_to_keys(dct, prefix): + return {prefix + k: v for k, v in dct.items()} + + +def set_requires_grad(module, value): + for param in module.parameters(): + param.requires_grad = value + + +def flatten_dict(dct): + result = {} + for k, v in dct.items(): + if isinstance(k, tuple): + k = '_'.join(k) + if isinstance(v, dict): + for sub_k, sub_v in flatten_dict(v).items(): + result[f'{k}_{sub_k}'] = sub_v + else: + result[k] = v + return result + + +class LinearRamp: + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class LadderRamp: + def __init__(self, start_iters, values): + self.start_iters = start_iters + self.values = values + assert len(values) == len(start_iters) + 1, (len(values), len(start_iters)) + + def __call__(self, i): + segment_i = bisect.bisect_right(self.start_iters, i) + return self.values[segment_i] + + +def get_ramp(kind='ladder', **kwargs): + if kind == 'linear': + return LinearRamp(**kwargs) + if kind == 'ladder': + return LadderRamp(**kwargs) + raise ValueError(f'Unexpected ramp kind: {kind}') + + +def print_traceback_handler(sig, frame): + LOGGER.warning(f'Received signal {sig}') + bt = ''.join(traceback.format_stack()) + LOGGER.warning(f'Requested stack trace:\n{bt}') + + +def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler): + 
LOGGER.warning(f'Setting signal {sig} handler {handler}') + signal.signal(sig, handler) + + +def handle_deterministic_config(config): + seed = dict(config).get('seed', None) + if seed is None: + return False + + seed_everything(seed) + return True + + +def get_shape(t): + if torch.is_tensor(t): + return tuple(t.shape) + elif isinstance(t, dict): + return {n: get_shape(q) for n, q in t.items()} + elif isinstance(t, (list, tuple)): + return [get_shape(q) for q in t] + elif isinstance(t, numbers.Number): + return type(t) + else: + raise ValueError('unexpected type {}'.format(type(t))) + + +def get_has_ddp_rank(): + master_port = os.environ.get('MASTER_PORT', None) + node_rank = os.environ.get('NODE_RANK', None) + local_rank = os.environ.get('LOCAL_RANK', None) + world_size = os.environ.get('WORLD_SIZE', None) + has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None + return has_rank + + +def handle_ddp_subprocess(): + def main_decorator(main_func): + @functools.wraps(main_func) + def new_main(*args, **kwargs): + # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE + parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) + has_parent = parent_cwd is not None + has_rank = get_has_ddp_rank() + assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' + + if has_parent: + # we are in the worker + sys.argv.extend([ + f'hydra.run.dir={parent_cwd}', + # 'hydra/hydra_logging=disabled', + # 'hydra/job_logging=disabled' + ]) + # do nothing if this is a top-level process + # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization + + main_func(*args, **kwargs) + return new_main + return main_decorator + + +def handle_ddp_parent_process(): + parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) + has_parent = parent_cwd is not None + has_rank = get_has_ddp_rank() + assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' + + if parent_cwd is None: + os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd() + + return has_parent diff --git a/DH-AISP/2/test.py b/DH-AISP/2/test.py new file mode 100644 index 0000000000000000000000000000000000000000..5ddd825791ec88aad21a3b539577f673350e7e3f --- /dev/null +++ b/DH-AISP/2/test.py @@ -0,0 +1,149 @@ +import torch +import argparse +import torch.nn as nn +from torch.utils.data import DataLoader +from torchvision.utils import save_image as imwrite +import os +import time +import re +from torchvision import transforms + +from test_dataset_for_testing import dehaze_test_dataset +from model_convnext2_hdr import fusion_net +import glob +import scipy.io +import torch.optim as optim +import cv2 +import matplotlib.image +from PIL import Image +import random +import math +import numpy as np +import sys +import json + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +#run python test_05_hdr.py ./data/ ./result/ ./daylight_isp_03/ 1 2 4 + + +input_dir2 = '../data/' +input_dir = '../1/mid/' + +result_dir = '../data/' +checkpoint_dir = './result_low_light_hdr/' + +# get train IDs +train_fns = glob.glob(input_dir + '*_1.png') +train_ids = [os.path.basename(train_fn) for train_fn in train_fns] + +if not os.path.exists(result_dir): + os.mkdir(result_dir) + +def json_read(fname, **kwargs): + with open(fname) as j: + data = json.load(j, **kwargs) + return data + +def fraction_from_json(json_object): + if 'Fraction' in json_object: + return Fraction(*json_object['Fraction']) + return json_object + 
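json_read and fraction_from_json above work together: fraction_from_json is passed to json.load as its object_hook (see the metadata loading further below), so any object serialized as {"Fraction": [numerator, denominator]} is decoded into a fractions.Fraction. Fraction is referenced here without an import, so from fractions import Fraction appears to be assumed. A self-contained sketch of that round trip on a hypothetical metadata fragment:

import json
from fractions import Fraction  # the import the decoding above appears to rely on

def fraction_from_json(obj):
    # json.load/json.loads call this hook for every decoded JSON object
    if 'Fraction' in obj:
        return Fraction(*obj['Fraction'])
    return obj

# Hypothetical metadata snippet using the same {"Fraction": [num, den]} convention
meta = json.loads('{"noise_profile": [{"Fraction": [1, 50]}]}', object_hook=fraction_from_json)
print(float(meta['noise_profile'][0]))  # 0.02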
+def fractions2floats(fractions): + floats = [] + for fraction in fractions: + floats.append(float(fraction.numerator) / fraction.denominator) + return floats + +def reprocessing(input): + output = np.zeros(input.shape) + + input_1 = input + + output[:,:,0] = input_1[:,:,0] * 1.9021 - input_1[:,:,1] * 1.1651 + input_1[:,:,2] * 0.2630 + output[:,:,1] = input_1[:,:,0] * (-0.3189) + input_1[:,:,1] * 1.5831 - input_1[:,:,2] * 0.2643 + output[:,:,2] = input_1[:,:,0] * (-0.0662) - input_1[:,:,1] * 0.9350 + input_1[:,:,2] * 2.0013 + + result = np.clip(output, 0, 255).astype(np.uint8) + + return output + +def reprocessing1(input): + output = np.zeros(input.shape) + + input_1 = input + + output[:,:,0] = input_1[:,:,0] * 1.521689 - input_1[:,:,1] * 0.673763 + input_1[:,:,2] * 0.152074 + output[:,:,1] = input_1[:,:,0] * (-0.145724) + input_1[:,:,1] * 1.266507 - input_1[:,:,2] * 0.120783 + output[:,:,2] = input_1[:,:,0] * (-0.0397583) - input_1[:,:,1] * 0.561249 + input_1[:,:,2] * 1.60100734 + + result = np.clip(output, 0, 255).astype(np.uint8) + + return output + +# --- Gpu device --- # +device = torch.device("cuda:0") + +# --- Define the network --- # + +model_g = fusion_net() + +model_g = nn.DataParallel(model_g) + +MyEnsembleNet = model_g.to(device) + +MyEnsembleNet.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'checkpoint_gen.pth'))) + +# --- Test --- # +with torch.no_grad(): + MyEnsembleNet.eval() + + for ind in range(len(train_ids)): + print(ind) + train_id = train_ids[ind] + in_path_in = input_dir + train_id[:-5] + in_path_in_js = input_dir2 + train_id[:-5] + metadata = json_read(in_path_in_js[:-1] + '.json', object_hook=fraction_from_json) + + noise_profile = float(metadata['noise_profile'][0]) + + pic_in1 = np.asarray(Image.open(in_path_in + '1.png'), np.float32) / 255. + pic_in2 = np.asarray(Image.open(in_path_in + '2.png'), np.float32) / 255. + pic_in3 = np.asarray(Image.open(in_path_in + '3.png'), np.float32) / 255. 
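# Next steps in the loop below: the three preprocessed frames are stacked channel-wise into a
# single 9-channel input, reflection-padded so height and width become multiples of 32
# (presumably the network's total downsampling factor), run through the fusion network,
# cropped back to h x w, and finally mapped through reprocessing()/reprocessing1() -- the two
# color matrices above -- depending on the estimated noise profile, before being saved as JPEG.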
+ + pic_in = np.concatenate([pic_in1, pic_in2, pic_in3],axis=2) + #pic_in = cv2.resize(pic_in, None, fx = 0.5, fy = 0.5, interpolation=cv2.INTER_CUBIC ) + + [h,w,c] = pic_in.shape + + pad_h = 32 - h % 32 + pad_w = 32 - w % 32 + + pic_in = np.expand_dims(np.pad(pic_in, ((0, pad_h), (0, pad_w),(0,0)), mode='reflect'),axis = 0) + + in_data = torch.from_numpy(pic_in).permute(0,3,1,2).to(device) + out_data = MyEnsembleNet(in_data) + out_datass = out_data.cpu().detach().numpy().transpose((0, 2, 3, 1)) + output = np.clip(out_datass[0,:,:,:], 0, 1) + + if noise_profile < 0.02: + output = reprocessing(output) + else: + output = reprocessing1(output) + + #cv2.imwrite(result_dir + train_id[:-6] + '.png', output[0:h,0:w,::-1] * 255) + cv2.imwrite(result_dir + train_id[:-6] + '.jpg', output[0:h,0:w,::-1] * 255, [cv2.IMWRITE_JPEG_QUALITY, 100]) + + + + + + + + + + + + + diff --git a/DH-AISP/2/test_dataset_for_testing.py b/DH-AISP/2/test_dataset_for_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..99aad7368ccbb24bec65353f628d9189eb297dff --- /dev/null +++ b/DH-AISP/2/test_dataset_for_testing.py @@ -0,0 +1,55 @@ +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +import os + +class dehaze_test_dataset(Dataset): + def __init__(self, test_dir): + self.transform = transforms.Compose([transforms.ToTensor()]) + self.list_test=[] + for i in os.listdir(test_dir): + self.list_test.append(i) + self.root_hazy = os.path.join(test_dir) + self.file_len = len(self.list_test) + def __getitem__(self, index, is_train=True): + hazy = Image.open(self.root_hazy + self.list_test[index]).convert('RGB') + #print(hazy) + hazy = self.transform(hazy) + #print(hazy.shape) + if hazy.shape[1]>hazy.shape[2]: + hazy_up_left=hazy[:,0:2432, 0:1600] + hazy_up_middle=hazy[:, 0:2432, 1200:2800] + hazy_up_right=hazy[:,0:2432, 2400:] + + hazy_middle_left=hazy[:,1800:4232, 0:1600] + hazy_middle_middle=hazy[:, 1800:4232, 1200:2800] + hazy_middle_right=hazy[:,1800:4232, 2400:] + + hazy_down_left=hazy[:,3568:6000, 0:1600] + hazy_down_middle=hazy[:, 3568:6000, 1200:2800] + hazy_down_right=hazy[:,3568:6000, 2400:] + + + + + name=self.list_test[index] + + return hazy_up_left, hazy_up_middle, hazy_up_right, hazy_middle_left, hazy_middle_middle, hazy_middle_right, hazy_down_left, hazy_down_middle, hazy_down_right, name + def __len__(self): + return self.file_len diff --git a/DH-AISP/2/utils_test.py b/DH-AISP/2/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a7517bae363619460514c9f82c286642a36442 --- /dev/null +++ b/DH-AISP/2/utils_test.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from math import log10 +import cv2 +import numpy as np +import torchvision +from skimage.metrics import structural_similarity as ssim +def to_psnr(frame_out, gt): + mse = F.mse_loss(frame_out, gt, reduction='none') + mse_split = torch.split(mse, 1, dim=0) + mse_list = [torch.mean(torch.squeeze(mse_split[ind])).item() for ind in range(len(mse_split))] + rmse_list = np.sqrt(mse_list) ## + intensity_max = 1.0 + psnr_list = [10.0 * log10(intensity_max / mse) for mse in mse_list] + return psnr_list, rmse_list + +def to_ssim_skimage(dehaze, gt): + dehaze_list = torch.split(dehaze, 1, dim=0) + gt_list = torch.split(gt, 1, dim=0) + + dehaze_list_np = [dehaze_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() for ind in range(len(dehaze_list))] + gt_list_np = [gt_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() for ind in
range(len(dehaze_list))] + ssim_list = [ssim(dehaze_list_np[ind], gt_list_np[ind], data_range=1, multichannel=True) for ind in range(len(dehaze_list))] + + return ssim_list + + +def predict(gridnet, test_data_loader): + + psnr_list = [] + for batch_idx, (frame1, frame2, frame3) in enumerate(test_data_loader): + with torch.no_grad(): + frame1 = frame1.to(torch.device('cuda')) + frame3 = frame3.to(torch.device('cuda')) + gt = frame2.to(torch.device('cuda')) + # print(frame1) + + frame_out = gridnet(frame1, frame3) + # print(frame_out) + frame_debug = torch.cat((frame1, frame_out, gt, frame3), dim =0) + filepath = "./image" + str(batch_idx) + '.png' + torchvision.utils.save_image(frame_debug, filepath) + # print(frame_out) + # img = np.asarray(frame_out.cpu()).astype(float) + + # cv2.imwrite(filepath , img) + + + + # --- Calculate the average PSNR --- # + a, b = to_psnr(frame_out, gt) + psnr_list.extend(a) + avr_psnr = sum(psnr_list) / len(psnr_list) + return avr_psnr + + diff --git a/DH-AISP/2/validation.py b/DH-AISP/2/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..4574691c7d55b2887194108d42b08f53aaadd656 --- /dev/null +++ b/DH-AISP/2/validation.py @@ -0,0 +1,194 @@ +import torch +import argparse +import torch.nn as nn +from torch.utils.data import DataLoader +from torchvision.utils import save_image as imwrite +import os +import time +import re +from torchvision import transforms + +from test_dataset_for_testing import dehaze_test_dataset +from model_convnext import fusion_net + + +parser = argparse.ArgumentParser(description='DWT-FFC') +parser.add_argument('--valid_dir', type=str, default='./NTIRE2023_Valid_Hazy/') #please check the path for hazy images +parser.add_argument('--valid_result', type=str, default='valid_result') +parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int) +args = parser.parse_args() +output_dir =args.valid_result +if not os.path.exists(output_dir + '/'): + os.makedirs(output_dir + '/') +valid_dir = args.valid_dir +test_batch_size = args.test_batch_size + +test_dataset = dehaze_test_dataset(valid_dir) +test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0) + +# --- Gpu device --- # +device = torch.device("cuda:0") + +# --- Define the network --- # +MyEnsembleNet= fusion_net() + + +MyEnsembleNet = MyEnsembleNet.to(device) + +MyEnsembleNet.load_state_dict(torch.load(os.path.join('./weights/', 'validation_best.pkl'))) + +# --- Test --- # +with torch.no_grad(): + MyEnsembleNet.eval() + start_time = time.time() + for batch_idx, (hazy_up_left, hazy_up_middle, hazy_up_right, hazy_middle_left, hazy_middle_middle, hazy_middle_right, hazy_down_left, hazy_down_middle, hazy_down_right,name) in enumerate(test_loader): + #print(hazy_up_left.shape) + + hazy_up_left = hazy_up_left.to(device) + hazy_up_middle =hazy_up_middle.to(device) + hazy_up_right = hazy_up_right.to(device) + hazy_middle_left =hazy_middle_left.to(device) + hazy_middle_middle = hazy_middle_middle.to(device) + hazy_middle_right =hazy_middle_right.to(device) + + hazy_down_left = hazy_down_left.to(device) + hazy_down_middle =hazy_down_middle.to(device) + hazy_down_right = hazy_down_right.to(device) + + frame_out_up_left = MyEnsembleNet(hazy_up_left) + frame_out_middle_left = MyEnsembleNet(hazy_middle_left) + frame_out_down_left = MyEnsembleNet(hazy_down_left) + + frame_out_up_middle = MyEnsembleNet(hazy_up_middle) + frame_out_middle_middle = MyEnsembleNet(hazy_middle_middle) + 
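# Each full-resolution hazy image arrives as nine overlapping tiles (up/middle/down x
# left/middle/right) produced by dehaze_test_dataset; every tile is dehazed independently
# here, and in the branches below the overlapping strips between neighbouring tiles are
# averaged before the pieces are concatenated back into the full-size result.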
frame_out_down_middle = MyEnsembleNet(hazy_down_middle) + + frame_out_up_right = MyEnsembleNet(hazy_up_right) + frame_out_middle_right = MyEnsembleNet(hazy_middle_right) + frame_out_down_right = MyEnsembleNet(hazy_down_right) + + frame_out_up_left=frame_out_up_left.to(device) + frame_out_middle_left =frame_out_middle_left .to(device) + frame_out_down_left=frame_out_down_left.to(device) + frame_out_up_middle=frame_out_up_middle.to(device) + frame_out_middle_middle=frame_out_middle_middle.to(device) + frame_out_down_middle=frame_out_down_middle.to(device) + frame_out_up_right=frame_out_up_right.to(device) + frame_out_middle_right=frame_out_middle_right.to(device) + frame_out_down_right=frame_out_down_right.to(device) + + + if frame_out_up_left.shape[2]==1600: + frame_out_up_left_middle=(frame_out_up_left[:,:,:,1800:2432]+frame_out_up_middle[:,:,:,0:632])/2 + frame_out_up_middle_right=(frame_out_up_middle[:,:,:,1768:2432]+frame_out_up_right[:,:,:,0:664])/2 + + frame_out_middle_left_middle=(frame_out_middle_left[:,:,:,1800:2432]+frame_out_middle_middle[:,:,:,0:632])/2 + frame_out_middle_middle_right=(frame_out_middle_middle[:,:,:,1768:2432]+frame_out_middle_right[:,:,:,0:664])/2 + + frame_out_down_left_middle=(frame_out_down_left[:,:,:,1800:2432]+frame_out_down_middle[:,:,:,0:632])/2 + frame_out_down_middle_right=(frame_out_down_middle[:,:,:,1768:2432]+frame_out_down_right[:,:,:,0:664])/2 + + + + frame_out_left_up_middle=(frame_out_up_left[:,:,1200:1600,0:1800]+frame_out_middle_left[:,:,0:400,0:1800])/2 + frame_out_left_middle_down=(frame_out_middle_left[:,:,1200:1600,0:1800]+frame_out_down_left[:,:,0:400,0:1800])/2 + + frame_out_left = (torch.cat([frame_out_up_left[:, :, 0:1200, 0:1800].permute(0, 2, 3, 1),frame_out_left_up_middle.permute(0, 2, 3, 1), frame_out_middle_left[:, :, 400:1200, 0:1800].permute(0, 2, 3, 1), frame_out_left_middle_down.permute(0, 2, 3, 1), frame_out_down_left[:, :, 400:, 0:1800].permute(0, 2, 3, 1)],1)) + + + frame_out_leftmiddle_up_middle=(frame_out_up_left_middle[:,:,1200:1600,:]+frame_out_middle_left_middle[:,:,0:400,:])/2 + frame_out_leftmiddle_middle_down=(frame_out_middle_left_middle[:,:,1200:1600,:]+frame_out_down_left_middle[:,:,0:400,:])/2 + + + frame_out_leftmiddle = (torch.cat([frame_out_up_left_middle[:, :, 0:1200, :].permute(0, 2, 3, 1),frame_out_leftmiddle_up_middle.permute(0, 2, 3, 1), frame_out_middle_left_middle[:, :, 400:1200, :].permute(0, 2, 3, 1), frame_out_leftmiddle_middle_down.permute(0, 2, 3, 1), frame_out_down_left_middle[:, :, 400:, :].permute(0, 2, 3, 1)],1)) + + + frame_out_middle_up_middle=(frame_out_up_middle[:,:,1200:1600,632:1768]+frame_out_middle_middle[:,:,0:400,632:1768])/2 + frame_out_middle_middle_down=(frame_out_middle_middle[:,:,1200:1600,632:1768]+frame_out_down_middle[:,:,0:400,632:1768])/2 + + frame_out_middle = (torch.cat([frame_out_up_middle[:, :, 0:1200, 632:1768].permute(0, 2, 3, 1),frame_out_middle_up_middle.permute(0, 2, 3, 1), frame_out_middle_middle[:, :, 400:1200, 632:1768].permute(0, 2, 3, 1), frame_out_middle_middle_down.permute(0, 2, 3, 1), frame_out_down_middle[:, :, 400:, 632:1768].permute(0, 2, 3, 1)],1)) + + frame_out_middleright_up_middle=(frame_out_up_middle_right[:,:,1200:1600,:]+frame_out_middle_middle_right[:,:,0:400,:])/2 + frame_out_middleright_middle_down=(frame_out_middle_middle_right[:,:,1200:1600,:]+frame_out_down_middle_right[:,:,0:400,:])/2 + + frame_out_middleright = (torch.cat([frame_out_up_middle_right[:, :, 0:1200, :].permute(0, 2, 3, 1),frame_out_middleright_up_middle.permute(0, 2, 3, 1), 
frame_out_middle_middle_right[:, :, 400:1200, :].permute(0, 2, 3, 1), frame_out_middleright_middle_down.permute(0, 2, 3, 1), frame_out_down_middle_right[:, :, 400:, :].permute(0, 2, 3, 1)],1)) + + + + frame_out_right_up_middle=(frame_out_up_right[:,:,1200:1600,664:]+frame_out_middle_right[:,:,0:400,664:])/2 + frame_out_right_middle_down=(frame_out_middle_right[:,:,1200:1600,664:]+frame_out_down_right[:,:,0:400,664:])/2 + + frame_out_right = (torch.cat([frame_out_up_right[:, :, 0:1200, 664:].permute(0, 2, 3, 1),frame_out_right_up_middle.permute(0, 2, 3, 1), frame_out_middle_right[:, :, 400:1200, 664:].permute(0, 2, 3, 1), frame_out_right_middle_down.permute(0, 2, 3, 1), frame_out_down_right[:, :, 400:, 664:].permute(0, 2, 3, 1)],1)) + + + + + if frame_out_up_left.shape[2]==2432: + frame_out_up_left_middle=(frame_out_up_left[:,:,:,1200:1600]+frame_out_up_middle[:,:,:,0:400])/2 + frame_out_up_middle_right=(frame_out_up_middle[:,:,:,1200:1600]+frame_out_up_right[:,:,:,0:400])/2 + + frame_out_middle_left_middle=(frame_out_middle_left[:,:,:,1200:1600]+frame_out_middle_middle[:,:,:,0:400])/2 + frame_out_middle_middle_right=(frame_out_middle_middle[:,:,:,1200:1600]+frame_out_middle_right[:,:,:,0:400])/2 + + frame_out_down_left_middle=(frame_out_down_left[:,:,:,1200:1600]+frame_out_down_middle[:,:,:,0:400])/2 + frame_out_down_middle_right=(frame_out_down_middle[:,:,:,1200:1600]+frame_out_down_right[:,:,:,0:400])/2 + + + frame_out_left_up_middle=(frame_out_up_left[:,:,1800:2432,0:1200]+frame_out_middle_left[:,:,0:632,0:1200])/2 + frame_out_left_middle_down=(frame_out_middle_left[:,:,1768:2432,0:1200]+frame_out_down_left[:,:,0:664,0:1200])/2 + + frame_out_left = (torch.cat([frame_out_up_left[:, :, 0:1800, 0:1200].permute(0, 2, 3, 1),frame_out_left_up_middle.permute(0, 2, 3, 1), frame_out_middle_left[:, :, 632:1768, 0:1200].permute(0, 2, 3, 1), frame_out_left_middle_down.permute(0, 2, 3, 1), frame_out_down_left[:, :, 664:, 0:1200].permute(0, 2, 3, 1)],1)) + + + frame_out_leftmiddle_up_middle=(frame_out_up_left_middle[:,:,1800:2432,:]+frame_out_middle_left_middle[:,:,0:632,:])/2 + frame_out_leftmiddle_middle_down=(frame_out_middle_left_middle[:,:,1768:2432,:]+frame_out_down_left_middle[:,:,0:664,:])/2 + + + frame_out_leftmiddle = (torch.cat([frame_out_up_left_middle[:, :, 0:1800, :].permute(0, 2, 3, 1),frame_out_leftmiddle_up_middle.permute(0, 2, 3, 1), frame_out_middle_left_middle[:, :, 632:1768, :].permute(0, 2, 3, 1), frame_out_leftmiddle_middle_down.permute(0, 2, 3, 1), frame_out_down_left_middle[:, :, 664:, :].permute(0, 2, 3, 1)],1)) + + + frame_out_middle_up_middle=(frame_out_up_middle[:,:,1800:2432,400:1200]+frame_out_middle_middle[:,:,0:632,400:1200])/2 + frame_out_middle_middle_down=(frame_out_middle_middle[:,:,1768:2432,400:1200]+frame_out_down_middle[:,:,0:664,400:1200])/2 + + frame_out_middle = (torch.cat([frame_out_up_middle[:, :, 0:1800, 400:1200].permute(0, 2, 3, 1),frame_out_middle_up_middle.permute(0, 2, 3, 1), frame_out_middle_middle[:, :, 632:1768, 400:1200].permute(0, 2, 3, 1), frame_out_middle_middle_down.permute(0, 2, 3, 1), frame_out_down_middle[:, :, 664:, 400:1200].permute(0, 2, 3, 1)],1)) + + + + frame_out_middleright_up_middle=(frame_out_up_middle_right[:,:,1800:2432,:]+frame_out_middle_middle_right[:,:,0:632,:])/2 + frame_out_middleright_middle_down=(frame_out_middle_middle_right[:,:,1768:2432,:]+frame_out_down_middle_right[:,:,0:664,:])/2 + + frame_out_middleright = (torch.cat([frame_out_up_middle_right[:, :, 0:1800, :].permute(0, 2, 3, 
1),frame_out_middleright_up_middle.permute(0, 2, 3, 1), frame_out_middle_middle_right[:, :, 632:1768, :].permute(0, 2, 3, 1), frame_out_middleright_middle_down.permute(0, 2, 3, 1), frame_out_down_middle_right[:, :, 664:, :].permute(0, 2, 3, 1)],1))
+
+
+            frame_out_right_up_middle=(frame_out_up_right[:,:,1800:2432,400:]+frame_out_middle_right[:,:,0:632,400:])/2
+            frame_out_right_middle_down=(frame_out_middle_right[:,:,1768:2432,400:]+frame_out_down_right[:,:,0:664,400:])/2
+
+            frame_out_right = (torch.cat([frame_out_up_right[:, :, 0:1800, 400:].permute(0, 2, 3, 1),frame_out_right_up_middle.permute(0, 2, 3, 1), frame_out_middle_right[:, :, 632:1768, 400:].permute(0, 2, 3, 1), frame_out_right_middle_down.permute(0, 2, 3, 1), frame_out_down_right[:, :, 664:, 400:].permute(0, 2, 3, 1)],1))
+
+
+        # stitch the five vertical strips back into the full-resolution image (NHWC -> NCHW)
+        frame_out = torch.cat([frame_out_left, frame_out_leftmiddle, frame_out_middle, frame_out_middleright, frame_out_right], 2).permute(0, 3, 1, 2)
+
+        frame_out = frame_out.to(device)
+
+        # append a constant fourth channel so the result is saved as a 4-channel RGBA PNG
+        fourth_channel = torch.ones([frame_out.shape[0], 1, frame_out.shape[2], frame_out.shape[3]], device='cuda:0')
+        frame_out_rgba = torch.cat([frame_out, fourth_channel], 1)
+        #print(frame_out_rgba.shape)
+
+        name = re.findall(r"\d+", str(name))
+        imwrite(frame_out_rgba, output_dir + '/' + str(name[0]) + '.png', range=(0, 1))
+
+test_time = time.time() - start_time
+print(test_time)
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DH-AISP/2/vgg_loss.py b/DH-AISP/2/vgg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..02d144ac77b4d83e7501da6589931e14e468abcb
--- /dev/null
+++ b/DH-AISP/2/vgg_loss.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+from torchvision.models import vgg19, vgg16
+import os.path as osp
+
+
+#-------------------------------------------------------------------------------------------------------------------#
+#-------------------------------------------------------------------------------------------------------------------#
+# **** VGG perceptual loss module
+class VGG_LOSS(nn.Module):
+    def __init__(self, model_type='vgg19', layer_names=('conv_1_1', 'conv_2_1'), loss_type='l1'):
+        super(VGG_LOSS, self).__init__()
+        # **** load the VGG backbone
+        mdir = osp.dirname(osp.realpath(__file__))
+        if model_type == 'vgg16':
+            vgg_model = vgg16(pretrained=False)
+            pre_trained = torch.load('../vgg16-397923af.pth')
+            vgg_model.cuda()
+            vgg_model.load_state_dict(pre_trained)
+        elif model_type == 'vgg19':
+            vgg_model = vgg19(pretrained=False)
+            pre_trained = torch.load('../vgg19-dcbb9e9d.pth')
+            vgg_model.cuda()
+            vgg_model.load_state_dict(pre_trained)
+
+        # **** layer names and layer indices
+        self.layer_names = get_layer_name_id(model_type, layer_names)
+        self.layer_ids = inverse_dict(self.layer_names)
+        self.lid_list = list(self.layer_names.values())
+        self.lname_input = 'input' if ('input' in layer_names) else None
+
+        # **** truncate the model at the deepest requested layer
+        lid_max = max(self.lid_list)
+        self.network = vgg_model.features[:lid_max + 1]
+
+        # **** input image normalization layer
+        self.mean_shift = MeanShift()
+
+        # **** loss function applied to the VGG features
+        loss_fun = nn.L1Loss()
+        if loss_type == 'l1':
+            loss_fun = nn.L1Loss()
+        elif loss_type == 'l2':
+            loss_fun = nn.MSELoss()
+        else:
+            pass
+        self.loss_fun = loss_fun
+
+        # **** freeze all parameters
+        self.set_not_requires_grad()
+        return
+
+    def forward(self, img_gt, img_infer, img_range=(-1.0, 1.0)):
+        '''
+        Compute the VGG loss.
+        '''
+        feas_gt = self.get_feas(img_gt, img_range)
+        feas_infer = self.get_feas(img_infer, img_range)
+
+        loss_total = 0
+        for lname, gt in feas_gt.items():
+            infer = feas_infer[lname]
+            loss_tmp = self.loss_fun(gt, infer)
+            loss_total = loss_total + loss_tmp
+        return loss_total
+
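+    # Usage sketch (illustrative, not part of the original file): with the
+    # defaults above, the module compares the conv_1_1/conv_2_1 activations of
+    # a ground-truth batch and a predicted batch whose values lie in img_range;
+    # gt_batch and pred_batch below are hypothetical CUDA tensors, e.g.
+    #   vgg_loss = VGG_LOSS(model_type='vgg19', layer_names=('conv_1_1', 'conv_2_1'), loss_type='l1')
+    #   loss = vgg_loss(gt_batch, pred_batch, img_range=(0.0, 1.0))
+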
+    def get_feas(self, xx, in_range):
+        '''
+        Extract the intermediate features.
+        '''
+        # **** adjust the input
+        xx = reset_range(xx, in_range)
+        xx = self.mean_shift(xx)
+
+        # **** collect the intermediate features
+        out_feas = dict()
+        if self.lname_input is not None:
+            inname = self.lname_input
+            out_feas[inname] = xx.clone()
+        for lid, layer in enumerate(self.network):
+            xx = layer(xx)
+            if lid in self.lid_list:
+                layer_name = self.layer_ids[lid]
+                out_feas[layer_name] = xx.clone()
+        return out_feas
+
+    def set_not_requires_grad(self):
+        for para in self.parameters():
+            para.requires_grad = False
+        self.eval()
+        return
+
+def reset_range(indata, in_range):
+    '''
+    Rescale the data range to 0~1.
+    '''
+    minv, maxv = in_range
+    midv = 1.0 / (maxv - minv)
+    return (indata - minv) * midv
+
+def get_layer_name_id(vgg_type, lnames):
+    '''
+    Look up the layer indices for the given layer names.
+    '''
+    out_dict = dict()
+    layer_id_dict = vgg_all_layers(vgg_type)
+    for lname in lnames:
+        lid = layer_id_dict[lname]
+        out_dict[lname] = lid
+    return out_dict
+
+def vgg_all_layers(vgg_type):
+    '''
+    Names and indices of the VGG intermediate layers.
+    '''
+    vgg_layer_vgg19 = {
+        'conv_1_1': 0, 'conv_1_2': 2, 'pool_1': 4,
+        'conv_2_1': 5, 'conv_2_2': 7, 'pool_2': 9,
+        'conv_3_1': 10, 'conv_3_2': 12, 'conv_3_3': 14, 'conv_3_4': 16, 'pool_3': 18,
+        'conv_4_1': 19, 'conv_4_2': 21, 'conv_4_3': 23, 'conv_4_4': 25, 'pool_4': 27,
+        'conv_5_1': 28, 'conv_5_2': 30, 'conv_5_3': 32, 'conv_5_4': 34, 'pool_5': 36
+    }
+    vgg_layer_vgg16 = {
+        'conv_1_1': 0, 'conv_1_2': 2, 'pool_1': 4,
+        'conv_2_1': 5, 'conv_2_2': 7, 'pool_2': 9,
+        'conv_3_1': 10, 'conv_3_2': 12, 'conv_3_3': 14, 'pool_3': 16,
+        'conv_4_1': 17, 'conv_4_2': 19, 'conv_4_3': 21, 'pool_4': 23,
+        'conv_5_1': 24, 'conv_5_2': 26, 'conv_5_3': 28, 'pool_5': 30
+    }
+
+    if vgg_type == 'vgg16':
+        vgg_layer_dict = vgg_layer_vgg16
+    elif vgg_type == 'vgg19':
+        vgg_layer_dict = vgg_layer_vgg19
+    else:
+        raise ValueError('Vgg network type should be either vgg16 or vgg19.')
+
+    # every conv layer is followed by a ReLU, so expose matching 'relu_x_y' entries as well
+    vgg_fea_dict = {}
+    for lname, lindex in vgg_layer_dict.items():
+        vgg_fea_dict[lname] = lindex
+        if 'conv' in lname:
+            lname_relu = lname.replace('conv', 'relu')
+            lindex_relu = lindex + 1
+            vgg_fea_dict[lname_relu] = lindex_relu
+    return vgg_fea_dict
+
+def inverse_dict(inp_dict):
+    '''
+    Swap the keys and values of a dict.
+    '''
+    out_dict = dict()
+    for key, val in inp_dict.items():
+        out_dict[val] = key
+    return out_dict
+
+class MeanShift(nn.Conv2d):
+    '''
+    Fixed-parameter convolution that converts an ordinary RGB image (range 0~1)
+    into the normalized VGG input format.
+    '''
+    def __init__(self, rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225)):
+        super(MeanShift, self).__init__(in_channels=3, out_channels=3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        # (x - mean) / std implemented as a fixed 1x1 convolution
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
+        self.bias.data = -torch.Tensor(rgb_mean) / std
+        return
+
+
+
+#-------------------------------------------------------------------------------------------------------------------#
+#-------------------------------------------------------------------------------------------------------------------#
+if __name__ == '__main__':
+    import torchvision.utils as tv_utils
+    import cv2
+    import numpy as np
+    import os
+    import os.path as osp
+
+
+    # **** parameters
+    inp = r'D:\tmp\test\baboon.png'
+
+    img_in = cv2.imread(inp)
+    img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
+    # img_in = (np.float32(img_in) - 127.5) / 127.5
+    img_in = (np.float32(img_in) - 0.0) / 255.0
+    img_in_t = torch.from_numpy(img_in).unsqueeze(0)
+
+    # **** test
+    layer_names = ('conv_1_1', 'conv_2_1', 'conv_3_1', 'conv_4_1')
+    vgg_test = VGG_LOSS(layer_names=layer_names)
+
+    # # **** check that the parameters are frozen
+    # for lname, paras in vgg_test.named_parameters():
+    #     print(lname, paras.requires_grad)
+
+    # **** extract the features
+    feas = vgg_test.get_feas(img_in_t, in_range=(0.0, 1.0))
+
+    # **** save the features as images
+    def vgg_fea2img(vgg_fea):
+        mid_feas = vgg_fea.data
+        mid_feas = torch.transpose(mid_feas, 0, 1)
+        fea_nrow = round((mid_feas.shape[0]) ** 0.5)
+        fea_grid = tv_utils.make_grid(mid_feas, nrow=fea_nrow, normalize=True, scale_each=True)
+        fea_grid = fea_grid.cpu().float().numpy().transpose((1, 2, 0))
+        fea_grid = (fea_grid * 255.0).round().clip(0, 255).astype(np.uint8)
+        return fea_grid
+
+    out_dir = osp.splitext(inp)[0]
+    if not osp.exists(out_dir):
+        os.mkdir(out_dir)
+    bind = 0
+    for layer_name, mid_feas in feas.items():
+        out_fea = vgg_fea2img(mid_feas[bind:(bind + 1)])
+        # out_fea = cv2.applyColorMap(out_fea[:, :, 0], cv2.COLORMAP_JET)
+        out_path = osp.join(out_dir, '{}_{}.png'.format(bind, layer_name))
+        cv2.imwrite(out_path, out_fea)
\ No newline at end of file
diff --git a/DH-AISP/Dockerfile b/DH-AISP/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..f6b96f0a4edbe9736bca887a699809a1c77c1e67
--- /dev/null
+++ b/DH-AISP/Dockerfile
@@ -0,0 +1,10 @@
+FROM tensorflow/tensorflow:2.4.3-gpu
+
+RUN apt-get update && apt-get install -y \
+    libsm6 libxext6 libxrender-dev
+
+COPY requirements.txt .
+RUN python -m pip install --no-cache-dir -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
+
+COPY . /nightimaging
+WORKDIR /nightimaging
\ No newline at end of file
diff --git a/DH-AISP/requirements.txt b/DH-AISP/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..833d80b29509ae2fd95a375ef8597f631e517999
--- /dev/null
+++ b/DH-AISP/requirements.txt
@@ -0,0 +1,59 @@
+absl-py==0.13.0
+asn1crypto==0.24.0
+astunparse==1.6.3
+cachetools==4.2.2
+certifi==2021.5.30
+charset-normalizer==2.0.4
+cryptography==2.1.4
+dataclasses==0.8
+flatbuffers==1.12
+future==0.18.3
+gast==0.3.3
+google-auth==1.34.0
+google-auth-oauthlib==0.4.5
+google-pasta==0.2.0
+grpcio==1.32.0
+h5py==2.10.0
+idna==2.6
+importlib-metadata==4.6.3
+Keras-Preprocessing==1.1.2
+keyring==10.6.0
+keyrings.alt==3.0
+Markdown==3.3.4
+numpy==1.19.5
+oauthlib==3.1.1
+opencv-contrib-python==4.2.0.34
+opencv-python==4.2.0.34
+opt-einsum==3.3.0
+Pillow==8.4.0
+pip==20.2.4
+protobuf==3.17.3
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycrypto==2.6.1
+pygobject==3.26.1
+python-apt==1.6.5+ubuntu0.6
+pyxdg==0.25
+requests==2.26.0
+requests-oauthlib==1.3.0
+rsa==4.7.2
+scipy==1.5.4
+SecretStorage==2.3.1
+setuptools==57.4.0
+six==1.15.0
+termcolor==1.1.0
+tf-slim==1.1.0
+torch==1.7.0+cu110
+torchvision==0.8.0
+typing-extensions==3.7.4.3
+urllib3==1.26.6
+Werkzeug==2.0.1
+wheel==0.37.0
+wrapt==1.12.1
+zipp==3.5.0
+timm==0.9.12
+kornia==0.5.8
+pytorch_lightning==1.5.0
+six==1.16.0
+scipy==1.7.3
+matplotlib==3.5.3
\ No newline at end of file
diff --git a/DH-AISP/run.sh b/DH-AISP/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82c9250de33ecb35f5e2493ff13205f2112e4856
--- /dev/null
+++ b/DH-AISP/run.sh
@@ -0,0 +1,6 @@
+
+cd ./1
+python tensorflow2to1_3_unet_bining3_7.py
+
+cd ../2
+python test.py
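
For reference, a minimal sketch of how the image defined by this Dockerfile might be built and the two-stage pipeline in run.sh launched; the image tag and the mounted host path below are illustrative assumptions, not defined anywhere in this diff:

    # build the image from the directory that contains the Dockerfile (tag name is an assumption)
    docker build -t dh-aisp ./DH-AISP
    # run both stages inside the container; the mounted host path is a placeholder for the expected input data
    docker run --gpus all -v /path/to/input_data:/nightimaging/data dh-aisp bash run.sh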