# PoseDiffusion_MVP / util / embedding.py
# hugoycj
# Initial commit
# 3d3e4e9
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import math
from pytorch3d.renderer import HarmonicEmbedding
class TimeStepEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP.

    Adapted from
    https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py

    Args:
        dim: width of the sinusoidal embedding that is fed to the MLP.
        max_period: sets the slowest sinusoid; frequencies decay
            geometrically from 1 towards 1 / max_period.

    Attributes:
        out_dim: size of the last axis returned by ``forward`` (``dim // 2``).
    """

    def __init__(self, dim=256, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

        self.linear = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.SiLU(),
            nn.Linear(dim // 2, dim // 2),
        )

        self.out_dim = dim // 2

    def _compute_freqs(self, half, device=None):
        """Return ``half`` frequencies with ``freqs[i] = max_period ** (-i / half)``.

        Args:
            half: number of frequencies to generate.
            device: optional device on which to create the tensor. ``None``
                keeps torch's default device, which is what the previous
                signature always did.

        Returns:
            A float32 tensor of shape ``(half,)``.
        """
        # Allocate directly on the requested device so callers do not pay
        # for a CPU allocation plus a host-to-device copy on every forward.
        exponents = torch.arange(
            start=0, end=half, dtype=torch.float32, device=device
        )
        return torch.exp(-math.log(self.max_period) * exponents / half)

    def forward(self, timesteps):
        """Embed a batch of timesteps.

        Args:
            timesteps: tensor indexed as ``timesteps[:, None]``, i.e. one
                value per batch element; it is cast to float before use.

        Returns:
            The MLP output; for 1-D input of length N this has shape
            ``(N, dim // 2)``.
        """
        half = self.dim // 2
        freqs = self._compute_freqs(half, device=timesteps.device)

        # Outer product: one row per timestep, one column per frequency.
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

        if self.dim % 2:
            # cos/sin only yield 2 * half columns; pad one zero column so
            # the width matches the ``dim`` expected by the first Linear.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )

        return self.linear(embedding)
class PoseEmbedding(nn.Module):
    """Thin wrapper that applies a ``HarmonicEmbedding`` to a pose encoding.

    Args:
        target_dim: passed to ``HarmonicEmbedding.get_output_dim`` to derive
            ``out_dim`` (presumably the channel count of the pose encoding;
            confirm against callers).
        n_harmonic_functions: forwarded to ``HarmonicEmbedding``.
        append_input: forwarded to ``HarmonicEmbedding``.

    Attributes:
        out_dim: value reported by ``get_output_dim(target_dim)``.
    """

    def __init__(self, target_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        harmonic = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions,
            append_input=append_input,
        )
        self._emb_pose = harmonic
        self.out_dim = harmonic.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        """Return the harmonic embedding of ``pose_encoding``."""
        return self._emb_pose(pose_encoding)