import copy
from typing import List, Optional

import torch


class AdaptiveLayerNorm1D(torch.nn.Module):
    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        self.norm = torch.nn.LayerNorm(data_dim)  # TODO: Check if elementwise_affine=True is correct
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        torch.nn.init.zeros_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., data_dim)
        # t: (batch, norm_cond_dim)
        # return: same shape as x, i.e. (batch, ..., data_dim)
        x = self.norm(x)
        alpha, beta = self.linear(t).chunk(2, dim=-1)

        # Add singleton dimensions to alpha and beta so they broadcast over x
        if x.dim() > 2:
            alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
            beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])

        return x * (1 + alpha) + beta


class SequentialCond(torch.nn.Sequential):
    def forward(self, input, *args, **kwargs):
        for module in self:
            if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
                # print(f'Passing on args to {module}', [a.shape for a in args])
                input = module(input, *args, **kwargs)
            else:
                # print(f'Skipping passing args to {module}', [a.shape for a in args])
                input = module(input)
        return input


def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    elif norm == "layer":
        return torch.nn.LayerNorm(dim)
    elif norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    elif norm is None:
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown norm: {norm}")


def linear_norm_activ_dropout(
    input_dim: int,
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    layers = []
    layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
    if norm is not None:
        layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
    layers.append(copy.deepcopy(activation))
    if dropout > 0.0:
        layers.append(torch.nn.Dropout(dropout))
    return SequentialCond(*layers)


def create_simple_mlp(
    input_dim: int,
    hidden_dims: List[int],
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    layers = []
    prev_dim = input_dim
    for hidden_dim in hidden_dims:
        layers.extend(
            linear_norm_activ_dropout(
                prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
            )
        )
        prev_dim = hidden_dim
    layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
    return SequentialCond(*layers)
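
# --- Illustrative sketch (not part of the original module) --------------------
# Hypothetical helper showing how SequentialCond threads a conditioning tensor
# `t` through AdaptiveLayerNorm1D while plain modules (Linear, activation)
# receive only `x`. The function name and all shapes below are assumptions
# chosen for demonstration; the function is defined but never called at import.
def _demo_adaptive_norm_conditioning():
    batch, seq_len, data_dim, cond_dim = 4, 16, 32, 8
    block = linear_norm_activ_dropout(
        data_dim, data_dim, norm="ada", norm_cond_dim=cond_dim
    )
    x = torch.randn(batch, seq_len, data_dim)
    t = torch.randn(batch, cond_dim)
    # `t` is forwarded only to the AdaptiveLayerNorm1D inside the block.
    out = block(x, t)
    assert out.shape == (batch, seq_len, data_dim)
    return out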
class ResidualMLPBlock(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        if not (input_dim == output_dim == hidden_dim):
            raise NotImplementedError(
                f"input_dim ({input_dim}), hidden_dim ({hidden_dim}) and output_dim "
                f"({output_dim}) must all be equal; other configurations are not implemented"
            )

        layers = []
        prev_dim = input_dim
        for _ in range(num_hidden_layers):
            layers.append(
                linear_norm_activ_dropout(
                    prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
                )
            )
            prev_dim = hidden_dim
        self.model = SequentialCond(*layers)
        self.skip = torch.nn.Identity()

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x + self.model(x, *args, **kwargs)


class ResidualMLP(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        num_blocks: int = 1,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.model = SequentialCond(
            linear_norm_activ_dropout(
                input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
            ),
            *[
                ResidualMLPBlock(
                    hidden_dim,
                    hidden_dim,
                    num_hidden_layers,
                    hidden_dim,
                    activation,
                    bias,
                    norm,
                    dropout,
                    norm_cond_dim,
                )
                for _ in range(num_blocks)
            ],
            torch.nn.Linear(hidden_dim, output_dim, bias=bias),
        )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.model(x, *args, **kwargs)


class FrequencyEmbedder(torch.nn.Module):
    def __init__(self, num_frequencies, max_freq_log2):
        super().__init__()
        frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
        self.register_buffer("frequencies", frequencies)

    def forward(self, x):
        # x should be of size (N,) or (N, D)
        N = x.size(0)
        if x.dim() == 1:  # (N,)
            x = x.unsqueeze(1)  # (N, D) where D=1
        x_unsqueezed = x.unsqueeze(-1)  # (N, D, 1)
        scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed  # (N, D, num_frequencies)
        s = torch.sin(scaled)
        c = torch.cos(scaled)
        embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
            N, -1
        )  # (N, D * 2 * num_frequencies + D)
        return embedded
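

# --- Illustrative smoke test (not part of the original module) ----------------
# A minimal sketch of one plausible way to combine FrequencyEmbedder with a
# conditioned ResidualMLP: embed a scalar condition (e.g. a timestep) and feed
# the embedding as the "ada" conditioning signal. All dimensions and variable
# names below are assumptions chosen for demonstration only.
if __name__ == "__main__":
    batch, in_dim, out_dim = 4, 64, 10
    num_frequencies, max_freq_log2 = 8, 6.0

    embedder = FrequencyEmbedder(num_frequencies, max_freq_log2)
    t = torch.rand(batch)               # (batch,) scalar condition per sample
    t_emb = embedder(t)                 # (batch, 2 * num_frequencies + 1)

    mlp = ResidualMLP(
        input_dim=in_dim,
        hidden_dim=128,
        num_hidden_layers=2,
        output_dim=out_dim,
        norm="ada",
        num_blocks=2,
        norm_cond_dim=t_emb.shape[-1],
    )
    x = torch.randn(batch, in_dim)
    y = mlp(x, t_emb)                   # t_emb reaches every AdaptiveLayerNorm1D
    print("output shape:", tuple(y.shape))  # expected: (4, 10)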