INR-Harmon / model /base /ih_model.py
WindVChen's picture
Update
033bd8b
raw
history blame
4.3 kB
import torch
import torchvision
import torch.nn as nn
from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
from .ops import ScaleLayer
class IHModelWithBackbone(nn.Module):
    """Wrap a harmonization model with an external feature backbone.

    On every forward pass the image and mask are resized to
    ``opt.base_size`` x ``opt.base_size``, fed through ``backbone`` together
    with convolutional mask features, and the resulting backbone features
    are handed to ``model`` along with the *original* (un-resized) inputs.
    """

    def __init__(
        self,
        model, backbone,
        downsize_backbone_input=False,
        mask_fusion='sum',
        backbone_conv1_channels=64, opt=None
    ):
        """Store the sub-modules and build the mask feature extractor.

        Args:
            model: Module called as ``model(image, mask, backbone_features,
                coord=..., start_proportion=...)``.
            backbone: Module called as ``backbone(image, mask, mask_features)``.
            downsize_backbone_input: Stored on the instance; not read by
                ``forward`` in this class.
            mask_fusion: Stored on the instance; not read by ``forward`` in
                this class.
            backbone_conv1_channels: Output channels of the mask convolution.
            opt: Options object. ``forward`` reads ``INRDecode``, ``hr_train``
                and ``base_size`` from it and probes for ``split_num`` /
                ``split_resolution``, so despite the ``None`` default it must
                be provided before the module is called.
        """
        super().__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model
        self.opt = opt

        # Single-channel mask -> `backbone_conv1_channels` feature maps at
        # half the spatial size (stride 2), followed by a ScaleLayer.
        self.mask_conv = nn.Sequential(
            nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
            ScaleLayer(init_value=0.1, lr_mult=1)
        )

    def forward(self, image, mask, coord=None, start_proportion=None):
        """Compute backbone features at base resolution and run the model.

        Args:
            image: Image tensor, or an indexable whose element ``0`` is the
                tensor given to the backbone when the paired-input mode below
                is active.
            mask: Mask, structured the same way as ``image``.
            coord: Forwarded unchanged to ``model`` as ``coord``.
            start_proportion: Forwarded unchanged to ``model``.

        Returns:
            Whatever ``self.model`` returns.
        """
        opt = self.opt

        # Paired-input mode: INR decoding with high-resolution training,
        # while training or when split inference options are present on opt.
        paired_input = opt.INRDecode and opt.hr_train and (
            self.training or hasattr(opt, 'split_num') or hasattr(opt, 'split_resolution')
        )
        if paired_input:
            base_image, base_mask = image[0], mask[0]
        else:
            base_image, base_mask = image, mask

        # Build the transform once and resize the mask once; the original
        # rebuilt the transform three times and resized the mask twice.
        to_base = torchvision.transforms.Resize([opt.base_size, opt.base_size])
        backbone_image = to_base(base_image)
        resized_mask = to_base(base_mask)

        # The backbone receives the mask concatenated with its complement.
        backbone_mask = torch.cat((resized_mask, 1.0 - resized_mask), dim=1)

        # Only the first channel (the mask itself) feeds the mask convolution.
        backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        # The wrapped model gets the original inputs, not the resized ones.
        return self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
class DeepImageHarmonization(nn.Module):
    """Encoder/decoder harmonization network.

    A ``ConvEncoder`` processes the image and mask (resized to
    ``opt.base_size`` x ``opt.base_size`` and concatenated on the channel
    axis). The decoder is an ``INRDecoder`` when ``opt.INRDecode`` is truthy
    and a ``DeconvDecoder`` otherwise.
    """

    def __init__(
        self,
        depth,
        norm_layer=nn.BatchNorm2d, batchnorm_from=0,
        attend_from=-1,
        image_fusion=False,
        ch=64, max_channels=512,
        backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        """Build the encoder and the decoder selected by ``opt.INRDecode``.

        Args:
            depth: Forwarded to the encoder and the decoder.
            norm_layer: Normalisation layer class forwarded to both.
            batchnorm_from: Forwarded to ``ConvEncoder``.
            attend_from: Forwarded to ``DeconvDecoder`` only.
            image_fusion: Forwarded to ``DeconvDecoder`` only.
            ch: Forwarded to ``ConvEncoder``.
            max_channels: Forwarded to ``ConvEncoder``.
            backbone_from: Forwarded to ``ConvEncoder`` and ``INRDecoder``.
            backbone_channels: Forwarded to ``ConvEncoder``.
            backbone_mode: Forwarded to ``ConvEncoder``.
            opt: Options object. ``opt.INRDecode`` is read here, so despite
                the ``None`` default it must be provided.
        """
        super().__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt

        if opt.INRDecode:
            # See Table 2 in the paper to test with different INR decoders' structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)

    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        """Encode at base resolution and decode.

        Args:
            image: Image tensor, or an indexable whose element ``0`` is
                encoded and whose element ``1`` is given to the decoder when
                the paired-input mode below is active.
            mask: Mask, structured the same way as ``image``.
            backbone_features: Optional features forwarded to the encoder.
            coord: Passed to the decoder as ``coord_samples`` in paired-input
                mode; unused otherwise.
            start_proportion: Passed to the decoder in paired-input mode;
                unused otherwise.

        Returns:
            Whatever the decoder returns.
        """
        opt = self.opt

        # Evaluated once and reused for both input selection and the decoder
        # call: INR decoding with high-resolution training, while training or
        # when split inference options are present on opt.
        paired_input = opt.INRDecode and opt.hr_train and (
            self.training or hasattr(opt, 'split_num') or hasattr(opt, 'split_resolution')
        )
        if paired_input:
            base_image, base_mask = image[0], mask[0]
        else:
            base_image, base_mask = image, mask

        to_base = torchvision.transforms.Resize([opt.base_size, opt.base_size])
        x = torch.cat((to_base(base_image), to_base(base_mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if paired_input:
            # The decoder receives the second element of each pair plus the
            # sampling arguments.
            return self.decoder(
                intermediates, image[1], mask[1],
                coord_samples=coord, start_proportion=start_proportion
            )
        return self.decoder(intermediates, image, mask)