import os
import random

import albumentations
import cv2
import numpy as np
import torch
import torchvision
from albumentations import HorizontalFlip, RandomResizedCrop, Resize
from torch.utils.data import DataLoader

from datasets.build_INR_dataset import Implicit2DGenerator
from utils.misc import customRandomCrop, prepare_cooridinate_input


class dataset_generator(torch.utils.data.Dataset):
    def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
        super().__init__()
        self.opt = opt
        self.root_path = opt.dataset_path
        self.mode = mode
        self.alb_transforms = alb_transforms
        self.torch_transforms = torch_transforms
        self.kp_t = area_keep_thresh

        with open(dataset_txt, 'r') as f:
            self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]

        self.INR_dataset = Implicit2DGenerator(opt, self.mode)

    def __len__(self):
        return len(self.dataset_samples)

    def __getitem__(self, idx):
        composite_image = self.dataset_samples[idx]

        if self.opt.hr_train:
            if self.opt.isFullRes:
                # In dataset preprocessing, the HAdobe5k images are resized to a lower
                # resolution for quick loading, so when `opt.isFullRes` is True the path
                # has to be redirected to the original-resolution HAdobe5k data.
                composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")

        real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
        mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'

        composite_image = cv2.imread(composite_image)
        composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)

        real_image = cv2.imread(real_image)
        real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask)
        mask = mask[:, :, 0].astype(np.float32) / 255.

        """
        If `opt.hr_train` is True, apply multi-resolution cropping for HR training.
        Specifically, for a 1024/2048 `input_size` (not full resolution), training first
        applies RandomResizedCrop to `input_size` and then randomly crops a `base_size`
        patch to feed into the multi-INR process; for inference, the image is simply
        resized. For full resolution, the RandomResizedCrop is removed and only a random
        crop is applied; for inference, the original size is kept. Note that we also
        implement mixed LR/HR training, i.e., the `random.random() < 0.5` branches below.
        """
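        # A worked instance of the scheme above (illustrative numbers, not fixed by the
        # code): with input_size=1024 and base_size=256, a training sample is first
        # RandomResizedCrop-ed to 1024x1024, then a 256x256 patch is cropped from it.
        # Two further pixel-aligned crops of 128x128 and 64x64 are taken from the 2x-
        # and 4x-downsampled images, and a global 256x256 resized view is kept, giving
        # the four levels packed into the returned lists.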
        if self.opt.hr_train:
            if self.mode == 'Train' and self.opt.isFullRes:
                if random.random() < 0.5:  # LR mix training
                    mixTransform = albumentations.Compose(
                        [
                            RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                            HorizontalFlip()],
                        additional_targets={'real_image': 'image', 'object_mask': 'image'}
                    )

                    origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
                    origin_bg_ratio = 1 - origin_fg_ratio

                    # Make sure that neither the foreground nor the background vanishes
                    # after the transformation.
                    valid_augmentation = False
                    transform_out = None
                    time = 0

                    while not valid_augmentation:
                        time += 1
                        # Some samples have extreme fg/bg ratios; after 20 attempts, fall
                        # back to a plain resize so they do not stall the loader.
                        if time == 20:
                            tmp_transform = albumentations.Compose(
                                [Resize(self.opt.base_size, self.opt.base_size)],
                                additional_targets={'real_image': 'image', 'object_mask': 'image'})
                            transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                          object_mask=mask)
                            valid_augmentation = True
                        else:
                            transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                         object_mask=mask)
                            valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                        origin_fg_ratio, origin_bg_ratio,
                                                                        self.kp_t)
                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']
                else:
                    # Pad so that the original resolution is at least 256 and divisible
                    # by 4. This is needed for the pixel-aligned crops below.
                    if real_image.shape[0] < 256:
                        bottom_pad = 256 - real_image.shape[0]
                    else:
                        bottom_pad = (4 - real_image.shape[0] % 4) % 4
                    if real_image.shape[1] < 256:
                        right_pad = 256 - real_image.shape[1]
                    else:
                        right_pad = (4 - real_image.shape[1] % 4) % 4
                    composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
                                                         cv2.BORDER_REPLICATE)
                    real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad,
                                                    cv2.BORDER_REPLICATE)
                    mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
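                    # A worked instance of the padding rule above (illustrative sizes):
                    # a 1023-pixel-high image gets bottom_pad = (4 - 1023 % 4) % 4 = 1,
                    # i.e. a height of 1024, while a 200-pixel-high image is padded up
                    # to the 256-pixel minimum (bottom_pad = 56). BORDER_REPLICATE
                    # repeats edge pixels, so no artificial black borders are introduced.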
        origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
        origin_bg_ratio = 1 - origin_fg_ratio

        # Make sure that neither the foreground nor the background vanishes after the
        # transformation.
        valid_augmentation = False
        transform_out = None
        time = 0

        if self.opt.hr_train:
            if self.mode == 'Train':
                if not self.opt.isFullRes:
                    if random.random() < 0.5:  # LR mix training
                        mixTransform = albumentations.Compose(
                            [
                                RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                                HorizontalFlip()],
                            additional_targets={'real_image': 'image', 'object_mask': 'image'}
                        )
                        while not valid_augmentation:
                            time += 1
                            # Extreme-ratio fallback, as above.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.base_size, self.opt.base_size)],
                                    additional_targets={'real_image': 'image', 'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                             object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio, origin_bg_ratio,
                                                                            self.kp_t)
                    else:
                        while not valid_augmentation:
                            time += 1
                            # Extreme-ratio fallback, as above.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.input_size, self.opt.input_size)],
                                    additional_targets={'real_image': 'image', 'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
                                                                    object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio, origin_bg_ratio,
                                                                            self.kp_t)
                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']
                    origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])

                full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

                tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                compos_list = [self.torch_transforms(transform_out['image'])]
                real_list = [self.torch_transforms(transform_out['real_image'])]
                mask_list = [torchvision.transforms.ToTensor()(
                    transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
                coord_map_list = []

                valid_augmentation = False
                while not valid_augmentation:
                    # RSC strategy: randomly crop pixel-aligned patches at different resolutions.
                    transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
                                                               self.opt.base_size, self.opt.base_size)
                    valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
                compos_list.append(self.torch_transforms(transform_out[0]))
                real_list.append(self.torch_transforms(transform_out[1]))
                mask_list.append(torchvision.transforms.ToTensor()(
                    transform_out[2][..., np.newaxis].astype(np.float32)))
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
                for n in range(2):
                    tmp_comp = cv2.resize(composite_image, (composite_image.shape[1] // 2 ** (n + 1),
                                                            composite_image.shape[0] // 2 ** (n + 1)))
                    tmp_real = cv2.resize(real_image, (real_image.shape[1] // 2 ** (n + 1),
                                                       real_image.shape[0] // 2 ** (n + 1)))
                    tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
                    tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
                    transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
                                                               self.opt.base_size // 2 ** (n + 1),
                                                               self.opt.base_size // 2 ** (n + 1), c_h, c_w)
                    compos_list.append(self.torch_transforms(transform_out[0]))
                    real_list.append(self.torch_transforms(transform_out[1]))
                    mask_list.append(torchvision.transforms.ToTensor()(
                        transform_out[2][..., np.newaxis].astype(np.float32)))
                    coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
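                # The lists now hold four aligned views: [global base_size resize,
                # base_size crop, base_size/2 crop of the 2x-downsampled image,
                # base_size/4 crop of the 4x-downsampled image]. The base crop's
                # coordinate map is appended twice above so that coord_map_list also
                # has four entries, one per view (our reading of the duplicated append).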
                out_comp = compos_list
                out_real = real_list
                out_mask = mask_list
                out_coord = coord_map_list

                fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = \
                    self.INR_dataset.generator(self.torch_transforms, transform_out[0], transform_out[1], mask)

                return {
                    'file_path': self.dataset_samples[idx],
                    'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                    'composite_image': out_comp,
                    'real_image': out_real,
                    'mask': out_mask,
                    'coordinate_map': out_coord,
                    'composite_image0': out_comp[0],
                    'real_image0': out_real[0],
                    'mask0': out_mask[0],
                    'coordinate_map0': out_coord[0],
                    'composite_image1': out_comp[1],
                    'real_image1': out_real[1],
                    'mask1': out_mask[1],
                    'coordinate_map1': out_coord[1],
                    'composite_image2': out_comp[2],
                    'real_image2': out_real[2],
                    'mask2': out_mask[2],
                    'coordinate_map2': out_coord[2],
                    'composite_image3': out_comp[3],
                    'real_image3': out_real[3],
                    'mask3': out_mask[3],
                    'coordinate_map3': out_coord[3],
                    'fg_INR_coordinates': fg_INR_coordinates,
                    'bg_INR_coordinates': bg_INR_coordinates,
                    'fg_INR_RGB': fg_INR_RGB,
                    'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                    'bg_INR_RGB': bg_INR_RGB
                }
            else:
                if not self.opt.isFullRes:
                    tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                           additional_targets={'real_image': 'image',
                                                                               'object_mask': 'image'})
                    transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                    coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

                    # Generate the INR dataset.
                    mask = (torchvision.transforms.ToTensor()(
                        transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
                    mask = np.bool_(mask.numpy())

                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = \
                        self.INR_dataset.generator(self.torch_transforms, transform_out['image'],
                                                   transform_out['real_image'], mask)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(transform_out['image']),
                        'real_image': self.torch_transforms(transform_out['real_image']),
                        # NumPy arrays are converted to tensors automatically by the
                        # default collate function.
                        'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }
                else:
                    coordinate_map = prepare_cooridinate_input(mask)

                    # Generate the INR dataset.
                    mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
                    mask_tmp = np.bool_(mask_tmp.numpy())

                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = \
                        self.INR_dataset.generator(self.torch_transforms, composite_image, real_image, mask_tmp)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(composite_image),
                        'real_image': self.torch_transforms(real_image),
                        # NumPy arrays are converted to tensors automatically by the
                        # default collate function.
                        'mask': mask[np.newaxis, ...].astype(np.float32),
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }

        while not valid_augmentation:
            time += 1
            # Some samples have extreme fg/bg ratios; after 20 attempts, fall back to a
            # plain resize so they do not stall the loader.
            if time == 20:
                tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = True
            else:
                transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
                                                            origin_bg_ratio, self.kp_t)

        coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

        # Generate the INR dataset.
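        # The augmented mask is binarized at the 100/255 threshold used throughout this
        # file and flattened to a 1-D boolean array, so the INR generator can index
        # foreground/background pixels directly when sampling coordinate/RGB pairs.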
        mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
        mask = np.bool_(mask.numpy())

        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = \
            self.INR_dataset.generator(self.torch_transforms, transform_out['image'],
                                       transform_out['real_image'], mask)

        return {
            'file_path': self.dataset_samples[idx],
            'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
            'composite_image': self.torch_transforms(transform_out['image']),
            'real_image': self.torch_transforms(transform_out['real_image']),
            # NumPy arrays are converted to tensors automatically by the default
            # collate function.
            'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
            'coordinate_map': coordinate_map,
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB
        }


def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
    # Reject the augmented sample if either the foreground or the background area has
    # shrunk below `area_keep_thresh` times its original ratio.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
    current_bg_ratio = 1 - current_fg_ratio

    if current_fg_ratio < origin_fg_ratio * area_keep_thresh or \
            current_bg_ratio < origin_bg_ratio * area_keep_thresh:
        return False
    return True


def check_hr_crop_sample(mask, origin_fg_ratio):
    # Reject an HR crop that retains less than 80% of the original foreground ratio.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])

    if current_fg_ratio < 0.8 * origin_fg_ratio:
        return False
    return True
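

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The option object below is a hypothetical
# stand-in for the project's real config: only the fields accessed in this file
# (`dataset_path`, `hr_train`, `isFullRes`, `base_size`, `input_size`) are set,
# and `Implicit2DGenerator` may require further fields in practice. The
# transform pipelines and file names are likewise minimal guesses, not the
# actual training recipe.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    from torchvision import transforms

    opt = SimpleNamespace(dataset_path='./data', hr_train=False, isFullRes=False,
                          base_size=256, input_size=256)
    alb_transforms = albumentations.Compose(
        [Resize(opt.input_size, opt.input_size), HorizontalFlip()],
        additional_targets={'real_image': 'image', 'object_mask': 'image'})
    torch_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    train_set = dataset_generator('./data/train.txt', alb_transforms, torch_transforms, opt)
    loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch['composite_image'].shape, batch['mask'].shape)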