import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn


class SimpleInputFusion(nn.Module):
    """Fuse an RGB image with extra input channels through a small 1x1-conv stack.

    ``image`` and ``additional_input`` are concatenated on dim 1, mapped to
    ``ch`` channels, activated (LeakyReLU 0.2), normalized with ``norm_layer``
    and projected back to ``rgb_ch`` channels.
    """

    def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2),
            norm_layer(ch),
            nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
        )

    def forward(self, image, additional_input):
        return self.fusion_conv(torch.cat((image, additional_input), dim=1))


class MaskedChannelAttention(nn.Module):
    """Channel attention driven by mask-aware max pooling plus global average pooling.

    ``x`` is pooled three ways: max of ``x * mask``, max of ``x * (1 - mask)``
    (both via :class:`MaskedGlobalMaxPool2d`) and a plain global average.
    This gives ``3 * in_channels`` features, which an MLP turns into
    per-channel sigmoid weights that rescale ``x``.

    Extra positional/keyword constructor arguments are accepted and ignored.
    """

    def __init__(self, in_channels, *args, **kwargs):
        super().__init__()
        self.global_max_pool = MaskedGlobalMaxPool2d()
        self.global_avg_pool = FastGlobalAvgPool2d()

        intermediate_channels_count = max(in_channels // 16, 8)
        self.attention_transform = nn.Sequential(
            nn.Linear(3 * in_channels, intermediate_channels_count),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_channels_count, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
        """Reweight the channels of ``x`` using statistics inside/outside ``mask``.

        Assumes ``x`` is ``(B, C, H, W)`` and that ``mask`` broadcasts against
        it once its spatial size matches (presumably ``(B, 1, h, w)`` with
        values in [0, 1] -- confirm with callers).
        """
        # Resize the mask only when its spatial size differs from x's.
        if mask.shape[2:] != x.shape[2:]:
            mask = F.interpolate(
                mask, size=x.size()[-2:], mode='bilinear', align_corners=True
            )
        pooled_x = torch.cat([
            self.global_max_pool(x, mask),
            self.global_avg_pool(x),
        ], dim=1)
        # (B, C) weights -> (B, C, 1, 1) so they broadcast over H and W.
        channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
        return channel_attention_weights * x


class MaskedGlobalMaxPool2d(nn.Module):
    """Global max pooling of ``x * mask`` and ``x * (1 - mask)``, concatenated on dim 1."""

    def __init__(self):
        super().__init__()
        self.global_max_pool = FastGlobalMaxPool2d()

    def forward(self, x, mask):
        return torch.cat((
            self.global_max_pool(x * mask),
            self.global_max_pool(x * (1.0 - mask)),
        ), dim=1)


class FastGlobalAvgPool2d(nn.Module):
    """Mean over every dim after the first two: ``(B, C, ...) -> (B, C)``."""

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs, e.g. channels_last, also work.
        return x.reshape(x.shape[0], x.shape[1], -1).mean(dim=2)


class FastGlobalMaxPool2d(nn.Module):
    """Max over every dim after the first two: ``(B, C, ...) -> (B, C)``."""

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs, e.g. channels_last, also work.
        return x.reshape(x.shape[0], x.shape[1], -1).max(dim=2)[0]


class ScaleLayer(nn.Module):
    """Multiply the input by a learnable scalar, applied as ``abs(scale * lr_mult)``.

    The parameter is stored divided by ``lr_mult`` and multiplied back in
    ``forward``; presumably this rescales its effective learning rate.
    Taking the absolute value keeps the applied scale non-negative.
    """

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        )

    def forward(self, x):
        scale = torch.abs(self.scale * self.lr_mult)
        return x * scale


class FeaturesConnector(nn.Module):
    """Merge ``features`` into ``x`` according to ``mode``.

    Modes:
      * ``'cat'``  -- concatenate ``x`` and ``features`` on dim 1.
      * ``'catc'`` -- concatenate, then reduce with a 1x1 conv to ``out_channels``.
      * ``'sum'``  -- project ``features`` with a 1x1 conv and add it to ``x``.
      * anything else, or a falsy ``feature_channels`` -- return ``x`` unchanged.

    ``output_channels`` is ``in_channels + feature_channels`` for ``'cat'``
    and ``out_channels`` otherwise.

    NOTE(review): for ``'sum'`` and the pass-through case the result actually
    has ``x``'s channel count, so this presumably assumes
    ``out_channels == in_channels``.
    """

    def __init__(self, mode, in_channels, feature_channels, out_channels):
        super().__init__()
        self.mode = mode if feature_channels else ''

        if self.mode == 'catc':
            self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
        elif self.mode == 'sum':
            self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)

        self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels

    def forward(self, x, features):
        if self.mode == 'cat':
            return torch.cat((x, features), 1)
        if self.mode == 'catc':
            return self.reduce_conv(torch.cat((x, features), 1))
        if self.mode == 'sum':
            return self.reduce_conv(features) + x
        return x

    def extra_repr(self):
        return self.mode


