Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import pickle | |
import random | |
import unittest | |
from typing import List, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
from pytorch3d.io import save_obj | |
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap | |
from pytorch3d.transforms.rotation_conversions import random_rotation | |
from .common_testing import get_random_cuda_device, get_tests_dir, TestCaseMixin | |
# Permutation applied along the vertex dimension of Objectron boxes to
# re-order their 8 corners into the PyTorch3D corner convention
# (used by `_test_compare_objectron` via `index_select`).
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
# Directory containing the saved fixtures loaded by the tests in this file.
DATA_DIR = get_tests_dir() / "data"
# NOTE(review): not referenced in the visible part of the file — presumably
# toggles debug output further down; confirm.
DEBUG = False
# Default tolerance on |<edge, normal>| for the coplanarity check in
# `box_planar_dir`.
DOT_EPS = 1e-3
# Default minimum triangle area below which `box_planar_dir` rejects a face
# as degenerate.
AREA_EPS = 1e-4
# The 8 corners of the axis-aligned unit cube [0, 1]^3: the z = 0 face first
# (corners 0-3), then the z = 1 face (corners 4-7).
UNIT_BOX = [
    [0, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 1],
    [1, 1, 1],
    [0, 1, 1],
]
class TestIoU3D(TestCaseMixin, unittest.TestCase): | |
def setUp(self) -> None:
    """Seed every random number generator the tests in this class draw from."""
    super().setUp()
    torch.manual_seed(1)
    # `_test_iou` also draws offsets from Python's `random` module; seed it
    # too so that a failing run can be reproduced.
    random.seed(1)
@staticmethod
def create_box(xyz, whl):
    """
    Build the 8 corners of an axis-aligned box.

    Declared as a static method because it takes no `self`; this makes both
    `self.create_box(...)` and `TestIoU3D.create_box(...)` valid calls.

    Args:
        xyz: tensor of shape (3,) giving the box center
        whl: the 3 extents of the box along x, y and z

    Returns:
        verts: float32 tensor of shape (8, 3) on `xyz.device`, with the
            z - le / 2 face first (corners 0-3) and the z + le / 2 face
            second (corners 4-7)
    """
    # Sign of the half-extent added to the center for each corner.
    corner_signs = torch.tensor(
        [
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ],
        device=xyz.device,
        dtype=torch.float32,
    )
    half_extent = torch.as_tensor(whl, device=xyz.device).reshape(1, 3) / 2.0
    verts = xyz.reshape(1, 3) + corner_signs * half_extent
    # Detach so the result is a plain leaf tensor, like a freshly built one.
    return verts.detach().to(dtype=torch.float32)
@staticmethod
def _box3d_overlap_naive_batched(boxes1, boxes2):
    """
    Wrapper around box3d_overlap_naive to support batched input.

    Declared as a static method because it takes no `self`; it is handed
    around as a plain `overlap_fn(boxes1, boxes2)` callable.

    Args:
        boxes1: tensor whose first dimension indexes N boxes
        boxes2: tensor whose first dimension indexes M boxes

    Returns:
        vols: float32 tensor of shape (N, M) of intersection volumes
        ious: float32 tensor of shape (N, M) of IoU values
    """
    N, M = boxes1.shape[0], boxes2.shape[0]
    vols = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
    ious = torch.zeros_like(vols)
    # The naive implementation handles a single pair of boxes at a time.
    for n in range(N):
        for m in range(M):
            vols[n, m], ious[n, m] = box3d_overlap_naive(boxes1[n], boxes2[m])
    return vols, ious
@staticmethod
def _box3d_overlap_sampling_batched(boxes1, boxes2, num_samples: int):
    """
    Wrapper around box3d_overlap_sampling to support batched input.

    Declared as a static method because it takes no `self`.

    Args:
        boxes1: tensor whose first dimension indexes N boxes
        boxes2: tensor whose first dimension indexes M boxes
        num_samples: number of samples forwarded to the sampling estimator

    Returns:
        ious: float32 tensor of shape (N, M) of approximate IoU values
    """
    N, M = boxes1.shape[0], boxes2.shape[0]
    ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
    for n in range(N):
        for m in range(M):
            # Forward the caller's sample count instead of silently dropping it.
            # NOTE(review): assumes `box3d_overlap_sampling` (defined later in
            # this file) accepts a `num_samples` keyword — confirm.
            ious[n, m] = box3d_overlap_sampling(
                boxes1[n], boxes2[m], num_samples=num_samples
            )
    return ious
