File size: 997 Bytes
033bd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn
from .backbone import build_backbone


class build_model(nn.Module):
    """Thin wrapper that routes inputs to the 'baseline' backbone.

    The wrapper keeps the option object it was built with and uses its
    ``INRDecode`` / ``hr_train`` flags (plus the optional ``split_num`` /
    ``split_resolution`` attributes) to decide how the backbone is invoked
    and in which form the result is handed back to the caller.
    """

    def __init__(self, opt):
        """Store ``opt`` and create the backbone via ``build_backbone('baseline', opt)``."""
        super().__init__()

        self.opt = opt
        self.backbone = build_backbone('baseline', opt)

    def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
        """Run the backbone and return its output in the configured layout.

        Args:
            composite_image: First positional input forwarded to the backbone.
            mask: Second positional input forwarded to the backbone.
            fg_INR_coordinates: Coordinates of the cropped regions; forwarded
                only on the HR-training path described below.
            start_proportion: Forwarded as a keyword argument together with
                ``fg_INR_coordinates``; ignored on the other path.

        Returns:
            The backbone output unchanged when ``opt.INRDecode`` is truthy,
            otherwise the tuple ``(None, None, backbone_output)``.
        """
        opt = self.opt
        positional = [composite_image, mask]
        keyword = {}

        # For HR training, the RSC strategy (Section 3.4 of the paper) needs
        # the coordinates of the cropped regions, so they are passed along
        # when INR decoding and HR training are both enabled and the module is
        # either in training mode or the options define a split setting.
        if opt.INRDecode and opt.hr_train and (
            self.training
            or any(hasattr(opt, flag) for flag in ('split_num', 'split_resolution'))
        ):
            positional.append(fg_INR_coordinates)
            keyword['start_proportion'] = start_proportion

        backbone_output = self.backbone(*positional, **keyword)

        # Without INR decoding the result is padded to a three-element tuple.
        if not opt.INRDecode:
            return None, None, backbone_output
        return backbone_output