|
from pathlib import Path |
|
from jaxtyping import Float |
|
import numpy as np |
|
from scipy.spatial.transform import Rotation as R |
|
from plyfile import PlyData, PlyElement |
|
import torch |
|
from torch import Tensor |
|
from einops import rearrange, einsum |
|
|
|
|
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for a Gaussian export.

    The order is: position (``x``, ``y``, ``z``), normals (``nx``, ``ny``,
    ``nz``), three ``f_dc_*`` coefficients, ``num_rest`` ``f_rest_*``
    coefficients, ``opacity``, three ``scale_*`` components and four
    ``rot_*`` components.

    Args:
        num_rest: number of ``f_rest_*`` entries to include (none if <= 0).

    Returns:
        A new list of property names in file order.
    """
    names = ["x", "y", "z", "nx", "ny", "nz"]
    names += [f"f_dc_{k}" for k in range(3)]
    names += [f"f_rest_{k}" for k in range(num_rest)]
    names.append("opacity")
    names += [f"scale_{k}" for k in range(3)]
    names += [f"rot_{k}" for k in range(4)]
    return names
|
|
|
|
|
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3"],
    opacities: Float[Tensor, "gaussian 1"],
    path: Path,
):
    """Normalize a set of 3D Gaussians and write them to a PLY vertex file.

    Processing steps:

    * Means are re-centred on their per-axis median and divided by the
      largest per-axis 95th-percentile absolute coordinate. Afterwards, at
      least 95% of the coordinates along every axis lie in [-1, 1].
    * Scales are divided by the same factor, multiplied by 4 and clamped to
      [0, 0.0075].
    * Positions and orientations are multiplied by a fixed 3x3 rotation
      (currently the identity).

    The written properties follow ``construct_list_of_attributes(0)``:
    position, zero normals, three DC coefficients, opacity, ``log(scale)``
    and a scalar-first quaternion.

    Args:
        means: Gaussian centres, shape (N, 3).
        scales: per-axis scales, shape (N, 3). They are stored as
            ``log(scale)``, so a scale clamped to 0 is written as ``-inf``.
        rotations: quaternions, shape (N, 4). They are read with SciPy's
            scalar-last ``(x, y, z, w)`` convention.
        harmonics: DC coefficients, shape (N, 3). A (N, 3, d_sh) tensor is
            also accepted; only index 0 of the last axis is written.
        opacities: shape (N,) or (N, 1). Written unchanged; no activation
            is applied here.
        path: output file (str or Path). Parent directories are created.

    Raises:
        ValueError: if the assembled columns do not match the PLY header.
    """
    path = Path(path)

    # Median re-centring is robust to outlier Gaussians; the 95th-percentile
    # scale ignores the most extreme 5% when choosing the normalization.
    means = means - means.median(dim=0).values
    scale_factor = means.abs().quantile(0.95, dim=0).max()
    if scale_factor <= 0:
        # Degenerate input (e.g. a single Gaussian, or all means identical):
        # skip normalization instead of dividing by zero.
        scale_factor = torch.ones_like(scale_factor)
    means = means / scale_factor
    scales = torch.clamp(scales / scale_factor * 4.0, 0, 0.0075)

    # Fixed alignment applied to positions and orientations. It is currently
    # the identity, presumably kept as a hook for re-orienting the scene.
    rotation = torch.eye(3, dtype=torch.float32, device=means.device)
    means = einsum(rotation, means, "i j, ... j -> ... i")

    # SciPy quaternions are scalar-last (x, y, z, w).
    # NOTE(review): confirm the producer of `rotations` uses the same order.
    orientation = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    orientation = rotation.detach().cpu().numpy() @ orientation
    quats_xyzw = R.from_matrix(orientation).as_quat()
    # The PLY layout stores the scalar component first (rot_0 = w).
    quats_wxyz = quats_xyzw[:, [3, 0, 1, 2]]

    if harmonics.ndim == 3:
        # Only f_dc_* is written, so keep the first (DC) SH coefficient.
        harmonics = harmonics[..., 0]

    means_np = means.detach().cpu().numpy()
    columns = np.concatenate(
        (
            means_np,
            np.zeros_like(means_np),  # normals are unused; written as zeros
            harmonics.detach().cpu().contiguous().numpy(),
            opacities.detach().cpu().numpy().reshape(-1, 1),
            scales.log().detach().cpu().numpy(),
            quats_wxyz,
        ),
        axis=1,
    )

    names = construct_list_of_attributes(0)
    if columns.shape[1] != len(names):
        raise ValueError(
            f"expected {len(names)} PLY columns, got {columns.shape[1]}"
        )
    elements = np.empty(columns.shape[0], dtype=[(name, "f4") for name in names])
    # Fill field-by-field (vectorized) rather than building a Python tuple
    # per Gaussian; values are cast to float32 on assignment.
    for name, column in zip(names, columns.T):
        elements[name] = column

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
|
|
|
|
def save_ply(outputs, path, num_gauss=3, *, height=256, width=384, pad=32):
    """Export the first batch element of a model's Gaussian outputs to PLY.

    Each ``outputs[(key, 0, 0)]`` entry holds ``num_gauss`` Gaussians per
    pixel, stacked along a leading ``(b v)`` axis, on a padded grid of
    ``(height + 2*pad) x (width + 2*pad)``. Processing steps:

    * Crop the ``pad``-pixel border from every side.
    * Flatten the ``num_gauss`` Gaussians of every remaining pixel into one
      list (view-major, then row, then column).
    * Pass the result to ``export_ply``.

    Only batch element 0 is exported.

    Args:
        outputs: mapping that must contain these keys:
            ``('gauss_means', 0, 0)``: shape ``((b v), c, H*W)``; only the
            first three channels are used.
            ``('gauss_scaling', 0, 0)``, ``('gauss_rotation', 0, 0)``,
            ``('gauss_opacity', 0, 0)`` and ``('gauss_features_dc', 0, 0)``:
            each of shape ``((b v), c, H, W)``.
        path: output PLY path.
        num_gauss: Gaussians per pixel (the ``v`` in ``(b v)``).
        height: unpadded image height.
        width: unpadded image width.
        pad: border width, in pixels, removed from every side.
    """
    padded_h = height + 2 * pad
    padded_w = width + 2 * pad

    def crop(t):
        # Drop the `pad`-pixel border from a (..., H, W) map.
        return t[..., pad:pad + height, pad:pad + width]

    def crop_flat(t):
        # Same as `crop`, for maps whose spatial dims are flattened to H*W.
        t = rearrange(t, "b c (h w) -> b c h w", h=padded_h, w=padded_w)
        return rearrange(crop(t), "b c h w -> b c (h w)")

    def gather(key):
        # ((b v), c, H, W) -> (v*h*w, c) for batch element 0.
        return rearrange(
            crop(outputs[(key, 0, 0)]),
            "(b v) c h w -> b (v h w) c",
            v=num_gauss,
        )[0]

    means = rearrange(
        crop_flat(outputs[("gauss_means", 0, 0)]),
        "(b v) c n -> b (v n) c",
        v=num_gauss,
    )[0, :, :3]

    export_ply(
        means,
        gather("gauss_scaling"),
        gather("gauss_rotation"),
        gather("gauss_features_dc"),
        gather("gauss_opacity"),
        path,
    )