Linly-Talker / pytorch3d /tests /test_compositing.py
linxianzhong0128's picture
Upload folder using huggingface_hub
7088d16 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.renderer.compositing import (
alpha_composite,
norm_weighted_sum,
weighted_sum,
)
from .common_testing import get_random_cuda_device, TestCaseMixin
class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
# NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
@staticmethod
def accumulate_alphacomposite_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of alpha_composite.
Inputs / Outputs: Same as function
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
t_alpha = 1
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += features[c, n_idx] * alpha * t_alpha
t_alpha = (1 - alpha) * t_alpha
return output
@staticmethod
def accumulate_weightedsum_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of weighted_sum rasterization.
Inputs / Outputs: Same as function
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += features[c, n_idx] * alpha
return output
@staticmethod
def accumulate_weightedsumnorm_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of norm_weighted_sum.
Inputs / Outputs: Same as function
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
t_alpha = 0
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
t_alpha += alphas[b, k, j, i]
t_alpha = max(t_alpha, 1e-4)
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += features[c, n_idx] * alpha / t_alpha
return output
def test_python(self):
device = torch.device("cpu")
self._simple_alphacomposite(self.accumulate_alphacomposite_python, device)
self._simple_wsum(self.accumulate_weightedsum_python, device)
self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)
def test_cpu(self):
device = torch.device("cpu")
self._simple_alphacomposite(alpha_composite, device)
self._simple_wsum(weighted_sum, device)
self._simple_wsumnorm(norm_weighted_sum, device)
def test_cuda(self):
device = get_random_cuda_device()
self._simple_alphacomposite(alpha_composite, device)
self._simple_wsum(weighted_sum, device)
self._simple_wsumnorm(norm_weighted_sum, device)
def test_python_vs_cpu_vs_cuda(self):
self._python_vs_cpu_vs_cuda(
self.accumulate_alphacomposite_python, alpha_composite
)
self._python_vs_cpu_vs_cuda(
self.accumulate_weightedsumnorm_python, norm_weighted_sum
)
self._python_vs_cpu_vs_cuda(self.accumulate_weightedsum_python, weighted_sum)
def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
torch.manual_seed(231)
device = torch.device("cpu")
W = 8
C = 3
P = 32
for d in ["cpu", get_random_cuda_device()]:
# TODO(gkioxari) add torch.float64 to types after double precision
# support is added to atomicAdd
for t in [torch.float32]:
device = torch.device(d)
# Create values
alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
alphas.requires_grad = True
alphas_cpu = alphas.detach().cpu()
alphas_cpu.requires_grad = True
features = torch.randn(C, P, dtype=t).to(device)
features.requires_grad = True
features_cpu = features.detach().cpu()
features_cpu.requires_grad = True
inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
inds_cpu = inds.detach().cpu()
args_cuda = (inds, alphas, features)
args_cpu = (inds_cpu, alphas_cpu, features_cpu)
self._compare_impls(
accumulate_func_python,
accumulate_func,
args_cpu,
args_cuda,
(alphas_cpu, features_cpu),
(alphas, features),
compare_grads=True,
)
def _compare_impls(
self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
):
res1 = fn1(*args1)
res2 = fn2(*args2)
self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)
if not compare_grads:
return
# Compare gradients
torch.manual_seed(231)
grad_res = torch.randn_like(res1)
loss1 = (res1 * grad_res).sum()
loss1.backward()
grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]
grad_res = grad_res.to(res2)
loss2 = (res2 * grad_res).sum()
loss2.backward()
grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
for i in range(0, len(grads1)):
self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
def _simple_wsum(self, accum_func, device):
# Initialise variables
features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
points_idx = (
torch.Tensor(
[
[
# fmt: off
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
# fmt: on
]
]
)
.long()
.to(device)
)
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.30],
[0.35, 1.30, 1.30, 0.35],
[0.35, 0.35, 0.05, 0.35],
],
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.30],
[0.35, 1.30, 1.30, 0.35],
[0.35, 0.35, 0.05, 0.35],
],
]
]
).to(device)
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
def _simple_wsumnorm(self, accum_func, device):
# Initialise variables
features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
# fmt: off
points_idx = (
torch.Tensor(
[
[
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
]
]
)
.long()
.to(device)
)
# fmt: on
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.60],
[0.35, 0.65, 0.65, 0.35],
[0.35, 0.35, 0.10, 0.35],
],
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.60],
[0.35, 0.65, 0.65, 0.35],
[0.35, 0.35, 0.10, 0.35],
],
]
]
).to(device)
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
def _simple_alphacomposite(self, accum_func, device):
# Initialise variables
features = torch.Tensor([[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
# fmt: off
points_idx = (
torch.Tensor(
[
[
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
]
]
)
.long()
.to(device)
)
# fmt: on
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.20, 0.20, 0.20, 0.20],
[0.20, 0.90, 0.90, 0.30],
[0.20, 0.40, 0.40, 0.20],
[0.20, 0.20, 0.05, 0.20],
],
[
[0.20, 0.20, 0.20, 0.20],
[0.20, 0.90, 0.90, 0.30],
[0.20, 0.40, 0.40, 0.20],
[0.20, 0.20, 0.05, 0.20],
],
]
]
).to(device)
self.assertTrue((result == true_result).all().item())