# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.renderer.compositing import (
    alpha_composite,
    norm_weighted_sum,
    weighted_sum,
)

from .common_testing import get_random_cuda_device, TestCaseMixin


class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
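    """
    Tests for the compositing functions in pytorch3d.renderer.compositing,
    checking the CPU and CUDA implementations against the naive pure-Python
    versions below, on a small hand-worked example and on random inputs
    (forward values and gradients).
    """
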
    # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
    @staticmethod
    def accumulate_alphacomposite_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of alpha_composite.
        Inputs / Outputs: same as alpha_composite.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):
                        t_alpha = 1
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha * t_alpha
                            t_alpha = (1 - alpha) * t_alpha
        return output
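
    # For reference, per pixel the loop above computes standard front-to-back
    # "over" blending:
    #   output = sum_k features[:, idx_k] * alpha_k * prod_{j<k} (1 - alpha_j)
    # where idx_k < 0 means "no point at rank k". E.g. two points with
    # alpha = 0.5 over features 0.1 and 0.6 give
    # 0.1 * 0.5 + 0.6 * 0.5 * (1 - 0.5) = 0.20, the value expected at the
    # border pixels in _simple_alphacomposite below.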

    @staticmethod
    def accumulate_weightedsum_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of weighted_sum.
        Inputs / Outputs: same as weighted_sum.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha
        return output
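
    # Unlike alpha_composite, weighted_sum applies no transmittance term:
    # every point at a pixel contributes alpha_k * features[:, idx_k]
    # directly, so the accumulated value can exceed the feature range
    # (see the 1.30 entries expected in _simple_wsum below).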

    @staticmethod
    def accumulate_weightedsumnorm_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of norm_weighted_sum.
        Inputs / Outputs: same as norm_weighted_sum.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):
                        t_alpha = 0
                        # Sum the alphas for this pixel first, clamping to
                        # avoid dividing by zero for uncovered pixels.
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]
                            if n_idx < 0:
                                continue
                            t_alpha += alphas[b, k, j, i]

                        t_alpha = max(t_alpha, 1e-4)
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha / t_alpha
        return output
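
    # norm_weighted_sum additionally divides by the per-pixel alpha sum
    # (clamped to 1e-4), so a fully covered pixel receives a convex
    # combination of features, e.g. (0.4 + 0.9) / 2 = 0.65 for the centre
    # pixels expected in _simple_wsumnorm below.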

    def test_python(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(self.accumulate_alphacomposite_python, device)
        self._simple_wsum(self.accumulate_weightedsum_python, device)
        self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)

    def test_cpu(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_cuda(self):
        device = get_random_cuda_device()
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_python_vs_cpu_vs_cuda(self):
        self._python_vs_cpu_vs_cuda(
            self.accumulate_alphacomposite_python, alpha_composite
        )
        self._python_vs_cpu_vs_cuda(
            self.accumulate_weightedsumnorm_python, norm_weighted_sum
        )
        self._python_vs_cpu_vs_cuda(self.accumulate_weightedsum_python, weighted_sum)

    def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
        torch.manual_seed(231)
        W = 8
        C = 3
        P = 32

        for d in ["cpu", get_random_cuda_device()]:
            # TODO(gkioxari) add torch.float64 to types after double precision
            # support is added to atomicAdd
            for t in [torch.float32]:
                device = torch.device(d)

                # Create values
                alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
                alphas.requires_grad = True
                alphas_cpu = alphas.detach().cpu()
                alphas_cpu.requires_grad = True

                features = torch.randn(C, P, dtype=t).to(device)
                features.requires_grad = True
                features_cpu = features.detach().cpu()
                features_cpu.requires_grad = True

                # Indices in [-1, P - 1]; -1 marks an empty slot.
                inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
                inds_cpu = inds.detach().cpu()

                args_cuda = (inds, alphas, features)
                args_cpu = (inds_cpu, alphas_cpu, features_cpu)

                self._compare_impls(
                    accumulate_func_python,
                    accumulate_func,
                    args_cpu,
                    args_cuda,
                    (alphas_cpu, features_cpu),
                    (alphas, features),
                    compare_grads=True,
                )

    def _compare_impls(
        self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
    ):
        res1 = fn1(*args1)
        res2 = fn2(*args2)
        self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)

        if not compare_grads:
            return

        # Compare gradients
        torch.manual_seed(231)
        grad_res = torch.randn_like(res1)
        loss1 = (res1 * grad_res).sum()
        loss1.backward()
        grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]

        grad_res = grad_res.to(res2)
        loss2 = (res2 * grad_res).sum()
        loss2.backward()
        grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]

        for i in range(0, len(grads1)):
            self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)

    def _simple_wsum(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(
            device
        )
        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0,  1,  1,  0],  # noqa: E241, E201
                            [0,  0,  0,  0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)
        self.assertTrue(result.shape == (1, 2, 4, 4))

        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                ]
            ]
        ).to(device)
        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_wsumnorm(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(
            device
        )
        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0,  1,  1,  0],  # noqa: E241, E201
                            [0,  0,  0,  0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)
        self.assertTrue(result.shape == (1, 2, 4, 4))

        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                ]
            ]
        ).to(device)
        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_alphacomposite(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(
            device
        )
        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],  # noqa: E241, E201
                            [0, -1, -1, -1],  # noqa: E241, E201
                            [0,  1,  1,  0],  # noqa: E241, E201
                            [0,  0,  0,  0],  # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)
        self.assertTrue(result.shape == (1, 2, 4, 4))

        true_result = torch.Tensor(
            [
                [
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                ]
            ]
        ).to(device)
        # Use a tolerance-based comparison rather than exact float equality,
        # consistent with _simple_wsum and _simple_wsumnorm above.
        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
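

if __name__ == "__main__":
    # Convenience entry point (an assumption, not part of the test harness)
    # so the module can also be run directly as a script.
    unittest.main()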