def _test_iou(self, overlap_fn, device):
    """
    Run the shared battery of overlap checks against `overlap_fn`.

    Args:
        overlap_fn: callable taking boxes of shape (N, 8, 3) and (M, 8, 3)
            and returning `(vol, iou)`, each of shape (N, M)
        device: torch device on which every test box is created
    """

    def make_box(rows):
        # All test boxes are float32 tensors living on the requested device.
        return torch.tensor(rows, dtype=torch.float32, device=device)

    def check_overlap(box_a, box_b, vol=None, iou=None, symmetric=True, **tol):
        # Compare overlap_fn(box_a, box_b) against the expected volume and/or
        # IoU; with `symmetric`, repeat with the arguments swapped.
        orders = [(box_a, box_b)]
        if symmetric:
            orders.append((box_b, box_a))
        for first, second in orders:
            got_vol, got_iou = overlap_fn(first[None], second[None])
            for got, want in ((got_vol, vol), (got_iou, iou)):
                if want is None:
                    continue
                if not torch.is_tensor(want):
                    want = torch.tensor(
                        [[want]], device=got.device, dtype=got.dtype
                    )
                self.assertClose(got, want, **tol)

    def check_raises(box_a, box_b, msg, symmetric=True):
        # overlap_fn must reject the pair with a ValueError matching `msg`.
        orders = [(box_a, box_b)]
        if symmetric:
            orders.append((box_b, box_a))
        for first, second in orders:
            with self.assertRaisesRegex(ValueError, msg):
                overlap_fn(first[None], second[None])

    def shifted(dx, dy, dz):
        # The unit box translated by (dx, dy, dz).
        return box1 + torch.tensor([[dx, dy, dz]], device=device)

    box1 = make_box(UNIT_BOX)

    # 1st test: same box, iou = 1.0
    check_overlap(box1, box1, vol=1.0, iou=1.0, symmetric=False)

    # 2nd test: shift along y
    dd = random.random()
    check_overlap(box1, shifted(0.0, dd, 0.0), vol=1 - dd)

    # 3rd test: shift along x
    dd = random.random()
    check_overlap(box1, shifted(dd, 0.0, 0.0), vol=1 - dd)

    # 4th test: shift along all three axes
    ddx, ddy, ddz = random.random(), random.random(), random.random()
    box2 = shifted(ddx, ddy, ddz)
    check_overlap(box1, box2, vol=(1 - ddx) * (1 - ddy) * (1 - ddz))
    # Also check IoU is 1 when computing overlap with the same shifted box
    check_overlap(box2, box2, iou=1.0, symmetric=False)

    # 5th test: shifted boxes under a shared random rotation
    ddx, ddy, ddz = random.random(), random.random(), random.random()
    box2 = shifted(ddx, ddy, ddz)
    RR = random_rotation(dtype=torch.float32, device=device)
    check_overlap(
        box1 @ RR.transpose(0, 1),
        box2 @ RR.transpose(0, 1),
        vol=(1 - ddx) * (1 - ddy) * (1 - ddz),
    )

    # 6th test: shifted boxes under a shared random rigid transform
    ddx, ddy, ddz = random.random(), random.random(), random.random()
    box2 = shifted(ddx, ddy, ddz)
    RR = random_rotation(dtype=torch.float32, device=device)
    TT = torch.rand((1, 3), dtype=torch.float32, device=device)
    check_overlap(
        box1 @ RR.transpose(0, 1) + TT,
        box2 @ RR.transpose(0, 1) + TT,
        vol=(1 - ddx) * (1 - ddy) * (1 - ddz),
        atol=1e-7,
    )

    # 7th test: hand coded example and test with meshlab output
    # Meshlab procedure to compute volumes of shapes
    # 1. Load a shape, then Filters
    #    -> Remeshing, Simplification, Reconstruction -> Convex Hull
    # 2. Select the convex hull shape (This is important!)
    # 3. Then Filters -> Quality Measure and Computation
    #    -> Compute Geometric Measures
    # 4. Check for "Mesh Volume" in the stdout
    box1r = make_box(
        [
            [3.1673, -2.2574, 0.4817],
            [4.6470, 0.2223, 2.4197],
            [5.2200, 1.1844, 0.7510],
            [3.7403, -1.2953, -1.1869],
            [-4.9316, 2.5724, 0.4856],
            [-3.4519, 5.0521, 2.4235],
            [-2.8789, 6.0142, 0.7549],
            [-4.3586, 3.5345, -1.1831],
        ]
    )
    box2r = make_box(
        [
            [0.5623, 4.0647, 3.4334],
            [3.3584, 4.3191, 1.1791],
            [3.0724, -5.9235, -0.3315],
            [0.2763, -6.1779, 1.9229],
            [-2.0773, 4.6121, 0.2213],
            [0.7188, 4.8665, -2.0331],
            [0.4328, -5.3761, -3.5436],
            [-2.3633, -5.6305, -1.2893],
        ]
    )
    # from Meshlab:
    vol_inters = 33.558529
    vol_box1 = 65.899010
    vol_box2 = 156.386719
    iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
    check_overlap(box1r, box2r, vol=vol_inters, iou=iou_mesh, atol=1e-1)

    # 8th test: compare with sampling
    ctrs = torch.rand((2, 3), device=device)
    whl = torch.rand((2, 3), device=device) * 10.0 + 1.0
    # Helpers are reached through the class so the calls do not depend on
    # how the helpers are bound to instances.
    box8a = TestIoU3D.create_box(ctrs[0], whl[0])
    box8b = TestIoU3D.create_box(ctrs[1], whl[1])
    RR1 = random_rotation(dtype=torch.float32, device=device)
    TT1 = torch.rand((1, 3), dtype=torch.float32, device=device)
    RR2 = random_rotation(dtype=torch.float32, device=device)
    TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
    box1r = box8a @ RR1.transpose(0, 1) + TT1
    box2r = box8b @ RR2.transpose(0, 1) + TT2
    iou_sampling = TestIoU3D._box3d_overlap_sampling_batched(
        box1r[None], box2r[None], num_samples=10000
    )
    check_overlap(box1r, box2r, iou=iou_sampling, atol=1e-2)

    # 9th test: non overlapping boxes, iou = 0.0
    check_overlap(box1, shifted(0.0, 100.0, 0.0), vol=0.0, iou=0.0)

    # 10th test: Non coplanar verts in a plane
    box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
    check_raises(box10, box10, "Plane vertices are not coplanar", symmetric=False)

    # 11th test: Skewed bounding boxes but all verts are coplanar
    box_skew_1 = make_box(
        [
            [0, 0, 0],
            [1, 0, 0],
            [1, 1, 0],
            [0, 1, 0],
            [-2, -2, 2],
            [2, -2, 2],
            [2, 2, 2],
            [-2, 2, 2],
        ]
    )
    box_skew_2 = make_box(
        [
            [2.015995, 0.695233, 2.152806],
            [2.832533, 0.663448, 1.576389],
            [2.675445, -0.309592, 1.407520],
            [1.858907, -0.277806, 1.983936],
            [-0.413922, 3.161758, 2.044343],
            [2.852230, 3.034615, -0.261321],
            [2.223878, -0.857545, -0.936800],
            [-1.042273, -0.730402, 1.368864],
        ]
    )
    vol1 = 14.000
    vol2 = 14.000005
    vol_inters = 5.431122
    iou_skew = vol_inters / (vol1 + vol2 - vol_inters)
    check_overlap(box_skew_1, box_skew_2, vol=vol_inters, iou=iou_skew, atol=1e-1)

    # 12th test: Zero area bounding box (from GH issue #992)
    box12a = make_box(
        [
            [-1.0000, -1.0000, -0.5000],
            [1.0000, -1.0000, -0.5000],
            [1.0000, 1.0000, -0.5000],
            [-1.0000, 1.0000, -0.5000],
            [-1.0000, -1.0000, 0.5000],
            [1.0000, -1.0000, 0.5000],
            [1.0000, 1.0000, 0.5000],
            [-1.0000, 1.0000, 0.5000],
        ]
    )
    box12b = torch.zeros((8, 3), device=device, dtype=torch.float32)
    check_raises(box12a, box12b, "Planes have zero areas")

    # 13th test: From GH issue #992
    # Zero area coplanar face after intersection
    # Created on `device` so the implementation under test is the one
    # that is actually exercised.
    ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]], device=device)
    whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]], device=device)
    box13a = TestIoU3D.create_box(ctrs[0], whl[0])
    box13b = TestIoU3D.create_box(ctrs[1], whl[1])
    check_overlap(box13a, box13b, vol=2.0, symmetric=False)

    # 14th test: From GH issue #992
    # Rotation about the y axis, same boxes, iou should be 1.0
    corners = (
        make_box(
            [
                [-1.0, -1.0, -1.0],
                [1.0, -1.0, -1.0],
                [1.0, 1.0, -1.0],
                [-1.0, 1.0, -1.0],
                [-1.0, -1.0, 1.0],
                [1.0, -1.0, 1.0],
                [1.0, 1.0, 1.0],
                [-1.0, 1.0, 1.0],
            ]
        )
        * 0.5
    )
    yaw = torch.tensor(0.185)
    Rot = torch.tensor(
        [
            [torch.cos(yaw), 0.0, torch.sin(yaw)],
            [0.0, 1.0, 0.0],
            [-torch.sin(yaw), 0.0, torch.cos(yaw)],
        ],
        dtype=torch.float32,
        device=device,
    )
    corners = (Rot.mm(corners.t())).t()
    check_overlap(corners, corners, iou=1.0, symmetric=False, atol=1e-2)

    # 15th test: From GH issue #1082
    box15a = make_box(
        [
            [-2.5629019, 4.13995749, -1.76344576],
            [1.92329434, 4.28127117, -1.86155124],
            [1.86994571, 5.97489644, -1.86155124],
            [-2.61625053, 5.83358276, -1.76344576],
            [-2.53123587, 4.14095496, -0.31397536],
            [1.95496037, 4.28226864, -0.41208084],
            [1.90161174, 5.97589391, -0.41208084],
            [-2.5845845, 5.83458023, -0.31397536],
        ]
    )
    box15b = make_box(
        [
            [-2.6256125, 4.13036357, -1.82893437],
            [1.87201008, 4.25296695, -1.82893437],
            [1.82562476, 5.95458116, -1.82893437],
            [-2.67199782, 5.83197777, -1.82893437],
            [-2.6256125, 4.13036357, -0.40095884],
            [1.87201008, 4.25296695, -0.40095884],
            [1.82562476, 5.95458116, -0.40095884],
            [-2.67199782, 5.83197777, -0.40095884],
        ]
    )
    check_overlap(box15a, box15b, iou=0.91, atol=1e-2)

    # 16th test: From GH issue 1287 (two identical boxes)
    box16a = make_box(
        [
            [-167.5847, -70.6167, -2.7927],
            [-166.7333, -72.4264, -2.7927],
            [-166.7333, -72.4264, -4.5927],
            [-167.5847, -70.6167, -4.5927],
            [-163.0605, -68.4880, -2.7927],
            [-162.2090, -70.2977, -2.7927],
            [-162.2090, -70.2977, -4.5927],
            [-163.0605, -68.4880, -4.5927],
        ]
    )
    box16b = box16a.clone()
    check_overlap(box16a, box16b, iou=1.0, atol=1e-2)

    # 17th test: From GH issue 1287 (nearly identical boxes)
    box17a = make_box(
        [
            [-33.94158, -4.51639, 0.96941],
            [-34.67156, -2.65437, 0.96941],
            [-34.67156, -2.65437, -0.95367],
            [-33.94158, -4.51639, -0.95367],
            [-38.75954, -6.40521, 0.96941],
            [-39.48952, -4.54319, 0.96941],
            [-39.48952, -4.54319, -0.95367],
            [-38.75954, -6.40521, -0.95367],
        ]
    )
    box17b = make_box(
        [
            [-33.94159, -4.51638, 0.96939],
            [-34.67158, -2.65437, 0.96939],
            [-34.67158, -2.65437, -0.95368],
            [-33.94159, -4.51638, -0.95368],
            [-38.75954, -6.40523, 0.96939],
            [-39.48953, -4.54321, 0.96939],
            [-39.48953, -4.54321, -0.95368],
            [-38.75954, -6.40523, -0.95368],
        ]
    )
    check_overlap(box17a, box17b, iou=1.0, atol=1e-2)

    # 18th test: From GH issue 1287
    box18a = make_box(
        [
            [-105.6248, -32.7026, -1.2279],
            [-106.4690, -30.8895, -1.2279],
            [-106.4690, -30.8895, -3.0279],
            [-105.6248, -32.7026, -3.0279],
            [-110.1575, -34.8132, -1.2279],
            [-111.0017, -33.0001, -1.2279],
            [-111.0017, -33.0001, -3.0279],
            [-110.1575, -34.8132, -3.0279],
        ]
    )
    box18b = make_box(
        [
            [-105.5094, -32.9504, -1.0641],
            [-106.4272, -30.9793, -1.0641],
            [-106.4272, -30.9793, -3.1916],
            [-105.5094, -32.9504, -3.1916],
            [-110.0421, -35.0609, -1.0641],
            [-110.9599, -33.0899, -1.0641],
            [-110.9599, -33.0899, -3.1916],
            [-110.0421, -35.0609, -3.1916],
        ]
    )
    # from Meshlab
    vol_inters = 17.108501
    vol_box1 = 18.000067
    vol_box2 = 23.128527
    iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
    check_overlap(box18a, box18b, vol=vol_inters, iou=iou_mesh, atol=1e-2)

    # 19th test: From GH issue 1287
    box19a = make_box(
        [
            [-59.4785, -15.6003, 0.4398],
            [-60.2263, -13.6928, 0.4398],
            [-60.2263, -13.6928, -1.3909],
            [-59.4785, -15.6003, -1.3909],
            [-64.1743, -17.4412, 0.4398],
            [-64.9221, -15.5337, 0.4398],
            [-64.9221, -15.5337, -1.3909],
            [-64.1743, -17.4412, -1.3909],
        ]
    )
    box19b = make_box(
        [
            [-59.4874, -15.5775, -0.1512],
            [-60.2174, -13.7155, -0.1512],
            [-60.2174, -13.7155, -1.9820],
            [-59.4874, -15.5775, -1.9820],
            [-64.1832, -17.4185, -0.1512],
            [-64.9132, -15.5564, -0.1512],
            [-64.9132, -15.5564, -1.9820],
            [-64.1832, -17.4185, -1.9820],
        ]
    )
    # from Meshlab
    vol_inters = 12.505723
    vol_box1 = 18.918238
    vol_box2 = 18.468531
    iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
    check_overlap(box19a, box19b, vol=vol_inters, iou=iou_mesh, atol=1e-2)
