Spaces:
Configuration error
Configuration error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
# Code adapted from https://github.com/ykasten/layered-neural-atlases
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.numel, trainable))
def positionalEncoding_vec(in_tensor, b):
    """Sinusoidal encoding of a 2-D batch of coordinates.

    Args:
        in_tensor: 2-D tensor of shape (batch, dims).
        b: 1-D tensor of frequencies, shape (freqs,).

    Returns:
        Tensor of shape (batch, 2 * dims * freqs). Frequencies form the outer
        grouping: for each ``b[k]`` the row holds ``sin(in_tensor * b[k])``
        (``dims`` values) followed by ``cos(in_tensor * b[k])`` (``dims`` values).
    """
    # Outer product coordinate x frequency: (batch, dims, freqs).
    angles = torch.einsum("ij, k -> ijk", in_tensor, b)
    # Stack sin above cos along the coordinate axis: (batch, 2 * dims, freqs).
    features = torch.cat([angles.sin(), angles.cos()], dim=1)
    batch = features.shape[0]
    # Bring frequencies in front so each frequency's sin/cos block is contiguous.
    return features.permute(0, 2, 1).reshape(batch, -1)
class IMLP(nn.Module):
    """Fully connected network on (optionally positionally encoded) coordinates.

    The network is ``num_layers`` ``nn.Linear`` layers with ReLU between them.
    Each layer whose index is in ``skip_layers`` receives the previous
    activation concatenated with the (encoded) network input.

    Args:
        input_dim: number of input features per sample (last dimension of ``x``).
        output_dim: number of features produced by the last layer.
        hidden_dim: width of every hidden layer.
        use_positional: if True, ``x`` is first mapped with
            ``positionalEncoding_vec`` using frequencies ``2**j * pi`` for
            ``j in range(positional_dim)``.
        positional_dim: number of encoding frequencies.
        skip_layers: indices of layers that get the skip concatenation.
            Index 0 is not supported, because the first layer is never sized
            for the concatenated input.
        num_layers: total number of linear layers, including the output layer.
        verbose: print the number of trainable parameters after construction.
        use_tanh: apply ``tanh`` to the last layer's output.
        apply_softmax: apply softmax over the last dimension of the output,
            after ``tanh`` when both are enabled.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        skip_layers=(4, 6),
        num_layers=8,  # includes the output layer
        verbose=True,
        use_tanh=True,
        apply_softmax=False,
    ):
        super().__init__()
        self.verbose = verbose
        self.use_tanh = use_tanh
        self.apply_softmax = apply_softmax
        if apply_softmax:
            # Explicit dim: the implicit-dim form is deprecated and warns on
            # every call. For 2-D input it resolved to the last dim anyway.
            self.softmax = nn.Softmax(dim=-1)
        if use_positional:
            encoding_dimensions = 2 * input_dim * positional_dim
            # Non-persistent buffer: follows .to()/.cuda()/.double() with the
            # module, but stays out of state_dict so existing checkpoints load.
            self.register_buffer(
                "b",
                torch.tensor([(2 ** j) * np.pi for j in range(positional_dim)]),
                persistent=False,
            )
        else:
            encoding_dimensions = input_dim
        # Private copy: no shared mutable default and no aliasing of the caller's list.
        self.skip_layers = list(skip_layers)
        self.hidden = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                input_dims = encoding_dimensions
            elif i in self.skip_layers:
                # Skip layers also receive the (encoded) network input.
                input_dims = hidden_dim + encoding_dimensions
            else:
                input_dims = hidden_dim
            out_dims = output_dim if i == num_layers - 1 else hidden_dim
            self.hidden.append(nn.Linear(input_dims, out_dims, bias=True))
        self.num_layers = num_layers
        self.positional_dim = positional_dim
        self.use_positional = use_positional
        if self.verbose:
            print(f"Model has {count_parameters(self)} params")

    def forward(self, x):
        """Run the network on ``x`` and return the output-layer activations.

        Args:
            x: input batch. When ``use_positional`` is True it must be 2-D,
                shaped (batch, input_dim), as ``positionalEncoding_vec`` requires.

        Returns:
            The last layer's output, followed by ``tanh`` and/or softmax if enabled.
        """
        if self.use_positional:
            # .to() is a no-op when the buffer already lives on x's device.
            x = positionalEncoding_vec(x, self.b.to(x.device))
        # Input re-injected at skip layers. It is detached, so gradients reach
        # the input only through the first layer, not through skip connections.
        # NOTE(review): kept as in the upstream code; confirm this is intended.
        skip_input = x.detach().clone()
        for i, layer in enumerate(self.hidden):
            if i > 0:
                x = F.relu(x)
            if i in self.skip_layers:
                x = torch.cat((x, skip_input), 1)
            x = layer(x)
        if self.use_tanh:
            x = torch.tanh(x)
        if self.apply_softmax:
            x = self.softmax(x)
        return x