"""Sparse-convolution ResNet encoders built on MinkowskiEngine."""

import MinkowskiEngine as ME
import torch.nn as nn
from MinkowskiEngine.modules.resnet_block import BasicBlock


class ResNetBase(nn.Module):
    """Generic sparse ResNet; subclasses choose BLOCK / LAYERS / PLANES.

    The default network (see ``network_initialization``) is: a strided stem
    (conv k=3 s=2, instance norm, ReLU, max-pool k=2 s=2), four residual
    stages that each downsample with stride 2, a ``conv5`` head (dropout,
    conv k=3 s=3, instance norm, GELU), global max pooling and a final
    linear layer producing ``out_channels`` features.
    """

    BLOCK = None
    LAYERS = ()
    INIT_DIM = 64
    PLANES = (64, 128, 256, 512)

    def __init__(self, in_channels, out_channels, D=3):
        """Build the layers and initialize their weights.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by the final layer.
            D: Spatial dimension passed to every Minkowski convolution.
        """
        nn.Module.__init__(self)
        self.D = D
        assert self.BLOCK is not None, "subclasses must set the BLOCK class attribute"
        self.network_initialization(in_channels, out_channels, D)
        self.weight_initialization()

    def network_initialization(self, in_channels, out_channels, D):
        """Create all submodules; ``self.inplanes`` tracks the running width."""
        self.inplanes = self.INIT_DIM
        self.conv1 = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D
            ),
            ME.MinkowskiInstanceNorm(self.inplanes),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D),
        )

        self.layer1 = self._make_layer(
            self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2
        )
        self.layer2 = self._make_layer(
            self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2
        )
        self.layer3 = self._make_layer(
            self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2
        )
        self.layer4 = self._make_layer(
            self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2
        )

        self.conv5 = nn.Sequential(
            ME.MinkowskiDropout(),
            ME.MinkowskiConvolution(
                self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D
            ),
            ME.MinkowskiInstanceNorm(self.inplanes),
            ME.MinkowskiGELU(),
        )

        self.glob_pool = ME.MinkowskiGlobalMaxPooling()

        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

    def weight_initialization(self):
        """Kaiming-init conv kernels; set BatchNorm scale to 1 and shift to 0."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
        """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

        The first block applies ``stride`` and, when the stride or channel
        count changes, a 1x1 conv + BatchNorm projection on the shortcut.
        Updates ``self.inplanes`` to the stage's output width.

        Args:
            block: Residual block class (must expose ``expansion``).
            planes: Base width of the blocks in this stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block.
            dilation: Dilation used by every block.
            bn_momentum: Momentum for the BatchNorm layers of this stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D,
                ),
                ME.MinkowskiBatchNorm(out_planes, momentum=bn_momentum),
            )
        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                bn_momentum=bn_momentum,
                dimension=self.D,
            )
        ]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=dilation,
                    bn_momentum=bn_momentum,
                    dimension=self.D,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x: ME.SparseTensor):
        """Run stem, four stages, head, global max pool and final linear layer."""
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv5(x)
        x = self.glob_pool(x)
        return self.final(x)


