# Flash3d / flash3d/util/vis3d.py
# Page residue kept for provenance (not code): "Ryukijano's picture",
# "commit the whole flash3d", "ffbcf9e verified".
from pathlib import Path
from jaxtyping import Float
import numpy as np
from scipy.spatial.transform import Rotation as R
from plyfile import PlyData, PlyElement
import torch
from torch import Tensor
from einops import rearrange, einsum
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered per-vertex property names for the exported PLY.

    The order is: position (``x, y, z``), normal (``nx, ny, nz``), three
    ``f_dc_*`` names, ``num_rest`` ``f_rest_*`` names, ``opacity``, three
    ``scale_*`` names and four ``rot_*`` names.

    Args:
        num_rest: How many ``f_rest_*`` names to include. Zero (or a negative
            value) yields none.

    Returns:
        A new list of property-name strings in the order described above.
    """
    return [
        "x", "y", "z",
        "nx", "ny", "nz",
        *(f"f_dc_{k}" for k in range(3)),
        *(f"f_rest_{k}" for k in range(num_rest)),
        "opacity",
        *(f"scale_{k}" for k in range(3)),
        *(f"rot_{k}" for k in range(4)),
    ]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
):
    """Normalize a set of 3D Gaussians and write them to a PLY file.

    The scene is re-centred on the per-axis median position and divided by the
    largest per-axis 95th percentile of absolute coordinates, so most
    Gaussians end up within [-1, 1]. Scales are divided by the same factor,
    multiplied by 4 and clamped to [0, 0.0075] before their natural log is
    written.

    Args:
        means: Gaussian centres, shape ``(gaussian, 3)``.
        scales: Per-axis Gaussian scales, shape ``(gaussian, 3)``.
        rotations: Quaternions, shape ``(gaussian, 4)``. They are handed
            unchanged to ``scipy`` ``Rotation.from_quat`` and written out as
            ``(w, x, y, z)``.
            NOTE(review): ``from_quat`` reads scalar-last ``(x, y, z, w)`` by
            default -- confirm the caller uses that order.
        harmonics: Either ``(gaussian, 3)`` colour coefficients, or
            ``(gaussian, 3, d_sh)``, in which case only band 0 is exported.
        opacities: Either ``(gaussian,)`` or ``(gaussian, 1)``. Values are
            written as given; no activation or inverse activation is applied.
        path: Destination file. Missing parent directories are created.
    """
    path = Path(path)

    # Shift the scene so that the median Gaussian is at the origin.
    means = means - means.median(dim=0).values

    # Rescale the scene so that most Gaussians are within range [-1, 1].
    scale_factor = means.abs().quantile(0.95, dim=0).max()
    means = means / scale_factor
    # Enlarge the rescaled Gaussians, then cap their size.
    scales = torch.clamp(scales / scale_factor * 4.0, 0, 0.0075)

    # World re-orientation applied to positions and orientations. It is
    # currently the identity. A previously used alternative that makes +Z the
    # world up vector is [[0, 0, 1], [-1, 0, 0], [0, -1, 0]]. Earlier versions
    # also composed a 45 degree viewer adjustment and a world-to-camera
    # rotation onto it; both are disabled here.
    rotation = torch.tensor(
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ],
        dtype=torch.float32,
        device=means.device,
    )

    # Apply the rotation to the means (Gaussian positions).
    means = einsum(rotation, means, "i j, ... j -> ... i")

    # Apply the rotation to the Gaussian rotations, going through rotation
    # matrices because the quaternions cannot be multiplied by a matrix
    # directly.
    gaussian_matrices = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    gaussian_matrices = rotation.detach().cpu().numpy() @ gaussian_matrices
    quats_xyzw = R.from_matrix(gaussian_matrices).as_quat()
    # Reorder the (x, y, z, w) quaternions to (w, x, y, z) for rot_0..rot_3.
    quats_wxyz = quats_xyzw[:, [3, 0, 1, 2]]

    # Only the DC band is exported (the PLY header is built with zero
    # f_rest_* fields), so reduce full harmonics to their first band.
    # Two-dimensional input is already the DC band and is used as is.
    if harmonics.ndim == 3:
        harmonics = harmonics[..., 0]

    # Every attribute must be 2-D to be concatenated column-wise below.
    if opacities.ndim == 1:
        opacities = opacities[..., None]

    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
    elements = np.empty(means.shape[0], dtype=dtype_full)

    # Column order must match construct_list_of_attributes(0):
    # position, normal (all zeros), f_dc, opacity, scale, rot.
    # NOTE(review): a scale clamped to exactly 0 is written as log(0) = -inf.
    columns = (
        means.detach().cpu().numpy(),
        torch.zeros_like(means).detach().cpu().numpy(),
        harmonics.detach().cpu().contiguous().numpy(),
        opacities.detach().cpu().numpy(),
        scales.log().detach().cpu().numpy(),
        quats_wxyz,
    )
    table = np.concatenate(columns, axis=1)
    elements[:] = list(map(tuple, table))

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
def save_ply(outputs, path, num_gauss=3, *, height=256, width=384, pad=32):
    """Crop the padded Gaussian maps in ``outputs`` and export them as a PLY.

    Only the first batch element is exported. The ``num_gauss`` Gaussians per
    pixel are unfolded from the leading ``(b v)`` axis and flattened together
    with the pixels into one list of Gaussians.

    Args:
        outputs: Mapping indexed by ``(name, 0, 0)`` tuples with the entries
            ``gauss_means`` (shape ``(b*v, c, H*W)``, only the first three
            channels are used), ``gauss_scaling``, ``gauss_rotation``,
            ``gauss_opacity`` and ``gauss_features_dc`` (each
            ``(b*v, c, H, W)``), where ``H = height + 2*pad`` and
            ``W = width + 2*pad``.
            NOTE(review): the meaning of the two trailing ``0`` key components
            is not visible here -- confirm against the producer of
            ``outputs``.
        path: Destination PLY file, forwarded to ``export_ply``.
        num_gauss: Number of Gaussians per pixel folded into the leading axis.
        height: Un-padded image height in pixels.
        width: Un-padded image width in pixels.
        pad: Border, in pixels, removed from every side of each map.
    """
    padded_h = height + 2 * pad
    padded_w = width + 2 * pad

    def crop(t):
        # Strip the padding border from the two trailing (spatial) axes.
        return t[..., pad:padded_h - pad, pad:padded_w - pad]

    def crop_r(t):
        # Same crop for maps whose spatial axes are flattened into one axis.
        t = rearrange(t, "b c (h w) -> b c h w", h=padded_h, w=padded_w)
        return rearrange(crop(t), "b c h w -> b c (h w)")

    def gaussians_from_map(name):
        # (b*v, c, H, W) map -> (v*h*w, c) rows for the first batch element.
        cropped = crop(outputs[(name, 0, 0)])
        return rearrange(cropped, "(b v) c h w -> b (v h w) c", v=num_gauss)[0]

    means = rearrange(
        crop_r(outputs[("gauss_means", 0, 0)]),
        "(b v) c n -> b (v n) c",
        v=num_gauss,
    )[0, :, :3]
    scales = gaussians_from_map("gauss_scaling")
    rotations = gaussians_from_map("gauss_rotation")
    opacities = gaussians_from_map("gauss_opacity")
    harmonics = gaussians_from_map("gauss_features_dc")

    export_ply(means, scales, rotations, harmonics, opacities, path)