# S23DR-P2R / pointnet2.py
# Commit 8d5039c by colin1842: "add model"
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from model import pointnet_util
from model.model_utils import *
# idx = pointnet_util.ball_query(self.radius, self.nsample, xyz, new_xyz)
# # _, idx = pointnet_util.knn_query(self.nsample, xyz, new_xyz)
# xyz_trans = xyz.permute(0, 2, 1)
# grouped_xyz = pointnet_util.grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
# grouped_xyz -= new_xyz.permute(0, 2, 1).unsqueeze(-1)
def get_model():
    """Build and return a :class:`PointNet2` with its default configuration."""
    model = PointNet2()
    return model
class PointNet2(nn.Module):
    """PointNet++ encoder/decoder with per-point score and offset heads.

    Four set-abstraction (SA) levels progressively downsample the cloud, four
    feature-propagation (FP) levels interpolate features back onto the input
    points, and two 1x1 conv heads predict a per-point score (passed through a
    sigmoid) and a per-point 3-vector offset.

    :param in_channel: number of feature channels fed to the first SA level.
        NOTE(review): ``forward`` always feeds the raw xyz coordinates as the
        level-0 features, so only ``in_channel=3`` matches what it passes in.
    """

    def __init__(self, in_channel=3):
        super().__init__()
        # Encoder levels: (npoint, radius, nsample, input channels, MLP widths).
        # Each SA module internally widens its first MLP layer by 3 for the
        # grouped xyz (use_xyz defaults to True).
        self.sa1 = PointNetSAModule(256, 0.1, 16, in_channel, [32, 32, 64])
        self.sa2 = PointNetSAModule(128, 0.1, 16, 64, [64, 64, 128])
        self.sa3 = PointNetSAModule(64, 0.2, 16, 128, [128, 128, 256])
        self.sa4 = PointNetSAModule(16, 0.4, 16, 256, [256, 256, 512])
        # Decoder levels: input width = interpolated coarse features + skip features.
        self.fp4 = PointNetFPModule(768, [256, 256])        # 512 (sa4) + 256 (sa3)
        self.fp3 = PointNetFPModule(384, [256, 256])        # 256 (fp4) + 128 (sa2)
        self.fp2 = PointNetFPModule(320, [256, 128])        # 256 (fp3) + 64 (sa1)
        self.fp1 = PointNetFPModule(128, [128, 128, 128])   # 128 (fp2), no skip input
        # Prediction heads share one conv + dropout trunk.
        self.shared_fc = Conv1dBN(128, 128)
        self.drop = nn.Dropout(0.5)
        self.offset_fc = nn.Conv1d(128, 3, 1)
        self.cls_fc = nn.Conv1d(128, 1, 1)

    def forward(self, batch_dict):
        """Run the network on ``batch_dict['points']`` and store the outputs in it.

        :param batch_dict: dict holding ``'points'``, a (B, N, 3) coordinate tensor
        :return: the same dict, with these keys added or overwritten:
            ``'point_features'``   (B, N, 128) decoder features,
            ``'point_pred_score'`` (B, N) sigmoid scores,
            ``'point_pred_offset'`` (B, N, 3) predicted offsets.
        """
        points = batch_dict['points']
        # Coordinates double as the level-0 features, in channel-first layout.
        xyz_levels = [points]
        fea_levels = [points.permute(0, 2, 1)]
        for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
            coarse_xyz, coarse_fea = sa(xyz_levels[-1], fea_levels[-1])
            xyz_levels.append(coarse_xyz)
            fea_levels.append(coarse_fea)

        # Decode coarse-to-fine. Each FP result replaces the skip features of
        # its level; level 0 gets no skip features.
        decoders = ((3, self.fp4), (2, self.fp3), (1, self.fp2), (0, self.fp1))
        for level, fp in decoders:
            skip = fea_levels[level] if level > 0 else None
            fea_levels[level] = fp(
                xyz_levels[level], xyz_levels[level + 1], skip, fea_levels[level + 1]
            )

        per_point = fea_levels[0]
        hidden = self.drop(self.shared_fc(per_point))
        offsets = self.offset_fc(hidden).permute(0, 2, 1)
        logits = self.cls_fc(hidden).permute(0, 2, 1)
        batch_dict['point_features'] = per_point.permute(0, 2, 1)
        batch_dict['point_pred_score'] = torch.sigmoid(logits).squeeze(-1)
        batch_dict['point_pred_offset'] = offsets
        return batch_dict