class PosEncodingNeRF(nn.Module):
    """NeRF-style sinusoidal positional encoding.

    Each coordinate ``c`` is augmented with ``sin(2**i * pi * c)`` and
    ``cos(2**i * pi * c)`` for ``i`` in ``range(num_frequencies)``. The
    per-point layout is
    ``[coords, sin(f0*c0), cos(f0*c0), sin(f0*c1), cos(f0*c1), ..., sin(f1*c0), ...]``.

    Args:
        in_features: coordinate dimensionality; must be 1, 2 or 3.
        sidelength: required when ``in_features == 2``; an int or an (H, W) pair.
        fn_samples: required when ``in_features == 1``; number of samples.
        use_nyquist: for 1-D/2-D inputs, derive ``num_frequencies`` from the
            sample count instead of using the default of 4.

    Raises:
        ValueError: if ``in_features`` is not 1, 2 or 3.
    """

    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
        super().__init__()

        self.in_features = in_features

        if self.in_features == 3:
            self.num_frequencies = 10
        elif self.in_features == 2:
            assert sidelength is not None
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
        elif self.in_features == 1:
            assert fn_samples is not None
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
        else:
            raise ValueError(f"PosEncodingNeRF supports in_features of 1, 2 or 3, got {in_features}")

        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        """Return ``floor(log2(samples / 4))``, the Nyquist-derived frequency count."""
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        return int(math.floor(math.log(nyquist_rate, 2)))

    def forward(self, coords):
        """Encode ``coords`` (flattened to ``(B, -1, in_features)``) to ``(B, N, out_dim)``."""
        coords = coords.view(coords.shape[0], -1, self.in_features)

        # Collect every piece and concatenate once instead of growing a tensor
        # with torch.cat inside the loop (which is quadratic in the output width).
        pieces = [coords]
        for i in range(self.num_frequencies):
            for j in range(self.in_features):
                c = coords[..., j]
                pieces.append(torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1))
                pieces.append(torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1))
        coords_pos_enc = torch.cat(pieces, dim=-1)

        return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)


class RandomFourier(nn.Module):
    """Random Fourier features for 2-D coordinates.

    A fixed Gaussian matrix ``embed`` of shape ``(2, embedding_length)``,
    scaled by ``std_scale``, projects ``2 * pi * coords``. The output is
    ``[coords, sin(proj), cos(proj)]`` with ``out_dim = 2 * embedding_length + 2``.
    """

    def __init__(self, std_scale, embedding_length, device):
        super().__init__()
        embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
        # A non-persistent buffer follows module.to()/.cuda() but, like the
        # former plain attribute, is not stored in state_dict.
        self.register_buffer('embed', embed.to(device), persistent=False)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coords):
        """Encode ``coords`` (assumed ``(B, N, 2)``) to ``(B, N, out_dim)``."""
        batch = coords.shape[0]
        proj = torch.matmul(2 * np.pi * coords, self.embed)
        coords_pos_enc = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
        # The sin/cos part is out_dim - 2 wide; the remaining 2 columns are the coords.
        coords_pos_enc = coords_pos_enc.reshape(batch, -1, self.out_dim - 2)
        return torch.cat([coords.reshape(batch, -1, coords.shape[-1]), coords_pos_enc], dim=-1)