def _test_real_boxes(self, overlap_fn, device):
    """
    Check `overlap_fn` on the two boxes stored in the `real_boxes.pkl` fixture.

    Each box must have IoU 1 with itself and IoU 0 with the other one, i.e.
    the pairwise IoU matrix is expected to be the 2x2 identity.
    """
    with open(DATA_DIR / "./real_boxes.pkl", "rb") as handle:
        record = pickle.load(handle)
    per_box = [torch.FloatTensor(record[key]) for key in ("verts1", "verts2")]
    boxes = torch.stack(per_box).to(device)
    _, iou = overlap_fn(boxes, boxes)
    self.assertClose(iou, torch.eye(2).to(device))
def test_iou_naive(self):
    """Run every shared check against the naive reference implementation."""
    device = get_random_cuda_device()
    # Fetch the wrapper through the class: it takes `(boxes1, boxes2)` only,
    # so it must not be turned into a bound method that injects `self`.
    naive_fn = TestIoU3D._box3d_overlap_naive_batched
    self._test_iou(naive_fn, device)
    self._test_compare_objectron(naive_fn, device)
    self._test_real_boxes(naive_fn, device)
def test_iou_cpu(self):
    """Run every shared check against `box3d_overlap` on the CPU."""
    cpu = torch.device("cpu")
    suites = (self._test_iou, self._test_compare_objectron, self._test_real_boxes)
    for run_suite in suites:
        run_suite(box3d_overlap, cpu)
def test_iou_cuda(self):
    """Run every shared check against `box3d_overlap` on the first CUDA device."""
    gpu = torch.device("cuda:0")
    suites = (self._test_iou, self._test_compare_objectron, self._test_real_boxes)
    for run_suite in suites:
        run_suite(box3d_overlap, gpu)
def _test_compare_objectron(self, overlap_fn, device):
    """
    Compare `overlap_fn` with the volumes and IoUs saved in the
    `objectron_vols_ious.pt` fixture.
    """
    saved = torch.load(DATA_DIR / "./objectron_vols_ious.pt")

    # The saved boxes use the Objectron corner ordering; permute the vertex
    # dimension into the PyTorch3D convention before calling overlap_fn.
    corner_order = torch.tensor(
        OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device
    )
    boxes1, boxes2 = (
        saved[key]
        .to(device=device, dtype=torch.float32)
        .index_select(index=corner_order, dim=1)
        for key in ("boxes1", "boxes2")
    )

    vols, ious = overlap_fn(boxes1, boxes2)

    # The reference values are compared against CPU copies of the results.
    self.assertClose(saved["vols"], vols.cpu())
    self.assertClose(saved["ious"], ious.cpu())
