|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Different model implementation plus a general port for all the models.""" |
|
from typing import Any, Callable |
|
from flax import linen as nn |
|
from jax import random |
|
import jax.numpy as jnp |
|
|
|
from nerf import model_utils |
|
from nerf import utils |
|
|
|
|
|
def get_model(key, example_batch, args):
    """Build the model selected by `args.model`.

    The model name is looked up in a small registry of constructor functions
    and the matching constructor is called with the arguments given here.

    Args:
      key: random number generator key, forwarded to the constructor.
      example_batch: an example batch of data, forwarded to the constructor.
      args: hyperparameter container; `args.model` selects the constructor.

    Returns:
      Whatever the selected constructor returns.

    Raises:
      KeyError: if `args.model` is not a registered model name.
    """
    # Registry of available model constructors, keyed by model name.
    constructors = {"nerf": construct_nerf}
    build = constructors[args.model]
    return build(key, example_batch, args)
|
|
|
|
|
class NerfModel(nn.Module):
    """Nerf NN Model with both coarse and fine MLPs.

    The coarse pass always runs. The fine pass runs only when
    `num_fine_samples > 0`; it draws new samples along each ray using the
    rendering weights produced by the coarse pass.
    """

    num_coarse_samples: int  # Sample count passed to `sample_along_rays`.
    num_fine_samples: int  # Sample count passed to `sample_pdf`; 0 skips the fine pass.
    use_viewdirs: bool  # If True, encoded view directions are fed to the MLPs.
    near: float  # Near bound passed to `sample_along_rays`.
    far: float  # Far bound passed to `sample_along_rays`.
    noise_std: float  # Passed to `add_gaussian_noise` for the raw sigma output.
    net_depth: int  # Forwarded to `model_utils.MLP`.
    net_width: int  # Forwarded to `model_utils.MLP`.
    net_depth_condition: int  # Forwarded to `model_utils.MLP`.
    net_width_condition: int  # Forwarded to `model_utils.MLP`.
    net_activation: Callable[..., Any]  # Forwarded to `model_utils.MLP`.
    skip_layer: int  # Forwarded to `model_utils.MLP`.
    num_rgb_channels: int  # Forwarded to `model_utils.MLP`.
    num_sigma_channels: int  # Forwarded to `model_utils.MLP`.
    white_bkgd: bool  # Forwarded to `volumetric_rendering`.
    min_deg_point: int  # Passed to `posenc` when encoding sample points.
    max_deg_point: int  # Passed to `posenc` when encoding sample points.
    deg_view: int  # Upper degree passed to `posenc` for view directions.
    lindisp: bool  # Forwarded to `sample_along_rays`.
    rgb_activation: Callable[..., Any]  # Applied to the MLP's raw rgb output.
    sigma_activation: Callable[..., Any]  # Applied to the noised raw sigma.
    legacy_posenc_order: bool  # Forwarded to every `posenc` call.

    @nn.compact
    def __call__(self, rng_0, rng_1, rays, randomized, rgb_only=False):
        """Nerf Model.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
          rays: utils.Rays, a namedtuple of ray origins, directions, and
            viewdirs.
          randomized: bool, use randomized stratified sampling.
          rgb_only: bool, if True return only the composited rgb of each pass
            instead of the full `(rgb, disp, acc)` tuples.

        Returns:
          When `rgb_only` is False, a list of `(rgb, disp, acc)` tuples: the
          coarse result first, followed by the fine result if
          `num_fine_samples > 0`. When `rgb_only` is True, a list holding only
          the rgb entry of each of those tuples, in the same order (so a
          single-element list when the fine pass is disabled).
        """
        # Both MLPs are configured identically; only their parameters differ.
        mlp_kwargs = dict(
            net_depth=self.net_depth,
            net_width=self.net_width,
            net_depth_condition=self.net_depth_condition,
            net_width_condition=self.net_width_condition,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels,
            num_sigma_channels=self.num_sigma_channels,
        )

        # The view-direction encoding is shared by the coarse and fine passes.
        if self.use_viewdirs:
            viewdirs_enc = model_utils.posenc(
                rays.viewdirs,
                0,
                self.deg_view,
                self.legacy_posenc_order,
            )
        else:
            viewdirs_enc = None

        def render_pass(mlp, noise_key, pass_z_vals, pass_samples):
            """Encode samples, query `mlp`, and composite along each ray."""
            samples_enc = model_utils.posenc(
                pass_samples,
                self.min_deg_point,
                self.max_deg_point,
                self.legacy_posenc_order,
            )
            if self.use_viewdirs:
                raw_rgb, raw_sigma = mlp(samples_enc, viewdirs_enc)
            else:
                raw_rgb, raw_sigma = mlp(samples_enc)
            raw_sigma = model_utils.add_gaussian_noise(
                noise_key,
                raw_sigma,
                self.noise_std,
                randomized,
            )
            rgb = self.rgb_activation(raw_rgb)
            sigma = self.sigma_activation(raw_sigma)
            return model_utils.volumetric_rendering(
                rgb,
                sigma,
                pass_z_vals,
                rays.directions,
                white_bkgd=self.white_bkgd,
            )

        # ---- Coarse pass ----
        key, rng_0 = random.split(rng_0)
        z_vals, samples = model_utils.sample_along_rays(
            key,
            rays.origins,
            rays.directions,
            self.num_coarse_samples,
            self.near,
            self.far,
            randomized,
            self.lindisp,
        )
        # Keep the coarse MLP constructed before the fine one: under
        # @nn.compact, auto-generated submodule names depend on construction
        # order, and existing checkpoints rely on them.
        coarse_mlp = model_utils.MLP(**mlp_kwargs)
        key, rng_0 = random.split(rng_0)
        comp_rgb, disp, acc, weights = render_pass(
            coarse_mlp, key, z_vals, samples)
        ret = [(comp_rgb, disp, acc)]

        # ---- Fine pass (optional) ----
        if self.num_fine_samples > 0:
            # Midpoints between consecutive coarse sample depths.
            z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
            key, rng_1 = random.split(rng_1)
            z_vals, samples = model_utils.sample_pdf(
                key,
                z_vals_mid,
                # First and last coarse weights are dropped; presumably so one
                # weight remains per interval between consecutive midpoints.
                weights[..., 1:-1],
                rays.origins,
                rays.directions,
                z_vals,
                self.num_fine_samples,
                randomized,
            )
            fine_mlp = model_utils.MLP(**mlp_kwargs)
            key, rng_1 = random.split(rng_1)
            comp_rgb, disp, acc, _ = render_pass(
                fine_mlp, key, z_vals, samples)
            ret.append((comp_rgb, disp, acc))

        if rgb_only:
            # Iterate over whatever passes actually ran; indexing ret[1]
            # unconditionally fails when the fine pass is disabled.
            return [result[0] for result in ret]
        return ret
|
|
|
def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field and initialize its variables.

    Args:
      key: jnp.ndarray. Random number generator.
      example_batch: dict, an example of a batch of data. Only its "rays"
        entry is read here.
      args: FLAGS class. Hyperparameters of nerf.

    Returns:
      model: NerfModel, the constructed module.
      init_variables: the value returned by `model.init` for an example ray.

    Raises:
      NotImplementedError: if the chosen rgb activation yields values outside
        [0, 1], or the chosen sigma activation yields negative values, on the
        probe inputs below.
    """
    # Activation functions are resolved by name as attributes of `nn`.
    activations = {
        field: getattr(nn, str(getattr(args, field)))
        for field in ("net_activation", "rgb_activation", "sigma_activation")
    }

    # Probe inputs: magnitudes exp(-90)..exp(90), mirrored to cover both signs.
    magnitudes = jnp.exp(jnp.linspace(-90, 90, 1024))
    probe = jnp.concatenate([-magnitudes[::-1], magnitudes], 0)

    # Reject rgb activations whose output leaves [0, 1] on the probe.
    rgb_out = activations["rgb_activation"](probe)
    if jnp.any((rgb_out < 0) | (rgb_out > 1)):
        message = (
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
            .format(args.rgb_activation))
        raise NotImplementedError(message)

    # Reject sigma activations that go negative on the probe.
    sigma_out = activations["sigma_activation"](probe)
    if jnp.any(sigma_out < 0):
        message = (
            "Choice of sigma_activation `{}` produces negative densities"
            .format(args.sigma_activation))
        raise NotImplementedError(message)

    # Hyperparameters copied verbatim from `args` onto same-named model fields.
    copied_fields = (
        "min_deg_point",
        "max_deg_point",
        "deg_view",
        "num_coarse_samples",
        "num_fine_samples",
        "use_viewdirs",
        "near",
        "far",
        "noise_std",
        "white_bkgd",
        "net_depth",
        "net_width",
        "net_depth_condition",
        "net_width_condition",
        "skip_layer",
        "num_rgb_channels",
        "num_sigma_channels",
        "lindisp",
        "legacy_posenc_order",
    )
    hyperparams = {field: getattr(args, field) for field in copied_fields}
    model = NerfModel(**hyperparams, **activations)

    batch_rays = example_batch["rays"]
    init_key, coarse_key, fine_key = random.split(key, num=3)

    # Initialize with index 0 of every ray field (presumably the first
    # device's shard or first example -- confirm against the data loader).
    example_rays = utils.namedtuple_map(lambda field: field[0], batch_rays)
    init_variables = model.init(
        init_key,
        rng_0=coarse_key,
        rng_1=fine_key,
        rays=example_rays,
        randomized=args.randomized,
    )

    return model, init_variables
|
|