class MinkResNet(ResNetBase):
    """Multi-scale sparse encoder producing a global feature vector.

    Four residual stages are computed at tensor strides 2, 4, 8 and 16. Each
    stage's output is sliced back onto the input points and concatenated.
    The result passes through two strided conv blocks, global max and average
    pooling, and an MLP producing ``out_channels`` (1280) features.

    Only the first four entries of ``LAYERS`` / ``PLANES`` are used by
    ``network_initialization``. ``DILATIONS`` and ``OUT_TENSOR_STRIDE`` are
    not referenced in this module.
    """

    BLOCK = BasicBlock
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    def __init__(self, D=3):
        """Build the encoder for ``D``-dimensional inputs with 6 feature channels."""
        # Must be set before ResNetBase.__init__, which calls
        # network_initialization (that reads ``embedding_channel``).
        self.in_channels = 6
        self.out_channels = 1280
        self.embedding_channel = 1024
        ResNetBase.__init__(self, self.in_channels, self.out_channels, D)

    def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
        """Return a Minkowski conv -> BatchNorm -> LeakyReLU block."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                dimension=self.D,
            ),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def get_mlp_block(self, in_channel, out_channel):
        """Return a bias-free linear -> BatchNorm -> LeakyReLU block."""
        return nn.Sequential(
            ME.MinkowskiLinear(in_channel, out_channel, bias=False),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def network_initialization(self, in_channels, out_channels, D):
        """Create the stem, four downsampling stages, the conv head and the MLP."""
        self.inplanes = self.INIT_DIM
        # Stem: kernel-5 conv at the input resolution (default stride).
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D
        )
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # Each conv*s2 halves the resolution before the next residual stage.
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D
        )
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0])

        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D
        )
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1])

        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D
        )
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2])

        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D
        )
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3])

        # Input width = concatenation of the four stage outputs (see forward).
        self.conv5 = nn.Sequential(
            self.get_conv_block(
                self.PLANES[0] + self.PLANES[1] + self.PLANES[2] + self.PLANES[3],
                self.embedding_channel // 2,
                kernel_size=3,
                stride=2,
            ),
            self.get_conv_block(
                self.embedding_channel // 2,
                self.embedding_channel,
                kernel_size=3,
                stride=2,
            ),
        )

        self.relu = ME.MinkowskiReLU(inplace=True)

        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()

        # Max- and avg-pooled embeddings are concatenated, hence * 2.
        self.final = nn.Sequential(
            self.get_mlp_block(self.embedding_channel * 2, 1024),
            ME.MinkowskiDropout(),
            self.get_mlp_block(1024, 1024),
            ME.MinkowskiLinear(1024, out_channels, bias=True),
        )

    def forward(self, xyz, features, device="cuda", quantization_size=0.05):
        """Encode a sparse point set into global features.

        Args:
            xyz: Coordinate tensor. Columns ``1:`` are divided by
                ``quantization_size``; column 0 is left untouched (presumably
                the batch index, per the MinkowskiEngine convention — confirm
                with callers). The caller's tensor is not modified.
            features: Per-point features matching ``xyz`` row-for-row
                (``in_channels`` = 6 columns expected by the stem).
            device: Device on which the ``ME.TensorField`` is built.
            quantization_size: Voxel size used to scale the spatial coordinates.

        Returns:
            The dense feature matrix (``.F``) of the final MLP output, with
            ``out_channels`` columns (presumably one row per batch item after
            global pooling — TODO confirm).
        """
        # Scale a copy: scaling ``xyz`` in place would make repeated calls on
        # the same tensor keep shrinking its coordinates.
        coords = xyz.clone()
        coords[:, 1:] = coords[:, 1:] / quantization_size
        field = ME.TensorField(
            coordinates=coords,
            features=features,
            device=device,
        )

        out = self.relu(self.bn0(self.conv0p1s1(field.sparse())))

        out = self.relu(self.bn1(self.conv1p1s2(out)))
        out_b1p2 = self.block1(out)

        out = self.relu(self.bn2(self.conv2p2s2(out_b1p2)))
        out_b2p4 = self.block2(out)

        out = self.relu(self.bn3(self.conv3p4s2(out_b2p4)))
        out_b3p8 = self.block3(out)

        # tensor_stride=16 after this stage.
        out = self.relu(self.bn4(self.conv4p8s2(out_b3p8)))
        out_b4p16 = self.block4(out)

        # Project every scale back onto the input points and concatenate.
        multi_scale = ME.cat(
            out_b1p2.slice(field),
            out_b2p4.slice(field),
            out_b3p8.slice(field),
            out_b4p16.slice(field),
        )
        embedding = self.conv5(multi_scale.sparse())

        pooled = ME.cat(
            self.global_max_pool(embedding),
            self.global_avg_pool(embedding),
        )
        return self.final(pooled).F


class MinkResNet34(MinkResNet):
    """MinkResNet with (3, 4, 6, 3) residual blocks per stage (ResNet-34 layout)."""

    LAYERS = (3, 4, 6, 3)