|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Different datasets implementation plus a general port for all the datasets.""" |
|
INTERNAL = False |
|
import json |
|
import os, time |
|
from os import path |
|
import queue |
|
import threading |
|
|
|
if not INTERNAL: |
|
import cv2 |
|
import jax |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from nerf import utils |
|
from nerf import clip_utils |
|
|
|
def get_dataset(split, args, clip_model = None):
    """Instantiate the dataset class registered under ``args.dataset``.

    Args:
      split: dataset split name, forwarded to the dataset constructor.
      args: flags object; ``args.dataset`` selects the class from ``dataset_dict``
        and the whole object is forwarded as the constructor's flags.
      clip_model: optional CLIP model forwarded to the constructor.

    Returns:
      The constructed dataset instance.
    """
    dataset_cls = dataset_dict[args.dataset]
    return dataset_cls(split, args, clip_model)
|
|
|
|
|
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
    """Convert a set of rays to NDC coordinates.

    Each ray origin is first slid along its ray onto the plane ``z = -near``.
    The NeRF forward-facing NDC projection is then applied to origins and
    directions.

    Args:
      origins: array (..., 3) of ray origins.
      directions: array (..., 3) of ray directions.
      focal: focal length, in the same units as ``w`` and ``h``.
      w: image width.
      h: image height.
      near: distance of the near plane.

    Returns:
      Tuple ``(origins, directions)`` in NDC, each of shape (..., 3).
    """
    # Move each origin onto the near plane: oz + t * dz == -near.
    t_near = -(near + origins[..., 2]) / directions[..., 2]
    shifted = origins + t_near[..., None] * directions

    dx, dy, dz = directions[..., 0], directions[..., 1], directions[..., 2]
    ox, oy, oz = shifted[..., 0], shifted[..., 1], shifted[..., 2]

    x_scale = -((2 * focal) / w)
    y_scale = -((2 * focal) / h)

    ndc_origins = np.stack(
        [x_scale * (ox / oz),
         y_scale * (oy / oz),
         1 + 2 * near / oz],
        -1)
    ndc_directions = np.stack(
        [x_scale * (dx / dz - ox / oz),
         y_scale * (dy / dz - oy / oz),
         -2 * near / oz],
        -1)
    return ndc_origins, ndc_directions
|
|
|
|
|
class Dataset(threading.Thread):
    """Dataset base class.

    Subclasses implement ``_load_renderings(flags, clip_model)``. It must set
    ``images``, ``camtoworlds``, ``focal``, ``h``, ``w``, ``resolution`` and
    ``n_examples``, which the methods here read.

    On construction the base class builds rays for every pixel. It then starts a
    daemon producer thread that keeps a bounded queue of ready batches, which
    callers consume through the iterator protocol.

    If the producer raises, the exception is put on the queue and re-raised by
    ``__next__``/``peek``. Without this, the consumer would block forever on an
    empty queue after the thread died.
    """

    def __init__(self, split, flags, clip_model):
        super(Dataset, self).__init__()
        # Bounded prefetch queue: the producer blocks once 3 items are waiting.
        self.queue = queue.Queue(3)
        self.daemon = True
        self.use_pixel_centers = flags.use_pixel_centers
        self.split = split

        if split == "train":
            self._train_init(flags, clip_model)
        elif split == "test":
            self._test_init(flags)
        else:
            raise ValueError(
                "the split argument should be either \"train\" or \"test\", set "
                "to {} here.".format(split))
        # Per-process batch size: the global batch is divided across JAX processes.
        self.batch_size = flags.batch_size // jax.process_count()
        self.batching = flags.batching
        self.render_path = flags.render_path
        self.far = flags.far
        self.near = flags.near
        self.max_steps = flags.max_steps
        self.start()

    def __iter__(self):
        return self

    def __next__(self):
        """Get the next training batch or test example.

        Returns:
          batch: dict, has "pixels" and "rays" (render-path test examples carry
            only "rays"). Training batches go through ``utils.shard``; test
            examples go through ``utils.to_device``.

        Raises:
          Exception: whatever the producer thread raised while building a batch.
            The same error is raised again on every later call.
        """
        x = self.queue.get()
        if isinstance(x, BaseException):
            # The producer has exited; put the error back so later calls fail
            # fast instead of blocking forever on an empty queue.
            self.queue.put(x)
            raise x
        return self._place_batch(x)

    def peek(self):
        """Peek at the next training batch or test example without dequeuing it.

        This does not wait: it raises ``IndexError`` if nothing is queued yet.

        Returns:
          batch: dict, has "pixels" and "rays".

        Raises:
          Exception: whatever the producer thread raised while building a batch.
        """
        x = self.queue.queue[0]
        if isinstance(x, BaseException):
            raise x
        # Shallow copy so device placement does not alter the queued dict.
        return self._place_batch(x.copy())

    def _place_batch(self, x):
        """Shard a training batch across devices, or move a test example to device."""
        if self.split == "train":
            return utils.shard(x)
        return utils.to_device(x)

    def run(self):
        """Producer loop: build items forever and push them onto the queue.

        If building an item raises, the exception is pushed onto the queue (so the
        consumer re-raises it) and the thread exits.
        """
        next_func = self._next_train if self.split == "train" else self._next_test
        while True:
            try:
                item = next_func()
            except Exception as err:  # pylint: disable=broad-except
                # Thread boundary: hand the failure to the consumer.
                self.queue.put(err)
                return
            self.queue.put(item)

    @property
    def size(self):
        return self.n_examples

    def _train_init(self, flags, clip_model):
        """Initialize training: load data, build rays, reshape for batching.

        "all_images" flattens images and rays to one row per ray. "single_image"
        keeps one row per image with ``resolution`` rays each.
        """
        self._load_renderings(flags, clip_model)
        self._generate_rays()

        if flags.batching == "all_images":
            self.images = self.images.reshape([-1, 3])
            self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                             self.rays)
        elif flags.batching == "single_image":
            self.images = self.images.reshape([-1, self.resolution, 3])
            self.rays = utils.namedtuple_map(
                lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
        else:
            raise NotImplementedError(
                f"{flags.batching} batching strategy is not implemented.")

    def _test_init(self, flags):
        """Initialize testing: load data, build rays, reset the example cursor."""
        self._load_renderings(flags, clip_model = None)
        self._generate_rays()
        self.it = 0

    def _next_train(self):
        """Sample the next training batch.

        Only "single_image" batching is supported. One image is drawn uniformly at
        random, and ``batch_size`` of its rays are drawn independently (with
        replacement).

        Returns:
          dict with "pixels", "rays" and the scalar "image_index".

        Raises:
          NotImplementedError: for "all_images" batching (per-ray image indices
            are not tracked) or an unknown batching strategy.
        """
        if self.batching == "all_images":
            # Fail before doing any sampling work that would be thrown away.
            raise NotImplementedError("image_index not implemented for batching=all_images")
        if self.batching != "single_image":
            raise NotImplementedError(
                f"{self.batching} batching strategy is not implemented.")

        image_index = np.random.randint(0, self.n_examples, ())
        ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                        (self.batch_size,))
        batch_pixels = self.images[image_index][ray_indices]
        batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                          self.rays)
        return {"pixels": batch_pixels, "rays": batch_rays, "image_index": image_index}

    def _next_test(self):
        """Return the next test example, cycling through ``n_examples`` in order."""
        idx = self.it
        self.it = (self.it + 1) % self.n_examples

        if self.render_path:
            return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
        return {"pixels": self.images[idx],
                "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
                "image_index": idx}

    def _camera_dirs(self, downsample=1):
        """Camera-space ray directions for every ``downsample``-th pixel.

        Returns an array of shape (ceil(h / downsample), ceil(w / downsample), 3).
        The z component is -1, and the y component is negated relative to the
        pixel row so that rows further down the image get smaller y.
        """
        pixel_center = 0.5 if self.use_pixel_centers else 0.0
        x, y = np.meshgrid(
            np.arange(self.w, step=downsample, dtype=np.float32) + pixel_center,
            np.arange(self.h, step=downsample, dtype=np.float32) + pixel_center,
            indexing="xy")
        return np.stack([(x - self.w * 0.5) / self.focal,
                         -(y - self.h * 0.5) / self.focal,
                         -np.ones_like(x)],
                        axis=-1)

    def _generate_rays(self):
        """Generate rays for every pixel of every camera in ``self.camtoworlds``.

        Sets ``self.rays`` to a ``utils.Rays`` whose fields have shape
        (n_cameras, h, w, 3); ``viewdirs`` are the unit-normalized directions.
        """
        camera_dirs = self._camera_dirs()
        # Rotate every pixel direction by every camera's 3x3 rotation block.
        directions = ((camera_dirs[None, ..., None, :] *
                       self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
        # All rays of one camera start at that camera's translation column.
        origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                                  directions.shape)
        viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
        self.rays = utils.Rays(
            origins=origins, directions=directions, viewdirs=viewdirs)

    def camtoworld_matrix_to_rays(self, camtoworld, downsample = 1):
        """Render one instance of rays given a camera-to-world matrix.

        Args:
          camtoworld: (4, 4) matrix; only its top 3x4 block is read.
          downsample: keep every ``downsample``-th pixel in each direction.

        Returns:
          ``utils.Rays`` with fields of shape
          (ceil(h / downsample), ceil(w / downsample), 3).
        """
        camera_dirs = self._camera_dirs(downsample)
        directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
        origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
        viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
        return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
|
|
|
class Blender(Dataset):
    """Blender dataset: ``transforms_{split}.json`` plus one PNG per frame."""

    def _load_renderings(self, flags, clip_model = None):
        """Load images and camera poses from disk.

        When ``flags.use_semantic_loss`` is set and a ``clip_model`` is given, this
        also precomputes one L2-normalized CLIP image embedding per image into
        ``self.embeddings``. It always initializes the shuffled image order used
        by ``get_clip_data``.

        Raises:
          ValueError: if ``flags.render_path`` is set (unsupported for Blender).
        """
        if flags.render_path:
            raise ValueError("render_path cannot be used for the blender dataset.")
        cams, images, meta = self.load_files(flags.data_dir, self.split, flags.factor, flags.few_shot)

        self.images = np.stack(images, axis=0)
        if flags.white_bkgd:
            # Composite onto white, treating the last channel as alpha.
            self.images = (self.images[..., :3] * self.images[..., -1:] +
                           (1. - self.images[..., -1:]))
        else:
            self.images = self.images[..., :3]
        self.h, self.w = self.images.shape[1:3]
        self.resolution = self.h * self.w
        self.camtoworlds = np.stack(cams, axis=0)
        camera_angle_x = float(meta["camera_angle_x"])
        # Focal length from the horizontal field of view.
        self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
        self.n_examples = self.images.shape[0]
        self.dtype = flags.clip_output_dtype

        if flags.use_semantic_loss and clip_model is not None:
            embs = []
            for img in self.images:
                # HWC -> CHW, plus a leading batch axis of size 1.
                img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
                emb = clip_model.get_image_features(pixel_values = clip_utils.preprocess_for_CLIP(img))
                embs.append(emb / np.linalg.norm(emb))
            self.embeddings = np.concatenate(embs, 0)

        self._reshuffle_image_idx()

    def _reshuffle_image_idx(self):
        """Refill ``self.image_idx`` with a fresh random permutation of image indices."""
        image_idx = np.arange(self.images.shape[0])
        np.random.shuffle(image_idx)
        self.image_idx = image_idx.tolist()

    @staticmethod
    def load_files(data_dir, split, factor, few_shot):
        """Read ``transforms_{split}.json`` and the frames it lists.

        Args:
          data_dir: dataset root directory.
          split: split name used in the transforms file name.
          factor: 0 (or negative) keeps full resolution; 2 or 4 downsamples each
            image by that factor with area interpolation.
          few_shot: if > 0 and ``split == 'train'``, keep only the first
            ``few_shot`` frames.

        Returns:
          ``(cams, images, meta)``. ``cams`` is a list of float32
          ``transform_matrix`` arrays, ``images`` is a list of float32 arrays
          scaled to [0, 1], and ``meta`` is the parsed JSON.

        Raises:
          ValueError: for any other positive ``factor``.
        """
        with utils.open_file(path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
            meta = json.load(fp)
        images = []
        cams = []

        frames = np.arange(len(meta["frames"]))
        if few_shot > 0 and split == 'train':
            frames = frames[:few_shot]

        for i in frames:
            frame = meta["frames"][i]
            fname = os.path.join(data_dir, frame["file_path"] + ".png")
            with utils.open_file(fname, "rb") as imgin:
                image = np.array(Image.open(imgin)).astype(np.float32) / 255.
                if factor in (2, 4):
                    # int() keeps cv2's target size integral even for factor=2.0.
                    new_h, new_w = (hw // int(factor) for hw in image.shape[:2])
                    image = cv2.resize(image, (new_w, new_h),
                                       interpolation=cv2.INTER_AREA)
                elif factor > 0:
                    raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
                                     "set.".format(factor))
            cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
            images.append(image)

        print(f'No. of samples: {len(frames)}')
        return cams, images, meta

    def _next_train(self):
        """Sample a training batch and drop its ``image_index`` entry.

        Raises:
          NotImplementedError: unless ``batching == "single_image"``.
        """
        batch_dict = super(Blender, self)._next_train()
        if self.batching != "single_image":
            raise NotImplementedError
        batch_dict.pop("image_index")
        return batch_dict

    def get_clip_data(self):
        """Build the inputs for the CLIP semantic-consistency loss.

        Takes the next target image from a shuffled pass over all images. It then
        samples a random camera pose with ``clip_utils.random_pose`` and takes a
        strided, roughly central square patch of that camera's rays.

        This requires ``self.embeddings``, which ``_load_renderings`` only sets when
        the semantic loss is enabled and a CLIP model was given.

        Returns:
          dict with "embedding" (the target image's precomputed embedding) and
          "random_rays" (patch rays passed through ``utils.shard``). If
          ``self.dtype == 'float16'``, every array is cast to float16.
        """
        if not self.image_idx:
            self._reshuffle_image_idx()
        image_index = self.image_idx.pop()

        batch_dict = {"embedding": self.embeddings[image_index]}

        # Seed from numpy's global RNG. Seeding from int(time.time()) gave every
        # call within the same second the same "random" pose.
        src_seed = np.random.randint(0, np.iinfo(np.int32).max)
        src_rng = jax.random.PRNGKey(src_seed)
        src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))

        # Patch center in the middle fifth of each axis; half-size 7/40 of the
        # shorter side. This is exactly 320..480 and 140 for 800x800 images, and
        # scales with resolution so downsampled datasets stay in bounds.
        cx = np.random.randint(2 * self.w // 5, 3 * self.w // 5)
        cy = np.random.randint(2 * self.h // 5, 3 * self.h // 5)
        d = 7 * min(self.h, self.w) // 40

        random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 1)
        random_rays = jax.tree_util.tree_map(
            lambda x: x[cy - d:cy + d:4, cx - d:cx + d:4], random_rays)

        # Trim the square side to a multiple of the local device count,
        # presumably so utils.shard can split the rays evenly across devices.
        side = random_rays[0].shape[0] - random_rays[0].shape[0] % jax.local_device_count()
        random_rays = jax.tree_util.tree_map(
            lambda x: x[:side, :side].reshape(-1, 3), random_rays)
        batch_dict["random_rays"] = utils.shard(random_rays)
        if self.dtype == 'float16':
            batch_dict = jax.tree_util.tree_map(lambda x: x.astype(np.float16), batch_dict)
        return batch_dict
|
|
|
class LLFF(Dataset):
    """LLFF dataset: an image folder plus ``poses_bounds.npy``."""

    def _load_renderings(self, flags, clip_model=None):
        """Load images and poses from disk and keep this split's views.

        Args:
          flags: configuration flags.
          clip_model: accepted because ``Dataset._train_init``/``_test_init``
            always pass it; ignored here. Before this parameter existed, building
            an LLFF dataset raised ``TypeError``.

        Raises:
          ValueError: if the image folder does not exist.
          RuntimeError: if the number of images and poses differ.
        """
        # Load images.
        imgdir_suffix = ""
        if flags.factor > 0:
            imgdir_suffix = "_{}".format(flags.factor)
            factor = flags.factor
        else:
            factor = 1
        imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
        if not utils.file_exists(imgdir):
            raise ValueError("Image folder {} doesn't exist.".format(imgdir))
        imgfiles = [
            path.join(imgdir, f)
            for f in sorted(utils.listdir(imgdir))
            if f.endswith(("JPG", "jpg", "png"))
        ]
        images = []
        for imgfile in imgfiles:
            with utils.open_file(imgfile, "rb") as imgin:
                image = np.array(Image.open(imgin), dtype=np.float32) / 255.
                images.append(image)
        # Stack along a trailing per-image axis to match the pose layout below.
        images = np.stack(images, axis=-1)

        # Each row of poses_bounds.npy: a flattened 3x5 pose, then two depth bounds.
        with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
                             "rb") as fp:
            poses_arr = np.load(fp)
        poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
        bds = poses_arr[:, -2:].transpose([1, 0])
        if poses.shape[-1] != images.shape[-1]:
            raise RuntimeError("Mismatch between imgs {} and poses {}".format(
                images.shape[-1], poses.shape[-1]))

        # Column 4 holds (height, width, focal): write the loaded image size and
        # rescale the focal length by the downsampling factor.
        poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
        poses[2, 4, :] = poses[2, 4, :] * 1. / factor

        # Reorder rotation columns to [col1, -col0, col2, ...] and move the
        # per-image axis to the front.
        poses = np.concatenate(
            [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)
        images = np.moveaxis(images, -1, 0)
        bds = np.moveaxis(bds, -1, 0).astype(np.float32)

        # Rescale translations and bounds so the smallest bound becomes 1 / 0.75.
        scale = 1. / (bds.min() * .75)
        poses[:, :3, 3] *= scale
        bds *= scale

        poses = self._recenter_poses(poses)

        if flags.spherify:
            poses = self._generate_spherical_poses(poses, bds)
            self.spherify = True
        else:
            self.spherify = False
        if not flags.spherify and self.split == "test":
            self._generate_spiral_poses(poses, bds)

        # Every `llffhold`-th image is held out for testing; the rest train.
        i_test = np.arange(images.shape[0])[::flags.llffhold]
        i_train = np.setdiff1d(np.arange(images.shape[0]), i_test)
        indices = i_train if self.split == "train" else i_test
        images = images[indices]
        poses = poses[indices]

        self.images = images
        self.camtoworlds = poses[:, :3, :4]
        self.focal = poses[0, -1, -1]
        self.h, self.w = images.shape[1:3]
        self.resolution = self.h * self.w
        # render_poses only exist for the test split; training always iterates
        # over the loaded images.
        if flags.render_path and self.split == "test":
            self.n_examples = self.render_poses.shape[0]
        else:
            self.n_examples = images.shape[0]

    def _generate_rays(self):
        """Generate normalized device coordinate rays for llff.

        Rays are converted to NDC unless the scene was spherified. For the test
        split, rays are also generated for ``self.render_poses`` and stored
        separately in ``self.render_rays``.
        """
        if self.split == "test":
            n_render_poses = self.render_poses.shape[0]
            # Prepend the render poses so one pass generates rays for both sets.
            self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
                                              axis=0)

        super()._generate_rays()

        if not self.spherify:
            ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
                                                         self.rays.directions,
                                                         self.focal, self.w, self.h)
            self.rays = utils.Rays(
                origins=ndc_origins,
                directions=ndc_directions,
                viewdirs=self.rays.viewdirs)

        # Split the render-path rays back out from the test-view rays.
        if self.split == "test":
            self.camtoworlds = self.camtoworlds[n_render_poses:]
            split = [np.split(r, [n_render_poses], 0) for r in self.rays]
            split0, split1 = zip(*split)
            self.render_rays = utils.Rays(*split0)
            self.rays = utils.Rays(*split1)

    def _recenter_poses(self, poses):
        """Recenter poses according to the original NeRF code.

        Every 3x4 pose is left-multiplied by the inverse of the average pose from
        ``_poses_avg``. The fifth column is left untouched.
        """
        poses_ = poses.copy()
        bottom = np.reshape([0, 0, 0, 1.], [1, 4])
        c2w = self._poses_avg(poses)
        c2w = np.concatenate([c2w[:3, :4], bottom], -2)
        bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
        poses = np.concatenate([poses[:, :3, :4], bottom], -2)
        poses = np.linalg.inv(c2w) @ poses
        poses_[:, :3, :4] = poses[:, :3, :4]
        poses = poses_
        return poses

    def _poses_avg(self, poses):
        """Average poses according to the original NeRF code.

        Returns a 3x5 matrix. Its rotation is built from the summed z and up
        columns, its position is the mean translation, and its fifth column is
        the first pose's column 4.
        """
        hwf = poses[0, :3, -1:]
        center = poses[:, :3, 3].mean(0)
        vec2 = self._normalize(poses[:, :3, 2].sum(0))
        up = poses[:, :3, 1].sum(0)
        c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
        return c2w

    def _viewmatrix(self, z, up, pos):
        """Construct a 3x4 lookat view matrix [x, y, z, pos] from z and an up hint."""
        vec2 = self._normalize(z)
        vec1_avg = up
        vec0 = self._normalize(np.cross(vec1_avg, vec2))
        vec1 = self._normalize(np.cross(vec2, vec0))
        m = np.stack([vec0, vec1, vec2, pos], 1)
        return m

    def _normalize(self, x):
        """Scale ``x`` to unit L2 norm."""
        return x / np.linalg.norm(x)

    def _generate_spiral_poses(self, poses, bds):
        """Generate a spiral path for rendering.

        Stores 120 float32 3x4 poses, spanning 2 rotations around the average
        pose, in ``self.render_poses``.
        """
        c2w = self._poses_avg(poses)
        # Up vector shared by all generated poses.
        up = self._normalize(poses[:, :3, 1].sum(0))
        # Focus distance: weighted harmonic mean of near and far depths.
        close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
        dt = .75
        mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
        focal = mean_dz
        # Spiral radii: 90th percentile of absolute camera translations.
        tt = poses[:, :3, 3]
        rads = np.percentile(np.abs(tt), 90, 0)
        c2w_path = c2w
        n_views = 120
        n_rots = 2

        render_poses = []
        rads = np.array(list(rads) + [1.])
        hwf = c2w_path[:, 4:5]
        zrate = .5
        for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
            c = np.dot(c2w[:3, :4], (np.array(
                [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
            z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
            render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
        self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]

    def _generate_spherical_poses(self, poses, bds):
        """Generate a 360 degree spherical path for rendering.

        Re-expresses ``poses`` in a frame centred on the point nearest all camera
        axes, scaled to unit RMS camera distance. ``bds`` is rescaled in place by
        the same factor. For the test split, 120 circular render poses are
        stored in ``self.render_poses``.

        Returns:
          The re-expressed poses, shape (N, 3, 5).
        """
        p34_to_44 = lambda p: np.concatenate([
            p,
            np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
        ], 1)
        rays_d = poses[:, :3, 2:3]
        rays_o = poses[:, :3, 3:4]

        def min_line_dist(rays_o, rays_d):
            # Least-squares point closest to all camera z-axes (assumes the
            # axis columns are unit length).
            a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
            b_i = -a_i @ rays_o
            pt_mindist = np.squeeze(-np.linalg.inv(
                (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
            return pt_mindist

        pt_mindist = min_line_dist(rays_o, rays_d)
        center = pt_mindist
        up = (poses[:, :3, 3] - center).mean(0)
        vec0 = self._normalize(up)
        vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
        vec2 = self._normalize(np.cross(vec0, vec1))
        pos = center
        c2w = np.stack([vec1, vec2, vec0, pos], 1)
        poses_reset = (
            np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
        rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
        sc = 1. / rad
        poses_reset[:, :3, 3] *= sc
        bds *= sc  # NOTE: mutates the caller's array in place.
        rad *= sc
        centroid = np.mean(poses_reset[:, :3, 3], 0)
        zh = centroid[2]
        radcircle = np.sqrt(rad ** 2 - zh ** 2)
        new_poses = []

        for th in np.linspace(0., 2. * np.pi, 120):
            camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
            up = np.array([0, 0, -1.])
            vec2 = self._normalize(camorigin)
            vec0 = self._normalize(np.cross(vec2, up))
            vec1 = self._normalize(np.cross(vec2, vec0))
            pos = camorigin
            p = np.stack([vec0, vec1, vec2, pos], 1)
            new_poses.append(p)

        new_poses = np.stack(new_poses, 0)
        new_poses = np.concatenate([
            new_poses,
            np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
        ], -1)
        poses_reset = np.concatenate([
            poses_reset[:, :3, :4],
            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
        ], -1)
        if self.split == "test":
            self.render_poses = new_poses[:, :3, :4]
        return poses_reset
|
|
|
|
|
# Registry mapping ``args.dataset`` names to dataset classes; read by get_dataset.
dataset_dict = dict(blender=Blender, llff=LLFF)
|
|