import torch
import cv2
import numpy as np
import torchvision
import os
import random
from utils.misc import prepare_cooridinate_input, customRandomCrop
from datasets.build_INR_dataset import Implicit2DGenerator
import albumentations
from albumentations import Resize, RandomResizedCrop, HorizontalFlip
from torch.utils.data import DataLoader


class dataset_generator(torch.utils.data.Dataset):
    def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
        super().__init__()
        self.opt = opt
        self.root_path = opt.dataset_path
        self.mode = mode
        self.alb_transforms = alb_transforms
        self.torch_transforms = torch_transforms
        # Minimum fraction of the original fg/bg area that must survive an augmentation.
        self.kp_t = area_keep_thresh
        # `dataset_txt` lists composite-image paths relative to `opt.dataset_path`.
        with open(dataset_txt, 'r') as f:
            self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
        self.INR_dataset = Implicit2DGenerator(opt, self.mode)

    def __len__(self):
        return len(self.dataset_samples)

    def __getitem__(self, idx):
        composite_image = self.dataset_samples[idx]

        if self.opt.hr_train:
            if self.opt.isFullRes:
                "In dataset preprocessing, the HAdobe5k images are resized to a lower resolution " \
                "for quick loading, so when `opt.isFullRes` is True we redirect the path to the " \
                "original-resolution copy of HAdobe5k."
                composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
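                # e.g. (assumed iHarmony4-style name):
                #   <root>/HAdobe5k/composite_images/a0001_1_2.jpg
                #   -> <root>/HAdobe5kori/composite_images/a0001_1_2.jpg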
        # Derive the real-image and mask paths from the composite path, e.g. (iHarmony4
        # naming) .../composite_images/a0001_1_2.jpg pairs with .../real_images/a0001.jpg
        # and .../masks/a0001_1.png.
        real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
        mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'

        composite_image = cv2.imread(composite_image)
        composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
        real_image = cv2.imread(real_image)
        real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask)
        mask = mask[:, :, 0].astype(np.float32) / 255.

        """
        If `opt.hr_train` is True:
            Apply multi-resolution cropping for high-resolution training. For a 1024/2048
            `input_size` (not full resolution), training first applies RandomResizedCrop to
            `input_size`, then randomly crops a `base_size` patch to feed the multi-INR
            process; for inference, the image is simply resized.
            For full resolution, the RandomResizedCrop is removed and only a random crop is
            applied; for inference, the original size is kept.
            We also implement mixed LR/HR training: that is what the `random.random() < 0.5`
            branches below do.
        """
        if self.opt.hr_train:
            if self.mode == 'Train' and self.opt.isFullRes:
                if random.random() < 0.5:  # LR mix training
                    mixTransform = albumentations.Compose(
                        [
                            RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                            HorizontalFlip()],
                        additional_targets={'real_image': 'image', 'object_mask': 'image'}
                    )

                    origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
                    origin_bg_ratio = 1 - origin_fg_ratio

                    "Ensure that the foreground and background do not (almost) vanish after augmentation."
                    valid_augmentation = False
                    transform_out = None
                    time = 0

                    while not valid_augmentation:
                        time += 1
                        # Some samples have extreme foreground ratios; after 20 failed
                        # attempts, fall back to a plain resize so the loop cannot stall.
                        if time == 20:
                            tmp_transform = albumentations.Compose(
                                [Resize(self.opt.base_size, self.opt.base_size)],
                                additional_targets={'real_image': 'image',
                                                    'object_mask': 'image'})
                            transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                          object_mask=mask)
                            valid_augmentation = True
                        else:
                            transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                         object_mask=mask)
                            valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                        origin_fg_ratio,
                                                                        origin_bg_ratio,
                                                                        self.kp_t)
                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']
                else:
                    # Pad so the original resolution is divisible by 4; this keeps the
                    # pixel-aligned multi-resolution crops below consistent.
                    if real_image.shape[0] < 256:
                        bottom_pad = 256 - real_image.shape[0]
                    else:
                        bottom_pad = (4 - real_image.shape[0] % 4) % 4
                    if real_image.shape[1] < 256:
                        right_pad = 256 - real_image.shape[1]
                    else:
                        right_pad = (4 - real_image.shape[1] % 4) % 4
                    composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
                                                         cv2.BORDER_REPLICATE)
                    real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
                    mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
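                    # Worked example of the padding rule (illustrative numbers): a
                    # 1023x1534 image gets bottom_pad = (4 - 1023 % 4) % 4 = 1 and
                    # right_pad = (4 - 1534 % 4) % 4 = 2, i.e. 1024x1536; anything
                    # smaller than 256 on a side is first padded up to 256.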
        origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
        origin_bg_ratio = 1 - origin_fg_ratio

        "Ensure that the foreground and background do not (almost) vanish after augmentation."
        valid_augmentation = False
        transform_out = None
        time = 0

        if self.opt.hr_train:
            if self.mode == 'Train':
                if not self.opt.isFullRes:
                    if random.random() < 0.5:  # LR mix training
                        mixTransform = albumentations.Compose(
                            [
                                RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                                HorizontalFlip()],
                            additional_targets={'real_image': 'image', 'object_mask': 'image'}
                        )
                        while not valid_augmentation:
                            time += 1
                            # Extreme-ratio fallback, as above: plain resize after 20 attempts.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.base_size, self.opt.base_size)],
                                    additional_targets={'real_image': 'image',
                                                        'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                             object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio,
                                                                            origin_bg_ratio,
                                                                            self.kp_t)
                    else:
                        while not valid_augmentation:
                            time += 1
                            # Extreme-ratio fallback, as above: plain resize after 20 attempts.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.input_size, self.opt.input_size)],
                                    additional_targets={'real_image': 'image',
                                                        'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
                                                                    object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio,
                                                                            origin_bg_ratio,
                                                                            self.kp_t)
                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']
                origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])

                full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

                # View 0: the whole image resized to `base_size`.
                tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                compos_list = [self.torch_transforms(transform_out['image'])]
                real_list = [self.torch_transforms(transform_out['real_image'])]
                mask_list = [
                    torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
                coord_map_list = []

                valid_augmentation = False
                while not valid_augmentation:
                    # RSC strategy: crop the same region at different resolutions.
                    # `customRandomCrop` crops all arrays in the list at a shared random
                    # location and returns the offsets (c_h, c_w), which the loop below
                    # reuses to keep the lower-resolution crops pixel-aligned.
                    transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
                                                               self.opt.base_size, self.opt.base_size)
                    valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
                compos_list.append(self.torch_transforms(transform_out[0]))
                real_list.append(self.torch_transforms(transform_out[1]))
                mask_list.append(
                    torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
                # The crop's coordinate map is appended twice so that `coord_map_list` lines
                # up with the four image views: the resized base view (index 0) reuses the
                # full-resolution crop's coordinates.
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))

                # Views 2 and 3: the same region cropped from 1/2- and 1/4-resolution copies.
                for n in range(2):
                    tmp_comp = cv2.resize(composite_image, (
                        composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
                    tmp_real = cv2.resize(real_image,
                                          (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
                    tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
                    tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
                    transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
                                                               self.opt.base_size // 2 ** (n + 1),
                                                               self.opt.base_size // 2 ** (n + 1), c_h, c_w)
                    compos_list.append(self.torch_transforms(transform_out[0]))
                    real_list.append(self.torch_transforms(transform_out[1]))
                    mask_list.append(
                        torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
                    coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))

                out_comp = compos_list
                out_real = real_list
                out_mask = mask_list
                out_coord = coord_map_list

                fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                    self.torch_transforms, transform_out[0], transform_out[1], mask)

                return {
                    'file_path': self.dataset_samples[idx],
                    'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                    'composite_image': out_comp,
                    'real_image': out_real,
                    'mask': out_mask,
                    'coordinate_map': out_coord,
                    'composite_image0': out_comp[0],
                    'real_image0': out_real[0],
                    'mask0': out_mask[0],
                    'coordinate_map0': out_coord[0],
                    'composite_image1': out_comp[1],
                    'real_image1': out_real[1],
                    'mask1': out_mask[1],
                    'coordinate_map1': out_coord[1],
                    'composite_image2': out_comp[2],
                    'real_image2': out_real[2],
                    'mask2': out_mask[2],
                    'coordinate_map2': out_coord[2],
                    'composite_image3': out_comp[3],
                    'real_image3': out_real[3],
                    'mask3': out_mask[3],
                    'coordinate_map3': out_coord[3],
                    'fg_INR_coordinates': fg_INR_coordinates,
                    'bg_INR_coordinates': bg_INR_coordinates,
                    'fg_INR_RGB': fg_INR_RGB,
                    'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                    'bg_INR_RGB': bg_INR_RGB
                }
            else:
                # Inference: resize to `input_size` (non-full-res), or keep the size as-is.
                if not self.opt.isFullRes:
                    tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                           additional_targets={'real_image': 'image',
                                                                               'object_mask': 'image'})
                    transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                    coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

                    "Generate the INR dataset."
                    mask = (torchvision.transforms.ToTensor()(
                        transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
                    mask = np.bool_(mask.numpy())
                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                        self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(transform_out['image']),
                        'real_image': self.torch_transforms(transform_out['real_image']),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
                        'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }
                else:
                    coordinate_map = prepare_cooridinate_input(mask)

                    "Generate the INR dataset."
                    mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
                    mask_tmp = np.bool_(mask_tmp.numpy())
                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                        self.torch_transforms, composite_image, real_image, mask_tmp)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(composite_image),
                        'real_image': self.torch_transforms(real_image),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
                        'mask': mask[np.newaxis, ...].astype(np.float32),
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }
        while not valid_augmentation:
            time += 1
            # Some samples have extreme foreground ratios; after 20 failed attempts,
            # fall back to a plain resize so the loop cannot stall.
            if time == 20:
                tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = True
            else:
                transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
                                                            origin_bg_ratio,
                                                            self.kp_t)

        coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

        "Generate the INR dataset."
        mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
        mask = np.bool_(mask.numpy())
        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
            self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)

        return {
            'file_path': self.dataset_samples[idx],
            'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
            'composite_image': self.torch_transforms(transform_out['image']),
            'real_image': self.torch_transforms(transform_out['real_image']),
            # The default collate_fn converts this numpy array to a Tensor automatically.
            'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
            'coordinate_map': coordinate_map,
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB
        }

def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
    # Accept an augmentation only if it keeps at least `area_keep_thresh` of both the
    # original foreground and the original background area fractions.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
    current_bg_ratio = 1 - current_fg_ratio
    if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
        return False
    return True
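# Illustrative check with toy numbers (not from the original repo): a 10x10 mask with
# a 2x2 foreground patch has a fg ratio of 0.04; with origin_fg_ratio=0.5 and the
# default area_keep_thresh=0.2, it is rejected, since 0.04 < 0.5 * 0.2.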

def check_hr_crop_sample(mask, origin_fg_ratio):
    # Accept an HR crop only if it retains at least 80% of the original foreground ratio.
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
    if current_fg_ratio < 0.8 * origin_fg_ratio:
        return False
    return True
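
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It assumes
# `./train.txt` lists composite-image paths relative to `opt.dataset_path`, that
# `opt` carries the fields referenced above, and standard ImageNet normalization;
# the real repo builds `opt` with its own parser, and Implicit2DGenerator may
# require extra fields. Names and values here are placeholders.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(dataset_path='./iHarmony4', hr_train=False, isFullRes=False,
                    base_size=256, input_size=256)  # assumed fields/values
    alb = albumentations.Compose(
        [Resize(opt.input_size, opt.input_size), HorizontalFlip()],
        additional_targets={'real_image': 'image', 'object_mask': 'image'})
    torch_tf = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225]),  # assumed stats
    ])
    dataset = dataset_generator('./train.txt', alb, torch_tf, opt, mode='Train')
    # Per-sample INR outputs may differ in shape; depending on Implicit2DGenerator's
    # sampling, a custom collate_fn (or batch_size=1) may be needed.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['composite_image'].shape, batch['mask'].shape)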