class CIPS_embed(nn.Module):
    """CIPS-style coordinate embedding.

    Concatenates ``[coord, Predict_embed(coord), ConstantInput sampled at coord]``,
    giving ``out_dim = 2 * embedding_length + 2`` features per point.

    NOTE: despite its name, ``fourier_embed`` is a :class:`ConstantInput`;
    it uses the predicted embedding only for its batch size.
    """

    def __init__(self, size, embedding_length):
        super().__init__()
        self.fourier_embed = ConstantInput(size, embedding_length)
        self.predict_embed = Predict_embed(embedding_length)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coord, res=None):
        x = self.predict_embed(coord)
        y = self.fourier_embed(x, coord, res)

        return torch.cat([coord, x, y], dim=-1)


class Predict_embed(nn.Module):
    """``sin(Linear(2 -> embedding_length)(x))`` with weights ~ U(-sqrt(4.5), sqrt(4.5))."""

    def __init__(self, embedding_length):
        super().__init__()
        self.ffm = nn.Linear(2, embedding_length, bias=True)
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

    def forward(self, x):
        x = self.ffm(x)
        x = torch.sin(x)
        return x


class ConstantInput(nn.Module):
    """Learned constant grid of ``size * size`` embeddings with ``channel`` features each.

    If the number of query points equals ``size ** 2``, the stored embeddings
    are returned as-is in storage order. Otherwise the grid is bilinearly
    sampled at ``coord``.
    """

    def __init__(self, size, channel):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, size ** 2, channel))

    def forward(self, input, coord, resolution=None):
        """Return ``(B, N, channel)`` embeddings for ``coord`` of shape ``(B, N, 2)``.

        ``input`` is used only for its batch size. ``resolution`` gives the
        ``(h, w)`` layout of the ``N`` query points; when omitted they are
        assumed to form a square. ``coord`` is flipped on its last axis before
        ``grid_sample``. It is presumably normalized to [-1, 1], as
        ``grid_sample`` expects -- confirm with callers.
        """
        batch = input.shape[0]
        num_stored = self.input.shape[1]
        channels = self.input.shape[-1]
        out = self.input.repeat(batch, 1, 1)

        if coord.shape[1] != num_stored:
            side = int(num_stored ** 0.5)
            x = out.permute(0, 2, 1).contiguous().view(batch, channels, side, side)

            if resolution is None:
                query_side = int(coord.shape[1] ** 0.5)
                grid = coord.view(coord.shape[0], query_side, query_side, coord.shape[-1])
            else:
                grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])

            out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
            out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, channels)

        return out


class INRGAN_embed(nn.Module):
    """INR-GAN style coordinate embedding for 2-D coordinates.

    ``raw_coords`` of shape ``(B, N, 2)`` are projected onto the configured
    bases and passed through ``sin`` (and ``cos`` when ``use_cosine``). A
    learned per-position constant embedding may be appended.

    Output layout: ``[raw_coords, sin(projections), cos(projections), const]``.

    The configuration is fixed in ``self.res_cfg``: ``log_emb_size=32`` (an
    upper bound on the logarithmic rows), 32 random Fourier features, 64
    constant features and cosine enabled.
    """

    # Every basis below is built with 2 rows, i.e. for 2-D coordinates.
    coord_dim = 2

    def __init__(self, resolution: int, w_dim=None):
        super().__init__()

        self.resolution = resolution
        self.res_cfg = {"log_emb_size": 32,
                        "random_emb_size": 32,
                        "const_emb_size": 64,
                        "use_cosine": True}
        self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
        self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
        self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
        self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
        self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
        self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
        self.use_cosine = self.res_cfg.get('use_cosine', False)

        if self.log_emb_size > 0:
            self.register_buffer('log_basis', generate_logarithmic_basis(
                resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))

        if self.random_emb_size > 0:
            self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))

        if self.shared_emb_size > 0:
            self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))

        if self.predictable_emb_size > 0:
            # Per-sample basis W (coord_dim x predictable) plus bias, predicted from w.
            self.W_size = self.predictable_emb_size * self.coord_dim
            self.b_size = self.predictable_emb_size
            self.affine = nn.Linear(w_dim, self.W_size + self.b_size)

        if self.const_emb_size > 0:
            self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))

        self.out_dim = self.get_total_dim() + 2

    def sample_w_matrix(self, shape, scale: float):
        """Return a ``shape`` tensor of N(0, 1) samples multiplied by ``scale``."""
        return torch.randn(shape) * scale

    def get_total_dim(self) -> int:
        """Number of embedding features appended after the raw coordinates."""
        trig_factor = 2 if self.use_cosine else 1
        total_dim = 0
        if self.log_emb_size > 0:
            total_dim += self.log_basis.shape[0] * trig_factor
        total_dim += self.random_emb_size * trig_factor
        total_dim += self.shared_emb_size * trig_factor
        total_dim += self.predictable_emb_size * trig_factor
        total_dim += self.const_emb_size

        return total_dim

    def forward(self, raw_coords, w=None):
        """Embed ``raw_coords`` of shape ``(B, N, 2)`` into ``(B, N, out_dim)``.

        ``w`` is used only by the predictable-basis branch, which the default
        config disables. The constant embedding requires ``N == resolution ** 2``.
        """
        batch_size, _, _ = raw_coords.shape
        raw_embs = []

        if self.log_emb_size > 0:
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
            raw_embs.append(torch.matmul(raw_coords, log_bases))

        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, random_bases))

        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, shared_bases))

        if self.predictable_emb_size > 0:
            mod = self.affine(w)
            W = self.fourier_scale * mod[:, :self.W_size]
            W = W.view(batch_size, self.coord_dim, self.predictable_emb_size)
            bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
            raw_embs.append(torch.matmul(raw_coords, W) + bias)

        embs = []
        if raw_embs:
            raw = torch.cat(raw_embs, dim=-1).contiguous()
            embs.append(raw.sin())
            if self.use_cosine:
                embs.append(raw.cos())

        if self.const_emb_size > 0:
            embs.append(self.const_embs.repeat([batch_size, 1, 1]))

        # If nothing is configured, this returns raw_coords alone (out_dim == 2).
        return torch.cat([raw_coords, *embs], dim=-1)


