import torch.nn as nn

from .backbone import build_backbone


class build_model(nn.Module):
    """Thin ``nn.Module`` wrapper around the ``'baseline'`` backbone.

    The backbone is created with ``build_backbone('baseline', opt)``. The
    option object ``opt`` is kept on the instance because ``forward`` reads
    its flags (``INRDecode``, ``hr_train`` and, if present, ``split_num`` /
    ``split_resolution``) on every call.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.backbone = build_backbone('baseline', opt)

    def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
        """Run the backbone on ``composite_image`` and ``mask``.

        Args:
            composite_image: first positional input, forwarded unchanged to
                the backbone.
            mask: second positional input, forwarded unchanged to the backbone.
            fg_INR_coordinates: forwarded to the backbone only on the
                coordinate-aware path (see below); ignored otherwise.
            start_proportion: forwarded as a keyword argument on that same
                path; ignored otherwise.

        Returns:
            The backbone output unchanged when ``opt.INRDecode`` is truthy,
            otherwise the 3-tuple ``(None, None, backbone_output)``.
        """
        cfg = self.opt

        # Coordinate-aware call: INR decoding and HR training must both be on,
        # and we must either be in training mode or have a split option
        # configured. The operands are evaluated in the same short-circuit
        # order as a single chained condition.
        use_coordinates = (
            cfg.INRDecode
            and cfg.hr_train
            and (
                self.training
                or hasattr(cfg, 'split_num')
                or hasattr(cfg, 'split_resolution')
            )
        )

        if use_coordinates:
            # For HR training, the RSC strategy (Section 3.4 of the paper)
            # requires passing in the coordinates of the cropped regions.
            features = self.backbone(
                composite_image,
                mask,
                fg_INR_coordinates,
                start_proportion=start_proportion,
            )
        else:
            features = self.backbone(composite_image, mask)

        # With INR decoding the backbone output is returned as-is; otherwise
        # callers receive a 3-tuple whose first two slots are None.
        return features if cfg.INRDecode else (None, None, features)