# INR-Harmon / model / base / basic_blocks.py
import torch
from torch import nn
import numpy as np
def hyper_weight_init(m, in_features_main_net, activation):
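    """Initialize the output layer of a hypernetwork head.

    Weights get a Kaiming-normal init scaled down by 1e2 so the predicted INR
    parameters start near zero; the bias range follows the SIREN-style
    sqrt(6 / fan_in) / 30 bound for 'sine' and the unscaled bound for
    'leakyrelu_pe'.
    """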
if hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
m.weight.data = m.weight.data / 1.e2
if hasattr(m, 'bias'):
with torch.no_grad():
if activation == 'sine':
m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
elif activation == 'leakyrelu_pe':
m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
else:
raise NotImplementedError
class ConvBlock(nn.Module):
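    """Conv2d followed by an optional normalization layer and an activation."""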
def __init__(
self,
in_channels, out_channels,
kernel_size=4, stride=2, padding=1,
norm_layer=nn.BatchNorm2d, activation=nn.ELU,
bias=True,
):
super(ConvBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
activation(),
)
def forward(self, x):
return self.block(x)
class MaxPoolDownSize(nn.Module):
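    """Build a `depth`-level feature pyramid.

    A 1x1 conv first reduces the input to `mid_channels`; level 0 applies a
    3x3 conv directly, and every later level halves the running resolution
    with 2x2 max-pooling before applying its own 3x3 conv.
    """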
def __init__(self, in_channels, mid_channels, out_channels, depth):
super(MaxPoolDownSize, self).__init__()
self.depth = depth
self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
self.convs = nn.ModuleList([
ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
for conv_i in range(depth)
])
self.pool2d = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
outputs = []
output = self.reduce_conv(x)
for conv_i, conv in enumerate(self.convs):
output = output if conv_i == 0 else self.pool2d(output)
outputs.append(conv(output))
return outputs
class convParams(nn.Module):
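    """Hypernetwork predicting spatially-varying INR MLP parameters.

    A small conv stack maps a content feature map to, at every (stride-2
    downsampled) location, the weights and biases of `hidden_mlp_num + 1` MLP
    layers plus an optional to-RGB layer. When opt.param_factorize_dim > 0,
    each weight matrix is predicted as a low-rank product passed through tanh
    and modulated by a learned per-layer base parameter.
    """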
def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
super(convParams, self).__init__()
self.INR_in_out = INR_in_out
self.cont_split_weight = []
self.cont_split_bias = []
self.hidden_mlp_num = hidden_mlp_num
self.param_factorize_dim = opt.param_factorize_dim
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
self.output_dim = output_dim
self.toRGB = toRGB
self.cont_extraction_net = nn.Sequential(
nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
)
self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
self.basic_params = nn.ParameterList()
if opt.param_factorize_dim > 0:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
if toRGB:
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))
def forward(self, feat, outMore=False):
cont_params = self.cont_extraction_net(feat)
out_mlp = self.to_mlp(cont_params)
if outMore:
return out_mlp, cont_params
return out_mlp
def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
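        # Count how many channels the hypernetwork must emit and record, for
        # each MLP layer, the [start, end] (or [start, split, end] when
        # factorized) slice indices that to_mlp() uses later.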
cont_params = 0
start = 0
if self.param_factorize_dim == -1:
cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
start = cont_params
for id in range(hidden_mlp_num):
cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
start = cont_params
if toRGB:
cont_params += INR_in_out[1] * 3 + 3
self.cont_split_weight.append([start, cont_params - 3])
self.cont_split_bias.append([cont_params - 3, cont_params])
elif self.param_factorize_dim > 0:
cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
INR_in_out[1]
self.cont_split_weight.append(
[start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
start = cont_params
for id in range(hidden_mlp_num):
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
INR_in_out[1]
self.cont_split_weight.append(
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
start = cont_params
if toRGB:
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
self.cont_split_weight.append(
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
self.cont_split_bias.append([cont_params - 3, cont_params])
return cont_params
def to_mlp(self, params):
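        # Slice the predicted channel dimension into per-layer weights and
        # biases, reshaped to (B, H, W, in, out) and (B, H, W, 1, out).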
all_weight_bias = []
if self.param_factorize_dim == -1:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
inp, outp)
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
all_weight_bias.append([weight, bias])
if self.toRGB:
inp, outp = self.INR_in_out[1], 3
weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
inp, outp)
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
all_weight_bias.append([weight, bias])
return all_weight_bias
else:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
inp, self.param_factorize_dim)
weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
self.param_factorize_dim, outp)
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
if self.toRGB:
inp, outp = self.INR_in_out[1], 3
weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
inp, self.param_factorize_dim)
weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
self.param_factorize_dim, outp)
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
return all_weight_bias
class lineParams(nn.Module):
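    """Hypernetwork predicting one image-level set of INR MLP parameters.

    The input feature map is flattened and compressed to a single vector per
    image, from which fully connected layers predict the weights and biases of
    `hidden_mlp_num + 1` MLP layers plus an optional to-RGB layer, with the
    same optional low-rank factorization as convParams.
    """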
def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
hidden_dim=512):
super(lineParams, self).__init__()
self.INR_in_out = INR_in_out
self.app_split_weight = []
self.app_split_bias = []
self.toRGB = toRGB
self.hidden_mlp_num = hidden_mlp_num
self.param_factorize_dim = opt.param_factorize_dim
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
self.output_dim = output_dim
self.compress_layer = nn.Sequential(
nn.Linear(input_resolution, 64, bias=False),
nn.BatchNorm1d(input_dim),
nn.ReLU(inplace=True),
nn.Linear(64, 1, bias=True)
)
self.app_extraction_net = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=False),
# nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim, bias=False),
# nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, output_dim, bias=True)
)
self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
self.basic_params = nn.ParameterList()
if opt.param_factorize_dim > 0:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
if toRGB:
self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))
def forward(self, feat):
app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
out_mlp = self.to_mlp(app_params)
return out_mlp, app_params
def cal_params_num(self, INR_in_out, hidden_mlp_num):
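        # Same bookkeeping as convParams.cal_params_num: total size of the flat
        # parameter vector plus per-layer slice indices into it.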
app_params = 0
start = 0
if self.param_factorize_dim == -1:
app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
self.app_split_weight.append([start, app_params - INR_in_out[1]])
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
start = app_params
for id in range(hidden_mlp_num):
app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
self.app_split_weight.append([start, app_params - INR_in_out[1]])
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
start = app_params
if self.toRGB:
app_params += INR_in_out[1] * 3 + 3
self.app_split_weight.append([start, app_params - 3])
self.app_split_bias.append([app_params - 3, app_params])
elif self.param_factorize_dim > 0:
app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
INR_in_out[1]
self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
app_params - INR_in_out[1]])
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
start = app_params
for id in range(hidden_mlp_num):
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
INR_in_out[1]
self.app_split_weight.append(
[start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
start = app_params
if self.toRGB:
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
app_params - 3])
self.app_split_bias.append([app_params - 3, app_params])
return app_params
def to_mlp(self, params):
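        # Slice the flat parameter vector into per-layer (B, in, out) weights
        # and (B, 1, out) biases; factorized weights are rebuilt as
        # tanh(W1 @ W2) * basic_params.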
all_weight_bias = []
if self.param_factorize_dim == -1:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
weight = weight.view(weight.shape[0], inp, outp)
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
bias = bias.view(bias.shape[0], 1, outp)
all_weight_bias.append([weight, bias])
if self.toRGB:
id = -1
inp, outp = self.INR_in_out[1], 3
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
weight = weight.view(weight.shape[0], inp, outp)
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
bias = bias.view(bias.shape[0], 1, outp)
all_weight_bias.append([weight, bias])
return all_weight_bias
else:
for id in range(self.hidden_mlp_num + 1):
if id == 0:
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
else:
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
bias = bias.view(bias.shape[0], 1, outp)
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
if self.toRGB:
id = -1
inp, outp = self.INR_in_out[1], 3
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
bias = bias.view(bias.shape[0], 1, outp)
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
return all_weight_bias
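

# A minimal smoke test, assuming a dummy `opt` namespace: `param_factorize_dim`
# and `activation` are the only attributes these modules read, and the concrete
# dimensions below are illustrative rather than taken from the original config.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(param_factorize_dim=16, activation='sine')
    feat = torch.randn(2, 256, 16, 16)  # (B, C, H, W) feature map

    conv_hyper = convParams(input_dim=256, INR_in_out=(64, 32), opt=opt,
                            hidden_mlp_num=1, toRGB=True)
    line_hyper = lineParams(input_dim=256, INR_in_out=(64, 32), input_resolution=16 * 16,
                            opt=opt, hidden_mlp_num=1, toRGB=True)

    conv_wb = conv_hyper(feat)      # list of [weight, bias] pairs, one per MLP layer
    line_wb, _ = line_hyper(feat)
    print(conv_wb[0][0].shape)      # torch.Size([2, 8, 8, 64, 32]); stride-2 halves H and W
    print(line_wb[0][0].shape)      # torch.Size([2, 64, 32])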