import torch
from torch import nn
import numpy as np


def hyper_weight_init(m, in_features_main_net, activation):
    """Initialize a hypernetwork output layer: small Kaiming-normal weights plus a
    bias range matched to the main network's activation ('sine' or 'leakyrelu_pe')."""
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        m.weight.data = m.weight.data / 1.e2

    if hasattr(m, 'bias'):
        with torch.no_grad():
            if activation == 'sine':
                m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30,
                                np.sqrt(6 / in_features_main_net) / 30)
            elif activation == 'leakyrelu_pe':
                m.bias.uniform_(-np.sqrt(6 / in_features_main_net),
                                np.sqrt(6 / in_features_main_net))
            else:
                raise NotImplementedError


class ConvBlock(nn.Module):
    """Conv2d followed by an optional normalization layer and an activation."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            norm_layer=nn.BatchNorm2d,
            activation=nn.ELU,
            bias=True,
    ):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


class MaxPoolDownSize(nn.Module):
    """Produce a `depth`-level feature pyramid: a 1x1 reduction, then a 3x3 conv at
    each level, max-pooling by 2 between levels."""

    def __init__(self, in_channels, mid_channels, out_channels, depth):
        super(MaxPoolDownSize, self).__init__()
        self.depth = depth
        self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
        self.convs = nn.ModuleList([
            ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
            for conv_i in range(depth)
        ])
        self.pool2d = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        outputs = []
        output = self.reduce_conv(x)
        for conv_i, conv in enumerate(self.convs):
            output = output if conv_i == 0 else self.pool2d(output)
            outputs.append(conv(output))
        return outputs
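# ---------------------------------------------------------------------------
# Sketch (not called anywhere in this module): MaxPoolDownSize returns `depth`
# feature maps whose spatial size halves at each level. The channel counts,
# batch size, and resolution below are illustrative assumptions only.
def _pyramid_shapes_sketch():
    pyramid = MaxPoolDownSize(in_channels=64, mid_channels=32, out_channels=32, depth=3)
    feats = pyramid(torch.randn(2, 64, 64, 64))
    # Expected shapes: (2, 32, 64, 64), (2, 32, 32, 32), (2, 32, 16, 16)
    return [tuple(f.shape) for f in feats]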
class convParams(nn.Module):
    """Hypernetwork head that predicts spatially varying INR MLP parameters
    (one weight/bias set per output pixel) from a convolutional feature map."""

    def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
        super(convParams, self).__init__()
        self.INR_in_out = INR_in_out
        self.cont_split_weight = []
        self.cont_split_bias = []
        self.hidden_mlp_num = hidden_mlp_num
        self.param_factorize_dim = opt.param_factorize_dim
        output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
        self.output_dim = output_dim
        self.toRGB = toRGB
        self.cont_extraction_net = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
        )
        self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))

        # Learned base matrices, used only when weights are predicted in factorized form.
        self.basic_params = nn.ParameterList()
        if opt.param_factorize_dim > 0:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
            if toRGB:
                self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))

    def forward(self, feat, outMore=False):
        cont_params = self.cont_extraction_net(feat)
        out_mlp = self.to_mlp(cont_params)
        if outMore:
            return out_mlp, cont_params
        return out_mlp

    def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
        """Count the predicted parameters and record the channel slices for each
        layer's weight (or low-rank factors) and bias."""
        cont_params = 0
        start = 0
        if self.param_factorize_dim == -1:
            # Full weight matrices: in*out weights plus out biases per layer.
            cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
            self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
            self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
            start = cont_params
            for id in range(hidden_mlp_num):
                cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
                self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
                self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
                start = cont_params
            if toRGB:
                cont_params += INR_in_out[1] * 3 + 3
                self.cont_split_weight.append([start, cont_params - 3])
                self.cont_split_bias.append([cont_params - 3, cont_params])
        elif self.param_factorize_dim > 0:
            # Low-rank form: each weight is predicted as two factors of rank
            # `param_factorize_dim`, so the weight slice stores three offsets.
            cont_params += INR_in_out[0] * self.param_factorize_dim + \
                           self.param_factorize_dim * INR_in_out[1] + INR_in_out[1]
            self.cont_split_weight.append(
                [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
            self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
            start = cont_params
            for id in range(hidden_mlp_num):
                cont_params += INR_in_out[1] * self.param_factorize_dim + \
                               self.param_factorize_dim * INR_in_out[1] + INR_in_out[1]
                self.cont_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
                self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
                start = cont_params
            if toRGB:
                cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
                self.cont_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
                self.cont_split_bias.append([cont_params - 3, cont_params])
        return cont_params

    def to_mlp(self, params):
        """Slice the predicted parameter map into per-layer [weight, bias] pairs of
        shape (B, H, W, in, out) and (B, H, W, 1, out)."""
        all_weight_bias = []
        if self.param_factorize_dim == -1:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
                weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], inp, outp)
                bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                all_weight_bias.append([weight, bias])

            if self.toRGB:
                inp, outp = self.INR_in_out[1], 3
                weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
                weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], inp, outp)
                bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                all_weight_bias.append([weight, bias])
            return all_weight_bias
        else:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
                weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(
                    weight1.shape[0], *weight1.shape[2:], inp, self.param_factorize_dim)
                weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
                weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(
                    weight2.shape[0], *weight2.shape[2:], self.param_factorize_dim, outp)
                bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                # Recompose the low-rank weight and modulate the learned base matrix.
                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])

            if self.toRGB:
                inp, outp = self.INR_in_out[1], 3
                weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
                weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(
                    weight1.shape[0], *weight1.shape[2:], inp, self.param_factorize_dim)
                weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
                weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(
                    weight2.shape[0], *weight2.shape[2:], self.param_factorize_dim, outp)
                bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
            return all_weight_bias
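# ---------------------------------------------------------------------------
# Sketch (not used by the model): one way the per-pixel [weight, bias] list from
# convParams.to_mlp could be applied to an input tensor with broadcast matmuls.
# The function name, the input layout, and the absence of a nonlinearity are
# illustrative assumptions; the INR that actually consumes these parameters
# lives outside this file.
def _apply_pixelwise_mlp_sketch(all_weight_bias, x):
    """x: (B, H, W, 1, in_dim), where H and W match the parameter map's grid."""
    out = x
    for weight, bias in all_weight_bias:
        # (B, H, W, 1, in) @ (B, H, W, in, out) -> (B, H, W, 1, out)
        out = torch.matmul(out, weight) + bias
        # A real INR would apply its activation (e.g. a sine) between layers here.
    return out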
class lineParams(nn.Module):
    """Hypernetwork head that predicts a single (spatially shared) set of INR MLP
    parameters from a feature map, after compressing its spatial dimension."""

    def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False, hidden_dim=512):
        super(lineParams, self).__init__()
        self.INR_in_out = INR_in_out
        self.app_split_weight = []
        self.app_split_bias = []
        self.toRGB = toRGB
        self.hidden_mlp_num = hidden_mlp_num
        self.param_factorize_dim = opt.param_factorize_dim
        output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
        self.output_dim = output_dim

        # Collapse the flattened spatial axis (input_resolution) to one value per channel.
        self.compress_layer = nn.Sequential(
            nn.Linear(input_resolution, 64, bias=False),
            nn.BatchNorm1d(input_dim),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1, bias=True)
        )
        self.app_extraction_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=False),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim, bias=True)
        )
        self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))

        # Learned base matrices, used only when weights are predicted in factorized form.
        self.basic_params = nn.ParameterList()
        if opt.param_factorize_dim > 0:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
            if toRGB:
                self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))

    def forward(self, feat):
        app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
        out_mlp = self.to_mlp(app_params)
        return out_mlp, app_params

    def cal_params_num(self, INR_in_out, hidden_mlp_num):
        """Count the predicted parameters and record the slices for each layer's
        weight (or low-rank factors) and bias."""
        app_params = 0
        start = 0
        if self.param_factorize_dim == -1:
            app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
            self.app_split_weight.append([start, app_params - INR_in_out[1]])
            self.app_split_bias.append([app_params - INR_in_out[1], app_params])
            start = app_params
            for id in range(hidden_mlp_num):
                app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
                self.app_split_weight.append([start, app_params - INR_in_out[1]])
                self.app_split_bias.append([app_params - INR_in_out[1], app_params])
                start = app_params
            if self.toRGB:
                app_params += INR_in_out[1] * 3 + 3
                self.app_split_weight.append([start, app_params - 3])
                self.app_split_bias.append([app_params - 3, app_params])
        elif self.param_factorize_dim > 0:
            app_params += INR_in_out[0] * self.param_factorize_dim + \
                          self.param_factorize_dim * INR_in_out[1] + INR_in_out[1]
            self.app_split_weight.append(
                [start, start + INR_in_out[0] * self.param_factorize_dim, app_params - INR_in_out[1]])
            self.app_split_bias.append([app_params - INR_in_out[1], app_params])
            start = app_params
            for id in range(hidden_mlp_num):
                app_params += INR_in_out[1] * self.param_factorize_dim + \
                              self.param_factorize_dim * INR_in_out[1] + INR_in_out[1]
                self.app_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
                self.app_split_bias.append([app_params - INR_in_out[1], app_params])
                start = app_params
            if self.toRGB:
                app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
                self.app_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - 3])
                self.app_split_bias.append([app_params - 3, app_params])
        return app_params

    def to_mlp(self, params):
        """Slice the predicted parameter vector into per-layer [weight, bias] pairs
        of shape (B, in, out) and (B, 1, out)."""
        all_weight_bias = []
        if self.param_factorize_dim == -1:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight = weight.view(weight.shape[0], inp, outp)
                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)
                all_weight_bias.append([weight, bias])

            if self.toRGB:
                id = -1
                inp, outp = self.INR_in_out[1], 3
                weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight = weight.view(weight.shape[0], inp, outp)
                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)
                all_weight_bias.append([weight, bias])
            return all_weight_bias
        else:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
                weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
                weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)
                # Recompose the low-rank weight and modulate the learned base matrix.
                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])

            if self.toRGB:
                id = -1
                inp, outp = self.INR_in_out[1], 3
                weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
                weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
                weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)
                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
            return all_weight_bias
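# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the library). The `opt` namespace
# below only mimics the two fields these modules read (`param_factorize_dim`
# and `activation`); the real option object, feature dimensions, and INR sizes
# come from the training code and are assumed here for illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(param_factorize_dim=8, activation='leakyrelu_pe')
    feat = torch.randn(2, 64, 32, 32)

    conv_hyper = convParams(input_dim=64, INR_in_out=(34, 32), opt=opt,
                            hidden_mlp_num=2, hidden_dim=128, toRGB=True)
    conv_layers = conv_hyper(feat)
    # One (B, H, W, in, out) weight per predicted MLP layer, RGB head included.
    print([tuple(w.shape) for w, _ in conv_layers])

    line_hyper = lineParams(input_dim=64, INR_in_out=(34, 32), input_resolution=32 * 32,
                            opt=opt, hidden_mlp_num=2, toRGB=True, hidden_dim=128)
    line_layers, raw_params = line_hyper(feat)
    # One (B, in, out) weight per predicted MLP layer, shared across all pixels.
    print([tuple(w.shape) for w, _ in line_layers])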