import torch
import torchvision
import torch.nn as nn

from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
from .ops import ScaleLayer
class IHModelWithBackbone(nn.Module):
    def __init__(
            self,
            model, backbone,
            downsize_backbone_input=False,
            mask_fusion='sum',
            backbone_conv1_channels=64, opt=None
    ):
        super(IHModelWithBackbone, self).__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model
        self.opt = opt

        # Embed the single-channel mask into the backbone's first-conv feature space,
        # scaled down so it starts as a small perturbation of the image features.
        self.mask_conv = nn.Sequential(
            nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
            ScaleLayer(init_value=0.1, lr_mult=1)
        )
    def forward(self, image, mask, coord=None, start_proportion=None):
        # During high-resolution INR training (or patch-split inference), `image` and `mask`
        # are pairs: element 0 feeds the low-resolution backbone/encoder, element 1 keeps the
        # full resolution for the decoder.
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
            backbone_mask = torch.cat(
                (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
                 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
            backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
                                       1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        # The backbone sees the resized image, a (mask, 1 - mask) pair, and the embedded mask features.
        backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
        return output
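
# A minimal, hypothetical usage sketch (not part of the original file): the wrapper resizes
# its inputs to `opt.base_size`, runs `backbone` on the resized image plus a (mask, 1 - mask)
# map and the embedded mask features, then hands the backbone features to the wrapped `model`.
# `SomeBackbone` is a placeholder name, not an actual module in this repository:
#
#     base = DeepImageHarmonization(depth=4, opt=opt)
#     net = IHModelWithBackbone(base, SomeBackbone(), opt=opt)
#     pred = net(image, mask)   # image: (B, 3, H, W), mask: (B, 1, H, W)
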
class DeepImageHarmonization(nn.Module):
    def __init__(
            self,
            depth,
            norm_layer=nn.BatchNorm2d, batchnorm_from=0,
            attend_from=-1,
            image_fusion=False,
            ch=64, max_channels=512,
            backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        super(DeepImageHarmonization, self).__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt

        if opt.INRDecode:
            # See Table 2 in the paper to test with different INR decoders' structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        # The encoder always runs at `opt.base_size`; in the high-resolution INR setting the
        # decoder additionally receives the full-resolution image/mask (element 1 of the pair).
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
        else:
            output = self.decoder(intermediates, image, mask)
        return output
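
# A hedged construction example (assumptions labeled): `opt` is taken to be an
# argparse.Namespace carrying at least INRDecode, hr_train and base_size; any other
# option names or values are assumptions for illustration only.
#
#     from argparse import Namespace
#     opt = Namespace(INRDecode=False, hr_train=False, base_size=256)
#     net = DeepImageHarmonization(depth=4, opt=opt)   # baseline deconv decoder
#     harmonized = net(image, mask)                    # full-resolution image/mask tensors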