# NOTE: provenance captured from a repository web viewer (not module code):
# author: sseung0703 | commit message: "update" | commit: e8c4ed3 | size: 9.81 kB
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Different model implementation plus a general port for all the models."""
from typing import Any, Callable
from flax import linen as nn
from jax import random
import jax.numpy as jnp
from nerf import model_utils
from nerf import utils
def get_model(key, example_batch, args):
  """A helper function that wraps around a 'model zoo'.

  Args:
    key: random number generator passed through to the model constructor.
    example_batch: an example batch passed through to the model constructor.
    args: hyperparameters; `args.model` selects the constructor by name.

  Returns:
    Whatever the selected constructor returns.

  Raises:
    KeyError: if `args.model` is not a registered model name.
  """
  model_dict = {"nerf": construct_nerf}
  try:
    constructor = model_dict[args.model]
  except KeyError:
    # Re-raise the same exception type, but tell the user what is valid.
    raise KeyError("Unknown model {!r}; expected one of {}".format(
        args.model, sorted(model_dict))) from None
  return constructor(key, example_batch, args)
class NerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs.

  Rays are sampled at `num_coarse_samples` depths and rendered with a
  "coarse" MLP. If `num_fine_samples > 0`, the coarse rendering weights are
  passed to `model_utils.sample_pdf` to draw additional sample depths, which
  are rendered with a separate "fine" MLP of the same architecture.
  """
  num_coarse_samples: int  # The number of samples for the coarse nerf.
  num_fine_samples: int  # The number of samples for the fine nerf.
  use_viewdirs: bool  # If True, use viewdirs as an input.
  near: float  # The distance to the near plane
  far: float  # The distance to the far plane
  noise_std: float  # The std dev of noise added to raw sigma.
  net_depth: int  # The depth of the first part of MLP.
  net_width: int  # The width of the first part of MLP.
  net_depth_condition: int  # The depth of the second part of MLP.
  net_width_condition: int  # The width of the second part of MLP.
  net_activation: Callable[..., Any]  # MLP activation
  skip_layer: int  # How often to add skip connections.
  num_rgb_channels: int  # The number of RGB channels.
  num_sigma_channels: int  # The number of density channels.
  white_bkgd: bool  # If True, use a white background.
  min_deg_point: int  # The minimum degree of positional encoding for positions.
  max_deg_point: int  # The maximum degree of positional encoding for positions.
  deg_view: int  # The degree of positional encoding for viewdirs.
  lindisp: bool  # If True, sample linearly in disparity rather than in depth.
  rgb_activation: Callable[..., Any]  # Output RGB activation.
  sigma_activation: Callable[..., Any]  # Output sigma activation.
  legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

  @nn.compact
  def __call__(self, rng_0, rng_1, rays, randomized, rgb_only=False):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: namedtuple with `origins`, `directions` and `viewdirs` fields
        (presumably `utils.Rays`).
      randomized: bool, use randomized stratified sampling.
      rgb_only: bool, if True the fine pass is skipped entirely and only the
        coarse composited RGB is returned (see Returns).

    Returns:
      ret: list. If rgb_only is False:
        [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)], where the
        second (fine) tuple is present only when num_fine_samples > 0.
        If rgb_only is True: [None, rgb_coarse].
    """
    # Stratified sampling along rays.
    key, rng_0 = random.split(rng_0)
    z_vals, samples = model_utils.sample_along_rays(
        key,
        rays.origins,
        rays.directions,
        self.num_coarse_samples,
        self.near,
        self.far,
        randomized,
        self.lindisp,
    )
    samples_enc = model_utils.posenc(
        samples,
        self.min_deg_point,
        self.max_deg_point,
        self.legacy_posenc_order,
    )

    # Both MLPs share one architecture; building them from a single config
    # keeps them in sync. They are still created coarse-first, so the
    # auto-generated submodule names (and parameter layout) are unchanged.
    mlp_config = dict(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_depth_condition=self.net_depth_condition,
        net_width_condition=self.net_width_condition,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels,
        num_sigma_channels=self.num_sigma_channels)

    # Construct the "coarse" MLP.
    coarse_mlp = model_utils.MLP(**mlp_config)

    # Point attribute predictions. The view-direction encoding is computed
    # once here and reused by the fine MLP below.
    if self.use_viewdirs:
      viewdirs_enc = model_utils.posenc(
          rays.viewdirs,
          0,
          self.deg_view,
          self.legacy_posenc_order,
      )
      raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
    else:
      viewdirs_enc = None
      raw_rgb, raw_sigma = coarse_mlp(samples_enc)

    # Add noise to regularize the density predictions if needed.
    key, rng_0 = random.split(rng_0)
    raw_sigma = model_utils.add_gaussian_noise(
        key,
        raw_sigma,
        self.noise_std,
        randomized,
    )
    rgb = self.rgb_activation(raw_rgb)
    sigma = self.sigma_activation(raw_sigma)

    # Volumetric rendering.
    comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
        rgb,
        sigma,
        z_vals,
        rays.directions,
        white_bkgd=self.white_bkgd,
    )
    ret = [
        (comp_rgb, disp, acc),
    ]

    if self.num_fine_samples > 0 and not rgb_only:
      # Midpoints between consecutive coarse depths, with the interior coarse
      # weights, are passed to sample_pdf to draw the fine sample depths.
      z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
      key, rng_1 = random.split(rng_1)
      z_vals, samples = model_utils.sample_pdf(
          key,
          z_vals_mid,
          weights[..., 1:-1],
          rays.origins,
          rays.directions,
          z_vals,
          self.num_fine_samples,
          randomized,
      )
      samples_enc = model_utils.posenc(
          samples,
          self.min_deg_point,
          self.max_deg_point,
          self.legacy_posenc_order,
      )

      # Construct the "fine" MLP.
      fine_mlp = model_utils.MLP(**mlp_config)
      if self.use_viewdirs:
        raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
      else:
        raw_rgb, raw_sigma = fine_mlp(samples_enc)

      key, rng_1 = random.split(rng_1)
      raw_sigma = model_utils.add_gaussian_noise(
          key,
          raw_sigma,
          self.noise_std,
          randomized,
      )
      rgb = self.rgb_activation(raw_rgb)
      sigma = self.sigma_activation(raw_sigma)
      comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
          rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )
      ret.append((comp_rgb, disp, acc))

    if rgb_only:
      # Only the coarse RGB is returned, placed in the last slot with None in
      # the first. NOTE(review): presumably so callers reading ret[-1] keep
      # working without a fine pass — confirm against callers.
      return [None, ret[0][0]]
    return ret
def _check_activation_ranges(rgb_activation, sigma_activation, rgb_name,
                             sigma_name):
  """Raise if an output activation can leave its required output range.

  Args:
    rgb_activation: callable, activation applied to raw RGB outputs.
    sigma_activation: callable, activation applied to raw density outputs.
    rgb_name: the configured name of rgb_activation, used in the error.
    sigma_name: the configured name of sigma_activation, used in the error.

  Raises:
    NotImplementedError: if rgb_activation produces values outside [0, 1] or
      sigma_activation produces negative values on the probe inputs.
  """
  # Probe on a symmetric grid of log-spaced magnitudes, mirrored so negative
  # inputs are covered too. Under JAX's default float32 the largest
  # magnitudes overflow to +/-inf, which also exercises the extremes.
  x = jnp.exp(jnp.linspace(-90, 90, 1024))
  x = jnp.concatenate([-x[::-1], x], 0)
  rgb = rgb_activation(x)
  if jnp.any(rgb < 0) or jnp.any(rgb > 1):
    raise NotImplementedError(
        "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
        .format(rgb_name))
  sigma = sigma_activation(x)
  if jnp.any(sigma < 0):
    raise NotImplementedError(
        "Choice of sigma_activation `{}` produces negative densities".format(
            sigma_name))


def construct_nerf(key, example_batch, args):
  """Construct a Neural Radiance Field.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data; only its "rays" entry
      is used.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: NerfModel. The Flax module (not bound to any variables).
    init_variables: the variables returned by `model.init` for the example
      rays.

  Raises:
    NotImplementedError: if the configured rgb_activation can produce values
      outside [0, 1], or sigma_activation can produce negative values.
  """
  # Activations are looked up by name on flax.linen.
  net_activation = getattr(nn, str(args.net_activation))
  rgb_activation = getattr(nn, str(args.rgb_activation))
  sigma_activation = getattr(nn, str(args.sigma_activation))
  _check_activation_ranges(rgb_activation, sigma_activation,
                           args.rgb_activation, args.sigma_activation)

  model = NerfModel(
      min_deg_point=args.min_deg_point,
      max_deg_point=args.max_deg_point,
      deg_view=args.deg_view,
      num_coarse_samples=args.num_coarse_samples,
      num_fine_samples=args.num_fine_samples,
      use_viewdirs=args.use_viewdirs,
      near=args.near,
      far=args.far,
      noise_std=args.noise_std,
      white_bkgd=args.white_bkgd,
      net_depth=args.net_depth,
      net_width=args.net_width,
      net_depth_condition=args.net_depth_condition,
      net_width_condition=args.net_width_condition,
      skip_layer=args.skip_layer,
      num_rgb_channels=args.num_rgb_channels,
      num_sigma_channels=args.num_sigma_channels,
      lindisp=args.lindisp,
      net_activation=net_activation,
      rgb_activation=rgb_activation,
      sigma_activation=sigma_activation,
      legacy_posenc_order=args.legacy_posenc_order)

  rays = example_batch["rays"]
  key1, key2, key3 = random.split(key, num=3)
  # Initialize from one slice of the batch: x[0] takes the first entry along
  # the leading axis of every ray field. NOTE(review): presumably that axis
  # is a per-device shard axis — confirm against the data pipeline.
  init_variables = model.init(
      key1,
      rng_0=key2,
      rng_1=key3,
      rays=utils.namedtuple_map(lambda x: x[0], rays),
      randomized=args.randomized)
  return model, init_variables