File size: 8,298 Bytes
8d5039c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from model import pointnet_util
from model.model_utils import *
# idx = pointnet_util.ball_query(self.radius, self.nsample, xyz, new_xyz)
# # _, idx = pointnet_util.knn_query(self.nsample, xyz, new_xyz)
# xyz_trans = xyz.permute(0, 2, 1)
# grouped_xyz = pointnet_util.grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
# grouped_xyz -= new_xyz.permute(0, 2, 1).unsqueeze(-1)
def get_model():
    """Build and return a :class:`PointNet2` network with its default settings."""
    model = PointNet2()
    return model
class PointNet2(nn.Module):
    """PointNet++ backbone with per-point offset and classification heads.

    Four set-abstraction (SA) levels sample 256 -> 128 -> 64 -> 16 centroids;
    four feature-propagation (FP) levels bring the features back to every
    input point. A shared 1x1 conv + dropout then feeds two 1x1 conv heads.
    """

    def __init__(self, in_channel=3):
        # NOTE(review): forward() always uses the raw xyz (3 channels) as the
        # input descriptors, so in_channel values other than 3 look unsupported
        # — confirm before relying on this parameter.
        super().__init__()
        # Set-abstraction levels (construction order kept: it fixes param init order).
        self.sa1 = PointNetSAModule(npoint=256, radius=0.1, nsample=16, in_channel=in_channel, mlp=[32, 32, 64])
        self.sa2 = PointNetSAModule(npoint=128, radius=0.1, nsample=16, in_channel=64, mlp=[64, 64, 128])
        self.sa3 = PointNetSAModule(npoint=64, radius=0.2, nsample=16, in_channel=128, mlp=[128, 128, 256])
        self.sa4 = PointNetSAModule(npoint=16, radius=0.4, nsample=16, in_channel=256, mlp=[256, 256, 512])
        # Feature-propagation levels: in_channel = interpolated width + skip width.
        self.fp4 = PointNetFPModule(in_channel=768, mlp=[256, 256])  # 512 + 256
        self.fp3 = PointNetFPModule(in_channel=384, mlp=[256, 256])  # 256 + 128
        self.fp2 = PointNetFPModule(in_channel=320, mlp=[256, 128])  # 256 + 64
        self.fp1 = PointNetFPModule(in_channel=128, mlp=[128, 128, 128])  # 128, no skip features
        # Per-point prediction heads.
        self.shared_fc = Conv1dBN(128, 128)
        self.drop = nn.Dropout(0.5)
        self.offset_fc = nn.Conv1d(128, 3, kernel_size=1)
        self.cls_fc = nn.Conv1d(128, 1, kernel_size=1)

    def forward(self, batch_dict):
        """
        :param batch_dict: dict whose 'points' entry is a (B, N, 3) xyz tensor
        :return: the same dict, updated in place with
            'point_features': (B, N, 128) per-point features,
            'point_pred_score': (B, N) sigmoid scores,
            'point_pred_offset': (B, N, 3) predicted offsets
        """
        points = batch_dict['points']
        # Level 0: the coordinates double as the input descriptors, (B, 3, N).
        level_xyz = [points]
        level_fea = [points.permute(0, 2, 1)]
        for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
            sub_xyz, sub_fea = sa(level_xyz[-1], level_fea[-1])
            level_xyz.append(sub_xyz)
            level_fea.append(sub_fea)

        # Coarse-to-fine propagation; each step consumes the level refined just before it.
        for lvl, fp in ((3, self.fp4), (2, self.fp3), (1, self.fp2)):
            level_fea[lvl] = fp(level_xyz[lvl], level_xyz[lvl + 1], level_fea[lvl], level_fea[lvl + 1])
        point_fea = self.fp1(level_xyz[0], level_xyz[1], None, level_fea[1])  # (B, 128, N)

        shared = self.drop(self.shared_fc(point_fea))
        offset = self.offset_fc(shared).permute(0, 2, 1)
        logits = self.cls_fc(shared).permute(0, 2, 1)

        batch_dict['point_features'] = point_fea.permute(0, 2, 1)
        batch_dict['point_pred_score'] = torch.sigmoid(logits).squeeze(-1)
        batch_dict['point_pred_offset'] = offset
        return batch_dict