def test_batched_errors(self):
    """`box3d_overlap` must raise a ValueError when a box does not have 8 corners."""
    num_valid, num_invalid = 5, 10
    valid_boxes = torch.randn((num_valid, 8, 3))
    # 10 vertices per box instead of the required 8.
    invalid_boxes = torch.randn((num_invalid, 10, 3))
    with self.assertRaisesRegex(ValueError, "(8, 3)"):
        box3d_overlap(valid_boxes, invalid_boxes)
def test_box_volume(self):
    """Check `box_volume` on two hand-coded boxes and on (transformed) unit cubes."""
    device = torch.device("cuda:0")

    def on_device(rows):
        return torch.tensor(rows, dtype=torch.float32, device=device)

    box1 = on_device(
        [
            [3.1673, -2.2574, 0.4817],
            [4.6470, 0.2223, 2.4197],
            [5.2200, 1.1844, 0.7510],
            [3.7403, -1.2953, -1.1869],
            [-4.9316, 2.5724, 0.4856],
            [-3.4519, 5.0521, 2.4235],
            [-2.8789, 6.0142, 0.7549],
            [-4.3586, 3.5345, -1.1831],
        ]
    )
    box2 = on_device(
        [
            [0.5623, 4.0647, 3.4334],
            [3.3584, 4.3191, 1.1791],
            [3.0724, -5.9235, -0.3315],
            [0.2763, -6.1779, 1.9229],
            [-2.0773, 4.6121, 0.2213],
            [0.7188, 4.8665, -2.0331],
            [0.4328, -5.3761, -3.5436],
            [-2.3633, -5.6305, -1.2893],
        ]
    )
    # Unit cube, and the same cube after a random rigid transform.
    box3 = on_device(UNIT_BOX)
    RR = random_rotation(dtype=torch.float32, device=device)
    TT = torch.rand((1, 3), dtype=torch.float32, device=device)
    box4 = box3 @ RR.transpose(0, 1) + TT

    cases = (
        (box1, 65.899010),
        (box2, 156.386719),
        (box3, 1.0),
        (box4, 1.0),
    )
    for box, expected_volume in cases:
        self.assertClose(
            box_volume(box).cpu(), torch.tensor(expected_volume), atol=1e-3
        )
def test_box_planar_dir(self):
    """Check the inward face normals of the unit cube and of a transformed copy."""
    device = torch.device("cuda:0")
    float_on_device = {"dtype": torch.float32, "device": device}

    unit_cube = torch.tensor(UNIT_BOX, **float_on_device)
    # Expected inward-pointing normal of each of the six faces of the unit cube.
    unit_normals = torch.tensor(
        [
            [0.0, 0.0, 1.0],
            [0.0, -1.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0],
        ],
        **float_on_device,
    )

    RR = random_rotation(**float_on_device)
    TT = torch.rand((1, 3), **float_on_device)
    # Normals rotate with the box but are unaffected by the translation.
    moved_cube = unit_cube @ RR.transpose(0, 1) + TT
    moved_normals = unit_normals @ RR.transpose(0, 1)

    self.assertClose(box_planar_dir(unit_cube), unit_normals)
    self.assertClose(box_planar_dir(moved_cube), moved_normals)
@staticmethod
def iou_naive(N: int, M: int, device="cpu"):
    """
    Build a zero-argument closure that runs the naive batched overlap once.

    Declared as a static method because it takes no `self`.

    Args:
        N: number of boxes in the first set
        M: number of boxes in the second set
        device: device on which the boxes are created

    Returns:
        A callable that computes the overlap of the two sets of randomly
        translated unit boxes and discards the result.
    """
    box = torch.tensor([UNIT_BOX], dtype=torch.float32, device=device)
    boxes1 = box + torch.randn((N, 1, 3), device=device)
    boxes2 = box + torch.randn((M, 1, 3), device=device)

    def output():
        TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)

    return output
@staticmethod
def iou(N: int, M: int, device="cpu"):
    """
    Build a zero-argument closure that runs `box3d_overlap` once.

    Declared as a static method because it takes no `self`.

    Args:
        N: number of boxes in the first set
        M: number of boxes in the second set
        device: device on which the boxes are created

    Returns:
        A callable that computes the overlap of the two sets of randomly
        translated unit boxes and discards the result.
    """
    box = torch.tensor([UNIT_BOX], dtype=torch.float32, device=device)
    boxes1 = box + torch.randn((N, 1, 3), device=device)
    boxes2 = box + torch.randn((M, 1, 3), device=device)

    def output():
        box3d_overlap(boxes1, boxes2)

    return output
@staticmethod
def iou_sampling(N: int, M: int, num_samples: int, device="cpu"):
    """
    Build a zero-argument closure that runs the sampling-based overlap once.

    Declared as a static method because it takes no `self`.

    Args:
        N: number of boxes in the first set
        M: number of boxes in the second set
        num_samples: sample count handed to the batched sampling wrapper
        device: device on which the boxes are created

    Returns:
        A callable that estimates the IoU of the two sets of randomly
        translated unit boxes and discards the result.
    """
    box = torch.tensor([UNIT_BOX], dtype=torch.float32, device=device)
    boxes1 = box + torch.randn((N, 1, 3), device=device)
    boxes2 = box + torch.randn((M, 1, 3), device=device)

    def output():
        TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)

    return output
# -------------------------------------------------- # | |
# NAIVE IMPLEMENTATION # | |
# -------------------------------------------------- # | |
""" | |
The main functions below are: | |
* box3d_overlap_naive: which computes the exact IoU of box1 and box2 | |
* box3d_overlap_sampling: which computes an approximate IoU of box1 and box2 | |
by sampling points within the boxes | |
Note that both implementations currently do not support batching. | |
""" | |
# -------------------------------------------------- #
# Throughout this implementation, we assume that boxes
# are defined by their 8 corners in the following order
#
#   (4) +---------+. (5)
#       | ` .     |  ` .
#       | (0) +---+-----+ (1)
#       |     |   |     |
#   (7) +-----+---+. (6)|
#       ` .   |     ` . |
#       (3) ` +---------+ (2)
#
# -------------------------------------------------- #
# -------------------------------------------------- # | |
# HELPER FUNCTIONS FOR EXACT SOLUTION # | |
# -------------------------------------------------- # | |
def get_tri_verts(box: torch.Tensor) -> torch.Tensor:
    """
    Gather the corner coordinates of every triangle of the box.

    The triangles are the ones listed in `_box_triangles`; this is a tiny
    stand-in for the face indexing a Meshes structure would provide.

    Args:
        box: tensor of shape (8, 3)

    Returns:
        tri_verts: tensor of shape (12, 3, 3)
    """
    tri_index = torch.tensor(_box_triangles, dtype=torch.int64, device=box.device)
    # Indexing with a (12, 3) index tensor yields a (12, 3, 3) result.
    return box[tri_index]
def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
    """
    Gather the corner coordinates of every planar face of the box.

    The faces are the ones listed in `_box_planes`; this is a tiny stand-in
    for the face indexing a Meshes structure would provide.

    Args:
        box: tensor of shape (8, 3)

    Returns:
        plane_verts: tensor of shape (6, 4, 3)
    """
    plane_index = torch.tensor(_box_planes, dtype=torch.int64, device=box.device)
    # Indexing with a (6, 4) index tensor yields a (6, 4, 3) result.
    return box[plane_index]
def get_tri_center_normal(tris: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the center and unit normal of triangles.

    The normal of each triangle is taken from the cross product, among all
    pairs of centroid-to-vertex vectors, that has the largest norm. Its sign
    is therefore not canonical.

    Args:
        tris: tensor of shape (T, 3, 3), or (3, 3) for a single triangle

    Returns:
        center: tensor of shape (T, 3), or (3,) for a single triangle
        normal: tensor of shape (T, 3), or (3,) for a single triangle
    """
    single = tris.ndim == 2
    if single:
        tris = tris.unsqueeze(0)

    ctr = tris.mean(1)  # (T, 3)
    # Vertices expressed relative to the centroid of their triangle.
    r0, r1, r2 = (tris - ctr.unsqueeze(1)).unbind(1)  # 3 x (T, 3)

    # All candidate normals, computed in the dtype of the input.
    candidates = torch.stack(
        (
            torch.cross(r0, r1, dim=-1),
            torch.cross(r0, r2, dim=-1),
            torch.cross(r1, r2, dim=-1),
        ),
        dim=1,
    )  # (T, 3, 3)

    # Keep the longest candidate per triangle: it is the best conditioned one.
    best = candidates.norm(dim=-1).argmax(dim=1)  # (T,)
    rows = torch.arange(tris.shape[0], device=tris.device)
    normals = F.normalize(candidates[rows, best], dim=-1)  # (T, 3)

    if single:
        return ctr[0], normals[0]
    return ctr, normals
def get_plane_center_normal(
    planes: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the center and unit normal of planes.

    The normal of each plane is taken from the cross product, among all pairs
    of centroid-to-vertex vectors, that has the largest norm. Its sign is
    therefore not canonical.

    Args:
        planes: tensor of shape (P, 4, 3), or (4, 3) for a single plane

    Returns:
        center: tensor of shape (P, 3), or (3,) for a single plane
        normal: tensor of shape (P, 3), or (3,) for a single plane
    """
    single = planes.ndim == 2
    if single:
        planes = planes.unsqueeze(0)

    ctr = planes.mean(1)  # (P, 3)
    # Vertices expressed relative to the centroid of their plane.
    rel = planes - ctr.unsqueeze(1)  # (P, 4, 3)

    # Every unordered pair of the four vertices yields a candidate normal,
    # computed in the dtype of the input.
    vertex_pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
    candidates = torch.stack(
        [torch.cross(rel[:, i], rel[:, j], dim=-1) for i, j in vertex_pairs],
        dim=1,
    )  # (P, 6, 3)

    # Keep the longest candidate per plane: pairs of opposite vertices give
    # (near) zero cross products and must not be selected.
    best = candidates.norm(dim=-1).argmax(dim=1)  # (P,)
    rows = torch.arange(planes.shape[0], device=planes.device)
    normals = F.normalize(candidates[rows, best], dim=-1)  # (P, 3)

    if single:
        return ctr[0], normals[0]
    return ctr, normals
def box_planar_dir(
    box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS
) -> torch.Tensor:
    """
    Compute, for every face of the box, the unit normal pointing into the box.

    The faces are the ones listed in `_box_planes`. Because the box is convex,
    "into the box" is taken to mean "towards the mean of the 8 corners".

    Args:
        box: tensor of shape (8, 3) of the vertices of the 3D box
        dot_eps: tolerance on the absolute dot product between a normalized
            face edge and the face normal used by the coplanarity check
        area_eps: minimum area of the two triangles checked on each face

    Returns:
        n: tensor of shape (6, 3) holding, for each face, the unit vector
            orthogonal to the face and pointing towards the interior

    Raises:
        ValueError: if the vertices of a face are not coplanar, or if a face
            has (near) zero area
    """
    assert box.shape[0] == 8 and box.shape[1] == 3

    # Mean of the corners: a point inside the convex box.
    centroid = box.mean(0).view(1, 3)

    quads = get_plane_verts(box)  # (6, 4, 3)
    p0, p1, p2, p3 = quads.unbind(1)
    face_ctr, n = get_plane_center_normal(quads)

    # Coplanarity: the edge p0 -> p3 has to be orthogonal to the face normal.
    edge_dir = F.normalize(p3 - p0, dim=-1)
    off_plane = edge_dir.unsqueeze(1).bmm(n.unsqueeze(2)).abs()
    if not (off_plane < dot_eps).all().item():
        raise ValueError("Plane vertices are not coplanar")

    # Degeneracy: both triangles sharing the diagonal p0 -> p2 need some area.
    diagonal = p2 - p0
    area_first = torch.cross(p1 - p0, diagonal, dim=-1).norm(dim=-1) / 2
    area_second = torch.cross(p3 - p0, diagonal, dim=-1).norm(dim=-1) / 2
    if (area_first < area_eps).any().item() or (area_second < area_eps).any().item():
        raise ValueError("Planes have zero areas")

    # Write `centroid = face_ctr + a * e0 + b * e1 + c * n` with e0 and e1
    # spanning the face, i.e. <e0, n> = 0 and <e1, n> = 0. Projecting the
    # equation onto n removes the e0 and e1 terms, so the sign of c is the
    # sign of <centroid - face_ctr, n> and no linear solve for (a, b) is
    # required.
    towards_inside = F.normalize(centroid - face_ctr, dim=-1)  # (6, 3)
    c = (towards_inside * n).sum(-1)

    # A negative c means the normal points away from the centroid: flip it.
    points_outside = c < 0.0
    n[points_outside] *= -1.0
    return n
def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor:
    """
    Area of a single triangle or of a batch of triangles.

    Args:
        tri_verts: tensor of shape (T, 3, 3) with the three vertices of each
            of T triangles, or of shape (3, 3) for one triangle
    Returns:
        areas: tensor of shape (T,) with one area per triangle; a 0-dim
            tensor when a single (3, 3) triangle was passed in
    """
    single = tri_verts.ndim == 2
    batch = tri_verts.unsqueeze(0) if single else tri_verts
    a, b, c = batch.unbind(1)
    # |(b - a) x (c - a)| is the area of the parallelogram spanned by two
    # edges; the triangle is half of it
    areas = torch.cross(b - a, c - a, dim=-1).norm(dim=-1) / 2.0
    return areas[0] if single else areas
def box_volume(box: torch.Tensor) -> torch.Tensor:
    """
    Volume of a single 3D box given by its 8 corners.

    The box is decomposed into tetrahedra: every triangle returned by
    `get_tri_verts` is used as a base and the mean of the 8 corners as the
    shared apex. The box volume is the sum of the tetrahedron volumes, each
    computed with the scalar triple product |a . (b x c)| / 6 after moving
    the apex to the origin.
    See https://en.wikipedia.org/wiki/Tetrahedron#Volume

    Args:
        box: tensor of shape (8, 3) containing the vertices of the 3D box
    Returns:
        vols: 0-dim tensor holding the volume of the box
    """
    assert box.shape[0] == 8 and box.shape[1] == 3
    # the box center acts as the apex shared by all tetrahedra
    apex = box.mean(0).view(1, 1, 3)
    # triangle vertices expressed relative to the apex
    tets = get_tri_verts(box) - apex
    a, b, c = tets[:, 0], tets[:, 1], tets[:, 2]
    # six times the signed volume of each tetrahedron
    triple = (a * torch.cross(b, c, dim=-1)).sum(dim=-1)
    return (triple.abs() / 6.0).sum()
def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = DOT_EPS):
    """
    Checks whether two triangles in 3D lie in the same plane.

    Two conditions must hold: the triangle normals are parallel, and the
    segment joining the two most distant vertices (one from each triangle)
    is perpendicular to at least one of the two normals.

    Args:
        tri1: tensor of shape (3, 3) of the vertices of the 1st triangle
        tri2: tensor of shape (3, 3) of the vertices of the 2nd triangle
        eps: tolerance applied to the dot products
    Returns:
        is_coplanar: 0-dim bool tensor
    """
    _, normal1 = get_tri_center_normal(tri1)
    _, normal2 = get_tri_center_normal(tri2)
    # normals are parallel when |<n1, n2>| is close to 1
    parallel = normal1.dot(normal2).abs() > 1 - eps
    # distances between every vertex of tri1 and every vertex of tri2, (3, 3)
    pair_dist = torch.norm(tri1.unsqueeze(1) - tri2.unsqueeze(0), dim=-1)
    flat_idx = pair_dist.argmax()
    row = flat_idx // 3
    col = flat_idx % 3
    assert pair_dist[row, col] == pair_dist.max()
    # the longest vertex-to-vertex segment gives the most stable direction
    span = F.normalize(tri1[row] - tri2[col], dim=0)
    in_plane = (span.dot(normal1).abs() < eps) or span.dot(normal2).abs() < eps
    return parallel and in_plane
def coplanar_tri_plane(
    tri: torch.Tensor, plane: torch.Tensor, n: torch.Tensor, eps: float = DOT_EPS
):
    """
    Checks whether a triangle in 3D lies in the plane of a box face.

    Two conditions must hold: the triangle normal is parallel to n, and the
    segment joining the most distant (triangle vertex, plane vertex) pair is
    perpendicular to n.

    Args:
        tri: tensor of shape (3, 3) of the vertices of the triangle
        plane: tensor of shape (4, 3) of the vertices of the plane
        n: tensor of shape (3,) of the unit "inside" direction on the plane
        eps: tolerance applied to the dot products
    Returns:
        is_coplanar: 0-dim bool tensor
    """
    _, tri_normal = get_tri_center_normal(tri)
    # the triangle is parallel to the plane when |<tri_n, n>| is close to 1
    parallel = tri_normal.dot(n).abs() > 1 - eps
    # distances between every triangle vertex and every plane vertex, (3, 4)
    pair_dist = torch.norm(tri.unsqueeze(1) - plane.unsqueeze(0), dim=-1)
    flat_idx = pair_dist.argmax()
    row = flat_idx // 4
    col = flat_idx % 4
    assert pair_dist[row, col] == pair_dist.max()
    span = F.normalize(tri[row] - plane[col], dim=0)
    in_plane = span.dot(n).abs() < eps
    return parallel and in_plane
def is_inside(
    plane: torch.Tensor,
    n: torch.Tensor,
    points: torch.Tensor,
    return_proj: bool = True,
):
    """
    Computes whether points are "inside" the plane.
    "Inside" means that the point has a non-negative component in the
    direction of the plane normal defined by n.
    For example,
                  plane
                    |
                    |         . (A)
                    |--> n
                    |
         .(B)       |

    Point (A) is "inside" the plane, while point (B) is "outside" the plane.

    Args:
        plane: tensor of shape (4, 3) of vertices of a box plane
        n: tensor of shape (3,) of the unit "inside" direction on the plane
        points: tensor of shape (P, 3) of point coordinates, or of shape (3,)
            for a single point
        return_proj: bool whether to return the projected point on the plane
    Returns:
        is_inside: bool tensor of shape (P,) of whether each point is inside
            (shape (1,) when a single (3,) point was given)
        p_proj: tensor of shape (P, 3) of the points projected on the plane
            ((3,) for a single point), or None when return_proj is False
    Raises:
        ValueError: if n is not perpendicular to the plane
    """
    device = plane.device
    # unpacking also enforces that the plane has exactly four vertices
    v0, v1, _, _ = plane.unbind(0)
    plane_ctr = plane.mean(0)
    # two in-plane directions, from the plane center to two of its corners
    e0 = F.normalize(v0 - plane_ctr, dim=0)
    e1 = F.normalize(v1 - plane_ctr, dim=0)
    zero = torch.zeros((1,), device=device)
    if not torch.allclose(e0.dot(n), zero, atol=1e-2):
        raise ValueError("Input n is not perpendicular to the plane")
    if not torch.allclose(e1.dot(n), zero, atol=1e-2):
        raise ValueError("Input n is not perpendicular to the plane")

    add_dim = False
    if points.ndim == 1:
        points = points.unsqueeze(0)
        add_dim = True
    assert points.shape[1] == 3

    # offsets of the points from the plane center, (P, 3)
    rel = points - plane_ctr.view(1, 3)

    # Every point p can be written as p = ctr + a e0 + b e1 + c n
    # If return_proj is True, we need to solve for (a, b)
    p_proj = None
    if return_proj:
        # e0 and e1 are unit vectors but not necessarily orthogonal, so
        # (a, b) is the solution of a 2x2 linear system
        e01 = e0.dot(e1).item()
        A = torch.tensor(
            [[1.0, e01], [e01, 1.0]], dtype=torch.float32, device=device
        )
        B = torch.zeros((2, points.shape[0]), dtype=torch.float32, device=device)
        B[0, :] = torch.sum(rel * e0.view(1, 3), dim=-1)
        B[1, :] = torch.sum(rel * e1.view(1, 3), dim=-1)
        ab = A.inverse() @ B  # (2, P)
        p_proj = plane_ctr.view(1, 3) + ab.transpose(0, 1) @ torch.stack(
            (e0, e1), dim=0
        )

    # c = <p - ctr, n> because e0 and e1 are orthogonal to n; its sign tells
    # on which side of the plane the point lies
    direc = torch.sum(rel * n.view(1, 3), dim=-1)
    ins = direc >= 0.0

    # Undo the batching of a single point. p_proj is None when no projection
    # was requested, so it must not be touched in that case.
    if add_dim and p_proj is not None:
        assert p_proj.shape[0] == 1
        p_proj = p_proj[0]
    return ins, p_proj
def plane_edge_point_of_intersection(plane, n, p0, p1, eps: float = DOT_EPS):
    """
    Intersects the line segment (p0, p1) with the infinite plane that
    contains a box face.

    The intersection is parametrized as p = p0 + a * (p1 - p0) and `a` is
    chosen such that <p - ctr, n> = 0, where ctr is the mean of the plane
    vertices. When the segment is parallel to the plane (it can then only lie
    on it), the midpoint of the segment is returned with a = 0.5.

    Args:
        plane: tensor of shape (4, 3) of the coordinates of the vertices
            defining the plane
        n: tensor of shape (3,) of the unit direction perpendicular to the
            plane; whether it points "inside" or "outside" does not matter
            here
        p0, p1: tensors of shape (3,), (3,) of the segment end points
        eps: tolerance below which the segment counts as parallel to the plane
    Returns:
        p: tensor of shape (3,) of the coordinates of the point of intersection
        a: scalar such that p = p0 + a * (p1 - p0)
    """
    edge = p1 - p0
    # segment parallel to the plane: no unique intersection point exists
    if F.normalize(edge, dim=0).dot(n).abs() < eps:
        return (p1 + p0) / 2.0, 0.5
    # solve <p0 + a * edge - ctr, n> = 0 for a
    offset = p0 - plane.mean(0)
    a = -offset.dot(n) / (edge.dot(n))
    return p0 + a * edge, a
""" | |
The three following functions support clipping a triangle face by a plane. | |
They contain the following cases: (a) the triangle has one point "outside" the plane and | |
(b) the triangle has two points "outside" the plane. | |
This logic follows the logic of clipping triangles when they intersect the image plane while | |
rendering. | |
""" | |
def clip_tri_by_plane_oneout(
    plane: torch.Tensor,
    n: torch.Tensor,
    vout: torch.Tensor,
    vin1: torch.Tensor,
    vin2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Case (a): exactly one vertex of the triangle is "outside" the plane.

    Cutting off the corner at vout leaves a quadrilateral
    (vin1, pint1, pint2, vin2), where pint1 and pint2 are the points at which
    the edges (vin1, vout) and (vin2, vout) cross the plane. The
    quadrilateral is split into two triangles.

    Args:
        plane: tensor of shape (4, 3) of the coordinates of the vertices
            forming the plane
        n: tensor of shape (3,) of the unit "inside" direction of the plane
        vout, vin1, vin2: tensors of shape (3,) of the points forming the
            triangle, where vout is "outside" the plane and vin1, vin2 are
            "inside"
    Returns:
        verts: tensor of shape (4, 3) containing the vertices of the clipped
            quadrilateral
        faces: tensor of shape (2, 3) of indices into verts defining the two
            triangles which are "inside" the plane after clipping
    """
    # where the edge (vin1, vout) crosses the plane
    cut1, t1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
    assert -0.0001 <= t1 <= 1.0001, t1
    # where the edge (vin2, vout) crosses the plane
    cut2, t2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
    assert -0.0001 <= t2 <= 1.0001, t2

    quad = torch.stack((vin1, cut1, cut2, vin2), dim=0)  # 4x3
    # fan triangulation of the quadrilateral around its first vertex
    quad_faces = torch.tensor(
        [[0, 1, 2], [0, 2, 3]], dtype=torch.int64, device=plane.device
    )  # 2x3
    return quad, quad_faces
def clip_tri_by_plane_twoout(
    plane: torch.Tensor,
    n: torch.Tensor,
    vout1: torch.Tensor,
    vout2: torch.Tensor,
    vin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Case (b): exactly one vertex of the triangle is "inside" the plane.

    Only the corner at vin survives: the clipped triangle is
    (vin, pint1, pint2), where pint1 and pint2 are the points at which the
    edges (vin, vout1) and (vin, vout2) cross the plane.

    Args:
        plane: tensor of shape (4, 3) of the coordinates of the vertices
            forming the plane
        n: tensor of shape (3,) of the unit "inside" direction of the plane
        vout1, vout2, vin: tensors of shape (3,) of the points forming the
            triangle, where vin is "inside" the plane and vout1, vout2 are
            "outside"
    Returns:
        verts: tensor of shape (3, 3) containing the vertices of the clipped
            triangle
        faces: tensor of shape (1, 3) of indices into verts defining the
            single triangle which is "inside" the plane after clipping
    """
    # where the edge (vin, vout1) crosses the plane
    cut1, t1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
    assert -0.0001 <= t1 <= 1.0001, t1
    # where the edge (vin, vout2) crosses the plane
    cut2, t2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
    assert -0.0001 <= t2 <= 1.0001, t2

    corner = torch.stack((vin, cut1, cut2), dim=0)  # 3x3
    corner_faces = torch.tensor(
        [[0, 1, 2]], dtype=torch.int64, device=plane.device
    )  # 1x3
    return corner, corner_faces
def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
    """
    Clips a triangular face by a plane whose "inside" direction is n.

    A triangle coplanar with the plane, or with all three vertices "inside",
    is returned unchanged. A triangle with all vertices "outside" is dropped.
    Otherwise the part of the triangle that is "inside" is returned, as two
    triangles when one vertex is outside and as one triangle when two
    vertices are outside.

    Args:
        plane: tensor of shape (4, 3) of the vertex coordinates of the plane
        n: tensor of shape (3,) of the unit "inside" direction of the plane
        tri_verts: tensor of shape (3, 3) of the vertex coordinates of the
            triangle face
    Returns:
        tri_verts: tensor of shape (K, 3, 3) of the vertex coordinates of the
            triangles formed after clipping, all "inside" the plane, or an
            empty list when nothing of the triangle remains
    """
    if coplanar_tri_plane(tri_verts, plane, n):
        return tri_verts.view(1, 3, 3)

    v0, v1, v2 = tri_verts.unbind(0)
    corners = (v0, v1, v2)
    inside = [bool(is_inside(plane, n, v)[0]) for v in corners]
    num_inside = sum(inside)

    if num_inside == 3:
        # all in, no clipping, keep the old triangle face
        return tri_verts.view(1, 3, 3)
    if num_inside == 0:
        # all out, delete triangle
        return []

    if num_inside == 2:
        # one vertex out: pass it first, then the inside ones in their
        # original order
        out_idx = inside.index(False)
        vin1, vin2 = (corners[i] for i in range(3) if i != out_idx)
        verts, faces = clip_tri_by_plane_oneout(
            plane, n, corners[out_idx], vin1, vin2
        )
    else:
        # one vertex in: pass the outside ones in their original order, then
        # the inside one
        in_idx = inside.index(True)
        vout1, vout2 = (corners[i] for i in range(3) if i != in_idx)
        verts, faces = clip_tri_by_plane_twoout(
            plane, n, vout1, vout2, corners[in_idx]
        )
    return verts[faces]
# -------------------------------------------------- # | |
# MAIN: BOX3D_OVERLAP # | |
# -------------------------------------------------- # | |
def _clip_tris_by_box_planes(tri_verts, plane_verts, normals, device):
    """Clips a set of triangles successively by every plane of a box."""
    for plane, nplane in zip(plane_verts, normals):
        # seeding with an empty tensor keeps the result a (0, 3, 3) tensor
        # when every triangle is clipped away
        pieces = [torch.zeros((0, 3, 3), dtype=torch.float32, device=device)]
        for tri in tri_verts:
            clipped = clip_tri_by_plane(plane, nplane, tri)
            if len(clipped) > 0:
                pieces.append(clipped)
        tri_verts = torch.cat(pieces, dim=0)
    return tri_verts


def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
    """
    Computes the intersection of the 3D boxes box1 and box2.
    Inputs box1, box2 are tensors of shape (8, 3) containing
    the 8 corners of the boxes, as follows

        (4) +---------+. (5)
            | ` .     |  ` .
            | (0) +---+-----+ (1)
            |     |   |     |
        (7) +-----+---+. (6)|
            ` .   |     ` . |
            (3) ` +---------+ (2)

    The triangulated faces of each box are clipped by the planes of the other
    box; the surviving triangles bound the intersecting convex shape, whose
    volume is summed up from tetrahedra sharing a common apex.

    Args:
        box1: tensor of shape (8, 3) of the coordinates of the 1st box
        box2: tensor of shape (8, 3) of the coordinates of the 2nd box
    Returns:
        vol: the volume of the intersecting convex shape
        iou: the intersection over union which is simply
            `iou = vol / (vol1 + vol2 - vol)`
    """
    device = box1.device

    # unit "inside" directions of the planes of each box, (6, 3)
    n1 = box_planar_dir(box1)
    n2 = box_planar_dir(box2)

    vol1 = box_volume(box1)
    vol2 = box_volume(box2)

    # triangle faces and quad planes of each box
    tri_verts1 = get_tri_verts(box1)  # (12, 3, 3)
    plane_verts1 = get_plane_verts(box1)  # (6, 4, 3)
    tri_verts2 = get_tri_verts(box2)  # (12, 3, 3)
    plane_verts2 = get_plane_verts(box2)  # (6, 4, 3)

    # Every triangle of one box is compared to each plane of the other box.
    # A triangle fully inside a plane is kept, a triangle fully outside is
    # dropped, and a triangle crossing the (infinite) plane is cut so that
    # only its inside part remains.
    tri_verts1 = _clip_tris_by_box_planes(tri_verts1, plane_verts2, n2, device)
    tri_verts2 = _clip_tris_by_box_planes(tri_verts2, plane_verts1, n1, device)

    # Remove triangles that are coplanar from the intersection as otherwise
    # they would be double counted towards the volume. This happens only if
    # the original 3D boxes have common planes. Since the resulting shape is
    # convex and composed of planar segments, each planar segment can belong
    # either to box1 or box2 but not both. Without loss of generality, shared
    # planar segments are assigned to box1.
    keep2 = [True] * tri_verts2.shape[0]
    for tri1 in tri_verts1:
        # the area depends only on tri1, so it is tested once per triangle;
        # (near) zero-area triangles of box1 never evict a triangle of box2
        if not bool(tri_verts_area(tri1) > AREA_EPS):
            continue
        for i2, tri2 in enumerate(tri_verts2):
            if keep2[i2] and coplanar_tri_faces(tri1, tri2):
                keep2[i2] = False
    keep2 = torch.tensor(keep2, dtype=torch.bool, device=device)
    tri_verts2 = tri_verts2[keep2]

    # intersecting shape
    overlap_tri_verts = torch.cat((tri_verts1, tri_verts2), dim=0)  # Fx3x3

    # The volume of the convex hull can be defined as the sum of all the
    # tetrahedrons formed where for each tetrahedron the base is a triangle
    # and the apex is the center point of the convex hull.
    # See the math here: https://en.wikipedia.org/wiki/Tetrahedron#Volume
    # The center is computed by averaging the center points of the faces.
    ctr = overlap_tri_verts.mean(1).mean(0)
    tetras = overlap_tri_verts - ctr.view(1, 1, 3)
    vol = torch.sum(
        tetras[:, 0] * torch.cross(tetras[:, 1], tetras[:, 2], dim=-1), dim=-1
    )
    vol = (vol.abs() / 6.0).sum()

    iou = vol / (vol1 + vol2 - vol)

    if DEBUG:
        # save shapes; every face gets its own three vertices (V = F * 3)
        num_faces = overlap_tri_verts.shape[0]
        num_verts = num_faces * 3
        overlap_faces = torch.arange(num_verts).view(num_faces, 3)  # Fx3
        overlap_verts = overlap_tri_verts.view(num_verts, 3)  # Vx3
        tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64)
        save_obj("/tmp/output/shape1.obj", box1, tri_faces)
        save_obj("/tmp/output/shape2.obj", box2, tri_faces)
        if len(overlap_verts) > 0:
            save_obj("/tmp/output/inters_shape.obj", overlap_verts, overlap_faces)
    return vol, iou
# -------------------------------------------------- # | |
# HELPER FUNCTIONS FOR SAMPLING SOLUTION # | |
# -------------------------------------------------- # | |
def is_point_inside_box(box: torch.Tensor, points: torch.Tensor):
    """
    Determines whether points are inside a box.

    A point is inside the box when it is "inside" every one of the planes of
    the box, i.e. on the side the plane's inside direction points to.

    Args:
        box: tensor of shape (8, 3) of the corners of the box
        points: tensor of shape (P, 3) of the points
    Returns:
        inside: bool tensor of shape (P,)
    """
    normals = box_planar_dir(box)  # (6, 3)
    box_planes = get_plane_verts(box)  # (6, 4, 3)
    # one column of inside/outside flags per plane, (P, 6)
    per_plane = torch.stack(
        [
            is_inside(plane, normal, points, return_proj=False)[0]
            for plane, normal in zip(box_planes, normals)
        ],
        dim=1,
    )
    return per_plane.all(dim=1)
def sample_points_within_box(box: torch.Tensor, num_samples: int = 10):
    """
    Sample points within a box defined by its 8 coordinates, using rejection
    sampling: candidates are drawn uniformly from the axis-aligned bounding
    box of the corners and only those inside the box are kept.

    Note that the loop keeps drawing until enough points are accepted, so a
    box that occupies a tiny fraction of its bounding box takes many rounds.

    Args:
        box: tensor of shape (8, 3) of the box coordinates
        num_samples: int defining the number of samples
    Returns:
        points: (num_samples, 3) of points inside the box
    """
    assert box.shape[0] == 8 and box.shape[1] == 3
    xyzmin = box.min(0).values.view(1, 3)
    xyzmax = box.max(0).values.view(1, 3)
    extent = xyzmax - xyzmin

    def draw(count: int) -> torch.Tensor:
        # uniform samples in the axis-aligned bounding box of the corners
        uvw = torch.rand((count, 3), device=box.device)
        return uvw * extent + xyzmin

    if num_samples <= 0:
        # nothing requested; torch.cat would fail on an empty list below
        return draw(0)

    # because the box is not axis aligned we need to check whether
    # the points are within the box
    num_points = 0
    samples = []
    while num_points < num_samples:
        # draw fresh candidates every round; reusing the same candidates
        # would only duplicate points and never terminate if none is inside
        points = draw(num_samples)
        inside = is_point_inside_box(box, points)
        samples.append(points[inside].view(-1, 3))
        num_points += int(inside.sum())

    samples = torch.cat(samples, dim=0)
    # keep exactly num_samples points, starting from the first one
    return samples[:num_samples]
# -------------------------------------------------- # | |
# MAIN: BOX3D_OVERLAP_SAMPLING # | |
# -------------------------------------------------- # | |
def box3d_overlap_sampling(
    box1: torch.Tensor, box2: torch.Tensor, num_samples: int = 10000
):
    """
    Approximates the IoU of two boxes by sampling points.

    Points are sampled inside each box. The fraction of the points of one box
    that also fall inside the other box, times the volume of the first box,
    estimates the intersection volume. The two estimates (one per box) are
    averaged.

    Args:
        box1: tensor of shape (8, 3) of the coordinates of the 1st box
        box2: tensor of shape (8, 3) of the coordinates of the 2nd box
        num_samples: number of points to sample inside each box
    Returns:
        iou: estimate of the intersection over union,
            `inters / (vol1 + vol2 - inters)`
    """
    vol1 = box_volume(box1)
    vol2 = box_volume(box2)

    points1 = sample_points_within_box(box1, num_samples=num_samples)
    points2 = sample_points_within_box(box2, num_samples=num_samples)

    # points of box2 that are also inside box1, and vice versa
    isin21 = is_point_inside_box(box1, points2)
    num21 = isin21.sum()
    isin12 = is_point_inside_box(box2, points1)
    num12 = isin12.sum()

    assert num12 <= num_samples
    assert num21 <= num_samples

    # Normalize by the number of points actually drawn, which is what the
    # hit counts refer to; it may differ from the requested num_samples.
    # max(., 1) avoids a division by zero when no points were returned.
    count1 = max(points1.shape[0], 1)
    count2 = max(points2.shape[0], 1)

    inters = (vol1 * num12 / count1 + vol2 * num21 / count2) / 2.0
    union = vol1 + vol2 - inters
    return inters / union