"""Frustum PointNet loss and 3D bounding-box corner utilities."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Project functional ops (huber_loss). Must be aliased PF: aliasing it as F
# would shadow torch.nn.functional and break both F.cross_entropy and
# the PF.huber_loss calls below.
from . import functional as PF
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
class FrustumPointNetLoss(nn.Module):
    """Multi-task loss for Frustum PointNets 3D object detection.

    Combines an instance-segmentation (mask) loss with box-related terms:
    center regression, heading/size classification, normalized heading/size
    residual regression, and a corner-distance loss computed via
    ``get_box_corners_3d``.

    NOTE(review): ``PF`` below must resolve to this package's ``functional``
    module (which provides ``huber_loss``), while ``F`` must remain
    ``torch.nn.functional`` — confirm the top-of-file import aliases the
    project module as ``PF`` rather than ``F``.
    """
    def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
                 corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
        """
        :param num_heading_angle_bins: int, number of discrete heading bins (NH)
        :param num_size_templates: int, number of size templates / anchors (NS)
        :param size_templates: FloatTensor with NS * 3 elements, per-template box sizes
        :param box_loss_weight: float, weight of all box terms relative to the mask loss
        :param corners_loss_weight: float, weight of the corner-distance term
        :param heading_residual_loss_weight: float, weight of the normalized heading-residual term
        :param size_residual_loss_weight: float, weight of the normalized size-residual term
        """
        super().__init__()
        self.box_loss_weight = box_loss_weight
        self.corners_loss_weight = corners_loss_weight
        self.heading_residual_loss_weight = heading_residual_loss_weight
        self.size_residual_loss_weight = size_residual_loss_weight
        self.num_heading_angle_bins = num_heading_angle_bins
        self.num_size_templates = num_size_templates
        # Buffers move with the module across devices but are not trained.
        self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
        # Bin centers evenly spaced over [0, 2*pi): one per heading bin.
        self.register_buffer(
            'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
        )

    def forward(self, inputs, targets):
        """Compute the scalar FrustumPointNet training loss.

        :param inputs: dict of network predictions (shapes noted inline)
        :param targets: dict of ground-truth tensors (shapes noted inline)
        :return: scalar loss tensor
        """
        mask_logits = inputs['mask_logits']  # (B, 2, N)
        center_reg = inputs['center_reg']  # (B, 3)
        center = inputs['center']  # (B, 3)
        heading_scores = inputs['heading_scores']  # (B, NH)
        heading_residuals_normalized = inputs['heading_residuals_normalized']  # (B, NH)
        heading_residuals = inputs['heading_residuals']  # (B, NH)
        size_scores = inputs['size_scores']  # (B, NS)
        size_residuals_normalized = inputs['size_residuals_normalized']  # (B, NS, 3)
        size_residuals = inputs['size_residuals']  # (B, NS, 3)

        mask_logits_target = targets['mask_logits']  # (B, N)
        center_target = targets['center']  # (B, 3)
        heading_bin_id_target = targets['heading_bin_id']  # (B, )
        heading_residual_target = targets['heading_residual']  # (B, )
        size_template_id_target = targets['size_template_id']  # (B, )
        size_residual_target = targets['size_residual']  # (B, 3)
        batch_size = center.size(0)
        # Row index used to gather the ground-truth bin/template per sample.
        batch_id = torch.arange(batch_size, device=center.device)

        # Basic Classification and Regression losses
        mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
        heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
        size_loss = F.cross_entropy(size_scores, size_template_id_target)
        center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
        center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)

        # Refinement losses for size/heading
        # Supervise only the residual predicted for the ground-truth bin.
        heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target]  # (B, )
        # Bin centers are 2*pi/NH apart, so pi/NH (half the bin width) maps the
        # target residual into roughly [-1, 1].
        heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
        heading_residual_normalized_loss = PF.huber_loss(
            heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
        )
        size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target]  # (B, 3)
        # Size residuals are normalized by the matching template's dimensions.
        size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
        size_residual_normalized_loss = PF.huber_loss(
            torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
        )

        # Bounding box losses
        heading = (heading_residuals[batch_id, heading_bin_id_target]
                   + self.heading_angle_bin_centers[heading_bin_id_target])  # (B, )
        # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
        size = (size_residuals[batch_id, size_template_id_target]
                + self.size_templates[size_template_id_target])  # (B, 3)
        corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
        heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
        size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
        corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
                                                                 sizes=size_target, with_flip=True)  # (B, 3, 8)
        # Use whichever of the target box or its pi-rotated flip is closer,
        # making the corner loss insensitive to 180-degree heading ambiguity.
        corners_loss = PF.huber_loss(torch.min(
            torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
        ), delta=1.0)

        # Summing up
        loss = mask_loss + self.box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss
            + self.heading_residual_loss_weight * heading_residual_normalized_loss
            + self.size_residual_loss_weight * size_residual_normalized_loss
            + self.corners_loss_weight * corners_loss
        )
        return loss
def get_box_corners_3d(centers, headings, sizes, with_flip=False):
    """Compute the 8 corner coordinates of heading-rotated 3D boxes.

    Sizes are (l, w, h) along the (x, z, y) axes of the un-rotated box, and
    the heading rotates the box about the y axis.

    :param centers: coords of box centers, FloatTensor[N, 3]
    :param headings: heading angles, FloatTensor[N, ]
    :param sizes: box sizes (l, w, h), FloatTensor[N, 3]
    :param with_flip: bool, whether to also return the flipped box (headings + np.pi)
    :return:
        coords of box corners, FloatTensor[N, 3, 8]; when with_flip is True,
        a pair (corners, flipped_corners) of such tensors.
        NOTE: corner points are in counter clockwise order, e.g.,
          2--1
        3--0 5
          7--4
    """
    l = sizes[:, 0]  # (N,)
    w = sizes[:, 1]  # (N,)
    h = sizes[:, 2]  # (N,)
    # Corner offsets from the center in the un-rotated box frame.
    x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1)  # (N, 8)
    y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1)  # (N, 8)
    z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1)  # (N, 8)
    c = torch.cos(headings)  # (N,)
    s = torch.sin(headings)  # (N,)
    o = torch.ones_like(headings)  # (N,)
    z = torch.zeros_like(headings)  # (N,)
    centers = centers.unsqueeze(-1)  # (N, 3, 1)
    corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    # Rotation about y: [[c, 0, s], [0, 1, 0], [-s, 0, c]].
    R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    if with_flip:
        # Same rotation evaluated at headings + pi: cos -> -c, sin -> -s.
        R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
        return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
    return torch.matmul(R, corners) + centers
|