# INR-Harmon/datasets/build_dataset.py
import torch
import cv2
import numpy as np
import torchvision
import os
import random
from utils.misc import prepare_cooridinate_input, customRandomCrop
from datasets.build_INR_dataset import Implicit2DGenerator
import albumentations
from albumentations import Resize, RandomResizedCrop, HorizontalFlip
from torch.utils.data import DataLoader
class dataset_generator(torch.utils.data.Dataset):
def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
super().__init__()
self.opt = opt
self.root_path = opt.dataset_path
self.mode = mode
self.alb_transforms = alb_transforms
self.torch_transforms = torch_transforms
self.kp_t = area_keep_thresh
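        # `dataset_txt` lists one composite-image path per line, relative to
        # `opt.dataset_path`.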
with open(dataset_txt, 'r') as f:
self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
self.INR_dataset = Implicit2DGenerator(opt, self.mode)
def __len__(self):
return len(self.dataset_samples)
def __getitem__(self, idx):
composite_image = self.dataset_samples[idx]
if self.opt.hr_train:
if self.opt.isFullRes:
"Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
"quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
"if `opt.isFullRes` is set to True."
composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
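        # The paired real image and mask paths are derived from the iHarmony4 naming
        # convention: a composite like "HAdobe5k/composite_images/a0001_1_2.jpg"
        # (hypothetical name) maps to the real image "HAdobe5k/real_images/a0001.jpg"
        # and the mask "HAdobe5k/masks/a0001_1.png".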
composite_image = cv2.imread(composite_image)
composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
real_image = cv2.imread(real_image)
real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask)
mask = mask[:, :, 0].astype(np.float32) / 255.
"""
If set `opt.hr_train` to True:
Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
patch to feed in multiINR process. For inference, just resize it.
While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
"""
if self.opt.hr_train:
if self.mode == 'Train' and self.opt.isFullRes:
if random.random() < 0.5: # LR mix training
mixTransform = albumentations.Compose(
[
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
HorizontalFlip()],
additional_targets={'real_image': 'image', 'object_mask': 'image'}
)
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
origin_bg_ratio = 1 - origin_fg_ratio
"Ensure fg and bg not disappear after transformation"
valid_augmentation = False
transform_out = None
time = 0
while not valid_augmentation:
time += 1
                        # Some images have extreme foreground/background ratios; after 20 failed attempts, fall back to a plain resize so they do not stall the loop.
if time == 20:
tmp_transform = albumentations.Compose(
[Resize(self.opt.base_size, self.opt.base_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = True
else:
transform_out = mixTransform(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
origin_fg_ratio,
origin_bg_ratio,
self.kp_t)
composite_image = transform_out['image']
real_image = transform_out['real_image']
mask = transform_out['object_mask']
            else:  # Pad so the resolution is at least 256 and divisible by 4, as the pixel-aligned multi-resolution crop requires.
if real_image.shape[0] < 256:
bottom_pad = 256 - real_image.shape[0]
else:
bottom_pad = (4 - real_image.shape[0] % 4) % 4
if real_image.shape[1] < 256:
right_pad = 256 - real_image.shape[1]
else:
right_pad = (4 - real_image.shape[1] % 4) % 4
composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
cv2.BORDER_REPLICATE)
real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
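                # E.g. (illustrative sizes): a 1023x1534 image is replicate-padded on the
                # bottom/right to 1024x1536, and a 200x300 image is padded up to 256x300.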
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
origin_bg_ratio = 1 - origin_fg_ratio
"Ensure fg and bg not disappear after transformation"
valid_augmentation = False
transform_out = None
time = 0
if self.opt.hr_train:
if self.mode == 'Train':
if not self.opt.isFullRes:
if random.random() < 0.5: # LR mix training
mixTransform = albumentations.Compose(
[
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
HorizontalFlip()],
additional_targets={'real_image': 'image', 'object_mask': 'image'}
)
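                        # mixTransform downsamples the HR sample to base_size, mixing LR
                        # training into HR training (this branch is taken with probability 0.5).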
while not valid_augmentation:
time += 1
                            # Some images have extreme foreground/background ratios; after 20 failed attempts, fall back to a plain resize so they do not stall the loop.
if time == 20:
tmp_transform = albumentations.Compose(
[Resize(self.opt.base_size, self.opt.base_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = True
else:
transform_out = mixTransform(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
origin_fg_ratio,
origin_bg_ratio,
self.kp_t)
else:
while not valid_augmentation:
time += 1
                            # Some images have extreme foreground/background ratios; after 20 failed attempts, fall back to a plain resize so they do not stall the loop.
if time == 20:
tmp_transform = albumentations.Compose(
[Resize(self.opt.input_size, self.opt.input_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = True
else:
transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
object_mask=mask)
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
origin_fg_ratio,
origin_bg_ratio,
self.kp_t)
composite_image = transform_out['image']
real_image = transform_out['real_image']
mask = transform_out['object_mask']
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
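                # full_coord is the coordinate map for the (padded) working image,
                # transposed to HWC so it can be cropped together with the images below.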
tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
compos_list = [self.torch_transforms(transform_out['image'])]
real_list = [self.torch_transforms(transform_out['real_image'])]
mask_list = [
torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
coord_map_list = []
valid_augmentation = False
while not valid_augmentation:
                    # RSC strategy: re-sample the base_size crop until it keeps enough of the foreground.
transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
self.opt.base_size, self.opt.base_size)
valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
compos_list.append(self.torch_transforms(transform_out[0]))
real_list.append(self.torch_transforms(transform_out[1]))
mask_list.append(
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
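                # The crop's coordinate map is appended twice, presumably so that
                # coord_map_list stays aligned with the four entries of the other lists
                # (the base-size resize of the whole image has no crop of its own). The
                # loop below then takes pixel-aligned crops of the same region (via
                # c_h, c_w) from the 1/2- and 1/4-downsampled images.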
for n in range(2):
tmp_comp = cv2.resize(composite_image, (
composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
tmp_real = cv2.resize(real_image,
(real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
self.opt.base_size // 2 ** (n + 1),
self.opt.base_size // 2 ** (n + 1), c_h, c_w)
compos_list.append(self.torch_transforms(transform_out[0]))
real_list.append(self.torch_transforms(transform_out[1]))
mask_list.append(
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
out_comp = compos_list
out_real = real_list
out_mask = mask_list
out_coord = coord_map_list
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
self.torch_transforms, transform_out[0], transform_out[1], mask)
return {
'file_path': self.dataset_samples[idx],
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
'composite_image': out_comp,
'real_image': out_real,
'mask': out_mask,
'coordinate_map': out_coord,
'composite_image0': out_comp[0],
'real_image0': out_real[0],
'mask0': out_mask[0],
'coordinate_map0': out_coord[0],
'composite_image1': out_comp[1],
'real_image1': out_real[1],
'mask1': out_mask[1],
'coordinate_map1': out_coord[1],
'composite_image2': out_comp[2],
'real_image2': out_real[2],
'mask2': out_mask[2],
'coordinate_map2': out_coord[2],
'composite_image3': out_comp[3],
'real_image3': out_real[3],
'mask3': out_mask[3],
'coordinate_map3': out_coord[3],
'fg_INR_coordinates': fg_INR_coordinates,
'bg_INR_coordinates': bg_INR_coordinates,
'fg_INR_RGB': fg_INR_RGB,
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
'bg_INR_RGB': bg_INR_RGB
}
else:
if not self.opt.isFullRes:
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
"Generate INR dataset."
mask = (torchvision.transforms.ToTensor()(
transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
mask = np.bool_(mask.numpy())
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
return {
'file_path': self.dataset_samples[idx],
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
'composite_image': self.torch_transforms(transform_out['image']),
'real_image': self.torch_transforms(transform_out['real_image']),
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
'coordinate_map': coordinate_map,
'fg_INR_coordinates': fg_INR_coordinates,
'bg_INR_coordinates': bg_INR_coordinates,
'fg_INR_RGB': fg_INR_RGB,
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
'bg_INR_RGB': bg_INR_RGB
}
else:
coordinate_map = prepare_cooridinate_input(mask)
"Generate INR dataset."
mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
mask_tmp = np.bool_(mask_tmp.numpy())
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
self.torch_transforms, composite_image, real_image, mask_tmp)
return {
'file_path': self.dataset_samples[idx],
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
'composite_image': self.torch_transforms(composite_image),
'real_image': self.torch_transforms(real_image),
'mask': mask[np.newaxis, ...].astype(np.float32),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
'coordinate_map': coordinate_map,
'fg_INR_coordinates': fg_INR_coordinates,
'bg_INR_coordinates': bg_INR_coordinates,
'fg_INR_RGB': fg_INR_RGB,
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
'bg_INR_RGB': bg_INR_RGB
}
while not valid_augmentation:
time += 1
            # Some images have extreme foreground/background ratios; after 20 failed attempts, fall back to a plain resize so they do not stall the loop.
if time == 20:
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
additional_targets={'real_image': 'image',
'object_mask': 'image'})
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
valid_augmentation = True
else:
transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
origin_bg_ratio,
self.kp_t)
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
"Generate INR dataset."
mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
mask = np.bool_(mask.numpy())
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
return {
'file_path': self.dataset_samples[idx],
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
'composite_image': self.torch_transforms(transform_out['image']),
'real_image': self.torch_transforms(transform_out['real_image']),
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
            # The default collate_fn converts this numpy array to a Tensor automatically.
'coordinate_map': coordinate_map,
'fg_INR_coordinates': fg_INR_coordinates,
'bg_INR_coordinates': bg_INR_coordinates,
'fg_INR_RGB': fg_INR_RGB,
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
'bg_INR_RGB': bg_INR_RGB
}
def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
    # Reject the augmented sample if either the foreground or the background area
    # ratio has shrunk below `area_keep_thresh` of its original value.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
current_bg_ratio = 1 - current_fg_ratio
if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
return False
return True
def check_hr_crop_sample(mask, origin_fg_ratio):
    # Accept an HR crop only if it retains at least 80% of the original foreground
    # ratio, so the cropped patch still contains most of the object.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
if current_fg_ratio < 0.8 * origin_fg_ratio:
return False
return True
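

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original training pipeline. The option
    # fields mirror those accessed above (`dataset_path`, `hr_train`, `isFullRes`,
    # `base_size`, `input_size`); the txt path and the transform choices are hypothetical
    # placeholders, and `Implicit2DGenerator` may require further option fields.
    from types import SimpleNamespace

    opt = SimpleNamespace(dataset_path='./iHarmony4', hr_train=False, isFullRes=False,
                          base_size=256, input_size=256)
    alb_tf = albumentations.Compose(
        [Resize(opt.input_size, opt.input_size), HorizontalFlip()],
        additional_targets={'real_image': 'image', 'object_mask': 'image'})
    torch_tf = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    dataset = dataset_generator('./IHD_train.txt', alb_tf, torch_tf, opt, mode='Train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch['composite_image'].shape, batch['mask'].shape)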