"""Encoder / decoder building blocks for the harmonization network: a convolutional
encoder, a deconvolutional decoder, and an INR-based decoder (the LRIP structure,
see Sections 3.3-3.4 of the paper)."""
import math

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from utils import misc
from utils.misc import lin2img

from .basic_blocks import ConvBlock, lineParams, convParams
from .ops import MaskedChannelAttention, FeaturesConnector
from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
from ..lut_transformation_net import build_lut_transform


class Sine(nn.Module):
    """Sine activation with a fixed frequency scaling of 30, as in SIREN."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.sin(30 * input)


class Leaky_relu(nn.Module):
    """Leaky ReLU activation, used together with positional encodings."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return F.leaky_relu(input, 0.01, inplace=True)


def select_activation(activation_type):
    if activation_type == 'sine':
        return Sine()
    elif activation_type == 'leakyrelu_pe':
        return Leaky_relu()
    else:
        raise NotImplementedError


class ConvEncoder(nn.Module):
    def __init__(
        self,
        depth, ch,
        norm_layer, batchnorm_from, max_channels,
        backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
    ):
        super(ConvEncoder, self).__init__()
        self.depth = depth
        self.INRDecode = INRDecode
        self.backbone_from = backbone_from
        backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]

        in_channels = 4  # composite image (3) + mask (1)
        out_channels = ch

        self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
        self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
        self.blocks_channels = [out_channels, out_channels]

        self.blocks_connected = nn.ModuleDict()
        self.connectors = nn.ModuleDict()
        for block_i in range(2, depth):
            if block_i % 2:
                in_channels = out_channels
            else:
                in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)

            if 0 <= backbone_from <= block_i and len(backbone_channels):
                if INRDecode:
                    self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
                        in_channels, out_channels,
                        norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
                        padding=int(block_i < depth - 1)
                    )
                    self.blocks_channels += [out_channels]

                stage_channels = backbone_channels.pop()
                connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
                self.connectors[f'connector{block_i}'] = connector
                in_channels = connector.output_channels

            self.blocks_connected[f'block{block_i}'] = ConvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
                padding=int(block_i < depth - 1)
            )
            self.blocks_channels += [out_channels]

    def forward(self, x, backbone_features):
        backbone_features = [] if backbone_features is None else backbone_features[::-1]

        outputs = [self.block0(x)]
        outputs += [self.block1(outputs[-1])]

        for block_i in range(2, self.depth):
            output = outputs[-1]

            connector_name = f'connector{block_i}'
            if connector_name in self.connectors:
                if self.INRDecode:
                    block = self.blocks_connected[f'block{block_i}_decode']
                    outputs += [block(output)]

                stage_features = backbone_features.pop()
                connector = self.connectors[connector_name]
                output = connector(output, stage_features)

            block = self.blocks_connected[f'block{block_i}']
            outputs += [block(output)]

        return outputs[::-1]


class DeconvDecoder(nn.Module):
    def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
        super(DeconvDecoder, self).__init__()
        self.image_fusion = image_fusion
        self.deconv_blocks = nn.ModuleList()

        in_channels = encoder_blocks_channels.pop()
        out_channels = in_channels
        for d in range(depth):
            out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
            self.deconv_blocks.append(SEDeconvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer,
                padding=0 if d == 0 else 1,
                with_se=0 <= attend_from <= d
            ))
            in_channels = out_channels

        if self.image_fusion:
            self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
        self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)

    def forward(self, encoder_outputs, image, mask=None):
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output, mask)
            output = output + skip_output
        output = self.deconv_blocks[-1](output, mask)

        if self.image_fusion:
            attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
            output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
        else:
            output = self.to_rgb(output)

        return output


class SEDeconvBlock(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        kernel_size=4, stride=2, padding=1,
        norm_layer=nn.BatchNorm2d, activation=nn.ELU,
        with_se=False
    ):
        super(SEDeconvBlock, self).__init__()
        self.with_se = with_se
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            activation(),
        )
        if self.with_se:
            self.se = MaskedChannelAttention(out_channels)

    def forward(self, x, mask=None):
        out = self.block(x)
        if self.with_se:
            out = self.se(out, mask)
        return out


class INRDecoder(nn.Module):
    """INR-based decoder (LRIP structure): local content MLPs at several scales,
    followed by deconvolution blocks and a global appearance MLP."""

    def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
        super(INRDecoder, self).__init__()
        self.INR_encoding = None
        if opt.embedding_type == "PosEncodingNeRF":
            self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
        elif opt.embedding_type == "RandomFourier":
            self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
        elif opt.embedding_type == "CIPS_embed":
            self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
        elif opt.embedding_type == "INRGAN_embed":
            self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
        else:
            raise NotImplementedError

        encoder_blocks_channels = encoder_blocks_channels[::-1]
        max_hidden_mlp_num = attend_from + 1
        self.opt = opt
        self.max_hidden_mlp_num = max_hidden_mlp_num

        self.content_mlp_blocks = nn.ModuleDict()
        for n in range(max_hidden_mlp_num):
            if n != max_hidden_mlp_num - 1:
                self.content_mlp_blocks[f"block{n}"] = convParams(
                    encoder_blocks_channels.pop(),
                    [self.INR_encoding.out_dim + opt.INR_MLP_dim + (4 if opt.isMoreINRInput else 0),
                     opt.INR_MLP_dim],
                    opt, n + 1)
            else:
                self.content_mlp_blocks[f"block{n}"] = convParams(
                    encoder_blocks_channels.pop(),
                    [self.INR_encoding.out_dim + (4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
                    opt, n + 1)

        self.deconv_blocks = nn.ModuleList()
        encoder_blocks_channels = encoder_blocks_channels[::-1]
        in_channels = encoder_blocks_channels.pop()
        out_channels = in_channels
        for d in range(depth - attend_from):
            out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
            self.deconv_blocks.append(SEDeconvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer,
                padding=0 if d == 0 else 1,
                with_se=False
            ))
            in_channels = out_channels

        self.appearance_mlps = lineParams(
            out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
            (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2, opt, 2, toRGB=True)

        self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim, None, opt)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
        """For full-resolution inference, split the image and process it in strips."""
        if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num')
                                      or hasattr(self.opt, 'split_resolution')) and self.opt.isFullRes:
            return self.forward_fullResInference(encoder_outputs, image=image, mask=mask,
                                                 coord_samples=coord_samples)

        encoder_outputs = encoder_outputs[::-1]
        mlp_output = None
        waitToRGB = []
        for n in range(self.max_hidden_mlp_num):
            if not self.opt.hr_train:
                coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
                    .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
            else:
                if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
                    coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
                        encoder_outputs[0].shape[0], -1, 2)
                else:
                    coord = misc.get_mgrid(
                        self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
                        encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)

            """Whether to feed multiple inputs (image and mask) to the INR decoder. See Section 3.4 in the paper."""
            if self.opt.isMoreINRInput:
                if not self.opt.isFullRes or (
                        self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
                    res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
                else:
                    res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
                    res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))
                res_image = torchvision.transforms.Resize([res_h, res_w])(image)
                res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
                coord = torch.cat([self.INR_encoding(coord),
                                   res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
                                   res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
            else:
                coord = self.INR_encoding(coord)

            """============ LRIP structure, see Section 3.3 =============="""
            """Local MLPs."""
            if n == 0:
                mlp_output = self.mlp_process(
                    coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0), self.opt,
                    content_mlp=self.content_mlp_blocks[f"block{self.max_hidden_mlp_num - 1 - n}"](
                        encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
                    start_proportion=start_proportion)
                waitToRGB.append(mlp_output[1])
            else:
                mlp_output = self.mlp_process(
                    coord,
                    self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
                    self.opt, base_feat=mlp_output[0],
                    content_mlp=self.content_mlp_blocks[f"block{self.max_hidden_mlp_num - 1 - n}"](
                        encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
                    start_proportion=start_proportion)
                waitToRGB.append(mlp_output[1])

        encoder_outputs = encoder_outputs[::-1]
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output)
            output = output + skip_output
        output = self.deconv_blocks[-1](output)

        """Global MLPs."""
        app_mlp, app_params = self.appearance_mlps(output)
        harm_out = []
        for idx in range(len(waitToRGB)):
            output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[idx],
                                      appearance_mlp=app_mlp)
            harm_out.append(output[0])

        """Optional 3D LUT prediction."""
        fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)

        return harm_out, fit_lut3d, lut_transform_image

    def mlp_process(self, coordinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
                    resolution=None, start_proportion=None):
        activation = select_activation(opt.activation)
        output = None

        if content_mlp is not None:
            if base_feat is not None:
                coordinates = torch.cat([coordinates, base_feat], dim=2)
            coordinates = lin2img(coordinates, resolution)

            if hasattr(opt, 'split_resolution'):
                """
                Here we crop the needed MLPs according to the region of the split input patches.
                Note that this only supports inference on square images.
                """
                for idx in range(len(content_mlp)):
                    content_mlp[idx][0] = content_mlp[idx][0][:,
                                          (content_mlp[idx][0].shape[1] * start_proportion[0]).int():
                                          (content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
                                          (content_mlp[idx][0].shape[2] * start_proportion[1]).int():
                                          (content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :, :]
                    content_mlp[idx][1] = content_mlp[idx][1][:,
                                          (content_mlp[idx][1].shape[1] * start_proportion[0]).int():
                                          (content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
                                          (content_mlp[idx][1].shape[2] * start_proportion[1]).int():
                                          (content_mlp[idx][1].shape[2] * start_proportion[3]).int(), :, :]

                k_h = coordinates.shape[2] // content_mlp[0][0].shape[1]
                k_w = coordinates.shape[3] // content_mlp[0][0].shape[1]
                bs = coordinates.shape[0]
                h_lr = w_lr = content_mlp[0][0].shape[1]
                nci = INR_input_dim

                # Group pixels into the k_h x k_w patch owned by each cell of the
                # low-resolution parameter grid, then apply that cell's private MLP.
                coordinates = coordinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
                coordinates = coordinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
                    bs, h_lr, w_lr, int(k_h * k_w), nci)

                for idx, layer in enumerate(content_mlp):
                    if idx == 0:
                        output = torch.matmul(coordinates, layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
                    0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
                output_large = self.up(lin2img(output))
                return output_large.view(bs, -1, opt.INR_MLP_dim), output

            k_h = coordinates.shape[2] // content_mlp[0][0].shape[1]
            k_w = coordinates.shape[3] // content_mlp[0][0].shape[1]
            bs = coordinates.shape[0]
            h_lr = w_lr = content_mlp[0][0].shape[1]
            nci = INR_input_dim

            """(evaluation or not HR training) and not full-resolution evaluation"""
            if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
                    not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
                coordinates = coordinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
                coordinates = coordinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
                    bs, h_lr, w_lr, int(k_h * k_w), nci)

                for idx, layer in enumerate(content_mlp):
                    if idx == 0:
                        output = torch.matmul(coordinates, layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
                    0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
                output_large = self.up(lin2img(output))
                return output_large.view(bs, -1, opt.INR_MLP_dim), output
            else:
                # High-resolution path: look up each pixel's MLP weights from the
                # low-resolution parameter grid at its (continuous) coordinate.
                coordinates = coordinates.permute(0, 2, 3, 1)
                for idx, layer in enumerate(content_mlp):
                    weight_shape = layer[0].shape
                    bias_shape = layer[1].shape
                    layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
                    layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
                    layer[0] = F.grid_sample(layer[0], coordinates[..., :2].flip(-1), mode='nearest',
                                             padding_mode='border', align_corners=False)
                    layer[1] = F.grid_sample(layer[1], coordinates[..., :2].flip(-1), mode='nearest',
                                             padding_mode='border', align_corners=False)
                    layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coordinates.shape[:-1],
                                                                              *weight_shape[-2:])
                    layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coordinates.shape[:-1],
                                                                              *bias_shape[-2:])
                    if idx == 0:
                        output = torch.matmul(coordinates.unsqueeze(-2), layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)
                output_large = self.up(lin2img(output, resolution))
                return output_large.view(bs, -1, opt.INR_MLP_dim), output

        elif appearance_mlp is not None:
            output = base_feat
            for idx, layer in enumerate(appearance_mlp):
                if idx != len(appearance_mlp) - 1:
                    output = torch.matmul(output, layer[0]) + layer[1]
                    output = activation(output)
                else:
                    """Last layer."""
                    output = torch.matmul(output, layer[0]) + layer[1]
                    if opt.activation == 'leakyrelu_pe':
                        output = torch.tanh(output)

            return lin2img(output, resolution), None
    def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
        """Full-resolution inference: evaluate the INR decoder strip by strip over the whole image."""
        encoder_outputs = encoder_outputs[::-1]
        mlp_output = None
        res_w = image.shape[-1]
        res_h = image.shape[-2]

        coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
            encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)

        if self.opt.isMoreINRInput:
            coord = torch.cat([self.INR_encoding(coord, (res_h, res_w)),
                               image.view(*image.shape[:2], -1).permute(0, 2, 1),
                               mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
        else:
            coord = self.INR_encoding(coord, (res_h, res_w))

        total = coord.clone()
        interval = 10  # number of image rows processed per strip
        all_intervals = math.ceil(res_h / interval)
        divisible = True
        if res_h / interval != res_h // interval:
            divisible = False

        for n in range(self.max_hidden_mlp_num):
            accum_mlp_output = []
            for line in range(all_intervals):
                if not divisible and line == all_intervals - 1:
                    coord = total[:, line * interval * res_w:, :]
                else:
                    coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]

                if n == 0:
                    accum_mlp_output.append(self.mlp_process(
                        coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0), self.opt,
                        content_mlp=self.content_mlp_blocks[f"block{self.max_hidden_mlp_num - 1 - n}"](
                            encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1
                            else encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                        resolution=(interval, res_w) if divisible or line != all_intervals - 1 else (
                            res_h - interval * (all_intervals - 1), res_w))[1])
                else:
                    accum_mlp_output.append(self.mlp_process(
                        coord,
                        self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
                        self.opt,
                        base_feat=mlp_output[0][:, line * interval * res_w: (line + 1) * interval * res_w, :]
                        if divisible or line != all_intervals - 1
                        else mlp_output[0][:, line * interval * res_w:, :],
                        content_mlp=self.content_mlp_blocks[f"block{self.max_hidden_mlp_num - 1 - n}"](
                            encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1
                            else encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                        resolution=(interval, res_w) if divisible or line != all_intervals - 1 else (
                            res_h - interval * (all_intervals - 1), res_w))[1])

            accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
            mlp_output = [accum_mlp_output, accum_mlp_output]

        encoder_outputs = encoder_outputs[::-1]
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output)
            output = output + skip_output
        output = self.deconv_blocks[-1](output)

        app_mlp, app_params = self.appearance_mlps(output)
        harm_out = []
        accum_mlp_output = []
        for line in range(all_intervals):
            if not divisible and line == all_intervals - 1:
                base = mlp_output[1][:, line * interval * res_w:, :]
            else:
                base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]
            accum_mlp_output.append(self.mlp_process(
                None, self.opt.INR_MLP_dim, self.opt, base_feat=base, appearance_mlp=app_mlp,
                resolution=(interval, res_w) if divisible or line != all_intervals - 1 else (
                    res_h - interval * (all_intervals - 1), res_w))[0])

        accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
        harm_out.append(accum_mlp_output)

        fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)

        return harm_out, fit_lut3d, lut_transform_image
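

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the model above): a minimal, self-contained
# demonstration of the MLP-evaluation patterns that INRDecoder.mlp_process
# relies on. (1) "Local" content MLPs: every cell of a low-resolution parameter
# grid owns its own weight/bias and is applied, via unfold + batched matmul, to
# the k_h x k_w patch of pixels that falls inside that cell. (2) A "global"
# appearance MLP applied pointwise to every pixel. (3) The arbitrary-resolution
# variant, where per-pixel weights are looked up from the parameter grid with
# F.grid_sample instead of unfolding. All tensor sizes below (batch size, grid
# size, channel widths) are made-up toy values, and plain relu/tanh stand in
# for the configurable activations used by the real model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    bs, in_dim, hidden_dim = 2, 6, 8      # toy sizes (assumptions, not a real config)
    h_lr = w_lr = 4                       # low-resolution grid of per-cell MLPs
    k_h = k_w = 8                         # each cell covers an 8x8 pixel patch
    H, W = h_lr * k_h, w_lr * k_w         # full resolution of the coordinate plane

    # Per-cell weights/biases, analogous to what convParams predicts.
    weight = torch.randn(bs, h_lr, w_lr, in_dim, hidden_dim)
    bias = torch.randn(bs, h_lr, w_lr, 1, hidden_dim)

    # Coordinate/feature plane in image layout (bs, C, H, W).
    feat = torch.randn(bs, in_dim, H, W)

    # (1) Local MLPs: group pixels by the cell they belong to, then one batched
    # matmul applies each cell's private weight to its own patch.
    patches = feat.unfold(2, k_h, k_h).unfold(3, k_w, k_w)          # (bs, C, h_lr, w_lr, k_h, k_w)
    patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous().view(
        bs, h_lr, w_lr, k_h * k_w, in_dim)                          # (bs, h_lr, w_lr, k_h*k_w, C)
    local_out = torch.relu(torch.matmul(patches, weight) + bias)    # (bs, h_lr, w_lr, k_h*k_w, hidden)

    # (2) Global MLP: a single weight shared by every pixel.
    g_weight = torch.randn(bs, hidden_dim, 3)
    g_bias = torch.randn(bs, 1, 3)
    flat = local_out.view(bs, -1, hidden_dim)                        # (bs, H*W, hidden)
    rgb = torch.tanh(torch.matmul(flat, g_weight) + g_bias)          # (bs, H*W, 3)

    # (3) Arbitrary-resolution lookup: sample each query pixel's weight from the
    # low-resolution parameter grid with nearest-neighbour grid_sample, as in the
    # high-resolution branch of mlp_process. Grid values lie in [-1, 1], (x, y).
    grid = torch.rand(bs, 5, 7, 2) * 2 - 1
    w_maps = weight.view(bs, h_lr, w_lr, -1).permute(0, 3, 1, 2)     # (bs, in*hidden, h_lr, w_lr)
    sampled = F.grid_sample(w_maps, grid, mode='nearest',
                            padding_mode='border', align_corners=False)
    sampled = sampled.permute(0, 2, 3, 1).view(bs, 5, 7, in_dim, hidden_dim)

    print(local_out.shape, rgb.shape, sampled.shape)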