# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.renderer.blending import (
    BlendParams,
    hard_rgb_blend,
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.splatter_blend import SplatterBlender

from .common_testing import TestCaseMixin


def sigmoid_blend_naive_loop(colors, fragments, blend_params):
    """
    Naive for loop based implementation of distance based alpha calculation.
    Only for test purposes.

    Returns an (N, H, W, 4) RGBA image: RGB is taken from the closest face
    (k=0) and alpha is 1 - prod_k (1 - sigmoid(-dist_k / sigma)) over the
    visible faces in each pixel.
    """
    pix_to_face = fragments.pix_to_face
    dists = fragments.dists
    sigma = blend_params.sigma

    N, H, W, K = pix_to_face.shape
    device = pix_to_face.device
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                alpha = 1.0

                # Loop over k faces and calculate 2D distance based probability
                # map.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        prob = torch.sigmoid(-dists[n, h, w, k] / sigma)
                        alpha *= 1.0 - prob  # cumulative product

                pixel_colors[n, h, w, :3] = colors[n, h, w, 0, :]
                pixel_colors[n, h, w, 3] = 1.0 - alpha

    return pixel_colors


def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor:
    """
    Vectorized pure-PyTorch reference implementation of sigmoid alpha
    blending, equivalent to ``sigmoid_blend_naive_loop``. Only for test
    purposes.
    """
    N, H, W, K = fragments.pix_to_face.shape
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
    # Mask out padded faces (pix_to_face == -1) so they contribute zero prob.
    mask = fragments.pix_to_face >= 0
    prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
    pixel_colors[..., :3] = colors[..., 0, :]
    pixel_colors[..., 3] = 1.0 - torch.prod((1.0 - prob), dim=-1)
    return pixel_colors


def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params):
    """
    Manual backward pass for ``sigmoid_blend_naive_loop``: returns the
    gradient of the alpha channel w.r.t. ``fragments.dists``. Only for test
    purposes.
    """
    pix_to_face = fragments.pix_to_face
    dists = fragments.dists
    sigma = blend_params.sigma

    N, H, W, K = pix_to_face.shape
    device = pix_to_face.device
    grad_distances = torch.zeros((N, H, W, K), dtype=dists.dtype, device=device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                # alpha here is the residual transparency prod_k (1 - prob_k).
                alpha = 1.0 - images[n, h, w, 3]
                grad_alpha = grad_images[n, h, w, 3]

                # Loop over k faces and calculate 2D distance based probability
                # map.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        prob = torch.sigmoid(-dists[n, h, w, k] / sigma)
                        # d(alpha_out)/d(dist_k) = (-1/sigma) * prob_k * alpha
                        grad_distances[n, h, w, k] = (
                            grad_alpha * (-1.0 / sigma) * prob * alpha
                        )
    return grad_distances


def softmax_blend_naive(colors, fragments, blend_params):
    """
    Naive for loop based implementation of softmax blending.
    Only for test purposes.

    Colors are blended with weights prob_k * exp((zinv_k - zmax) / gamma)
    plus a background term; alpha is computed as in sigmoid blending.
    """
    pix_to_face = fragments.pix_to_face
    dists = fragments.dists
    zbuf = fragments.zbuf
    sigma = blend_params.sigma
    gamma = blend_params.gamma

    N, H, W, K = pix_to_face.shape
    device = pix_to_face.device
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)

    # Near and far clipping planes
    zfar = 100.0
    znear = 1.0
    eps = 1e-10

    bk_color = blend_params.background_color
    if not torch.is_tensor(bk_color):
        bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                alpha = 1.0
                weights_k = torch.zeros(K, device=device)
                zmax = torch.tensor(0.0, device=device)

                # Loop over K to find max z.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        zinv = (zfar - zbuf[n, h, w, k]) / (zfar - znear)
                        if zinv > zmax:
                            zmax = zinv

                # Loop over K faces to calculate 2D distance based probability
                # map and zbuf based weights for colors.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        zinv = (zfar - zbuf[n, h, w, k]) / (zfar - znear)
                        prob = torch.sigmoid(-dists[n, h, w, k] / sigma)
                        alpha *= 1.0 - prob  # cumulative product
                        weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)

                # Clamp to ensure delta is never 0
                delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps)
                delta = delta.to(device)
                denom = weights_k.sum() + delta
                cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
                pixel_colors[n, h, w, :3] = cols + delta * bk_color
                pixel_colors[n, h, w, :3] /= denom
                pixel_colors[n, h, w, 3] = 1.0 - alpha

    return pixel_colors