class PointNetSAModuleMSG(nn.Module):
    """Multi-scale-grouping (MSG) PointNet++ set-abstraction layer.

    Picks centroids by furthest point sampling, groups neighbours of each
    centroid at every (radius, nsample) scale, runs one shared MLP per scale,
    max-pools over the neighbours and concatenates the per-scale results along
    the channel axis. With ``npoint=None`` every scale groups the whole cloud.
    """

    def __init__(self, npoint, radii, nsamples, in_channel, mlps, use_xyz=True):
        """
        :param npoint: int or None, number of centroids; None groups all points
        :param radii: list of float, ball-query radius per scale
        :param nsamples: list of int, max neighbours per centroid per scale
        :param in_channel: int, channel count of the incoming ``features``
        :param mlps: list of list of int, MLP widths per scale (the input width
            is prepended here; the caller's lists are not modified)
        :param use_xyz: bool, prepend relative xyz to the grouped features
        """
        super().__init__()
        assert len(radii) == len(nsamples) == len(mlps)
        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        # The groupers prepend 3 xyz channels when use_xyz is set.
        first_width = in_channel + 3 if use_xyz else in_channel
        for radius, nsample, widths in zip(radii, nsamples, mlps):
            if npoint is None:
                grouper = GroupAll(use_xyz)
            else:
                grouper = QueryAndGroup(radius, nsample, use_xyz)
            self.groupers.append(grouper)
            self.mlps.append(Conv2ds([first_width] + widths))

    def forward(self, xyz, features, new_xyz=None):
        """
        :param xyz: (B, N, 3) point coordinates
        :param features: (B, C, N) point descriptors, or None
        :param new_xyz: optional (B, npoint, 3) centroids; when omitted and
            ``npoint`` is set they are chosen by furthest point sampling
        :return:
            new_xyz: (B, npoint, 3) centroid coordinates (None when ``npoint``
                is None and no centroids were given)
            new_features: (B, sum of per-scale mlp[-1], npoint) pooled descriptors
        """
        xyz = xyz.contiguous()
        if new_xyz is None and self.npoint is not None:
            centroid_idx = pointnet_util.furthest_point_sample(xyz, self.npoint, 1.0, 0.0)
            new_xyz = pointnet_util.gather_operation(
                xyz.permute(0, 2, 1), centroid_idx
            ).permute(0, 2, 1)

        pooled_per_scale = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)   # (B, C, npoint, nsample)
            activated = mlp(grouped)                    # (B, mlp[-1], npoint, nsample)
            # Max over the neighbour axis, then drop the pooled singleton dim.
            pooled = F.max_pool2d(activated, kernel_size=[1, activated.size(3)])
            pooled_per_scale.append(pooled.squeeze(-1))
        return new_xyz, torch.cat(pooled_per_scale, dim=1)
