import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
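
# Fuses an extra input map (e.g. a single-channel mask) into an RGB image:
# the two are concatenated and passed through two 1x1 convolutions, returning
# a tensor with the same channel count as the input image.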
class SimpleInputFusion(nn.Module):
    def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
        super(SimpleInputFusion, self).__init__()
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2),
            norm_layer(ch),
            nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
        )

    def forward(self, image, additional_input):
        return self.fusion_conv(torch.cat((image, additional_input), dim=1))
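
# Squeeze-and-excitation style channel attention conditioned on a mask:
# features are max-pooled inside and outside the mask and average-pooled
# globally (3 * in_channels values per sample), then a small MLP predicts
# per-channel gating weights in [0, 1].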
class MaskedChannelAttention(nn.Module):
    def __init__(self, in_channels, *args, **kwargs):
        super(MaskedChannelAttention, self).__init__()
        self.global_max_pool = MaskedGlobalMaxPool2d()
        self.global_avg_pool = FastGlobalAvgPool2d()

        intermediate_channels_count = max(in_channels // 16, 8)
        self.attention_transform = nn.Sequential(
            nn.Linear(3 * in_channels, intermediate_channels_count),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_channels_count, in_channels),
            nn.Sigmoid(),
        )
    def forward(self, x, mask):
        # Compare spatial dimensions (the original checked x.shape[:2], i.e.
        # batch and channels, which would trigger a resize on every call).
        if mask.shape[2:] != x.shape[2:]:
            mask = nn.functional.interpolate(
                mask, size=x.size()[-2:],
                mode='bilinear', align_corners=True
            )

        pooled_x = torch.cat([
            self.global_max_pool(x, mask),
            self.global_avg_pool(x)
        ], dim=1)
        channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]

        return channel_attention_weights * x
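
# Global max pooling applied separately inside and outside the mask;
# the two results are concatenated, doubling the channel count.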
class MaskedGlobalMaxPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.global_max_pool = FastGlobalMaxPool2d()

    def forward(self, x, mask):
        return torch.cat((
            self.global_max_pool(x * mask),
            self.global_max_pool(x * (1.0 - mask))
        ), dim=1)
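
# Global pooling implemented by flattening the spatial dimensions and
# reducing, returning (batch, channels) tensors.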
class FastGlobalAvgPool2d(nn.Module):
    def forward(self, x):
        in_size = x.size()
        return x.view((in_size[0], in_size[1], -1)).mean(dim=2)


class FastGlobalMaxPool2d(nn.Module):
    def forward(self, x):
        in_size = x.size()
        return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
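
# Learnable scalar multiplier. Storing the parameter pre-divided by lr_mult
# and multiplying it back in forward() effectively scales the parameter's
# learning rate by lr_mult.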
class ScaleLayer(nn.Module):
    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        )

    def forward(self, x):
        scale = torch.abs(self.scale * self.lr_mult)
        return x * scale
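
# Merges auxiliary features into a feature map in one of three modes:
# 'cat' (plain concatenation), 'catc' (concatenation followed by a 1x1
# convolution back to out_channels), or 'sum' (1x1-project the features,
# then add). With zero feature channels the mode is disabled and x passes
# through unchanged.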
class FeaturesConnector(nn.Module):
    def __init__(self, mode, in_channels, feature_channels, out_channels):
        super(FeaturesConnector, self).__init__()
        self.mode = mode if feature_channels else ''

        if self.mode == 'catc':
            self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
        elif self.mode == 'sum':
            self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)

        self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels

    def forward(self, x, features):
        if self.mode == 'cat':
            return torch.cat((x, features), 1)
        if self.mode == 'catc':
            return self.reduce_conv(torch.cat((x, features), 1))
        if self.mode == 'sum':
            return self.reduce_conv(features) + x
        return x

    def extra_repr(self):
        return self.mode
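
# NeRF-style positional encoding: each input coordinate is augmented with
# sin/cos pairs at num_frequencies dyadic frequencies. For 1D and 2D inputs
# the frequency count can be derived from the Nyquist rate of the sampling
# resolution.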
class PosEncodingNeRF(nn.Module):
    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
        super().__init__()
        self.in_features = in_features

        if self.in_features == 3:
            self.num_frequencies = 10
        elif self.in_features == 2:
            assert sidelength is not None
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
        elif self.in_features == 1:
            assert fn_samples is not None
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)

        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        return int(math.floor(math.log(nyquist_rate, 2)))
    def forward(self, coords):
        coords = coords.view(coords.shape[0], -1, self.in_features)

        coords_pos_enc = coords
        for i in range(self.num_frequencies):
            for j in range(self.in_features):
                c = coords[..., j]

                sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
                cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)

                coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), dim=-1)

        return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
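
# Random Fourier feature embedding: 2D coordinates are projected through a
# fixed random Gaussian matrix and mapped through sin and cos (as in
# Tancik et al., "Fourier Features Let Networks Learn High Frequency
# Functions in Low Dimensional Domains").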
class RandomFourier(nn.Module):
    def __init__(self, std_scale, embedding_length, device):
        super().__init__()
        # Fixed (untrained) random projection matrix for 2D coordinates.
        self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
        self.embed = self.embed.to(device)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coords):
        # coords: (batch, num_points, 2). The sin/cos features have
        # 2 * embedding_length channels; concatenating the raw coordinates
        # already yields out_dim channels, so no reshape is needed (the
        # original reshape to out_dim was shape-inconsistent).
        coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
                                    torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)
        return torch.cat([coords, coords_pos_enc], dim=-1)
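
# CIPS-style embedding (Anokhin et al., "Image Generators with
# Conditionally-Independent Pixel Synthesis"): concatenates the raw
# coordinates, a learned sine embedding of the coordinates, and a learned
# per-pixel constant embedding resampled at the queried coordinates.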
class CIPS_embed(nn.Module):
    def __init__(self, size, embedding_length):
        super().__init__()
        self.fourier_embed = ConstantInput(size, embedding_length)
        self.predict_embed = Predict_embed(embedding_length)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coord, res=None):
        x = self.predict_embed(coord)
        y = self.fourier_embed(x, coord, res)
        return torch.cat([coord, x, y], dim=-1)
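
# Learned sinusoidal coordinate embedding: a linear map of (x, y) followed
# by sin, with weights initialized uniformly in [-sqrt(9/2), sqrt(9/2)].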
class Predict_embed(nn.Module):
    def __init__(self, embedding_length):
        super(Predict_embed, self).__init__()
        self.ffm = nn.Linear(2, embedding_length, bias=True)
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

    def forward(self, x):
        return torch.sin(self.ffm(x))
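
# Learned constant input of shape (1, size**2, channel). When queried with a
# coordinate grid of a different resolution, it is reshaped to a square
# feature map and bilinearly resampled at those coordinates via grid_sample.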
class ConstantInput(nn.Module):
    def __init__(self, size, channel):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, size ** 2, channel))

    def forward(self, input, coord, resolution=None):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1)

        if coord.shape[1] != self.input.shape[1]:
            # Resample the learned square map at the requested coordinates.
            side = int(self.input.shape[1] ** 0.5)
            x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1], side, side)
            if resolution is None:
                grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
            else:
                grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])
            out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
            out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])

        return out
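
# INR-GAN style multi-component coordinate embedding. Depending on res_cfg it
# concatenates: a fixed logarithmic basis, a random Fourier basis, a shared
# learnable basis, a basis predicted from a latent code w, and constant
# per-pixel embeddings; sin (and optionally cos) is applied to the basis
# projections.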
class INRGAN_embed(nn.Module):
    def __init__(self, resolution: int, w_dim=None):
        super().__init__()
        self.resolution = resolution
        self.res_cfg = {"log_emb_size": 32,
                        "random_emb_size": 32,
                        "const_emb_size": 64,
                        "use_cosine": True}
        # The original referenced self.cfg.coord_dim, which does not exist in
        # this standalone class; read it from res_cfg with a 2D default.
        self.coord_dim = self.res_cfg.get('coord_dim', 2)
        self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
        self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
        self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
        self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
        self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
        self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
        self.use_cosine = self.res_cfg.get('use_cosine', False)

        if self.log_emb_size > 0:
            self.register_buffer('log_basis', generate_logarithmic_basis(
                resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
        if self.random_emb_size > 0:
            self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))
        if self.shared_emb_size > 0:
            self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))
        if self.predictable_emb_size > 0:
            self.W_size = self.predictable_emb_size * self.coord_dim
            self.b_size = self.predictable_emb_size
            self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
        if self.const_emb_size > 0:
            self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))

        self.out_dim = self.get_total_dim() + 2
    def sample_w_matrix(self, shape, scale: float):
        return torch.randn(shape) * scale

    def get_total_dim(self) -> int:
        total_dim = 0
        if self.log_emb_size > 0:
            total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
        total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.const_emb_size

        return total_dim
    def forward(self, raw_coords, w=None):
        batch_size, img_size, in_channels = raw_coords.shape
        raw_embs = []

        if self.log_emb_size > 0:
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
            raw_embs.append(torch.matmul(raw_coords, log_bases))
        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, random_bases))
        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, shared_bases))
        if self.predictable_emb_size > 0:
            # Predict a per-sample Fourier basis (weights and bias) from w.
            mod = self.affine(w)
            W = self.fourier_scale * mod[:, :self.W_size]
            W = W.view(batch_size, self.coord_dim, self.predictable_emb_size)
            bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
            raw_embs.append(torch.matmul(raw_coords, W) + bias)

        raw_embs = torch.cat(raw_embs, dim=-1).contiguous()
        out = raw_embs.sin()
        if self.use_cosine:
            out = torch.cat([out, raw_embs.cos()], dim=-1)
        if self.const_emb_size > 0:
            const_embs = self.const_embs.repeat([batch_size, 1, 1])
            out = torch.cat([out, const_embs], dim=-1)

        return torch.cat([raw_coords, out], dim=-1)
def generate_logarithmic_basis(
        resolution,
        max_num_feats,
        remove_lowest_freq: bool = False,
        use_diagonal: bool = True):
    """
    Generates a directional logarithmic basis with the following directions:
        - horizontal
        - vertical
        - main diagonal
        - anti-diagonal
    """
    max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
    bases = [
        generate_horizontal_basis(max_num_feats_per_direction),
        generate_vertical_basis(max_num_feats_per_direction),
    ]

    if use_diagonal:
        bases.extend([
            generate_diag_main_basis(max_num_feats_per_direction),
            generate_anti_diag_basis(max_num_feats_per_direction),
        ])

    if remove_lowest_freq:
        bases = [b[1:] for b in bases]

    basis = torch.cat(bases, dim=0)

    # If the basis exceeded `max_num_feats`, features could be pruned by first
    # dropping the anti-diagonal and main-diagonal directions, then every
    # second (third, fourth, ...) frequency. The horizontal and vertical
    # directions must always be kept, since without them the model cannot
    # locate positions. Pruning is currently disabled, so the assertion below
    # guards against overflow instead.
    assert basis.shape[0] <= max_num_feats, \
        f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."

    return basis
def generate_horizontal_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)


def generate_vertical_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)


def generate_diag_main_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))


def generate_anti_diag_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
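
# Builds num_feats frequency vectors along a single direction, with the
# frequency doubling at each step (a dyadic ladder starting from the given
# base period).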
def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
    period_coef = 2.0 * np.pi / period_length
    basis = torch.tensor([basis_block]).repeat(num_feats, 1)  # [num_feats, 2]
    powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1)  # [num_feats, 1]
    result = basis * powers * period_coef  # [num_feats, 2]

    return result.float()
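

# Minimal smoke test (an illustrative sketch, not part of the original
# module). The shapes below are assumptions based on how the classes above
# consume their inputs: coordinates as (batch, num_points, 2) tensors,
# images/features as NCHW tensors.
if __name__ == '__main__':
    # Channel attention over an 8x8 feature map with a random soft mask.
    attn = MaskedChannelAttention(in_channels=32)
    feats = torch.randn(2, 32, 8, 8)
    mask = torch.rand(2, 1, 8, 8)
    print(attn(feats, mask).shape)  # torch.Size([2, 32, 8, 8])

    # NeRF positional encoding of a 16x16 grid of 2D coordinates in [-1, 1].
    enc = PosEncodingNeRF(in_features=2, sidelength=16)
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, 16), torch.linspace(-1, 1, 16), indexing='ij')
    coords = torch.stack([ys, xs], dim=-1).view(1, -1, 2)
    print(enc(coords).shape)  # (1, 256, enc.out_dim)

    # INR-GAN embedding of the same coordinate grid.
    embed = INRGAN_embed(resolution=16)
    print(embed(coords).shape)  # (1, 256, embed.out_dim)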