diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e53bdd888d3221c0b317e64532d9a40078391bc6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo.gif filter=lfs diff=lfs merge=lfs -text +assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text +demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text +demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/assets/demo.gif b/assets/demo.gif new file mode 100644 index 0000000000000000000000000000000000000000..86b3d14c40233d09211296def0fa1730358ab6e3 --- /dev/null +++ b/assets/demo.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5f136d5335252050ca723e0360a767ebc5d94fd87d6d372221575769d6528a7 +size 1727946 diff --git a/assets/metrics.png b/assets/metrics.png new file mode 100644 index 0000000000000000000000000000000000000000..a582340ec3adeaf2238d29ac75dfe4379f10455e Binary files /dev/null and b/assets/metrics.png differ diff --git a/assets/network.png b/assets/network.png new file mode 100644 index 0000000000000000000000000000000000000000..33e0ce848c8c80259f21179b28cee41c160c4f91 Binary files /dev/null and b/assets/network.png differ diff --git a/assets/title_any_image.gif b/assets/title_any_image.gif new file mode 100644 index 0000000000000000000000000000000000000000..811cdca8c592f3e818c2e12ede9a416cb2ed9f0b Binary files /dev/null and b/assets/title_any_image.gif differ diff --git a/assets/title_harmon.gif b/assets/title_harmon.gif new file mode 100644 index 0000000000000000000000000000000000000000..dfd60802a20933714ee26ff6fefbbc996a120e9b Binary files /dev/null and b/assets/title_harmon.gif differ diff --git a/assets/title_you_want.gif b/assets/title_you_want.gif new file mode 100644 index 0000000000000000000000000000000000000000..26d5d9f036680a1cc2d9ab7c62959272f555c055 Binary files /dev/null and b/assets/title_you_want.gif differ diff --git a/assets/visualizations.png b/assets/visualizations.png new file mode 100644 index 0000000000000000000000000000000000000000..93fbd1362704fea79b06262da32d598e9b848f45 Binary files /dev/null and b/assets/visualizations.png differ diff --git a/assets/visualizations2.png b/assets/visualizations2.png new file mode 100644 index 0000000000000000000000000000000000000000..51a4c8467cce746fb33331d9e0ee901dd9a6e267 --- /dev/null +++ b/assets/visualizations2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fa5f4c202818ab94d6faf57055a323285e169a33ccfd59200bc93a8d597a4a4 +size 1673273 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/datasets/__pycache__/__init__.cpython-38.pyc b/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f5b0f1f3257f9cfd01dbb991a66e8fa81a79e2 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/datasets/__pycache__/build_INR_dataset.cpython-38.pyc b/datasets/__pycache__/build_INR_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42f28c9910a89a72dd3a7ff22f738bb5046f1ece Binary files /dev/null and b/datasets/__pycache__/build_INR_dataset.cpython-38.pyc differ diff --git 
a/datasets/__pycache__/build_dataset.cpython-38.pyc b/datasets/__pycache__/build_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4a9eaa5efacd4a95dc96956517b05136080c17a Binary files /dev/null and b/datasets/__pycache__/build_dataset.cpython-38.pyc differ diff --git a/datasets/build_INR_dataset.py b/datasets/build_INR_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..141384f87bee9e4e4741dc87e6297046e05d9fe7 --- /dev/null +++ b/datasets/build_INR_dataset.py @@ -0,0 +1,36 @@ +from utils import misc +from albumentations import Resize + + +class Implicit2DGenerator(object): + def __init__(self, opt, mode): + if mode == 'Train': + sidelength = opt.INR_input_size + elif mode == 'Val': + sidelength = opt.input_size + else: + raise NotImplementedError + + self.mode = mode + + self.size = sidelength + + if isinstance(sidelength, int): + sidelength = (sidelength, sidelength) + + self.mgrid = misc.get_mgrid(sidelength) + + self.transform = Resize(self.size, self.size) + + def generator(self, torch_transforms, composite_image, real_image, mask): + composite_image = torch_transforms(self.transform(image=composite_image)['image']) + real_image = torch_transforms(self.transform(image=real_image)['image']) + + fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3) + fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3) + bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3) + + fg_INR_coordinates = self.mgrid + bg_INR_coordinates = self.mgrid + + return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB diff --git a/datasets/build_dataset.py b/datasets/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..50414948efae73ac05c1343235417a32260bf265 --- /dev/null +++ b/datasets/build_dataset.py @@ -0,0 +1,371 @@ +import torch +import cv2 +import numpy as np +import torchvision +import os +import random + +from utils.misc import prepare_cooridinate_input, customRandomCrop + +from datasets.build_INR_dataset import Implicit2DGenerator +import albumentations +from albumentations import Resize, RandomResizedCrop, HorizontalFlip +from torch.utils.data import DataLoader + + +class dataset_generator(torch.utils.data.Dataset): + def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'): + super().__init__() + + self.opt = opt + self.root_path = opt.dataset_path + self.mode = mode + + self.alb_transforms = alb_transforms + self.torch_transforms = torch_transforms + self.kp_t = area_keep_thresh + + with open(dataset_txt, 'r') as f: + self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()] + + self.INR_dataset = Implicit2DGenerator(opt, self.mode) + + def __len__(self): + return len(self.dataset_samples) + + def __getitem__(self, idx): + composite_image = self.dataset_samples[idx] + + if self.opt.hr_train: + if self.opt.isFullRes: + "Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \ + "quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \ + "if `opt.isFullRes` is set to True." 
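Implicit2DGenerator (added in datasets/build_INR_dataset.py above) pairs every pixel of the resized composite and real images with a normalized coordinate from misc.get_mgrid, flattening both to (H*W, ...) tensors in matching row order. A minimal, self-contained sketch of that pairing, with make_grid standing in for the repo's get_mgrid (its exact normalization is assumed here, not verified):

import torch

# Stand-in for misc.get_mgrid: one normalized (y, x) coordinate per pixel,
# flattened to (H*W, 2) in the same row order as the flattened RGB below.
def make_grid(h, w):
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1)
    return grid.view(-1, 2)

image = torch.rand(3, 256, 256)                         # C, H, W tensor after torch_transforms
coords = make_grid(256, 256)                            # (65536, 2)
rgb = image.permute(1, 2, 0).contiguous().view(-1, 3)   # (65536, 3); row i pairs with coords[i]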
+ composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori") + + real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg' + mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png' + + composite_image = cv2.imread(composite_image) + composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB) + + real_image = cv2.imread(real_image) + real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB) + + mask = cv2.imread(mask) + mask = mask[:, :, 0].astype(np.float32) / 255. + + """ + If set `opt.hr_train` to True: + + Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres), + the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size` + patch to feed in multiINR process. For inference, just resize it. + + While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size. + + BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5` + """ + if self.opt.hr_train: + if self.mode == 'Train' and self.opt.isFullRes: + if random.random() < 0.5: # LR mix training + mixTransform = albumentations.Compose( + [ + RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)), + HorizontalFlip()], + additional_targets={'real_image': 'image', 'object_mask': 'image'} + ) + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + origin_bg_ratio = 1 - origin_fg_ratio + + "Ensure fg and bg not disappear after transformation" + valid_augmentation = False + transform_out = None + time = 0 + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = mixTransform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + composite_image = transform_out['image'] + real_image = transform_out['real_image'] + mask = transform_out['object_mask'] + else: # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop. 
+ if real_image.shape[0] < 256: + bottom_pad = 256 - real_image.shape[0] + else: + bottom_pad = (4 - real_image.shape[0] % 4) % 4 + if real_image.shape[1] < 256: + right_pad = 256 - real_image.shape[1] + else: + right_pad = (4 - real_image.shape[1] % 4) % 4 + composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad, + cv2.BORDER_REPLICATE) + real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE) + mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE) + + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + origin_bg_ratio = 1 - origin_fg_ratio + + "Ensure fg and bg not disappear after transformation" + valid_augmentation = False + transform_out = None + time = 0 + + if self.opt.hr_train: + if self.mode == 'Train': + if not self.opt.isFullRes: + if random.random() < 0.5: # LR mix training + mixTransform = albumentations.Compose( + [ + RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)), + HorizontalFlip()], + additional_targets={'real_image': 'image', 'object_mask': 'image'} + ) + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = mixTransform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + else: + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = self.alb_transforms(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + composite_image = transform_out['image'] + real_image = transform_out['real_image'] + mask = transform_out['object_mask'] + + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + + full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0) + + tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + compos_list = [self.torch_transforms(transform_out['image'])] + real_list = [self.torch_transforms(transform_out['real_image'])] + mask_list = [ + torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))] + coord_map_list = [] + + valid_augmentation = False + while not valid_augmentation: + # RSC strategy. To crop different resolutions. 
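The RSC crop named in the comment above keeps the multi-scale patches pixel-aligned: one crop window is drawn at full resolution, and the same relative offset (the c_h, c_w values reused below) is applied to the half- and quarter-resolution copies so every scale covers the same image region. A standalone sketch of that idea, using a hypothetical aligned_crops helper rather than the repo's customRandomCrop:

import numpy as np

# Pick one top-left corner at full resolution, then reuse the same relative
# offset for each downsampled copy so all scales cover the same region.
# Assumes h, w >= base; the strided slicing is a crude stand-in for cv2.resize.
def aligned_crops(img, base=256, levels=3):
    h, w = img.shape[:2]
    top = np.random.randint(0, h - base + 1)
    left = np.random.randint(0, w - base + 1)
    crops = []
    for n in range(levels):
        s = 2 ** n
        small = img[::s, ::s]
        size = base // s
        crops.append(small[top // s: top // s + size, left // s: left // s + size])
    return crops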
+ transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord], + self.opt.base_size, self.opt.base_size) + valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio) + + compos_list.append(self.torch_transforms(transform_out[0])) + real_list.append(self.torch_transforms(transform_out[1])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + for n in range(2): + tmp_comp = cv2.resize(composite_image, ( + composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1))) + tmp_real = cv2.resize(real_image, + (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1))) + tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1))) + tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0) + + transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord], + self.opt.base_size // 2 ** (n + 1), + self.opt.base_size // 2 ** (n + 1), c_h, c_w) + compos_list.append(self.torch_transforms(transform_out[0])) + real_list.append(self.torch_transforms(transform_out[1])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + out_comp = compos_list + out_real = real_list + out_mask = mask_list + out_coord = coord_map_list + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out[0], transform_out[1], mask) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': out_comp, + 'real_image': out_real, + 'mask': out_mask, + 'coordinate_map': out_coord, + 'composite_image0': out_comp[0], + 'real_image0': out_real[0], + 'mask0': out_mask[0], + 'coordinate_map0': out_coord[0], + 'composite_image1': out_comp[1], + 'real_image1': out_real[1], + 'mask1': out_mask[1], + 'coordinate_map1': out_coord[1], + 'composite_image2': out_comp[2], + 'real_image2': out_real[2], + 'mask2': out_mask[2], + 'coordinate_map2': out_coord[2], + 'composite_image3': out_comp[3], + 'real_image3': out_real[3], + 'mask3': out_mask[3], + 'coordinate_map3': out_coord[3], + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + else: + if not self.opt.isFullRes: + tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + + coordinate_map = prepare_cooridinate_input(transform_out['object_mask']) + + "Generate INR dataset." 
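The `> 100 / 255.` threshold used below turns the soft mask into a per-pixel boolean selector in the same flattened (H*W,) order as the INR coordinate and RGB tensors. A small illustrative sketch (the indexing at the end is an assumed usage, not code from this repo):

import torch

mask = torch.rand(256, 256)              # soft mask in [0, 1] after resizing
fg = (mask > 100 / 255.).view(-1)        # (65536,) boolean foreground selector
rgb = torch.rand(256 * 256, 3)           # flattened RGB, same row order as fg
fg_rgb = rgb[fg]                         # pixels covered by the composited object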
+ mask = (torchvision.transforms.ToTensor()( + transform_out['object_mask']).squeeze() > 100 / 255.).view(-1) + mask = np.bool_(mask.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out['image'], transform_out['real_image'], mask) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(transform_out['image']), + 'real_image': self.torch_transforms(transform_out['real_image']), + 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. + 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + else: + coordinate_map = prepare_cooridinate_input(mask) + + "Generate INR dataset." + mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1) + mask_tmp = np.bool_(mask_tmp.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, composite_image, real_image, mask_tmp) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(composite_image), + 'real_image': self.torch_transforms(real_image), + 'mask': mask[np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. + 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + valid_augmentation = True + else: + transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + + coordinate_map = prepare_cooridinate_input(transform_out['object_mask']) + + "Generate INR dataset." + mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1) + mask = np.bool_(mask.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out['image'], transform_out['real_image'], mask) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(transform_out['image']), + 'real_image': self.torch_transforms(transform_out['real_image']), + 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. 
+ 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + + +def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh): + current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + current_bg_ratio = 1 - current_fg_ratio + + if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh: + return False + + return True + + +def check_hr_crop_sample(mask, origin_fg_ratio): + current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + + if current_fg_ratio < 0.8 * origin_fg_ratio: + return False + + return True diff --git a/demo/demo_2k_composite.jpg b/demo/demo_2k_composite.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6ede304a1523cd1ad6919375cc6a32e642b8202 Binary files /dev/null and b/demo/demo_2k_composite.jpg differ diff --git a/demo/demo_2k_mask.jpg b/demo/demo_2k_mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99815f78d7db6e84a54dab16b60c426e64562021 Binary files /dev/null and b/demo/demo_2k_mask.jpg differ diff --git a/demo/demo_2k_real.jpg b/demo/demo_2k_real.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12237a6c5417a23f1452ffde80923d306a40a0bc Binary files /dev/null and b/demo/demo_2k_real.jpg differ diff --git a/demo/demo_6k_composite.jpg b/demo/demo_6k_composite.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6151dc8a077304e0e143e9776a5740a9bddc46b9 --- /dev/null +++ b/demo/demo_6k_composite.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:910f8a9787c7b2dd739c89a56f2cd64fa67be9a257ea17963b656f63e1ad2250 +size 5882481 diff --git a/demo/demo_6k_mask.jpg b/demo/demo_6k_mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..adb3ff591464ff42f3a72cf8450e9d31826305e4 Binary files /dev/null and b/demo/demo_6k_mask.jpg differ diff --git a/demo/demo_6k_real.jpg b/demo/demo_6k_real.jpg new file mode 100644 index 0000000000000000000000000000000000000000..809420164eff576bf1718eee4d43f05c14bca70f --- /dev/null +++ b/demo/demo_6k_real.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dd69a2a79388378e43079a0ca0dbb7e3e9c86822526083d54f315a3f1a48647 +size 6096010 diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc8ec4d248a4282af04444034f8ea6850e79b1f6 Binary files /dev/null and b/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/__pycache__/backbone.cpython-38.pyc b/model/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..380c365df4bdc42fcf51afb93b013fbb626cc66d Binary files /dev/null and b/model/__pycache__/backbone.cpython-38.pyc differ diff --git a/model/__pycache__/build_model.cpython-38.pyc b/model/__pycache__/build_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a4cd6d612c0e0b17084504727c3bf67a404a570 Binary files /dev/null and b/model/__pycache__/build_model.cpython-38.pyc differ diff --git a/model/__pycache__/lut_transformation_net.cpython-38.pyc 
b/model/__pycache__/lut_transformation_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc0588f62d2dbfa0110ff66bed1f04120f8ebab Binary files /dev/null and b/model/__pycache__/lut_transformation_net.cpython-38.pyc differ diff --git a/model/backbone.py b/model/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..738e112d404148241126c2fda769d883752431dd --- /dev/null +++ b/model/backbone.py @@ -0,0 +1,79 @@ +import torch.nn as nn + +from .hrnetv2.hrnet_ocr import HighResolutionNet +from .hrnetv2.modifiers import LRMult +from .base.basic_blocks import MaxPoolDownSize +from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization + + +def build_backbone(name, opt): + return eval(name)(opt) + + +class baseline(IHModelWithBackbone): + def __init__(self, opt, ocr=64): + base_config = {'model': DeepImageHarmonization, + 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}} + + params = base_config['params'] + + backbone = HRNetV2(opt, ocr=ocr) + + params.update(dict( + backbone_from=2, + backbone_channels=backbone.output_channels, + backbone_mode='cat', + opt=opt + )) + base_model = base_config['model'](**params) + + super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt) + + +class HRNetV2(nn.Module): + def __init__( + self, opt, + cat_outputs=True, + pyramid_channels=-1, pyramid_depth=4, + width=18, ocr=128, small=False, + lr_mult=0.1, pretained=True + ): + super(HRNetV2, self).__init__() + self.opt = opt + self.cat_outputs = cat_outputs + self.ocr_on = ocr > 0 and cat_outputs + self.pyramid_on = pyramid_channels > 0 and cat_outputs + + self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt) + self.hrnet.apply(LRMult(lr_mult)) + if self.ocr_on: + self.hrnet.ocr_distri_head.apply(LRMult(1.0)) + self.hrnet.ocr_gather_head.apply(LRMult(1.0)) + self.hrnet.conv3x3_ocr.apply(LRMult(1.0)) + + hrnet_cat_channels = [width * 2 ** i for i in range(4)] + if self.pyramid_on: + self.output_channels = [pyramid_channels] * 4 + elif self.ocr_on: + self.output_channels = [ocr * 2] + elif self.cat_outputs: + self.output_channels = [sum(hrnet_cat_channels)] + else: + self.output_channels = hrnet_cat_channels + + if self.pyramid_on: + downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels) + self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth) + + if pretained: + self.load_pretrained_weights( + r".\pretrained_models/hrnetv2_w18_imagenet_pretrained.pth") + + self.output_resolution = (opt.input_size // 8) ** 2 + + def forward(self, image, mask, mask_features=None): + outputs = list(self.hrnet(image, mask, mask_features)) + return outputs + + def load_pretrained_weights(self, pretrained_path): + self.hrnet.load_pretrained_weights(pretrained_path) diff --git a/model/base/__init__.py b/model/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/base/__pycache__/__init__.cpython-38.pyc b/model/base/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1edd3a427e7fb4dfb560338517d6b4454a832bc1 Binary files /dev/null and b/model/base/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/base/__pycache__/basic_blocks.cpython-38.pyc b/model/base/__pycache__/basic_blocks.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..34ceac858fe7708f5238dc3cbdf30aeba319516d Binary files /dev/null and b/model/base/__pycache__/basic_blocks.cpython-38.pyc differ diff --git a/model/base/__pycache__/conv_autoencoder.cpython-38.pyc b/model/base/__pycache__/conv_autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..915125f790d83cad0df4cfad11242272c0ad282f Binary files /dev/null and b/model/base/__pycache__/conv_autoencoder.cpython-38.pyc differ diff --git a/model/base/__pycache__/ih_model.cpython-38.pyc b/model/base/__pycache__/ih_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b3f9dca654370a12523c4c85faa9d54a0308c51 Binary files /dev/null and b/model/base/__pycache__/ih_model.cpython-38.pyc differ diff --git a/model/base/__pycache__/ops.cpython-38.pyc b/model/base/__pycache__/ops.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcb34b7c91dd66fb460436b3493c76afb4d19b2b Binary files /dev/null and b/model/base/__pycache__/ops.cpython-38.pyc differ diff --git a/model/base/basic_blocks.py b/model/base/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..cf62bec8deb6f23d23b94aaec27448adb4d03672 --- /dev/null +++ b/model/base/basic_blocks.py @@ -0,0 +1,366 @@ +import torch +from torch import nn as nn +import numpy as np + + +def hyper_weight_init(m, in_features_main_net, activation): + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + m.weight.data = m.weight.data / 1.e2 + + if hasattr(m, 'bias'): + with torch.no_grad(): + if activation == 'sine': + m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30) + elif activation == 'leakyrelu_pe': + m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net)) + else: + raise NotImplementedError + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, out_channels, + kernel_size=4, stride=2, padding=1, + norm_layer=nn.BatchNorm2d, activation=nn.ELU, + bias=True, + ): + super(ConvBlock, self).__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + activation(), + ) + + def forward(self, x): + return self.block(x) + + +class MaxPoolDownSize(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, depth): + super(MaxPoolDownSize, self).__init__() + self.depth = depth + self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0) + self.convs = nn.ModuleList([ + ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1) + for conv_i in range(depth) + ]) + self.pool2d = nn.MaxPool2d(kernel_size=2) + + def forward(self, x): + outputs = [] + + output = self.reduce_conv(x) + + for conv_i, conv in enumerate(self.convs): + output = output if conv_i == 0 else self.pool2d(output) + outputs.append(conv(output)) + + return outputs + + +class convParams(nn.Module): + def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False): + super(convParams, self).__init__() + self.INR_in_out = INR_in_out + self.cont_split_weight = [] + self.cont_split_bias = [] + self.hidden_mlp_num = hidden_mlp_num + self.param_factorize_dim = opt.param_factorize_dim + output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB) + 
self.output_dim = output_dim + self.toRGB = toRGB + self.cont_extraction_net = nn.Sequential( + nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True), + ) + + self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation)) + + self.basic_params = nn.ParameterList() + if opt.param_factorize_dim > 0: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp))) + + if toRGB: + self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3))) + + def forward(self, feat, outMore=False): + cont_params = self.cont_extraction_net(feat) + out_mlp = self.to_mlp(cont_params) + if outMore: + return out_mlp, cont_params + return out_mlp + + def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False): + cont_params = 0 + start = 0 + if self.param_factorize_dim == -1: + cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1] + self.cont_split_weight.append([start, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + for id in range(hidden_mlp_num): + cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1] + self.cont_split_weight.append([start, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + if toRGB: + cont_params += INR_in_out[1] * 3 + 3 + self.cont_split_weight.append([start, cont_params - 3]) + self.cont_split_bias.append([cont_params - 3, cont_params]) + + elif self.param_factorize_dim > 0: + cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.cont_split_weight.append( + [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + for id in range(hidden_mlp_num): + cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.cont_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + if toRGB: + cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3 + self.cont_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3]) + self.cont_split_bias.append([cont_params - 3, cont_params]) + + return cont_params + + def to_mlp(self, params): + all_weight_bias = [] + if self.param_factorize_dim == -1: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :] + weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], + inp, outp) + + bias = params[:, 
self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + all_weight_bias.append([weight, bias]) + + if self.toRGB: + inp, outp = self.INR_in_out[1], 3 + weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :] + weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], + inp, outp) + + bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + all_weight_bias.append([weight, bias]) + + return all_weight_bias + + else: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :] + weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:], + inp, self.param_factorize_dim) + + weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :] + weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:], + self.param_factorize_dim, outp) + + bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + if self.toRGB: + inp, outp = self.INR_in_out[1], 3 + weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :] + weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:], + inp, self.param_factorize_dim) + + weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :] + weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:], + self.param_factorize_dim, outp) + + bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias]) + + return all_weight_bias + + +class lineParams(nn.Module): + def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False, + hidden_dim=512): + super(lineParams, self).__init__() + self.INR_in_out = INR_in_out + self.app_split_weight = [] + self.app_split_bias = [] + self.toRGB = toRGB + self.hidden_mlp_num = hidden_mlp_num + self.param_factorize_dim = opt.param_factorize_dim + output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num) + self.output_dim = output_dim + + self.compress_layer = nn.Sequential( + nn.Linear(input_resolution, 64, bias=False), + nn.BatchNorm1d(input_dim), + nn.ReLU(inplace=True), + nn.Linear(64, 1, bias=True) + ) + + self.app_extraction_net = nn.Sequential( + nn.Linear(input_dim, hidden_dim, bias=False), + # nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim, bias=False), + # nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim, bias=True) + ) + + self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation)) + + self.basic_params = nn.ParameterList() + if opt.param_factorize_dim > 0: + for id in 
range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp))) + if toRGB: + self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3))) + + def forward(self, feat): + app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1)) + out_mlp = self.to_mlp(app_params) + return out_mlp, app_params + + def cal_params_num(self, INR_in_out, hidden_mlp_num): + app_params = 0 + start = 0 + if self.param_factorize_dim == -1: + app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1] + self.app_split_weight.append([start, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + for id in range(hidden_mlp_num): + app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1] + self.app_split_weight.append([start, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + if self.toRGB: + app_params += INR_in_out[1] * 3 + 3 + self.app_split_weight.append([start, app_params - 3]) + self.app_split_bias.append([app_params - 3, app_params]) + + elif self.param_factorize_dim > 0: + app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim, + app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + for id in range(hidden_mlp_num): + app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.app_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + if self.toRGB: + app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3 + self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim, + app_params - 3]) + self.app_split_bias.append([app_params - 3, app_params]) + + return app_params + + def to_mlp(self, params): + all_weight_bias = [] + if self.param_factorize_dim == -1: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight = weight.view(weight.shape[0], inp, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([weight, bias]) + + if self.toRGB: + id = -1 + inp, outp = self.INR_in_out[1], 3 + weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight = weight.view(weight.shape[0], inp, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([weight, bias]) + + return all_weight_bias + + else: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight1 = 
weight1.view(weight1.shape[0], inp, self.param_factorize_dim) + + weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]] + weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + if self.toRGB: + id = -1 + inp, outp = self.INR_in_out[1], 3 + weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim) + + weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]] + weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + return all_weight_bias diff --git a/model/base/conv_autoencoder.py b/model/base/conv_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bc30b98941037b4d6dc30ee177f5872daa1327c1 --- /dev/null +++ b/model/base/conv_autoencoder.py @@ -0,0 +1,519 @@ +import torch +import torchvision +from torch import nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from .basic_blocks import ConvBlock, lineParams, convParams +from .ops import MaskedChannelAttention, FeaturesConnector +from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed +from utils import misc +from utils.misc import lin2img +from ..lut_transformation_net import build_lut_transform + + +class Sine(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.sin(30 * input) + + +class Leaky_relu(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.leaky_relu(input, 0.01, inplace=True) + + +def select_activation(type): + if type == 'sine': + return Sine() + elif type == 'leakyrelu_pe': + return Leaky_relu() + else: + raise NotImplementedError + + +class ConvEncoder(nn.Module): + def __init__( + self, + depth, ch, + norm_layer, batchnorm_from, max_channels, + backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False + ): + super(ConvEncoder, self).__init__() + self.depth = depth + self.INRDecode = INRDecode + self.backbone_from = backbone_from + backbone_channels = [] if backbone_channels is None else backbone_channels[::-1] + + in_channels = 4 + out_channels = ch + + self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None) + self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None) + self.blocks_channels = [out_channels, out_channels] + + self.blocks_connected = nn.ModuleDict() + self.connectors = nn.ModuleDict() + for block_i in range(2, depth): + if block_i % 2: + in_channels = out_channels + else: + in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) + + if 0 <= backbone_from <= block_i and len(backbone_channels): + if INRDecode: + self.blocks_connected[f'block{block_i}_decode'] = ConvBlock( + in_channels, out_channels, + norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, + padding=int(block_i < depth - 1) + ) + self.blocks_channels += [out_channels] + stage_channels = 
backbone_channels.pop() + connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels) + self.connectors[f'connector{block_i}'] = connector + in_channels = connector.output_channels + + self.blocks_connected[f'block{block_i}'] = ConvBlock( + in_channels, out_channels, + norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, + padding=int(block_i < depth - 1) + ) + self.blocks_channels += [out_channels] + + def forward(self, x, backbone_features): + backbone_features = [] if backbone_features is None else backbone_features[::-1] + + outputs = [self.block0(x)] + outputs += [self.block1(outputs[-1])] + + for block_i in range(2, self.depth): + output = outputs[-1] + connector_name = f'connector{block_i}' + if connector_name in self.connectors: + if self.INRDecode: + block = self.blocks_connected[f'block{block_i}_decode'] + outputs += [block(output)] + + stage_features = backbone_features.pop() + connector = self.connectors[connector_name] + output = connector(output, stage_features) + block = self.blocks_connected[f'block{block_i}'] + outputs += [block(output)] + + return outputs[::-1] + + +class DeconvDecoder(nn.Module): + def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False): + super(DeconvDecoder, self).__init__() + self.image_fusion = image_fusion + self.deconv_blocks = nn.ModuleList() + + in_channels = encoder_blocks_channels.pop() + out_channels = in_channels + for d in range(depth): + out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 + self.deconv_blocks.append(SEDeconvBlock( + in_channels, out_channels, + norm_layer=norm_layer, + padding=0 if d == 0 else 1, + with_se=0 <= attend_from <= d + )) + in_channels = out_channels + + if self.image_fusion: + self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1) + self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1) + + def forward(self, encoder_outputs, image, mask=None): + output = encoder_outputs[0] + for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): + output = block(output, mask) + output = output + skip_output + output = self.deconv_blocks[-1](output, mask) + + if self.image_fusion: + attention_map = torch.sigmoid(3.0 * self.conv_attention(output)) + output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output) + else: + output = self.to_rgb(output) + + return output + + +class SEDeconvBlock(nn.Module): + def __init__( + self, + in_channels, out_channels, + kernel_size=4, stride=2, padding=1, + norm_layer=nn.BatchNorm2d, activation=nn.ELU, + with_se=False + ): + super(SEDeconvBlock, self).__init__() + self.with_se = with_se + self.block = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + activation(), + ) + if self.with_se: + self.se = MaskedChannelAttention(out_channels) + + def forward(self, x, mask=None): + out = self.block(x) + if self.with_se: + out = self.se(out, mask) + return out + + +class INRDecoder(nn.Module): + def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from): + super(INRDecoder, self).__init__() + self.INR_encoding = None + if opt.embedding_type == "PosEncodingNeRF": + self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size) + elif opt.embedding_type == "RandomFourier": + self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device) + 
elif opt.embedding_type == "CIPS_embed": + self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32) + elif opt.embedding_type == "INRGAN_embed": + self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size) + else: + raise NotImplementedError + encoder_blocks_channels = encoder_blocks_channels[::-1] + max_hidden_mlp_num = attend_from + 1 + self.opt = opt + self.max_hidden_mlp_num = max_hidden_mlp_num + self.content_mlp_blocks = nn.ModuleDict() + for n in range(max_hidden_mlp_num): + if n != max_hidden_mlp_num - 1: + self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(), + [self.INR_encoding.out_dim + opt.INR_MLP_dim + ( + 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim], + opt, n + 1) + else: + self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(), + [self.INR_encoding.out_dim + ( + 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim], + opt, n + 1) + + self.deconv_blocks = nn.ModuleList() + + encoder_blocks_channels = encoder_blocks_channels[::-1] + in_channels = encoder_blocks_channels.pop() + out_channels = in_channels + for d in range(depth - attend_from): + out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 + self.deconv_blocks.append(SEDeconvBlock( + in_channels, out_channels, + norm_layer=norm_layer, + padding=0 if d == 0 else 1, + with_se=False + )) + in_channels = out_channels + + self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim], + (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2, + opt, 2, toRGB=True) + + self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim, + None, opt) + + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None): + """For full resolution, do split.""" + if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, + 'split_resolution')) and self.opt.isFullRes: + return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples) + + encoder_outputs = encoder_outputs[::-1] + mlp_output = None + waitToRGB = [] + for n in range(self.max_hidden_mlp_num): + if not self.opt.hr_train: + coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \ + .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device) + else: + if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'): + coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view( + encoder_outputs[0].shape[0], -1, 2) + else: + coord = misc.get_mgrid( + self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat( + encoder_outputs[0].shape[0], 1, 1).to(self.opt.device) + + """Whether to leverage multiple input to INR decoder. 
See Section 3.4 in the paper.""" + if self.opt.isMoreINRInput: + if not self.opt.isFullRes or ( + self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + res_h = res_w = np.sqrt(coord.shape[1]).astype(int) + else: + res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1)) + res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1)) + + res_image = torchvision.transforms.Resize([res_h, res_w])(image) + res_mask = torchvision.transforms.Resize([res_h, res_w])(mask) + coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1), + res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1) + else: + coord = self.INR_encoding(coord) + + """============ LRIP structure, see Section 3.3 ==============""" + + """Local MLPs.""" + if n == 0: + mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0), + self.opt, content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion) + waitToRGB.append(mlp_output[1]) + else: + mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + ( + 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0], + content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), + start_proportion=start_proportion) + waitToRGB.append(mlp_output[1]) + + encoder_outputs = encoder_outputs[::-1] + output = encoder_outputs[0] + for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): + output = block(output) + output = output + skip_output + output = self.deconv_blocks[-1](output) + + """Global MLPs.""" + app_mlp, app_params = self.appearance_mlps(output) + harm_out = [] + for id in range(len(waitToRGB)): + output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id], + appearance_mlp=app_mlp) + harm_out.append(output[0]) + + """Optional 3D LUT prediction.""" + fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None) + + return harm_out, fit_lut3d, lut_transform_image + + def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None, + resolution=None, start_proportion=None): + + activation = select_activation(opt.activation) + + output = None + + if content_mlp is not None: + if base_feat is not None: + coorinates = torch.cat([coorinates, base_feat], dim=2) + coorinates = lin2img(coorinates, resolution) + + if hasattr(opt, 'split_resolution'): + """ + Here we crop the needed MLPs according to the region of the split input patches. + Note that this only support inferencing square images. 
+ """ + for idx in range(len(content_mlp)): + content_mlp[idx][0] = content_mlp[idx][0][:, + (content_mlp[idx][0].shape[1] * start_proportion[0]).int():( + content_mlp[idx][0].shape[1] * start_proportion[2]).int(), + (content_mlp[idx][0].shape[2] * start_proportion[1]).int():( + content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :, + :] + content_mlp[idx][1] = content_mlp[idx][1][:, + (content_mlp[idx][1].shape[1] * start_proportion[0]).int():( + content_mlp[idx][1].shape[1] * start_proportion[2]).int(), + (content_mlp[idx][1].shape[2] * start_proportion[1]).int():( + content_mlp[idx][1].shape[2] * start_proportion[3]).int(), + :, + :] + k_h = coorinates.shape[2] // content_mlp[0][0].shape[1] + k_w = coorinates.shape[3] // content_mlp[0][0].shape[1] + bs = coorinates.shape[0] + h_lr = w_lr = content_mlp[0][0].shape[1] + nci = INR_input_dim + + coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w) + coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view( + bs, h_lr, w_lr, int(k_h * k_w), nci) + + for id, layer in enumerate(content_mlp): + if id == 0: + output = torch.matmul(coorinates, layer[0]) + layer[1] + output = activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + + output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute( + 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim) + + output_large = self.up(lin2img(output)) + + return output_large.view(bs, -1, opt.INR_MLP_dim), output + + k_h = coorinates.shape[2] // content_mlp[0][0].shape[1] + k_w = coorinates.shape[3] // content_mlp[0][0].shape[1] + bs = coorinates.shape[0] + h_lr = w_lr = content_mlp[0][0].shape[1] + nci = INR_input_dim + + """(evaluation or not HR training) and not fullres evaluation""" + if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not ( + not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train): + coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w) + coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view( + bs, h_lr, w_lr, int(k_h * k_w), nci) + + for id, layer in enumerate(content_mlp): + if id == 0: + output = torch.matmul(coorinates, layer[0]) + layer[1] + output = activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + + output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute( + 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim) + + output_large = self.up(lin2img(output)) + + return output_large.view(bs, -1, opt.INR_MLP_dim), output + else: + coorinates = coorinates.permute(0, 2, 3, 1) + for id, layer in enumerate(content_mlp): + weigt_shape = layer[0].shape + bias_shape = layer[1].shape + layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous() + layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous() + layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True + else 'bilinear', padding_mode='border', align_corners=False) + layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True + else 'bilinear', padding_mode='border', align_corners=False) + layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weigt_shape[-2:]) + layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:]) + + if id == 0: + output = 
torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1] + output = activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + + output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim) + + output_large = self.up(lin2img(output, resolution)) + + return output_large.view(bs, -1, opt.INR_MLP_dim), output + + elif appearance_mlp is not None: + output = base_feat + genMask = None + for id, layer in enumerate(appearance_mlp): + if id != len(appearance_mlp) - 1: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] # last layer + if opt.activation == 'leakyrelu_pe': + output = torch.tanh(output) + return lin2img(output, resolution), None + + def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None): + encoder_outputs = encoder_outputs[::-1] + mlp_output = None + res_w = image.shape[-1] + res_h = image.shape[-2] + coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat( + encoder_outputs[0].shape[0], 1, 1).to(self.opt.device) + + if self.opt.isMoreINRInput: + coord = torch.cat( + [self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1), + mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1) + else: + coord = self.INR_encoding(coord, (res_h, res_w)) + + total = coord.clone() + + interval = 10 + all_intervals = math.ceil(res_h / interval) + divisible = True + if res_h / interval != res_h // interval: + divisible = False + + for n in range(self.max_hidden_mlp_num): + accum_mlp_output = [] + for line in range(all_intervals): + if not divisible and line == all_intervals - 1: + coord = total[:, line * interval * res_w:, :] + else: + coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :] + if n == 0: + accum_mlp_output.append(self.mlp_process(coord, + self.INR_encoding.out_dim + ( + 4 if self.opt.isMoreINRInput else 0), + self.opt, content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else + encoder_outputs[self.max_hidden_mlp_num - 1 - n]), + resolution=(interval, + res_w) if divisible or line != all_intervals - 1 else ( + res_h - interval * (all_intervals - 1), res_w))[1]) + + else: + accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + ( + 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:, + line * interval * res_w: ( + line + 1) * interval * res_w, + :] + if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :], + content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop( + self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else + encoder_outputs[self.max_hidden_mlp_num - 1 - n]), + resolution=(interval, + res_w) if divisible or line != all_intervals - 1 else ( + res_h - interval * (all_intervals - 1), res_w))[1]) + + accum_mlp_output = torch.cat(accum_mlp_output, dim=1) + mlp_output = [accum_mlp_output, accum_mlp_output] + + encoder_outputs = encoder_outputs[::-1] + output = encoder_outputs[0] + for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): + output = block(output) + output = output + skip_output + output = self.deconv_blocks[-1](output) + + app_mlp, app_params = self.appearance_mlps(output) + harm_out = [] + + accum_mlp_output = [] 
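+ # Apply the global (appearance) MLPs strip by strip as well: each pass covers `interval` rows of the
+ # full-resolution coordinate grid, and the per-strip outputs are stitched back together below.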
+ for line in range(all_intervals): + if not divisible and line == all_intervals - 1: + base = mlp_output[1][:, line * interval * res_w:, :] + else: + base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :] + + accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base, + appearance_mlp=app_mlp, + resolution=( + interval, + res_w) if divisible or line != all_intervals - 1 else ( + res_h - interval * (all_intervals - 1), res_w))[0]) + + accum_mlp_output = torch.cat(accum_mlp_output, dim=2) + harm_out.append(accum_mlp_output) + + fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None) + + return harm_out, fit_lut3d, lut_transform_image diff --git a/model/base/ih_model.py b/model/base/ih_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4dc531e41d99169dc35113e2c4bfcdb0aa5e67 --- /dev/null +++ b/model/base/ih_model.py @@ -0,0 +1,88 @@ +import torch +import torchvision +import torch.nn as nn + +from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder + +from .ops import ScaleLayer + + +class IHModelWithBackbone(nn.Module): + def __init__( + self, + model, backbone, + downsize_backbone_input=False, + mask_fusion='sum', + backbone_conv1_channels=64, opt=None + ): + super(IHModelWithBackbone, self).__init__() + self.downsize_backbone_input = downsize_backbone_input + self.mask_fusion = mask_fusion + + self.backbone = backbone + self.model = model + self.opt = opt + + self.mask_conv = nn.Sequential( + nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True), + ScaleLayer(init_value=0.1, lr_mult=1) + ) + + def forward(self, image, mask, coord=None, start_proportion=None): + if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]) + backbone_mask = torch.cat( + (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]), + 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1) + else: + backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image) + backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask), + 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1) + + backbone_mask_features = self.mask_conv(backbone_mask[:, :1]) + backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features) + + output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion) + return output + + +class DeepImageHarmonization(nn.Module): + def __init__( + self, + depth, + norm_layer=nn.BatchNorm2d, batchnorm_from=0, + attend_from=-1, + image_fusion=False, + ch=64, max_channels=512, + backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None + ): + super(DeepImageHarmonization, self).__init__() + self.depth = depth + self.encoder = ConvEncoder( + depth, ch, + norm_layer, batchnorm_from, max_channels, + backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode + ) + self.opt = opt + if opt.INRDecode: + "See Table 2 in the paper to test with different INR decoders' structures." 
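+ # The INR decoder turns the encoder's multi-scale features into per-level MLP weights
+ # (the LRIP structure, Section 3.3) and predicts colours from pixel coordinates.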
+ self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from) + else: + "Baseline: https://github.com/SamsungLabs/image_harmonization" + self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion) + + def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None): + if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]), + torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1) + else: + x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image), + torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1) + + intermediates = self.encoder(x, backbone_features) + + if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion) + else: + output = self.decoder(intermediates, image, mask) + return output diff --git a/model/base/ops.py b/model/base/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fd2a027c79b9995f8be59cf412c96ee05e8b5050 --- /dev/null +++ b/model/base/ops.py @@ -0,0 +1,397 @@ +import torch +from torch import nn as nn +import numpy as np +import math +import torch.nn.functional as F + + +class SimpleInputFusion(nn.Module): + def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d): + super(SimpleInputFusion, self).__init__() + + self.fusion_conv = nn.Sequential( + nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2), + norm_layer(ch), + nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1), + ) + + def forward(self, image, additional_input): + return self.fusion_conv(torch.cat((image, additional_input), dim=1)) + + +class MaskedChannelAttention(nn.Module): + def __init__(self, in_channels, *args, **kwargs): + super(MaskedChannelAttention, self).__init__() + self.global_max_pool = MaskedGlobalMaxPool2d() + self.global_avg_pool = FastGlobalAvgPool2d() + + intermediate_channels_count = max(in_channels // 16, 8) + self.attention_transform = nn.Sequential( + nn.Linear(3 * in_channels, intermediate_channels_count), + nn.ReLU(inplace=True), + nn.Linear(intermediate_channels_count, in_channels), + nn.Sigmoid(), + ) + + def forward(self, x, mask): + if mask.shape[2:] != x.shape[:2]: + mask = nn.functional.interpolate( + mask, size=x.size()[-2:], + mode='bilinear', align_corners=True + ) + pooled_x = torch.cat([ + self.global_max_pool(x, mask), + self.global_avg_pool(x) + ], dim=1) + channel_attention_weights = self.attention_transform(pooled_x)[..., None, None] + + return channel_attention_weights * x + + +class MaskedGlobalMaxPool2d(nn.Module): + def __init__(self): + super().__init__() + self.global_max_pool = FastGlobalMaxPool2d() + + def forward(self, x, mask): + return torch.cat(( + self.global_max_pool(x * mask), + self.global_max_pool(x * (1.0 - mask)) + ), dim=1) + + +class FastGlobalAvgPool2d(nn.Module): + def __init__(self): + super(FastGlobalAvgPool2d, self).__init__() + + def forward(self, x): + in_size = x.size() + return x.view((in_size[0], in_size[1], -1)).mean(dim=2) + + +class 
FastGlobalMaxPool2d(nn.Module): + def __init__(self): + super(FastGlobalMaxPool2d, self).__init__() + + def forward(self, x): + in_size = x.size() + return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0] + + +class ScaleLayer(nn.Module): + def __init__(self, init_value=1.0, lr_mult=1): + super().__init__() + self.lr_mult = lr_mult + self.scale = nn.Parameter( + torch.full((1,), init_value / lr_mult, dtype=torch.float32) + ) + + def forward(self, x): + scale = torch.abs(self.scale * self.lr_mult) + return x * scale + + +class FeaturesConnector(nn.Module): + def __init__(self, mode, in_channels, feature_channels, out_channels): + super(FeaturesConnector, self).__init__() + self.mode = mode if feature_channels else '' + + if self.mode == 'catc': + self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1) + elif self.mode == 'sum': + self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1) + + self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels + + def forward(self, x, features): + if self.mode == 'cat': + return torch.cat((x, features), 1) + if self.mode == 'catc': + return self.reduce_conv(torch.cat((x, features), 1)) + if self.mode == 'sum': + return self.reduce_conv(features) + x + return x + + def extra_repr(self): + return self.mode + + +class PosEncodingNeRF(nn.Module): + def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True): + super().__init__() + + self.in_features = in_features + + if self.in_features == 3: + self.num_frequencies = 10 + elif self.in_features == 2: + assert sidelength is not None + if isinstance(sidelength, int): + sidelength = (sidelength, sidelength) + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1])) + elif self.in_features == 1: + assert fn_samples is not None + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples) + + self.out_dim = in_features + 2 * in_features * self.num_frequencies + + def get_num_frequencies_nyquist(self, samples): + nyquist_rate = 1 / (2 * (2 * 1 / samples)) + return int(math.floor(math.log(nyquist_rate, 2))) + + def forward(self, coords): + coords = coords.view(coords.shape[0], -1, self.in_features) + + coords_pos_enc = coords + for i in range(self.num_frequencies): + for j in range(self.in_features): + c = coords[..., j] + + sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1) + cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1) + + coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1) + + return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim) + + +class RandomFourier(nn.Module): + def __init__(self, std_scale, embedding_length, device): + super().__init__() + + self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale + self.embed = self.embed.to(device) + + self.out_dim = embedding_length * 2 + 2 + + def forward(self, coords): + coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)), + torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1) + + return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1) + + +class CIPS_embed(nn.Module): + def __init__(self, size, embedding_length): + super().__init__() + self.fourier_embed = ConstantInput(size, embedding_length) + self.predict_embed = Predict_embed(embedding_length) + self.out_dim = embedding_length * 2 + 2 
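+ # out_dim = 2 raw coordinates + embedding_length channels from Predict_embed (sine of a learned linear map)
+ # + embedding_length channels from ConstantInput (a learned constant grid resampled at the query coordinates).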
+ + def forward(self, coord, res=None): + x = self.predict_embed(coord) + y = self.fourier_embed(x, coord, res) + + return torch.cat([coord, x, y], dim=-1) + + +class Predict_embed(nn.Module): + def __init__(self, embedding_length): + super(Predict_embed, self).__init__() + self.ffm = nn.Linear(2, embedding_length, bias=True) + nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2)) + + def forward(self, x): + x = self.ffm(x) + x = torch.sin(x) + return x + + +class ConstantInput(nn.Module): + def __init__(self, size, channel): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, size ** 2, channel)) + + def forward(self, input, coord, resolution=None): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1) + + if coord.shape[1] != self.input.shape[1]: + x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1], + int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5)) + + if resolution is None: + grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1]) + else: + grid = coord.view(coord.shape[0], *resolution, coord.shape[-1]) + + out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True) + + out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1]) + + return out + + +class INRGAN_embed(nn.Module): + def __init__(self, resolution: int, w_dim=None): + super().__init__() + + self.resolution = resolution + self.res_cfg = {"log_emb_size": 32, + "random_emb_size": 32, + "const_emb_size": 64, + "use_cosine": True} + self.log_emb_size = self.res_cfg.get('log_emb_size', 0) + self.random_emb_size = self.res_cfg.get('random_emb_size', 0) + self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0) + self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0) + self.const_emb_size = self.res_cfg.get('const_emb_size', 0) + self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10)) + self.use_cosine = self.res_cfg.get('use_cosine', False) + + if self.log_emb_size > 0: + self.register_buffer('log_basis', generate_logarithmic_basis( + resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False))) + + if self.random_emb_size > 0: + self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale)) + + if self.shared_emb_size > 0: + self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale)) + + if self.predictable_emb_size > 0: + self.W_size = self.predictable_emb_size * self.cfg.coord_dim + self.b_size = self.predictable_emb_size + self.affine = nn.Linear(w_dim, self.W_size + self.b_size) + + if self.const_emb_size > 0: + self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size)) + + self.out_dim = self.get_total_dim() + 2 + + def sample_w_matrix(self, shape, scale: float): + return torch.randn(shape) * scale + + def get_total_dim(self) -> int: + total_dim = 0 + if self.log_emb_size > 0: + total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1) + total_dim += self.random_emb_size * (2 if self.use_cosine else 1) + total_dim += self.shared_emb_size * (2 if self.use_cosine else 1) + total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1) + total_dim += self.const_emb_size + + return total_dim + + def forward(self, raw_coords, w=None): + batch_size, img_size, in_channels = raw_coords.shape + + raw_embs = [] + + if self.log_emb_size > 0: + log_bases = 
self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1) + raw_log_embs = torch.matmul(raw_coords, log_bases) + raw_embs.append(raw_log_embs) + + if self.random_emb_size > 0: + random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1) + raw_random_embs = torch.matmul(raw_coords, random_bases) + raw_embs.append(raw_random_embs) + + if self.shared_emb_size > 0: + shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1) + raw_shared_embs = torch.matmul(raw_coords, shared_bases) + raw_embs.append(raw_shared_embs) + + if self.predictable_emb_size > 0: + mod = self.affine(w) + W = self.fourier_scale * mod[:, :self.W_size] + W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size) + bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size) + raw_predictable_embs = (torch.matmul(raw_coords, W) + bias) + raw_embs.append(raw_predictable_embs) + + if len(raw_embs) > 0: + raw_embs = torch.cat(raw_embs, dim=-1) + raw_embs = raw_embs.contiguous() + out = raw_embs.sin() + + if self.use_cosine: + out = torch.cat([out, raw_embs.cos()], dim=-1) + + if self.const_emb_size > 0: + const_embs = self.const_embs.repeat([batch_size, 1, 1]) + const_embs = const_embs + out = torch.cat([out, const_embs], dim=-1) + + return torch.cat([raw_coords, out], dim=-1) + + +def generate_logarithmic_basis( + resolution, + max_num_feats, + remove_lowest_freq: bool = False, + use_diagonal: bool = True): + """ + Generates a directional logarithmic basis with the following directions: + - horizontal + - vertical + - main diagonal + - anti-diagonal + """ + max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int) + bases = [ + generate_horizontal_basis(max_num_feats_per_direction), + generate_vertical_basis(max_num_feats_per_direction), + ] + + if use_diagonal: + bases.extend([ + generate_diag_main_basis(max_num_feats_per_direction), + generate_anti_diag_basis(max_num_feats_per_direction), + ]) + + if remove_lowest_freq: + bases = [b[1:] for b in bases] + + # If we do not fit into `max_num_feats`, then trying to remove the features in the order: + # 1) anti-diagonal 2) main-diagonal + # while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2): + # bases = bases[:-1] + + basis = torch.cat(bases, dim=0) + + # If we still do not fit, then let's remove each second feature, + # then each third, each forth and so on + # We cannot drop the whole horizontal or vertical direction since otherwise + # model won't be able to locate the position + # (unless the previously computed embeddings encode the position) + # while basis.shape[0] > max_num_feats: + # num_exceeding_feats = basis.shape[0] - max_num_feats + # basis = basis[::2] + + assert basis.shape[0] <= max_num_feats, \ + f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}." 
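+ # basis stacks ceil(log2(resolution)) power-of-two wave vectors per direction (horizontal, vertical,
+ # and optionally the two diagonals); each row is a 2-D frequency vector.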
+ + return basis + + +def generate_horizontal_basis(num_feats: int): + return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0) + + +def generate_vertical_basis(num_feats: int): + return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0) + + +def generate_diag_main_basis(num_feats: int): + return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2)) + + +def generate_anti_diag_basis(num_feats: int): + return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2)) + + +def generate_wavefront_basis(num_feats: int, basis_block, period_length: float): + period_coef = 2.0 * np.pi / period_length + basis = torch.tensor([basis_block]).repeat(num_feats, 1) # [num_feats, 2] + powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1) # [num_feats, 1] + result = basis * powers * period_coef # [num_feats, 2] + + return result.float() \ No newline at end of file diff --git a/model/build_model.py b/model/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9d91c125aa7caca6467d2a6af6c8973112019681 --- /dev/null +++ b/model/build_model.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from .backbone import build_backbone + + +class build_model(nn.Module): + def __init__(self, opt): + super().__init__() + + self.opt = opt + self.backbone = build_backbone('baseline', opt) + + def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None): + if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + """ + For HR Training, due to the designed RSC strategy in Section 3.4 in the paper, + here we need to pass in the coordinates of the cropped regions. 
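+ Concretely, composite_image and mask are expected to arrive as pairs here (index 0: the full view that is
+ resized for the encoder; index 1: the high-resolution crop fed to the INR decoder), with start_proportion
+ giving the fractional location of that crop.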
+ """ + extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion) + else: + extracted_features = self.backbone(composite_image, mask) + + if self.opt.INRDecode: + return extracted_features + return None, None, extracted_features \ No newline at end of file diff --git a/model/hrnetv2/__init__.py b/model/hrnetv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/hrnetv2/__pycache__/__init__.cpython-38.pyc b/model/hrnetv2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..607ef591d5b7a70b96a93ce0417f4c13acf69489 Binary files /dev/null and b/model/hrnetv2/__pycache__/__init__.cpython-38.pyc differ diff --git a/model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc b/model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05deb9f6bbfeca67b6f6b4f0cc21a41b1e17c8a3 Binary files /dev/null and b/model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc differ diff --git a/model/hrnetv2/__pycache__/modifiers.cpython-38.pyc b/model/hrnetv2/__pycache__/modifiers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e752745b3c6dac211359f74769a212e47e20128 Binary files /dev/null and b/model/hrnetv2/__pycache__/modifiers.cpython-38.pyc differ diff --git a/model/hrnetv2/__pycache__/ocr.cpython-38.pyc b/model/hrnetv2/__pycache__/ocr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb0e0898e3bf3545322ef5b235b4ec09cb699c54 Binary files /dev/null and b/model/hrnetv2/__pycache__/ocr.cpython-38.pyc differ diff --git a/model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc b/model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac205af65ea9caba6a0821d4e3686d8a0b6be290 Binary files /dev/null and b/model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc differ diff --git a/model/hrnetv2/hrnet_ocr.py b/model/hrnetv2/hrnet_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7e98a03be0f9c8a44b5c171d4e64704af77d82 --- /dev/null +++ b/model/hrnetv2/hrnet_ocr.py @@ -0,0 +1,400 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch._utils + +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if 
num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + 
norm_layer=nn.BatchNorm2d, align_corners=True, opt=None): + super(HighResolutionNet, self).__init__() + self.opt = opt + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.ocr_on = ocr_width > 0 + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + if self.ocr_on: + last_inp_channels = np.int(np.sum(pre_stage_channels)) + ocr_mid_channels = 2 * ocr_width + ocr_key_channels = ocr_width + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners, opt=opt) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - 
num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, mask=None, additional_features=None): + hrnet_feats = self.compute_hrnet_feats(x, additional_features) + if not self.ocr_on: + return hrnet_feats, + + ocr_feats = self.conv3x3_ocr(hrnet_feats) + mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True) + context = self.ocr_gather_head(ocr_feats, mask) + ocr_feats = self.ocr_distri_head(ocr_feats, context) + return ocr_feats, + + def compute_hrnet_feats(self, x, additional_features, return_list=False): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + if return_list: + return x + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + 
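+ # Fuse the four HRNet branches: the three lower-resolution outputs were upsampled above to the
+ # resolution of branch 0 and are concatenated channel-wise here.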
return torch.cat([x[0], x1, x2, x3], 1) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + params_count = len(pretrained_dict) + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet') + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/model/hrnetv2/modifiers.py b/model/hrnetv2/modifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..046221838069e90ae201b9169db159cc69c13244 --- /dev/null +++ b/model/hrnetv2/modifiers.py @@ -0,0 +1,11 @@ + + +class LRMult(object): + def __init__(self, lr_mult=1.): + self.lr_mult = lr_mult + + def __call__(self, m): + if getattr(m, 'weight', None) is not None: + m.weight.lr_mult = self.lr_mult + if getattr(m, 'bias', None) is not None: + m.bias.lr_mult = self.lr_mult diff --git a/model/hrnetv2/ocr.py b/model/hrnetv2/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cfbb8eec51f1e5532b9e8b3e35c6a4e0757cff --- /dev/null +++ b/model/hrnetv2/ocr.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) \ + .permute(0, 2, 1).unsqueeze(3).contiguous() # batch x k x c + return ocr_context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. 
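+ In this harmonization model the "object" regions are simply foreground and background,
+ given by the two-channel composite mask passed to SpatialGather_Module.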
+ """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True, opt=None): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, + norm_layer, align_corners) + _in_channels = 2 * in_channels + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class ObjectAttentionBlock2D(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(ObjectAttentionBlock2D, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.align_corners = align_corners + + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
+ context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), + mode='bilinear', align_corners=self.align_corners) + + return context diff --git a/model/hrnetv2/resnetv1b.py b/model/hrnetv2/resnetv1b.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad24cef5bde19f2627cfd3f755636f37cfb39ac --- /dev/null +++ b/model/hrnetv2/resnetv1b.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + + +class BasicBlockV1b(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BasicBlockV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn1 = norm_layer(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=previous_dilation, dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class BottleneckV1b(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BottleneckV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class ResNetV1b(nn.Module): + """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used (default: :class:`nn.BatchNorm2d`) + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. 
+ final_drop : float, default 0.0 + Dropout ratio before the final classification layer. + + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." + Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." + """ + def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): + self.inplanes = stem_width*2 if deep_stem else 64 + super(ResNetV1b, self).__init__() + if not deep_stem: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, + norm_layer=norm_layer) + if dilated: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, + avg_down=avg_down, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = None + if final_drop > 0.0: + self.drop = nn.Dropout(final_drop) + self.fc = nn.Linear(512 * block.expansion, classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, + avg_down=False, norm_layer=nn.BatchNorm2d): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = [] + if avg_down: + if dilation == 1: + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + else: + downsample.append( + nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + ) + downsample.extend([ + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=1, bias=False), + norm_layer(planes * block.expansion) + ]) + downsample = nn.Sequential(*downsample) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + norm_layer(planes * block.expansion) + ) + + layers = [] + if dilation in (1, 2): + layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, + 
previous_dilation=dilation, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + if self.drop is not None: + x = self.drop(x) + x = self.fc(x) + + return x + + +def _safe_state_dict_filtering(orig_dict, model_dict_keys): + filtered_orig_dict = {} + for k, v in orig_dict.items(): + if k in model_dict_keys: + filtered_orig_dict[k] = v + else: + print(f"[ERROR] Failed to load <{k}> in backbone") + return filtered_orig_dict + + +def resnet34_v1b(pretrained=False, **kwargs): + model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet50_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet101_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model diff --git a/model/lut_transformation_net.py b/model/lut_transformation_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f119bc17e4d9ff3ac346bf85553061ec852177af --- /dev/null +++ b/model/lut_transformation_net.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.misc import normalize + + +class build_lut_transform(nn.Module): + + def __init__(self, input_dim, lut_dim, input_resolution, opt): + super().__init__() + + self.lut_dim = lut_dim + self.opt = opt + + # self.compress_layer = nn.Linear(input_resolution, 1) + + self.transform_layers = nn.Sequential( + nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True), + # nn.BatchNorm1d(3 * lut_dim ** 3, affine=False), + nn.ReLU(inplace=True), + nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True), + ) + self.transform_layers[-1].apply(lambda m: hyper_weight_init(m)) + + def forward(self, composite_image, fg_appearance_features, bg_appearance_features): + composite_image = normalize(composite_image, self.opt, 'inv') + + features = fg_appearance_features + + 
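+ # The appearance features are mapped to 3 * lut_dim**3 values, reshaped into a per-image 3D LUT of
+ # shape (3, lut_dim, lut_dim, lut_dim), and applied to the de-normalized composite by trilinear
+ # interpolation (grid_sample) below.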
lut_params = self.transform_layers(features) + + fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim) + + lut_transform_image = torch.stack( + [TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0) + + return fit_3DLUT, normalize(lut_transform_image, self.opt) + + +def TrilinearInterpolation(LUT, img): + img = (img - 0.5) * 2. + + img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1) + + # Note that the coordinates in the grid_sample are inverse to LUT DHW, i.e., xyz is to WHD not DHW. + LUT = LUT[None] + + # grid sample + result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True) + + # drop added dimensions and permute back + result = result[:, :, 0] + + return result + + +def hyper_weight_init(m): + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + m.weight.data = m.weight.data / 1.e2 + + if hasattr(m, 'bias'): + with torch.no_grad(): + m.bias.uniform_(0., 1.) diff --git a/pretrained_models/Resolution_1024_HAdobe5K.pth b/pretrained_models/Resolution_1024_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..be7712c1c3f781748eff3e713e314dbc101c503b --- /dev/null +++ b/pretrained_models/Resolution_1024_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4917e99cc20c2530b6d248d530368929c1784113d20365085b96bbb10860a2f8 +size 477235439 diff --git a/pretrained_models/Resolution_2048_HAdobe5K.pth b/pretrained_models/Resolution_2048_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..ef4d0b31b61d544bd97be5cb5d97608c34497473 --- /dev/null +++ b/pretrained_models/Resolution_2048_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa3d076b5cbf653f17fbf02b95f45b95a0d38e6cb53eaeed71cac1fb22af6f69 +size 477235439 diff --git a/pretrained_models/Resolution_256_iHarmony4.pth b/pretrained_models/Resolution_256_iHarmony4.pth new file mode 100644 index 0000000000000000000000000000000000000000..4d364bd0939ce0ba3d393e6f3461aff8e31b2646 --- /dev/null +++ b/pretrained_models/Resolution_256_iHarmony4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70a7df7a5b8ba502b69d8dba3b9f47cb99e95f5ebb64b289f9f0454613d6f5b6 +size 477528743 diff --git a/pretrained_models/Resolution_RAW_HAdobe5K.pth b/pretrained_models/Resolution_RAW_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..ebecc42d63627ca3dd76d91a50fbc99ce2eb66d9 --- /dev/null +++ b/pretrained_models/Resolution_RAW_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1829182a1a03e9bb5116ac166a9debfb6fdb97c8547790cc0d94bfe313e0c80 +size 953285076 diff --git a/pretrained_models/Resolution_RAW_iHarmony4.pth b/pretrained_models/Resolution_RAW_iHarmony4.pth new file mode 100644 index 0000000000000000000000000000000000000000..164d6467d05514225f26561b499b551aefc1ef40 --- /dev/null +++ b/pretrained_models/Resolution_RAW_iHarmony4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5475d7a2a77260f8c9503b02e46a7b456bd1a7a7d7b0fc3c85ef534b473eef8 +size 477235439 diff --git a/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth b/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth new file mode 100644 index 0000000000000000000000000000000000000000..5d6c001ff1443b030fdd11e72d60b9722b8f027d --- /dev/null +++ b/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00eb200687c1ed1fc1042767e5b965052f4c7338823ba4d36a790979a078b36b +size 85758673 diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/constructDataset.py b/tools/constructDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..10e9ff67bc531f959835bc1db84a28390b3eb42f --- /dev/null +++ b/tools/constructDataset.py @@ -0,0 +1,33 @@ +import os +import shutil + +# root = r"G:\Datasets\Images Harmonization\LR_real_composite_images_99_DIH\all" +# +# all_path = os.listdir(os.path.join(root, "image")) +# +# with open(os.path.join(root, "dataset.txt"), mode='w') as f: +# for im in all_path: +# f.write(os.path.join('image', im) + "\n") +# +# print("Done!") + + +# Re-order dataset +with open(r"G:\Datasets\Images Harmonization\iHarmony4\IHD_test.txt", mode="r") as f: + names = f.readlines() + +for id in range(len(names)): + names[id] = names[id].strip().split("/")[-1].split(".")[0] + +root = r"G:\ComputerPrograms\Image_Harmonization\Supervised_Harmonization\logs\HINet_2048×2048_HAdobe5k\figs\-1" +os.makedirs(os.path.join(root, "reorder"), exist_ok=True) +allFiles = os.listdir(root) +for id, file in enumerate(allFiles): + if "pred_harmonized_image" in file: + name = names[int(file.split("_")[0])] + "_" + file + shutil.copy(os.path.join(root, file), os.path.join(root, "reorder", name)) + shutil.copy(os.path.join(root, file).replace("pred_harmonized_image", "mask"), os.path.join(root, "reorder", name).replace("pred_harmonized_image", "mask")) + shutil.copy(os.path.join(root, file).replace("pred_harmonized_image", "real"), os.path.join(root, "reorder", name).replace("pred_harmonized_image", "real")) + shutil.copy(os.path.join(root, file).replace("pred_harmonized_image", "composite"), os.path.join(root, "reorder", name).replace("pred_harmonized_image", "composite")) + +print("Done!") diff --git a/tools/resize_Adobe.py b/tools/resize_Adobe.py new file mode 100644 index 0000000000000000000000000000000000000000..ce971ab2d2af260c567140443bd1f1cddb539747 --- /dev/null +++ b/tools/resize_Adobe.py @@ -0,0 +1,45 @@ +import cv2 +import shutil +from tqdm import tqdm +from pathlib import Path + +max_size = 1024 +input_dataset_path = r'.\iHarmony4\HAdobe5k' +output_path = f'{input_dataset_path}_resized{max_size}' + +input_dataset_path = Path(input_dataset_path) +output_path = Path(output_path) + +assert not output_path.exists() + +output_path.mkdir() +for subfolder in ['composite_images', 'masks', 'real_images']: + (output_path / subfolder).mkdir() + +for annotation_path in input_dataset_path.glob('*.txt'): + shutil.copy(annotation_path, output_path / annotation_path.name) + +images_list = sorted(input_dataset_path.rglob('*.jpg')) +images_list.extend(sorted(input_dataset_path.rglob('*.png'))) + +for x in tqdm(images_list): + image = cv2.imread(str(x), cv2.IMREAD_UNCHANGED) + new_path = output_path / x.relative_to(input_dataset_path) + + if max(image.shape[:2]) <= max_size: + shutil.copy(x, new_path) + continue + + new_width = max_size + new_height = max_size + scale = max_size / max(image.shape[:2]) + if image.shape[0] > image.shape[1]: + new_width = int(round(scale * image.shape[1])) + else: + new_height = int(round(scale * image.shape[0])) + + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4) + if x.suffix == '.jpg': + cv2.imwrite(str(new_path), image, 
[cv2.IMWRITE_JPEG_QUALITY, 90]) + else: + cv2.imwrite(str(new_path), image) \ No newline at end of file diff --git a/tools/resize_HR_Adobe.py b/tools/resize_HR_Adobe.py new file mode 100644 index 0000000000000000000000000000000000000000..d38bbdbf8a55b7ad5a88e73fc4c4c6ed7d874941 --- /dev/null +++ b/tools/resize_HR_Adobe.py @@ -0,0 +1,36 @@ +import cv2 +import shutil +from tqdm import tqdm +from pathlib import Path + +max_size = 2048 +input_dataset_path = r'.\iHarmony4\HAdobe5k' +output_path = f'{input_dataset_path}_resized{max_size}_{max_size}' + +input_dataset_path = Path(input_dataset_path) +output_path = Path(output_path) + +assert not output_path.exists() + +output_path.mkdir() +for subfolder in ['composite_images', 'masks', 'real_images']: + (output_path / subfolder).mkdir() + +for annotation_path in input_dataset_path.glob('*.txt'): + shutil.copy(annotation_path, output_path / annotation_path.name) + +images_list = sorted(input_dataset_path.rglob('*.jpg')) +images_list.extend(sorted(input_dataset_path.rglob('*.png'))) + +for x in tqdm(images_list): + image = cv2.imread(str(x), cv2.IMREAD_UNCHANGED) + new_path = output_path / x.relative_to(input_dataset_path) + + new_width = max_size + new_height = max_size + + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4) + if x.suffix == '.jpg': + cv2.imwrite(str(new_path), image, [cv2.IMWRITE_JPEG_QUALITY, 90]) + else: + cv2.imwrite(str(new_path), image) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0968dccae62e30c6aae070f2b7f3e7dfb64f4272 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/build_loss.cpython-38.pyc b/utils/__pycache__/build_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0940e0a8e373c00cff3f96423393c35114e024cf Binary files /dev/null and b/utils/__pycache__/build_loss.cpython-38.pyc differ diff --git a/utils/__pycache__/metrics.cpython-38.pyc b/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07daba99766fbef209343cd00425243fd97e284d Binary files /dev/null and b/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/utils/__pycache__/misc.cpython-38.pyc b/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fd30eef70d0705fcd63703b20e0dee151f18422 Binary files /dev/null and b/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/utils/build_loss.py b/utils/build_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..01ebe4bba88be6a3b611f69809a9c9960aefd9ae --- /dev/null +++ b/utils/build_loss.py @@ -0,0 +1,76 @@ +import torch + + +def loss_generator(ignore: list = None): + loss_fn = {'mse': mse, + 'lut_mse': lut_mse, + 'masked_mse': masked_mse, + 'sample_weighted_mse': sample_weighted_mse, + 'regularize_LUT': regularize_LUT, + 'MaskWeightedMSE': MaskWeightedMSE} + + if ignore: + for fn in ignore: + loss_fn.pop(fn, None) + + return loss_fn + + +def mse(pred, gt): + return torch.mean((pred - gt) ** 2) + + +def masked_mse(pred, gt, mask): + delimin = torch.clamp_min(torch.sum(mask, dim=([x for x in range(1, len(mask.shape))])), 100).cuda() + # total =
torch.sum(torch.ones_like(mask), dim=([x for x in range(1, len(mask.shape))])) + out = torch.sum((mask > 100 / 255.) * (pred - gt) ** 2, dim=([x for x in range(1, len(mask.shape))])) + out = out / delimin + return torch.mean(out) + + +def sample_weighted_mse(pred, gt, mask): + multi_factor = torch.clamp_min(torch.sum(mask, dim=([x for x in range(1, len(mask.shape))])), 100).cuda() + multi_factor = multi_factor / (multi_factor.sum()) + # total = torch.sum(torch.ones_like(mask), dim=([x for x in range(1, len(mask.shape))])) + out = torch.mean((pred - gt) ** 2, dim=([x for x in range(1, len(mask.shape))])) + out = out * multi_factor + return torch.sum(out) + + +def regularize_LUT(lut): + st = lut[lut < 0.] + reg_st = (st ** 2).mean() if min(st.shape) != 0 else 0 + + lt = lut[lut > 1.] + reg_lt = ((lt - 1.) ** 2).mean() if min(lt.shape) != 0 else 0 + + return reg_lt + reg_st + + +def lut_mse(feat, lut_batch): + loss = 0 + for id in range(feat.shape[0] // lut_batch): + for i in feat[id * lut_batch: id * lut_batch + lut_batch]: + for j in feat[id * lut_batch: id * lut_batch + lut_batch]: + loss += mse(i, j) + + return loss / lut_batch + + +def MaskWeightedMSE(pred, label, mask): + label = label.view(pred.size()) + reduce_dims = get_dims_with_exclusion(label.dim(), 0) + + loss = (pred - label) ** 2 + delimeter = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100) + loss = torch.sum(loss, dim=reduce_dims) / delimeter + + return torch.mean(loss) + + +def get_dims_with_exclusion(dim, exclude=None): + dims = list(range(dim)) + if exclude is not None: + dims.remove(exclude) + + return dims \ No newline at end of file diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0cad855fd24bd9fbad0803f1d77101e667215c --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,58 @@ +import skimage + +import torch +import numpy as np +from pytorch_msssim import ssim +import math + + +def calc_metrics(harmonized, real, mask_batch): + n, c, h, w = harmonized.shape + + mse = [] + fmse = [] + psnr = [] + ssim = [] + for id in range(n): + # fg = (mask_batch[id]).view(-1) + # fg_pixels = int(torch.sum(fg).cpu().numpy()) + # total_pixels = h * w + # + # pred = torch.clamp(harmonized[id] * 255, 0, 255) + # gt = torch.clamp(real[id] * 255, 0, 255) + # + # pred = pred.permute(1, 2, 0).cpu().numpy() + # gt = gt.permute(1, 2, 0).cpu().numpy() + # mask = mask_batch[id].permute(1, 2, 0).cpu().numpy() + # + # mse.append(skimage.metrics.mean_squared_error(pred, gt)) + # fmse.append(skimage.metrics.mean_squared_error(pred * mask, gt * mask) * total_pixels / fg_pixels) + # psnr.append(skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=pred.max() - pred.min())) + # ssim.append(skimage.metrics.structural_similarity(pred, gt, multichannel=True)) + mse.append(MSE(torch.clamp(harmonized[id] * 255, 0, 255), torch.clamp(real[id] * 255, 0, 255), mask_batch[id])) + fmse.append(fMSE(torch.clamp(harmonized[id] * 255, 0, 255), torch.clamp(real[id] * 255, 0, 255), mask_batch[id])) + psnr.append(PSNR(torch.clamp(harmonized[id] * 255, 0, 255), torch.clamp(real[id] * 255, 0, 255), mask_batch[id])) + ssim.append(SSIM(torch.clamp(harmonized[id] * 255, 0, 255), torch.clamp(real[id] * 255, 0, 255), mask_batch[id])) + + return mse, fmse, psnr, ssim + + +def SSIM(pred, target_image, mask): + pred = pred * mask + (target_image) * (1 - mask) + return ssim(pred.unsqueeze(0), target_image.unsqueeze(0)) + + +def MSE(pred, target_image, mask): + return (mask * (pred - 
target_image) ** 2).mean().item() + + +def fMSE(pred, target_image, mask): + diff = mask * ((pred - target_image) ** 2) + return (diff.sum() / (diff.size(0) * mask.sum() + 1e-6)).item() + + +def PSNR(pred, target_image, mask): + mse = (mask * (pred - target_image) ** 2).mean().item() + squared_max = target_image.max().item() ** 2 + + return 10 * math.log10(squared_max / (mse + 1e-6)) \ No newline at end of file diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b1419b5af7ace016621c0e3e8934c6397aef2a20 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,380 @@ +import re +from pathlib import Path +import glob +import logging +import numpy as np +import torch +import cv2 +import os +import math +from adamp import AdamP +import random +import torch.nn as nn + +_logger = None + + +def increment_path(path): + # Increment path, i.e. runs/exp1 --> runs/exp{sep}1, runs/exp{sep}2 etc. + res = re.search("\d+", path) + if res is None: + print("Set initial exp number!") + exit(1) + + if not Path(path).exists(): + return str(path) + else: + path = path[:res.start()] + dirs = glob.glob(f"{path}*") # similar paths + matches = [re.search(rf"%s(\d+)" % Path(path).stem, d) for d in dirs] + i = [int(m.groups()[0]) for m in matches if m] # indices + n = max(i) + 1 # increment number + return f"{path}{n}" # update path + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, fmt=':f'): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def create_logger(log_file, level=logging.INFO): + global _logger + _logger = logging.getLogger() + formatter = logging.Formatter( + '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s') + fh = logging.FileHandler(log_file) + fh.setFormatter(formatter) + sh = logging.StreamHandler() + sh.setFormatter(formatter) + _logger.setLevel(level) + _logger.addHandler(fh) + _logger.addHandler(sh) + + return _logger + + +def get_mgrid(sidelen, dim=2): + '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.''' + if isinstance(sidelen, int): + sidelen = dim * (sidelen,) + + if dim == 2: + pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) + pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1) + pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1) + elif dim == 3: + pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32) + pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1) + pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1) + pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1) + else: + raise NotImplementedError('Not implemented for dim=%d' % dim) + + pixel_coords -= 0.5 + pixel_coords *= 2. 
+ pixel_coords = torch.Tensor(pixel_coords).view(-1, dim) + return pixel_coords + + +def lin2img(tensor, image_resolution=None): + batch_size, num_samples, channels = tensor.shape + if image_resolution is None: + width = np.sqrt(num_samples).astype(int) + height = width + else: + if isinstance(image_resolution, int): + image_resolution = (image_resolution, image_resolution) + height = image_resolution[0] + width = image_resolution[1] + + return tensor.permute(0, 2, 1).contiguous().view(batch_size, channels, height, width) + + +def normalize(x, opt, mode='normal'): + device = x.device + mean = torch.tensor(np.array(opt.transform_mean), dtype=x.dtype)[np.newaxis, :, np.newaxis, np.newaxis].to(device) + var = torch.tensor(np.array(opt.transform_var), dtype=x.dtype)[np.newaxis, :, np.newaxis, np.newaxis].to(device) + if mode == 'normal': + return (x - mean) / var + elif mode == 'inv': + return x * var + mean + + +def prepare_cooridinate_input(mask, dim=2): + '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.''' + if mask.shape[0] == mask.shape[1]: + sidelen = mask.shape[0] + else: + sidelen = mask.shape[:2] + + if isinstance(sidelen, int): + sidelen = dim * (sidelen,) + + if dim == 2: + pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1]], axis=-1)[None, ...].astype(np.float32) + pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / (sidelen[0] - 1) + pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / (sidelen[1] - 1) + elif dim == 3: + pixel_coords = np.stack(np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2]], axis=-1)[None, ...].astype(np.float32) + pixel_coords[..., 0] = pixel_coords[..., 0] / max(sidelen[0] - 1, 1) + pixel_coords[..., 1] = pixel_coords[..., 1] / (sidelen[1] - 1) + pixel_coords[..., 2] = pixel_coords[..., 2] / (sidelen[2] - 1) + else: + raise NotImplementedError('Not implemented for dim=%d' % dim) + + pixel_coords -= 0.5 + pixel_coords *= 2. 
+ return pixel_coords.squeeze(0).transpose(2, 0, 1) + + +def visualize(real, composite, mask, pred_fg, pred_harmonized, lut_transform_image, opt, epoch, + show=False, wandb=True, isAll=False, step=None): + save_path = os.path.join(opt.save_path, "figs", str(epoch)) + os.makedirs(save_path, exist_ok=True) + + if isAll: + final_index = 1 + + """ + Uncomment the following code if you want to save all the results, otherwise will only save the first image + of each batch + """ + # final_index = len(real) + else: + final_index = 1 + + for id in range(final_index): + if show: + cv2.imshow("pred_fg", normalize(pred_fg, opt, 'inv')[id].permute(1, 2, 0).cpu().numpy()) + cv2.imshow("real", normalize(real, opt, 'inv')[id].permute(1, 2, 0).cpu().numpy()) + cv2.imshow("lut_transform", normalize(lut_transform_image, opt, 'inv')[id].permute(1, 2, 0).cpu().numpy()) + cv2.imshow("composite", normalize(composite, opt, 'inv')[id].permute(1, 2, 0).cpu().numpy()) + cv2.imshow("mask", mask[id].permute(1, 2, 0).cpu().numpy()) + cv2.imshow("pred_harmonized_image", + normalize(pred_harmonized, opt, 'inv')[id].permute(1, 2, 0).cpu().numpy()) + cv2.waitKey() + + if not opt.INRDecode: + real_tmp = cv2.cvtColor( + normalize(real, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype( + np.uint8), + cv2.COLOR_RGB2BGR) + composite_tmp = cv2.cvtColor( + normalize(composite, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype( + np.uint8), cv2.COLOR_RGB2BGR) + mask_tmp = mask[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype(np.uint8) + lut_transform_image_tmp = cv2.cvtColor( + normalize(lut_transform_image, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_( + 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR) + else: + pred_fg_tmp = cv2.cvtColor( + normalize(pred_fg, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype( + np.uint8), cv2.COLOR_RGB2BGR) + real_tmp = cv2.cvtColor( + normalize(real, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype( + np.uint8), + cv2.COLOR_RGB2BGR) + composite_tmp = cv2.cvtColor( + normalize(composite, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype( + np.uint8), cv2.COLOR_RGB2BGR) + lut_transform_image_tmp = cv2.cvtColor( + normalize(lut_transform_image, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_( + 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR) + mask_tmp = mask[id].permute(1, 2, 0).cpu().mul_(255.).clamp_(0., 255.).numpy().astype(np.uint8) + pred_harmonized_tmp = cv2.cvtColor( + normalize(pred_harmonized, opt, 'inv')[id].permute(1, 2, 0).cpu().mul_(255.).clamp_( + 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR) + + if isAll: + cv2.imwrite(os.path.join(save_path, f"{step}_{id}_composite.jpg"), composite_tmp) + cv2.imwrite(os.path.join(save_path, f"{step}_{id}_real.jpg"), real_tmp) + if opt.INRDecode: + cv2.imwrite(os.path.join(save_path, f"{step}_{id}_pred_harmonized_image.jpg"), pred_harmonized_tmp) + cv2.imwrite(os.path.join(save_path, f"{step}_{id}_lut_transform_image.jpg"), lut_transform_image_tmp) + cv2.imwrite(os.path.join(save_path, f"{step}_{id}_mask.jpg"), mask_tmp) + else: + if not opt.INRDecode: + cv2.imwrite(os.path.join(save_path, f"real_{step}_{id}.jpg"), real_tmp) + cv2.imwrite(os.path.join(save_path, f"composite_{step}_{id}.jpg"), composite_tmp) + cv2.imwrite(os.path.join(save_path, f"mask_{step}_{id}.jpg"), mask_tmp) + cv2.imwrite(os.path.join(save_path, 
f"lut_transform_image_{step}_{id}.jpg"), lut_transform_image_tmp) + else: + cv2.imwrite(os.path.join(save_path, f"pred_fg_{step}_{id}.jpg"), pred_fg_tmp) + cv2.imwrite(os.path.join(save_path, f"real_{step}_{id}.jpg"), real_tmp) + cv2.imwrite(os.path.join(save_path, f"composite_{step}_{id}.jpg"), composite_tmp) + cv2.imwrite(os.path.join(save_path, f"mask_{step}_{id}.jpg"), mask_tmp) + cv2.imwrite(os.path.join(save_path, f"pred_harmonized_image_{step}_{id}.jpg"), pred_harmonized_tmp) + cv2.imwrite(os.path.join(save_path, f"lut_transform_image_{step}_{id}.jpg"), lut_transform_image_tmp) + + "Only upload images of the first batch of the first epoch to save storage." + if wandb and id == 0 and step == 0: + import wandb + real_tmp = wandb.Image(real_tmp, caption=epoch) + composite_tmp = wandb.Image(composite_tmp, caption=epoch) + if opt.INRDecode: + pred_fg_tmp = wandb.Image(pred_fg_tmp, caption=epoch) + pred_harmonized_tmp = wandb.Image(pred_harmonized_tmp, caption=epoch) + lut_transform_image_tmp = wandb.Image(lut_transform_image_tmp, caption=epoch) + mask_tmp = wandb.Image(mask_tmp, caption=epoch) + if not opt.INRDecode: + wandb.log( + {"pic/real": real_tmp, "pic/composite": composite_tmp, + "pic/mask": mask_tmp, + "pic/lut_trans": lut_transform_image_tmp, + "pic/epoch": epoch}) + else: + wandb.log( + {"pic/pred_fg": pred_fg_tmp, "pic/real": real_tmp, "pic/composite": composite_tmp, + "pic/mask": mask_tmp, + "pic/lut_trans": lut_transform_image_tmp, + "pic/pred_harmonized": pred_harmonized_tmp, + "pic/epoch": epoch}) + wandb.log({}) + + +def get_optimizer(model, opt_name, opt_kwargs): + params = [] + base_lr = opt_kwargs['lr'] + for name, param in model.named_parameters(): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + + if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): + # print(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') + param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult + + params.append(param_group) + + optimizer = { + 'sgd': torch.optim.SGD, + 'adam': torch.optim.Adam, + 'adamw': torch.optim.AdamW, + 'adamp': AdamP + }[opt_name.lower()](params, **opt_kwargs) + + return optimizer + + +def improved_efficient_matmul(a, c, index, batch=256): + """ + Reduce the unneed memory cost, but the speed is very slow. + :param a: N * I * J + :param b: N * J * K + :return: N * I * K + """ + "The first can only support when a is not requires_grad_, and have high speed. While the second one supports " + "whatever situations, but speed is quite slow. 
More Details in " + "https://discuss.pytorch.org/t/many-weird-phenomena-about-torch-matmul-operation/158208" + + # out = torch.cat( + # [torch.matmul(a[i * batch:i * batch + batch, :, :], c[index[i * batch:i * batch + batch], :, :]) for i in + # range(a.shape[0] // batch)], dim=0) + + batch = 1 + out = torch.cat( + [torch.matmul(a[i * batch:i * batch + batch, :, :], c[index[i * batch], :, :]) for i in + range(a.shape[0] // batch)], dim=0) + + return out + + +class LRMult(object): + def __init__(self, lr_mult=1.): + self.lr_mult = lr_mult + + def __call__(self, m): + if getattr(m, 'weight', None) is not None: + m.weight.lr_mult = self.lr_mult + if getattr(m, 'bias', None) is not None: + m.bias.lr_mult = self.lr_mult + + +def customRandomCrop(objects, crop_height, crop_width, h_start=None, w_start=None): + if h_start is None: + h_start = random.random() + if w_start is None: + w_start = random.random() + if isinstance(objects, list): + out = [] + for obj in objects: + out.append(random_crop(obj, crop_height, crop_width, h_start, w_start)) + + else: + out = random_crop(objects, crop_height, crop_width, h_start, w_start) + + return out, h_start, w_start + + +def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float, + w_start: float): + y1 = int((height - crop_height) * h_start) + y2 = y1 + crop_height + x1 = int((width - crop_width) * w_start) + x2 = x1 + crop_width + return x1, y1, x2, y2 + + +def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float): + height, width = img.shape[:2] + if height < crop_height or width < crop_width: + raise ValueError( + "Requested crop size ({crop_height}, {crop_width}) is " + "larger than the image size ({height}, {width})".format( + crop_height=crop_height, crop_width=crop_width, height=height, width=width + ) + ) + x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start) + img = img[y1:y2, x1:x2] + return img + + +class PadToDivisor: + def __init__(self, divisor): + super().__init__() + self.divisor = divisor + + def transform(self, images): + + self._pads = (*self._get_dim_padding(images[0].shape[-1]), *self._get_dim_padding(images[0].shape[-2])) + self.pad_operation = nn.ZeroPad2d(padding=self._pads) + + out = [] + for im in images: + out.append(self.pad_operation(im)) + + return out + + def inv_transform(self, image): + assert self._pads is not None,\ + 'Something went wrong, inv_transform(...) should be called after transform(...)' + return self._remove_padding(image) + + def _get_dim_padding(self, dim_size): + pad = (self.divisor - dim_size % self.divisor) % self.divisor + pad_upper = pad // 2 + pad_lower = pad - pad_upper + + return pad_upper, pad_lower + + def _remove_padding(self, tensors): + tensor_h, tensor_w = tensors[0].shape[-2:] + out = [] + for t in tensors: + out.append(t[..., self._pads[2]:tensor_h - self._pads[3], self._pads[0]:tensor_w - self._pads[1]]) + return out
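# A short usage sketch for the cropping / padding helpers in utils/misc.py above. The sizes and the
# synthetic arrays are illustrative assumptions; only the import path follows the file layout in this
# diff. customRandomCrop draws one shared (h_start, w_start) pair and applies the same window to every
# array in the list, which is how composite, real image and mask stay pixel-aligned, while PadToDivisor
# zero-pads tensors so H and W become multiples of a chosen divisor and strips that padding afterwards.
import numpy as np
import torch
from utils.misc import customRandomCrop, PadToDivisor

composite = np.random.rand(512, 512, 3).astype(np.float32)
real = np.random.rand(512, 512, 3).astype(np.float32)
mask = np.random.rand(512, 512, 1).astype(np.float32)

# One random crop window, applied identically to all three arrays.
(composite_c, real_c, mask_c), h_start, w_start = customRandomCrop(
    [composite, real, mask], crop_height=256, crop_width=256)
assert composite_c.shape[:2] == real_c.shape[:2] == mask_c.shape[:2] == (256, 256)

# Pad a list of tensors so H and W are divisible by 32 (divisor chosen for illustration),
# then undo the padding after the network has run.
padder = PadToDivisor(divisor=32)
images = [torch.randn(1, 3, 250, 250), torch.randn(1, 1, 250, 250)]
padded = padder.transform(images)          # both padded to 1 x C x 256 x 256
restored = padder.inv_transform(padded)    # padding removed -> 1 x C x 250 x 250
assert restored[0].shape[-2:] == (250, 250)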