class PointNetSAModuleMSG(nn.Module):
    """PointNet++ multi-scale-grouping (MSG) set-abstraction module.

    Each scale has its own grouper and shared MLP; the per-scale features are
    max-pooled over the neighbourhood and concatenated along the channel axis.
    """

    def __init__(self, npoint, radii, nsamples, in_channel, mlps, use_xyz=True):
        """
        PointNet Set Abstraction Module
        :param npoint: int, number of centroids to sample; None groups the whole cloud
        :param radii: list of float, radius in ball_query, one per scale
        :param nsamples: list of int, number of samples in ball_query, one per scale
        :param in_channel: int, channels of the input descriptors
        :param mlps: list of list of int, layer widths of each scale's MLP
        :param use_xyz: bool, prepend relative xyz to the grouped descriptors (+3 input channels)
        :raises ValueError: if radii, nsamples and mlps do not have the same length
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not len(radii) == len(nsamples) == len(mlps):
            raise ValueError(
                "radii, nsamples and mlps must have the same length, got "
                f"{len(radii)}, {len(nsamples)} and {len(mlps)}"
            )
        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, mlp in zip(radii, nsamples, mlps):
            # Fresh channel list, so the caller's `mlps` is never mutated.
            channels = [in_channel + 3 if use_xyz else in_channel] + list(mlp)
            if npoint is not None:
                self.groupers.append(QueryAndGroup(radius, nsample, use_xyz))
            else:
                self.groupers.append(GroupAll(use_xyz))
            self.mlps.append(Conv2ds(channels))

    def forward(self, xyz, features, new_xyz=None):
        """
        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
        :param features: (B, C, N) tensor of the descriptors of the features
        :param new_xyz: optional (B, npoint, 3) centroids; when None they are
            chosen by furthest point sampling, or left as None if npoint is None
        :return:
            new_xyz: (B, npoint, 3) tensor of the new features' xyz (or None)
            new_features: (B, C1, npoint) tensor of the new_features descriptors,
                where C1 is the sum of every scale's mlp[-1]
        """
        xyz = xyz.contiguous()
        if new_xyz is None and self.npoint is not None:
            xyz_flipped = xyz.permute(0, 2, 1)  # (B, 3, N)
            # NOTE(review): the trailing (1.0, 0.0) FPS arguments are passed through
            # unchanged; their meaning is defined in pointnet_util.
            fps_idx = pointnet_util.furthest_point_sample(xyz, self.npoint, 1.0, 0.0)
            new_xyz = pointnet_util.gather_operation(xyz_flipped, fps_idx).permute(0, 2, 1)

        new_features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], npoint, nsample)
            # Max-pool over the neighbourhood (last) axis only.
            pooled = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)]).squeeze(-1)
            new_features_list.append(pooled)
        return new_xyz, torch.cat(new_features_list, dim=1)
class PointNetSAModule(PointNetSAModuleMSG):
    """Single-scale set abstraction: a PointNetSAModuleMSG with exactly one scale."""

    def __init__(self, npoint, radius, nsample, in_channel, mlp, use_xyz=True):
        """
        :param npoint: int, number of centroids to sample (None groups the whole cloud)
        :param radius: float, ball-query radius
        :param nsample: int, maximum neighbours gathered per ball
        :param in_channel: int, channels of the input descriptors
        :param mlp: list of int, layer widths of the shared MLP
        :param use_xyz: bool, prepend relative xyz to the grouped descriptors
        """
        # Each per-scale argument becomes a one-element list for the MSG parent.
        super().__init__(
            npoint,
            radii=[radius],
            nsamples=[nsample],
            in_channel=in_channel,
            mlps=[mlp],
            use_xyz=use_xyz,
        )
class PointNetFPModule(nn.Module):
    """PointNet++ feature-propagation module.

    Carries features from a sparse point set (pts2) onto a denser one (pts1),
    optionally concatenates skip features of pts1, then applies a shared MLP.
    """

    def __init__(self, in_channel, mlp):
        """
        :param in_channel: int, interpolated channels plus skip-feature channels
        :param mlp: list of int, layer widths of the shared MLP
        """
        super().__init__()
        widths = [in_channel] + mlp
        self.mlp = Conv2ds(widths)

    def forward(self, pts1, pts2, fea1, fea2):
        """
        :param pts1: (B, n, 3) target points
        :param pts2: (B, m, 3) source points, n > m; may be None, in which case
            fea2 is broadcast along the point axis (its last dim must be 1 or n)
        :param fea1: (B, C1, n) skip features of pts1, or None
        :param fea2: (B, C2, m) features of pts2
        :return:
            new_features: (B, mlp[-1], n)
        """
        if pts2 is None:
            interpolated = fea2.expand(fea2.size(0), fea2.size(1), pts1.size(1))
        else:
            # Inverse-distance weights over the neighbours returned by three_nn,
            # normalised to sum to one per target point.
            dist, nn_idx = pointnet_util.three_nn(pts1, pts2)
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
            interpolated = pointnet_util.three_interpolate(fea2, nn_idx, weight)

        if fea1 is None:
            stacked = interpolated
        else:
            stacked = torch.cat([interpolated, fea1], dim=1)  # (B, C2 + C1, n)

        # Conv2ds works on 4-D input: add a dummy trailing axis, then drop it.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)
class QueryAndGroup(nn.Module):
    """Group the neighbours of each centroid found by a radius (ball) query.

    Neighbour coordinates are expressed relative to their centroid; grouped
    descriptors are appended after them when ``use_xyz`` is set.
    """

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of the query ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz: bool, prepend relative xyz to the grouped descriptors
        """
        super().__init__()
        self.radius = radius
        self.nsample = nsample
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: (B, 3 + C, npoint, nsample) when use_xyz,
            otherwise the grouped descriptors alone
        """
        neighbour_idx = pointnet_util.ball_query(self.radius, self.nsample, xyz, new_xyz)
        # NOTE: pointnet_util.knn_query(self.nsample, xyz, new_xyz) is an alternative neighbourhood.
        grouped_xyz = pointnet_util.grouping_operation(xyz.permute(0, 2, 1), neighbour_idx)  # (B, 3, npoint, nsample)
        # Shift every neighbour into its centroid's local frame.
        centroids = new_xyz.permute(0, 2, 1).unsqueeze(-1)  # (B, 3, npoint, 1)
        grouped_xyz -= centroids

        if features is None:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            return grouped_xyz

        grouped_features = pointnet_util.grouping_operation(features, neighbour_idx)
        if not self.use_xyz:
            return grouped_features
        return torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, npoint, nsample)
class GroupAll(nn.Module):
    """Treat the whole point cloud as one group (PointNetSAModuleMSG uses it when npoint is None)."""

    def __init__(self, use_xyz: bool = True):
        """
        :param use_xyz: bool, prepend the xyz coordinates to the descriptors
        """
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: (B, 3 + C, 1, N) when features are given and use_xyz,
            (B, C, 1, N) when features are given without use_xyz,
            (B, 3, 1, N) when features is None
        """
        whole_xyz = xyz.permute(0, 2, 1).unsqueeze(2)  # (B, 3, 1, N)
        if features is None:
            return whole_xyz
        whole_features = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return whole_features
        return torch.cat([whole_xyz, whole_features], dim=1)
|