from utils import misc from albumentations import Resize class Implicit2DGenerator(object): def __init__(self, opt, mode): if mode == 'Train': sidelength = opt.INR_input_size elif mode == 'Val': sidelength = opt.input_size else: raise NotImplementedError self.mode = mode self.size = sidelength if isinstance(sidelength, int): sidelength = (sidelength, sidelength) self.mgrid = misc.get_mgrid(sidelength) self.transform = Resize(self.size, self.size) def generator(self, torch_transforms, composite_image, real_image, mask): composite_image = torch_transforms(self.transform(image=composite_image)['image']) real_image = torch_transforms(self.transform(image=real_image)['image']) fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3) fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3) bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3) fg_INR_coordinates = self.mgrid bg_INR_coordinates = self.mgrid return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB