File size: 5,364 Bytes
915f69b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import torch
################## sh function ##################
# Scale factors of the hardcoded real spherical-harmonics polynomials.
# eval_sh and eval_sh_bases index into these tables; the comment next to
# each entry names the polynomial in (x, y, z) that the entry multiplies.

# Degree 0: constant term, 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814

# Degree 1: shared by the y, z and x terms (signs are applied at the use
# site), sqrt(3) / (2 * sqrt(pi)).
C1 = 0.4886025119029199

# Degree 2 (SH indices 4..8).
C2 = [
    1.0925484305920792,   # x*y
    -1.0925484305920792,  # y*z
    0.31539156525252005,  # 2*z*z - x*x - y*y
    -1.0925484305920792,  # x*z
    0.5462742152960396,   # x*x - y*y
]

# Degree 3 (SH indices 9..15).
C3 = [
    -0.5900435899266435,  # y * (3*x*x - y*y)
    2.890611442640554,    # x*y*z
    -0.4570457994644658,  # y * (4*z*z - x*x - y*y)
    0.3731763325901154,   # z * (2*z*z - 3*x*x - 3*y*y)
    -0.4570457994644658,  # x * (4*z*z - x*x - y*y)
    1.445305721320277,    # z * (x*x - y*y)
    -0.5900435899266435,  # x * (x*x - 3*y*y)
]

# Degree 4 (SH indices 16..24).
C4 = [
    2.5033429417967046,   # x*y * (x*x - y*y)
    -1.7701307697799304,  # y*z * (3*x*x - y*y)
    0.9461746957575601,   # x*y * (7*z*z - 1)
    -0.6690465435572892,  # y*z * (7*z*z - 3)
    0.10578554691520431,  # z*z * (35*z*z - 30) + 3
    -0.6690465435572892,  # x*z * (7*z*z - 3)
    0.47308734787878004,  # (x*x - y*y) * (7*z*z - 1)
    -1.7701307697799304,  # x*z * (x*x - 3*y*y)
    0.6258357354491761,   # x*x * (x*x - 3*y*y) - y*y * (3*x*x - y*y)
]
def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp: only indexing and elementwise arithmetic
    are used on the inputs.
    ... Can be 0 or more batch dimensions.

    The result is the sum over all basis functions up to ``deg`` of
    (basis value at ``dirs``) * (matching coefficient in ``sh``).

    :param deg: int SH max degree. Currently, 0-4 supported
    :param sh: SH coeffs (..., C, (max degree + 1) ** 2)
    :param dirs: unit directions (..., 3); not normalized here
    :return: (..., C)
    :raises AssertionError: if ``deg`` is outside 0-4 or the last
        dimension of ``sh`` is not ``(deg + 1) ** 2``
    """
    assert deg <= 4 and deg >= 0, "deg must be in [0, 4], got %r" % (deg,)
    assert (deg + 1) ** 2 == sh.shape[-1], (
        "sh must have (deg + 1) ** 2 coefficients in its last dimension"
    )

    # Degree 0: direction-independent term.
    result = C0 * sh[..., 0]
    if deg > 0:
        # Slicing (0:1 rather than 0) keeps a trailing axis of size 1 so
        # each component broadcasts against the C axis of sh[..., k].
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            # Second-order monomials, reused by every higher degree.
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result +
                              C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result
def eval_sh_bases(deg, dirs):
    """
    Compute the value of every spherical-harmonics basis function at the
    given unit directions, without combining them with any coefficients.
    The caller can then obtain the final result at each point by
    multiplying these values with its SH coefficients.

    :param deg: int SH max degree. Currently, 0-4 supported
    :param dirs: torch.Tensor (..., 3) unit directions
    :return: torch.Tensor (..., (deg+1) ** 2), same dtype and device as dirs
    """
    assert 0 <= deg <= 4
    num_bases = (deg + 1) ** 2
    out = torch.empty(
        (*dirs.shape[:-1], num_bases),
        dtype=dirs.dtype,
        device=dirs.device,
    )

    # Degree 0 is a constant and needs no direction components.
    out[..., 0] = C0
    if deg == 0:
        return out

    x, y, z = dirs.unbind(-1)

    # Basis values for SH indices 1, 2, 3, ... in order; the list is
    # extended degree by degree and copied into `out` at the end.
    terms = [
        -C1 * y,
        C1 * z,
        -C1 * x,
    ]

    if deg >= 2:
        # Second-order monomials, shared by degrees 2 through 4.
        xx, yy, zz = x * x, y * y, z * z
        xy, yz, xz = x * y, y * z, x * z
        terms.extend([
            C2[0] * xy,
            C2[1] * yz,
            C2[2] * (2.0 * zz - xx - yy),
            C2[3] * xz,
            C2[4] * (xx - yy),
        ])

    if deg >= 3:
        terms.extend([
            C3[0] * y * (3 * xx - yy),
            C3[1] * xy * z,
            C3[2] * y * (4 * zz - xx - yy),
            C3[3] * z * (2 * zz - 3 * xx - 3 * yy),
            C3[4] * x * (4 * zz - xx - yy),
            C3[5] * z * (xx - yy),
            C3[6] * x * (xx - 3 * yy),
        ])

    if deg >= 4:
        terms.extend([
            C4[0] * xy * (xx - yy),
            C4[1] * yz * (3 * xx - yy),
            C4[2] * xy * (7 * zz - 1),
            C4[3] * yz * (7 * zz - 3),
            C4[4] * (zz * (35 * zz - 30) + 3),
            C4[5] * xz * (7 * zz - 3),
            C4[6] * (xx - yy) * (7 * zz - 1),
            C4[7] * xz * (xx - 3 * yy),
            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)),
        ])

    # Index 0 is already filled, so the collected terms start at 1.
    for basis_idx, term in enumerate(terms, start=1):
        out[..., basis_idx] = term
    return out
|