INR-Harmon / model /build_model.py
WindVChen's picture
Update
033bd8b
raw
history blame
997 Bytes
import torch.nn as nn
from .backbone import build_backbone
class build_model(nn.Module):
    """Model wrapper around the 'baseline' backbone.

    The backbone is created once from ``opt`` via ``build_backbone``.
    ``forward`` uses flags read from ``opt`` to choose which backbone call
    signature to use and how to package the result.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.backbone = build_backbone('baseline', opt)

    def _use_cropped_coordinates(self):
        """Return a truthy value when the backbone must receive crop coordinates.

        Both ``opt.INRDecode`` and ``opt.hr_train`` must be truthy. In addition,
        the module must be in training mode, or ``opt`` must define
        ``split_num`` or ``split_resolution``.
        """
        cfg = self.opt
        if not (cfg.INRDecode and cfg.hr_train):
            return False
        return (
            self.training
            or hasattr(cfg, 'split_num')
            or hasattr(cfg, 'split_resolution')
        )

    def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
        """Run the backbone on a composite image and its mask.

        Args:
            composite_image: input passed straight through to the backbone.
            mask: mask passed straight through to the backbone.
            fg_INR_coordinates: coordinates of the cropped regions; forwarded
                to the backbone only when ``_use_cropped_coordinates()`` holds.
            start_proportion: forwarded to the backbone as a keyword argument
                under the same condition; ignored otherwise.

        Returns:
            The backbone output unchanged when ``opt.INRDecode`` is truthy,
            otherwise the tuple ``(None, None, backbone_output)``.
        """
        if self._use_cropped_coordinates():
            # For HR training, the RSC strategy (Section 3.4 of the paper)
            # requires the coordinates of the cropped regions to be passed in.
            features = self.backbone(
                composite_image,
                mask,
                fg_INR_coordinates,
                start_proportion=start_proportion,
            )
        else:
            features = self.backbone(composite_image, mask)

        if not self.opt.INRDecode:
            return None, None, features
        return features