# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable

import torch
import torch.nn as nn
from hydra.utils import instantiate

from util.embedding import TimeStepEmbedding, PoseEmbedding

logger = logging.getLogger(__name__)


class Denoiser(nn.Module):
    """Transformer denoiser over a set of per-camera pose tokens.

    Each of the N cameras becomes one token: the concatenation of a timestep
    embedding, a pose embedding of the noisy input ``x``, the conditioning
    feature ``z`` and (optionally) a 1-channel flag marking camera 0 as the
    pivot. Tokens are projected to ``d_model``, processed by the transformer
    trunk, and mapped to ``target_dim`` values by an MLP head.

    Args:
        TRANSFORMER: Hydra config for the trunk. It must expose ``d_model``
            as an attribute (presumably an OmegaConf ``DictConfig`` despite
            the ``Dict`` annotation -- TODO confirm) and is built with
            ``hydra.utils.instantiate``.
        target_dim: Size of each camera's pose vector and of the output.
        pivot_cam_onehot: If True, append a flag channel to ``z`` that is
            1.0 for camera index 0 and 0.0 for all other cameras.
        z_dim: Size of the conditioning feature ``z``.
        mlp_hidden_dim: Hidden width of the output MLP head.
    """

    def __init__(
        self,
        TRANSFORMER: Dict,
        target_dim: int = 9,  # TODO: reduce fl dim from 2 to 1
        pivot_cam_onehot: bool = True,
        z_dim: int = 384,
        mlp_hidden_dim: int = 128,
    ):
        super().__init__()

        self.pivot_cam_onehot = pivot_cam_onehot
        self.target_dim = target_dim

        self.time_embed = TimeStepEmbedding()
        self.pose_embed = PoseEmbedding(target_dim=self.target_dim)

        # Per-token input width: [pose emb | time emb | z | optional pivot flag].
        first_dim = (
            self.time_embed.out_dim
            + self.pose_embed.out_dim
            + z_dim
            + int(self.pivot_cam_onehot)
        )

        d_model = TRANSFORMER.d_model
        self._first = nn.Linear(first_dim, d_model)

        # Slightly different from the paper, which uses 2 encoder layers and
        # 6 decoder layers: here we use an encoder-only transformer with
        # 8 encoder layers, built via TransformerEncoderWrapper().
        self._trunk = instantiate(TRANSFORMER, _recursive_=False)

        # TODO: change the implementation of MLP to a more mature one
        self._last = MLP(
            d_model,
            [mlp_hidden_dim, self.target_dim],
            norm_layer=nn.LayerNorm,
        )

    def forward(
        self,
        x: torch.Tensor,  # B x N x dim
        t: torch.Tensor,  # B
        z: torch.Tensor,  # B x N x dim_z
    ):
        """Run the denoiser on one batch.

        Args:
            x: Noisy pose vectors, B x N x dim.
            t: Diffusion timesteps, one per batch element (B).
            z: Per-camera conditioning features, B x N x dim_z.

        Returns:
            Output of the MLP head, whose last dimension is ``target_dim``.
            Presumably B x N x target_dim, assuming the trunk keeps the
            B x N token layout -- TODO confirm against the trunk config.
        """
        B, N, _ = x.shape

        t_emb = self.time_embed(t)
        # expand t from B x C to B x N x C
        t_emb = t_emb.view(B, 1, t_emb.shape[-1]).expand(-1, N, -1)

        x_emb = self.pose_embed(x)

        if self.pivot_cam_onehot:
            # add the one hot vector identifying the first camera as pivot
            cam_pivot_id = torch.zeros_like(z[..., :1])
            cam_pivot_id[:, 0, ...] = 1.0
            z = torch.cat([z, cam_pivot_id], dim=-1)

        feed_feats = torch.cat([x_emb, t_emb, z], dim=-1)

        input_ = self._first(feed_feats)

        feats_ = self._trunk(input_)

        output = self._last(feats_)

        return output


def TransformerEncoderWrapper(
    d_model: int,
    nhead: int,
    num_encoder_layers: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    norm_first: bool = True,
    batch_first: bool = True,
) -> torch.nn.TransformerEncoder:
    """Build an encoder-only ``torch.nn.TransformerEncoder``.

    Presumably referenced as the build target of the ``TRANSFORMER`` Hydra
    config consumed by ``Denoiser`` (verify against the config files).

    Args:
        d_model: Token width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout probability inside each layer.
        norm_first: Use pre-norm (LayerNorm before attention/FFN) layers.
        batch_first: Inputs and outputs are laid out batch x seq x feature.

    Returns:
        The encoder stack. No final ``norm`` module is passed, so with
        ``norm_first=True`` the last layer's output is not re-normalized.
    """
    encoder_layer = torch.nn.TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        batch_first=batch_first,
        norm_first=norm_first,
    )

    _trunk = torch.nn.TransformerEncoder(encoder_layer, num_encoder_layers)
    return _trunk


class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    For every entry of ``hidden_channels`` except the last, the block stacks
    ``Linear`` with an optional norm and an optional activation, then
    optional dropout. The norm goes before the ``Linear`` when ``norm_first``
    is True and after it otherwise. The last entry is the output width. It
    gets a final ``Linear`` (preceded by a norm when ``norm_first`` is True)
    followed by optional dropout. No activation is applied after it.

    Args:
        in_channels (int): Number of channels of the input.
        hidden_channels (List[int]): Widths of the successive linear layers;
            the last entry is the output dimension. Must be non-empty.
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer
            constructor, called with the feature width. If ``None`` no norm
            layers are added. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional):
            Activation constructor, applied after each hidden linear layer
            (and after its norm when ``norm_first`` is False). If ``None``
            no activation layers are added. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Passed as ``inplace=`` to both the
            activation and the dropout layers; if ``None`` the keyword is
            omitted and the layers' own defaults apply. Default ``True``
        bias (bool): Whether to use bias in the linear layers.
            Default ``True``
        norm_first (bool): If True, put the norm before every linear layer
            (including the output layer); otherwise after each hidden linear
            layer only. Default ``False``
        dropout (float): The probability for the dropout layers.
            Default: 0.0

    Raises:
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[
            Callable[..., torch.nn.Module]
        ] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        norm_first: bool = False,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from
        # the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        if not hidden_channels:
            # Without this, hidden_channels[-1] below fails with an opaque
            # IndexError.
            raise ValueError("hidden_channels must contain at least one entry")

        # inplace=None means "don't pass the keyword at all".
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            if norm_first and norm_layer is not None:
                layers.append(norm_layer(in_dim))

            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))

            if not norm_first and norm_layer is not None:
                layers.append(norm_layer(hidden_dim))

            # Honor the documented activation_layer=None option instead of
            # trying to call None.
            if activation_layer is not None:
                layers.append(activation_layer(**params))

            if dropout > 0:
                layers.append(torch.nn.Dropout(dropout, **params))

            in_dim = hidden_dim

        # Output layer: optional pre-norm, Linear, optional dropout; no
        # activation.
        if norm_first and norm_layer is not None:
            layers.append(norm_layer(in_dim))

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        if dropout > 0:
            layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)