class TestBlending(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)

    def _compare_impls(
        self,
        fn1,
        fn2,
        args1,
        args2,
        grad_var1=None,
        grad_var2=None,
        compare_grads=True,
    ):
        """
        Check that fn1(*args1) and fn2(*args2) produce the same alpha channel,
        and (optionally) the same gradients w.r.t. grad_var1 / grad_var2.
        """
        out1 = fn1(*args1)
        out2 = fn2(*args2)
        self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7)

        # Check gradients
        if not compare_grads:
            return

        grad_out = torch.randn_like(out1)
        (out1 * grad_out).sum().backward()
        # BUGFIX: hasattr(tensor, "grad") is always True for a torch.Tensor,
        # so the original assertion was vacuous. Check the gradient was
        # actually populated by the backward pass instead.
        self.assertIsNotNone(grad_var1.grad)

        (out2 * grad_out).sum().backward()
        self.assertIsNotNone(grad_var2.grad)

        self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)

    def test_hard_rgb_blend(self):
        """Foreground/background colors and alpha of hard_rgb_blend."""
        N, H, W, K = 5, 10, 10, 20
        pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
        bary_coords = torch.ones((N, H, W, K, 3))
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=bary_coords,
            zbuf=pix_to_face,  # dummy
            dists=pix_to_face,  # dummy
        )
        colors = torch.randn((N, H, W, K, 3))
        blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
        images = hard_rgb_blend(colors, fragments, blend_params)

        # Examine if the foreground colors are correct.
        is_foreground = pix_to_face[..., 0] >= 0
        self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])

        # Examine if the background colors are correct.
        for i in range(3):  # i.e. RGB
            channel_color = blend_params.background_color[i]
            self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())

        # Examine the alpha channel
        self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float())

    def test_sigmoid_alpha_blend_manual_gradients(self):
        """Autograd gradients of the naive loop match the manual backward."""
        # Create dummy outputs of rasterization
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        # The python loop version is really slow so only using small input sizes.
        N, S, K = 2, 3, 2
        device = torch.device("cuda")
        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
        colors = torch.randn((N, S, S, K, 3), device=device)

        # torch.randn already yields both signs, i.e. both "inside triangle"
        # (-) and "outside triangle" (+) distances, so no explicit sign flip
        # is needed. (The original computed an unused random_sign_flip.)
        dists = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=torch.tensor([], device=device),  # dummy
            zbuf=torch.tensor([], device=device),  # dummy
            dists=dists,
        )
        blend_params = BlendParams(sigma=1e-3)
        pix_cols = sigmoid_blend_naive_loop(colors, fragments, blend_params)
        grad_out = torch.randn_like(pix_cols)

        # Backward pass
        pix_cols.backward(grad_out)
        grad_dists = sigmoid_blend_naive_loop_backward(
            grad_out, pix_cols, fragments, blend_params
        )
        self.assertTrue(torch.allclose(dists.grad, grad_dists, atol=1e-7))

    def test_sigmoid_alpha_blend_python(self):
        """
        Test outputs of python tensorised function and python loop
        """
        # Create dummy outputs of rasterization
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        # The python loop version is really slow so only using small input sizes.
        N, S, K = 1, 4, 1
        device = torch.device("cuda")
        pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device)
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists1 = torch.randn(size=(N, S, S, K), device=device)
        dists2 = dists1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True

        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-2)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)
        self._compare_impls(
            sigmoid_alpha_blend,
            sigmoid_alpha_blend_vectorized,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    def test_softmax_rgb_blend(self):
        """softmax_rgb_blend agrees with the naive loop implementation."""
        # Create dummy outputs of rasterization simulating a cube in the center
        # of the image with surrounding padded values.
        N, S, K = 1, 8, 2
        device = torch.device("cuda")
        pix_to_face = torch.full(
            (N, S, S, K), fill_value=-1, dtype=torch.int64, device=device
        )
        h = int(S / 2)
        pix_to_face_full = torch.randint(
            size=(N, h, h, K), low=0, high=100, device=device
        )
        s = int(S / 4)
        e = int(0.75 * S)
        pix_to_face[:, s:e, s:e, :] = pix_to_face_full
        empty = torch.tensor([], device=device)

        random_sign_flip = torch.rand((N, S, S, K), device=device)
        random_sign_flip[random_sign_flip > 0.5] *= -1.0
        zbuf1 = torch.randn(size=(N, S, S, K), device=device)

        # randomly flip the sign of the distance
        # (-) means inside triangle, (+) means outside triangle.
        dists1 = torch.randn(size=(N, S, S, K), device=device) * random_sign_flip
        dists2 = dists1.clone()
        zbuf2 = zbuf1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True

        colors = torch.randn((N, S, S, K, 3), device=device)

        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf1,
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf2,
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-3)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)

        self._compare_impls(
            softmax_rgb_blend,
            softmax_blend_naive,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    @staticmethod
    def bm_sigmoid_alpha_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device="cuda",
        backend: str = "pytorch",
    ):
        """Return a closure benchmarking sigmoid alpha blending fwd+bwd."""
        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)
        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        blend_params = BlendParams(sigma=1e-3)

        blend_fn = (
            sigmoid_alpha_blend_vectorized
            if backend == "pytorch"
            else sigmoid_alpha_blend
        )

        torch.cuda.synchronize()

        def fn():
            # test forward and backward pass
            images = blend_fn(colors, fragments, blend_params)
            images.sum().backward()
            torch.cuda.synchronize()

        return fn

    @staticmethod
    def bm_softmax_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device: str = "cpu",
        backend: str = "pytorch",
    ):
        """Return a closure benchmarking softmax RGB blending fwd+bwd."""
        if torch.cuda.is_available() and "cuda:" in device:
            # If a device other than the default is used, set the device explicity.
            torch.cuda.set_device(device)

        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)
        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face, bary_coords=empty, zbuf=zbuf, dists=dists1  # dummy
        )
        blend_params = BlendParams(sigma=1e-3)

        torch.cuda.synchronize()

        def fn():
            # test forward and backward pass
            images = softmax_rgb_blend(colors, fragments, blend_params)
            images.sum().backward()
            torch.cuda.synchronize()

        return fn

    @staticmethod
    def bm_splatter_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 2,
        use_jit: bool = False,
        device: str = "cpu",
        backend: str = "pytorch",
    ):
        """Return a closure benchmarking splatter blending fwd+bwd."""
        if torch.cuda.is_available() and "cuda:" in device:
            # If a device other than the default is used, set the device explicity.
            torch.cuda.set_device(device)

        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        pixel_coords_camera = torch.randn(
            (N, S, S, K, 3), device=device, requires_grad=True
        )
        cameras = FoVPerspectiveCameras(device=device)
        colors = torch.randn((N, S, S, K, 3), device=device)
        # BUGFIX: the original assigned a torch.randint tensor here and then
        # immediately overwrote it; the dead assignment is removed.
        background_mask = torch.full((N, S, S, K), False, dtype=bool, device=device)
        blend_params = BlendParams(sigma=0.5)

        torch.cuda.synchronize()
        splatter_blender = SplatterBlender((N, S, S, K), colors.device)

        def fn():
            # test forward and backward pass
            images = splatter_blender(
                colors,
                pixel_coords_camera,
                cameras,
                background_mask,
                blend_params,
            )
            images.sum().backward()
            torch.cuda.synchronize()

        return fn

    def test_blend_params(self):
        """Test color parameter of BlendParams().
        Assert passed value overrides default value.
        """
        bp_default = BlendParams()
        bp_new = BlendParams(background_color=(0.5, 0.5, 0.5))
        self.assertEqual(bp_new.background_color, (0.5, 0.5, 0.5))
        self.assertEqual(bp_default.background_color, (1.0, 1.0, 1.0))