Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,786 Bytes
8866a87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional
import torch
class HarmonicEmbedding(torch.nn.Module):
    """Sine/cosine (harmonic) positional encoding with optional
    integrated (Gaussian-attenuated) encoding."""

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        The harmonic embedding layer supports the classical
        NeRF positional encoding described in
        `NeRF <https://arxiv.org/abs/2003.08934>`_
        and the integrated positional encoding in
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        At call time you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, every feature (i.e. every vector along the
        last dimension) of `x` is converted into a series of harmonic
        features. With `dim = x.shape[-1]` and `N = n_harmonic_functions`,
        the last dimension of the output is laid out as::

            [
                sin(f_1 * x[..., 0]), ..., sin(f_N * x[..., 0]),
                ...
                sin(f_1 * x[..., dim-1]), ..., sin(f_N * x[..., dim-1]),
                cos(f_1 * x[..., 0]), ..., cos(f_N * x[..., 0]),
                ...
                cos(f_1 * x[..., dim-1]), ..., cos(f_N * x[..., dim-1]),
                x[..., 0], ..., x[..., dim-1],  # only if append_input is True
            ]

        where f_i is a scalar denoting the i-th frequency of the harmonic
        embedding.

        If `diag_cov is not None`, `x` is interpreted as the means and
        `diag_cov` as the diagonal covariances of Gaussians. Every harmonic
        term above is then attenuated by the matching factor::

            sin(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])
            cos(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])

        The appended copy of `x` (if any) is left unchanged.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        Note that all frequencies are additionally multiplied by the base
        frequency `omega_0`, which is equivalent to premultiplying `x` by
        `omega_0` before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic frequencies
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions, dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Registered as non-persistent buffers: they follow the module
        # across devices but are left out of the state_dict, since they are
        # fully determined by the constructor arguments.
        self.register_buffer(
            "_frequencies", frequencies * omega_0, persistent=False
        )
        # Phase offsets that turn a single sin() call into (sin, cos).
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self,
        x: torch.Tensor,
        diag_cov: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our
                Gaussians, joined with x as means of the Gaussians.
            **kwargs: accepted and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1]
        #   => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2) and do one vectorized
        # call to sin([x, x + pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()

        if diag_cov is not None:
            # Variance of each coordinate after scaling by each frequency.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # Broadcast the attenuation over the (sin, cos) axis:
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        # Spell out the flattened size instead of using -1: reshape cannot
        # infer -1 for tensors with zero elements (e.g. an empty batch).
        flat_dim = 2 * x.shape[-1] * self._frequencies.shape[0]
        embed = embed.reshape(*x.shape[:-1], flat_dim)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding

        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this instance's number of
        frequencies and `append_input` setting. The default for input_dims
        is 3 for 3D applications which use harmonic embedding for
        positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
|