from pathlib import Path import einops import numpy as np import torch import torch.nn as nn import torch.nn.init as init class PoseNet(nn.Module): """a tiny conv network for introducing pose sequence as the condition """ def __init__(self, noise_latent_channels=320, *args, **kwargs): super().__init__(*args, **kwargs) # multiple convolution layers self.conv_layers = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1), nn.SiLU(), nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1), nn.SiLU(), nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1), nn.SiLU(), nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1), nn.SiLU(), nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1), nn.SiLU(), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1), nn.SiLU(), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1), nn.SiLU(), nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), nn.SiLU() ) # Final projection layer self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1) # Initialize layers self._initialize_weights() self.scale = nn.Parameter(torch.ones(1) * 2) def _initialize_weights(self): """Initialize weights with He. initialization and zero out the biases """ for m in self.conv_layers: if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n)) if m.bias is not None: init.zeros_(m.bias) init.zeros_(self.final_proj.weight) if self.final_proj.bias is not None: init.zeros_(self.final_proj.bias) def forward(self, x): if x.ndim == 5: x = einops.rearrange(x, "b f c h w -> (b f) c h w") x = self.conv_layers(x) x = self.final_proj(x) return x * self.scale @classmethod def from_pretrained(cls, pretrained_model_path): """load pretrained pose-net weights """ if not Path(pretrained_model_path).exists(): print(f"There is no model file in {pretrained_model_path}") print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.") state_dict = torch.load(pretrained_model_path, map_location="cpu") model = PoseNet(noise_latent_channels=320) model.load_state_dict(state_dict, strict=True) return model