import numpy as np
import torch
import torch.nn as nn
from model.pvcnn.modules import Attention
from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules
from model.pvcnn.pvcnn_utils import get_timestep_embedding
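

# Reference only (an assumption, not the model's actual code): a sketch of the
# standard DDPM-style sinusoidal timestep embedding that get_timestep_embedding
# (imported above) is believed to implement. The model calls the imported
# function; this sketch exists purely to document the expected behaviour.
def _sinusoidal_embedding_sketch(embed_dim: int, t: torch.Tensor, device) -> torch.Tensor:
    half = embed_dim // 2
    # Frequencies spaced geometrically from 1 down toward 1/10000
    freqs = torch.exp(-np.log(10000.0) * torch.arange(half, dtype=torch.float32, device=device) / half)
    args = t.float().to(device)[:, None] * freqs[None, :]  # (B, half)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, 2 * half)
    if embed_dim % 2 == 1:
        emb = nn.functional.pad(emb, (0, 1))  # zero-pad odd embedding widths
    return emb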


class PVCNN2Base(nn.Module):
def __init__(
self,
num_classes: int,
embed_dim: int,
use_att: bool = True,
dropout: float = 0.1,
extra_feature_channels: int = 3,
width_multiplier: int = 1,
voxel_resolution_multiplier: int = 1
):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
self.dropout = dropout
self.width_multiplier = width_multiplier
self.in_channels = extra_feature_channels + 3
        # Create the PointNet++ set abstraction (SA) components: the encoder half of the U-Net
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks_config=self.sa_blocks,
extra_feature_channels=extra_feature_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.sa_layers = nn.ModuleList(sa_layers)
        # Optional global attention module applied after the final SA layer (enabled by default)
        self.global_att = Attention(channels_sa_features, 8, D=1) if use_att else None
        # Only the extra input features (not the coordinates) are skip-connected into the last FP module
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks,
in_channels=channels_sa_features,
sa_in_channels=sa_in_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.fp_layers = nn.ModuleList(fp_layers)
        # Create the final classifier MLP
self.channels_fp_features = channels_fp_features
layers, _ = create_mlp_components(
in_channels=channels_fp_features,
            out_channels=[128, dropout, num_classes],  # the float entry is interpreted as a dropout probability (was 0.5)
classifier=True,
dim=2,
width_multiplier=width_multiplier
)
self.classifier = nn.Sequential(*layers) # applied to point features directly
        # MLP applied on top of the sinusoidal timestep embedding
self.embedf = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.LeakyReLU(0.1, inplace=True),
nn.Linear(embed_dim, embed_dim),
)

    def forward(self, inputs: torch.Tensor, t: torch.Tensor, ret_feats: bool = False):
"""
The inputs have size (B, 3 + S, N), where S is the number of additional
feature channels and N is the number of points. The timesteps t can be either
continuous or discrete. This model has a sort of U-Net-like structure I think,
which is why it first goes down and then up in terms of resolution (?)
torch.Size([16, 394, 16384])
Downscaling step 0 feature shape: torch.Size([16, 64, 1024])
Downscaling step 1 feature shape: torch.Size([16, 128, 256])
Downscaling step 2 feature shape: torch.Size([16, 256, 64])
Downscaling step 3 feature shape: torch.Size([16, 512, 16])
Upscaling step 0 feature shape: torch.Size([16, 256, 64])
Upscaling step 1 feature shape: torch.Size([16, 256, 256])
Upscaling step 2 feature shape: torch.Size([16, 128, 1024])
Upscaling step 3 feature shape: torch.Size([16, 64, 16384])
"""
        # Embed timesteps with a sinusoidal encoding, then broadcast over all N points
t_emb = get_timestep_embedding(self.embed_dim, t, inputs.device).float()
t_emb = self.embedf(t_emb)[:, :, None].expand(-1, -1, inputs.shape[-1])
        # Separate the input coordinates from the full feature tensor
        coords = inputs[:, :3, :].contiguous()  # (B, 3, N); coordinate values lie roughly in (-3.5, 3.5)
features = inputs # (B, 3 + S, N)
        # Downscaling (set abstraction) layers
coords_list = []
in_features_list = []
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
            if i == 0:
                # The first SA block consumes the raw features; later blocks also receive the time embedding
                features, coords, t_emb = sa_blocks((features, coords, t_emb))
else:
features, coords, t_emb = sa_blocks((torch.cat([features, t_emb], dim=1), coords, t_emb))
        # Keep only the extra feature channels as the skip input to the final FP module
in_features_list[0] = inputs[:, 3:, :].contiguous()
# Apply global attention layer
if self.global_att is not None:
features = self.global_att(features)
        # Upscaling (feature propagation) layers
        feats_list = []  # intermediate features from the decoder stages
        for fp_idx, fp_blocks in enumerate(self.fp_layers):
            features, coords, t_emb = fp_blocks(
                (  # packed as a tuple because the blocks are wrapped in nn.Sequential
                    coords_list[-1 - fp_idx],  # skip-connection coords from the matching SA stage
                    coords,  # coordinates at the current (coarser) resolution
                    torch.cat([features, t_emb], dim=1),  # upsampled features concatenated with the time embedding
                    in_features_list[-1 - fp_idx],  # skip-connection features from the matching SA stage
                    t_emb  # timestep embedding, unchanged across stages
                )  # point-voxel convolution happens here; the point branch preserves point ordering
            )
            feats_list.append((features, coords))  # t_emb is constant, so only features and coords are saved
# Output MLP layers
output = self.classifier(features)
if ret_feats:
return output, feats_list # return intermediate features
return output


class PVCNN2(PVCNN2Base):
    # Exact same configuration as PVD:
    # https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/train_completion.py#L375
    # Each sa_blocks entry is (conv_config, sa_config):
    #   conv_config: (out_channels, num_blocks, voxel_resolution), or None for no PVConv
    #   sa_config:   (num_centers, radius, num_neighbors, out_channels)
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),  # no PVConv at the coarsest level
    ]
    # Each fp_blocks entry is (fp_config, conv_config):
    #   fp_config:   out_channels of the feature propagation MLP
    #   conv_config: (out_channels, num_blocks, voxel_resolution)
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, embed_dim, use_att=True, dropout=0.1, extra_feature_channels=3,
width_multiplier=1, voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
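

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). The chosen
    # sizes are illustrative, and running it requires the model.pvcnn package
    # plus, typically, CUDA-built voxelization ops.
    B, S, N = 2, 3, 1024
    model = PVCNN2(num_classes=3, embed_dim=64, extra_feature_channels=S)
    coords = torch.randn(B, 3, N)               # xyz positions in the first 3 channels
    extra = torch.randn(B, S, N)                # extra per-point features (e.g. normals)
    inputs = torch.cat([coords, extra], dim=1)  # (B, 3 + S, N), as forward() expects
    t = torch.randint(0, 1000, (B,))            # discrete diffusion timesteps
    out, feats = model(inputs, t, ret_feats=True)
    print(out.shape)                            # expected: (B, num_classes, N)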