def generate_logarithmic_basis(
        resolution,
        max_num_feats,
        remove_lowest_freq: bool = False,
        use_diagonal: bool = True):
    """
    Generate a directional logarithmic basis with the following directions:
        - horizontal
        - vertical
        - main diagonal (only if ``use_diagonal``)
        - anti-diagonal (only if ``use_diagonal``)

    Each direction contributes ``ceil(log2(resolution))`` rows, or one fewer
    if ``remove_lowest_freq``. Frequencies double from one row to the next.

    Returns:
        float32 tensor of shape ``(num_feats, 2)``.

    Raises:
        AssertionError: if the basis has more than ``max_num_feats`` rows.
    """
    max_num_feats_per_direction = int(np.ceil(np.log2(resolution)))
    bases = [
        generate_horizontal_basis(max_num_feats_per_direction),
        generate_vertical_basis(max_num_feats_per_direction),
    ]

    if use_diagonal:
        bases.extend([
            generate_diag_main_basis(max_num_feats_per_direction),
            generate_anti_diag_basis(max_num_feats_per_direction),
        ])

    if remove_lowest_freq:
        bases = [b[1:] for b in bases]

    # NOTE: there is no fallback that drops directions or subsamples features
    # to fit into `max_num_feats`; an oversized basis is rejected instead.
    basis = torch.cat(bases, dim=0)
    assert basis.shape[0] <= max_num_feats, \
        f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."

    return basis


def generate_horizontal_basis(num_feats: int):
    """Wavefront basis along direction ``[0, 1]`` with period 4."""
    return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)


def generate_vertical_basis(num_feats: int):
    """Wavefront basis along direction ``[1, 0]`` with period 4."""
    return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)


def generate_diag_main_basis(num_feats: int):
    """Wavefront basis along unit direction ``[-1, 1] / sqrt(2)`` with period ``4 * sqrt(2)``."""
    return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))


def generate_anti_diag_basis(num_feats: int):
    """Wavefront basis along unit direction ``[1, 1] / sqrt(2)`` with period ``4 * sqrt(2)``."""
    return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))


def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
    """Return a ``(num_feats, 2)`` float32 tensor whose row ``k`` is
    ``basis_block * 2**k * 2 * pi / period_length``."""
    period_coef = 2.0 * np.pi / period_length
    basis = torch.tensor([basis_block]).repeat(num_feats, 1)  # [num_feats, 2]
    powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1)  # [num_feats, 1]
    result = basis * powers * period_coef  # [num_feats, 2]

    return result.float()