import torch.nn as nn

from .backbone import build_backbone


class build_model(nn.Module):
    """Thin wrapper that builds the backbone and routes inputs to it."""

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.backbone = build_backbone('baseline', opt)

    def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
        if self.opt.INRDecode and self.opt.hr_train and (
                self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            # For HR training, due to the RSC strategy designed in Section 3.4 of
            # the paper, we need to pass in the coordinates of the cropped regions.
            extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates,
                                               start_proportion=start_proportion)
        else:
            extracted_features = self.backbone(composite_image, mask)

        if self.opt.INRDecode:
            return extracted_features
        return None, None, extracted_features
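
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how this module might be driven. `Namespace` stands in for the project's
# real option object; the flags shown are only the ones read above, and
# `build_backbone('baseline', opt)` may require additional fields on `opt`,
# so the snippet is left commented out rather than executed on import.
# ---------------------------------------------------------------------------
# import torch
# from argparse import Namespace
#
# opt = Namespace(INRDecode=True, hr_train=False)       # assumed flags only
# model = build_model(opt).eval()
# composite_image = torch.randn(1, 3, 256, 256)         # RGB composite input
# mask = torch.randn(1, 1, 256, 256)                    # foreground mask
# fg_INR_coordinates = torch.rand(1, 256 * 256, 2)      # per-pixel INR coords
# with torch.no_grad():
#     features = model(composite_image, mask, fg_INR_coordinates)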