File size: 1,243 Bytes
033bd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from utils import misc
from albumentations import Resize


class Implicit2DGenerator(object):
    """Resize a composite/real image pair and flatten it into per-pixel INR samples.

    The output resolution depends on ``mode``:
    ``opt.INR_input_size`` for ``'Train'`` and ``opt.input_size`` for ``'Val'``.
    A single coordinate grid for that resolution is built once in ``__init__``
    and reused by every call to :meth:`generator`.
    """

    def __init__(self, opt, mode):
        """Configure the target resolution, coordinate grid and resize transform.

        Args:
            opt: Options object; reads ``INR_input_size`` (Train) or
                ``input_size`` (Val). Each may be an int (square) or a
                ``(height, width)`` pair.
            mode: ``'Train'`` or ``'Val'``.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'Train'`` nor ``'Val'``.
        """
        if mode == 'Train':
            sidelength = opt.INR_input_size
        elif mode == 'Val':
            sidelength = opt.input_size
        else:
            raise NotImplementedError(
                f"Unsupported mode {mode!r}; expected 'Train' or 'Val'")

        self.mode = mode

        # Kept exactly as configured (int or pair) so existing readers of
        # `self.size` see the same value as before.
        self.size = sidelength

        sidelength = self._to_hw(sidelength)

        # NOTE(review): get_mgrid's output layout is defined in utils.misc;
        # presumably one 2-D coordinate per pixel of `sidelength` — confirm.
        self.mgrid = misc.get_mgrid(sidelength)

        # albumentations' Resize takes height and width as two separate ints.
        # Passing the raw configured size twice broke pair-valued sizes.
        height, width = sidelength
        self.transform = Resize(height, width)

    @staticmethod
    def _to_hw(sidelength):
        """Expand an int side length to ``(s, s)``; return anything else unchanged."""
        if isinstance(sidelength, int):
            return (sidelength, sidelength)
        return sidelength

    @staticmethod
    def _flatten_rgb(image):
        """Reshape a channel-first ``(3, H, W)`` tensor into ``(H*W, 3)`` per-pixel rows."""
        return image.permute(1, 2, 0).contiguous().view(-1, 3)

    def generator(self, torch_transforms, composite_image, real_image, mask):
        """Resize both images and return coordinate grids and flattened pixel colours.

        Args:
            torch_transforms: Callable applied after resizing. The code below
                treats its output as a 3-channel, channel-first tensor
                (it permutes dim 0 to the end and views rows of 3).
                Presumably an HWC ndarray -> CHW tensor conversion — confirm
                at the call site.
            composite_image: Image accepted by the albumentations transform.
            real_image: Image accepted by the albumentations transform.
            mask: Accepted for interface compatibility; currently unused.

        Returns:
            Tuple ``(fg_coords, bg_coords, fg_rgb, fg_transfer_rgb, bg_rgb)``.
            Both coordinate entries are the same cached ``self.mgrid`` object.
            ``fg_rgb`` comes from the composite image. ``fg_transfer_rgb`` and
            ``bg_rgb`` both come from the real image.
        """
        composite_image = torch_transforms(self.transform(image=composite_image)['image'])
        real_image = torch_transforms(self.transform(image=real_image)['image'])

        fg_INR_RGB = self._flatten_rgb(composite_image)
        # Both targets are built from the real image. They are flattened
        # independently rather than returning one shared tensor twice.
        fg_transfer_INR_RGB = self._flatten_rgb(real_image)
        bg_INR_RGB = self._flatten_rgb(real_image)

        # Foreground and background coordinates alias the same cached grid;
        # callers should not modify them in place.
        fg_INR_coordinates = self.mgrid
        bg_INR_coordinates = self.mgrid

        return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB