import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.misc import normalize


class build_lut_transform(nn.Module):
    """Predicts a per-image 3D LUT from appearance features and applies it
    to the composite image via trilinear interpolation."""

    def __init__(self, input_dim, lut_dim, input_resolution, opt):
        super().__init__()
        self.lut_dim = lut_dim
        self.opt = opt

        # self.compress_layer = nn.Linear(input_resolution, 1)
        self.transform_layers = nn.Sequential(
            nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
            # nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
        )
        # Re-initialize the last layer so the predicted LUT starts close to
        # its (uniform) bias initialization; see hyper_weight_init below.
        self.transform_layers[-1].apply(lambda m: hyper_weight_init(m))

    def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
        # Undo the dataset normalization so pixel values lie in [0, 1].
        composite_image = normalize(composite_image, self.opt, 'inv')

        features = fg_appearance_features
        lut_params = self.transform_layers(features)
        # One (3, D, D, D) LUT per image in the batch.
        fit_3DLUT = lut_params.view(lut_params.shape[0], 3,
                                    self.lut_dim, self.lut_dim, self.lut_dim)

        lut_transform_image = torch.stack(
            [TrilinearInterpolation(lut, image)[0]
             for lut, image in zip(fit_3DLUT, composite_image)], dim=0)

        return fit_3DLUT, normalize(lut_transform_image, self.opt)


def TrilinearInterpolation(LUT, img):
    # Map pixel values from [0, 1] to grid_sample's [-1, 1] coordinate range.
    img = (img - 0.5) * 2.

    # (3, H, W) -> (1, 1, H, W, 3). grid_sample expects grid coordinates in
    # (x, y, z) order, which indexes the LUT as (W, H, D) -- the reverse of
    # its DHW layout -- hence the flip of the RGB channel axis.
    img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)

    LUT = LUT[None]  # (3, D, D, D) -> (1, 3, D, D, D)

    # For 5D inputs, mode='bilinear' performs trilinear interpolation.
    result = F.grid_sample(LUT, img, mode='bilinear',
                           padding_mode='border', align_corners=True)

    # Drop the added depth dimension: (1, 3, 1, H, W) -> (1, 3, H, W).
    result = result[:, :, 0]

    return result


def hyper_weight_init(m):
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        # Shrink the weights so the initial LUT is dominated by the bias term.
        m.weight.data = m.weight.data / 1.e2

    if hasattr(m, 'bias'):
        with torch.no_grad():
            m.bias.uniform_(0., 1.)
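

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It exercises
# TrilinearInterpolation in isolation, which needs no external dependencies.
# The identity LUT below is a hypothetical construction for testing: because
# TrilinearInterpolation flips the RGB channel axis to match grid_sample's
# (x, y, z) = (W, H, D) convention, the LUT axes are indexed (R, G, B) ->
# (D, H, W). An identity LUT built this way should reproduce the input image
# up to floating-point error. Requires PyTorch >= 1.10 for the meshgrid
# `indexing` keyword.
if __name__ == '__main__':
    dim = 17
    coords = torch.linspace(0.0, 1.0, dim)
    # r varies along D, g along H, b along W (see the note above).
    r, g, b = torch.meshgrid(coords, coords, coords, indexing='ij')
    identity_lut = torch.stack([r, g, b], dim=0)  # (3, D, D, D)

    img = torch.rand(3, 32, 32)  # pixel values in [0, 1]
    out = TrilinearInterpolation(identity_lut, img)[0]
    print(torch.allclose(out, img, atol=1e-4))  # expected: True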