class PointNetSAModule(PointNetSAModuleMSG):
    """Single-scale set abstraction: a ``PointNetSAModuleMSG`` with one scale.

    :param npoint: int or None, number of centroids to sample
    :param radius: float, ball-query radius
    :param nsample: int, max neighbours per centroid
    :param in_channel: int, channel count of the incoming features
    :param mlp: list of int, MLP widths
    :param use_xyz: bool, prepend relative xyz to the grouped features
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, use_xyz=True):
        super().__init__(
            npoint,
            radii=[radius],
            nsamples=[nsample],
            in_channel=in_channel,
            mlps=[mlp],
            use_xyz=use_xyz,
        )
class PointNetFPModule(nn.Module):
    """PointNet++ feature-propagation layer.

    Interpolates coarse features onto a denser point set using inverse-distance
    weights over the 3 nearest coarse points, concatenates the optional skip
    features, and applies a shared MLP.

    :param in_channel: int, C2 + C1 (interpolated + skip channels)
    :param mlp: list of int, MLP widths
    """

    def __init__(self, in_channel, mlp):
        super().__init__()
        self.mlp = Conv2ds([in_channel] + mlp)

    def forward(self, pts1, pts2, fea1, fea2):
        """
        :param pts1: (B, n, 3) dense target points
        :param pts2: (B, m, 3) coarse source points with n > m, or None to
            broadcast ``fea2`` to every target point
        :param fea1: (B, C1, n) skip features, or None
        :param fea2: (B, C2, m) coarse features; when ``pts2`` is None its last
            dim must be 1 (or already n) so it can be expanded
        :return:
            new_features: (B, mlp[-1], n)
        """
        if pts2 is None:
            batch, channels = fea2.size()[0:2]
            interpolated = fea2.expand(batch, channels, pts1.size(1))
        else:
            # NOTE(review): whether three_nn returns plain or squared distances
            # depends on pointnet_util; confirm.
            dist, idx = pointnet_util.three_nn(pts1, pts2)
            inv_dist = 1.0 / (dist + 1e-8)  # epsilon guards exact coincidences
            weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
            interpolated = pointnet_util.three_interpolate(fea2, idx, weight)

        stacked = interpolated if fea1 is None else torch.cat([interpolated, fea1], dim=1)
        # Conv2ds expects a 4-D input; add and then remove a trailing singleton.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)
class QueryAndGroup(nn.Module):
    """Ball-query grouping around each centroid, with centroid-relative xyz."""

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, ball radius
        :param nsample: int, maximum number of neighbours gathered per ball
        :param use_xyz: bool, prepend the relative xyz to the grouped features
        """
        super().__init__()
        self.radius = radius
        self.nsample = nsample
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) point coordinates
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) point descriptors, or None
        :return:
            new_features: (B, 3 + C, npoint, nsample) when use_xyz, otherwise
            (B, C, npoint, nsample); (B, 3, npoint, nsample) when features is None
        """
        # An alternative k-NN grouping (pointnet_util.knn_query) is left disabled.
        neighbour_idx = pointnet_util.ball_query(self.radius, self.nsample, xyz, new_xyz)
        grouped_xyz = pointnet_util.grouping_operation(xyz.permute(0, 2, 1), neighbour_idx)
        # Express neighbour coordinates relative to their centroid.
        grouped_xyz -= new_xyz.permute(0, 2, 1).unsqueeze(-1)

        if features is None:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            return grouped_xyz

        grouped_features = pointnet_util.grouping_operation(features, neighbour_idx)
        if not self.use_xyz:
            return grouped_features
        return torch.cat([grouped_xyz, grouped_features], dim=1)
class GroupAll(nn.Module):
    """Treat the whole point cloud as one group (global set abstraction).

    ``PointNetSAModuleMSG`` uses this in place of ``QueryAndGroup`` when
    ``npoint`` is None.
    """

    def __init__(self, use_xyz: bool = True):
        """
        :param use_xyz: bool, prepend the raw xyz coordinates to the features
        """
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) point coordinates
        :param new_xyz: ignored
        :param features: (B, C, N) point descriptors, or None
        :return:
            new_features: (B, 3 + C, 1, N) when use_xyz, (B, C, 1, N) otherwise;
            (B, 3, 1, N) when features is None
        :raises AssertionError: if ``features`` is None and ``use_xyz`` is False
            (same check as ``QueryAndGroup``; the old code silently returned
            xyz and ignored ``use_xyz``)
        """
        grouped_xyz = xyz.permute(0, 2, 1).unsqueeze(2)  # (B, 3, 1, N)
        if features is None:
            # Returning xyz here would contradict use_xyz=False and hand the MLP a
            # channel count it was not sized for.
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            return grouped_xyz

        grouped_features = features.unsqueeze(2)  # (B, C, 1, N)
        if self.use_xyz:
            return torch.cat([grouped_xyz, grouped_features], dim=1)
        return grouped_features