winfred2027 committed
Commit • 6de2454
1 Parent(s): 9fe654e
Upload 3 files
- openshape/__init__.py +47 -0
- openshape/pointnet_util.py +323 -0
- openshape/ppat_rgb.py +118 -0
openshape/__init__.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from .ppat_rgb import Projected, PointPatchTransformer


def module(state_dict: dict, name):
    # Extract a submodule's weights from a checkpoint by stripping the "name." prefix.
    return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items() if k.startswith(name + '.')}


def G14(s):
    # 512-d transformer features projected to a 1280-d embedding space
    # (the width of the paired ViT-G/14-scale CLIP model).
    model = Projected(
        PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
        nn.Linear(512, 1280)
    )
    model.load_state_dict(module(s['state_dict'], 'module'))
    return model


def L14(s):
    # 512-d transformer features projected to a 768-d embedding space (ViT-L/14 width).
    model = Projected(
        PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6),
        nn.Linear(512, 768)
    )
    model.load_state_dict(module(s, 'pc_encoder'))
    return model


def B32(s):
    # No projection head: the transformer width already matches the 512-d ViT-B/32 space.
    model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
    model.load_state_dict(module(s, 'pc_encoder'))
    return model


model_list = {
    "openshape-pointbert-vitb32-rgb": B32,
    "openshape-pointbert-vitl14-rgb": L14,
    "openshape-pointbert-vitg14-rgb": G14,
}


def load_pc_encoder(name):
    s = torch.load(hf_hub_download("OpenShape/" + name, "model.pt"), map_location='cpu')
    model = model_list[name](s).eval()
    if torch.cuda.is_available():
        model.cuda()
    return model
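
For reference, a minimal usage sketch of the loader above (not part of the commit). The repo id comes from the model_list dict; the input layout of [batch, 6, points] with xyz in the first three channels and RGB in the rest is an assumption, inferred from PointPatchTransformer.forward slicing features[:, :3] as coordinates:

import torch
import openshape

# Downloads "OpenShape/openshape-pointbert-vitb32-rgb/model.pt" from the Hub,
# builds the matching encoder, and moves it to GPU only if CUDA is available.
model = openshape.load_pc_encoder("openshape-pointbert-vitb32-rgb")

# Assumed layout: [batch, 6, num_points] -- xyz channels first, then RGB.
pc = torch.rand(1, 6, 10000)
with torch.no_grad():
    emb = model(pc)
print(emb.shape)  # expected: torch.Size([1, 512]) for the B32 variant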
openshape/pointnet_util.py
ADDED
@@ -0,0 +1,323 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
import dgl.geometry


def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()


def pc_normalize(pc):
    # Center the cloud and scale it into the unit sphere.
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


def square_distance(src, dst):
    """
    Calculate squared Euclidean distance between each pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point squared distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points: indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    # Delegates to DGL's optimized sampler; the reference implementation
    # below is kept for documentation but is unreachable.
    return dgl.geometry.farthest_point_sampler(xyz, npoint)
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # Points outside the ball get sentinel index N, then sort pushes them last.
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Pad under-full groups by repeating the first (closest) point.
    group_first = group_idx[..., :1].repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of centroids to sample
        radius: ball query radius
        nsample: max sample number in each local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over each local region to get one feature per patch.
        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points


class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat


class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Inverse-distance-weighted interpolation over the 3 nearest neighbors.
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
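
A small shape sanity check for the sampling and grouping utilities above (not part of the commit). The sizes are arbitrary, and farthest_point_sample requires dgl since it delegates to dgl.geometry.farthest_point_sampler:

import torch
from openshape.pointnet_util import (
    square_distance, farthest_point_sample, sample_and_group
)

B, N = 2, 1024
xyz = torch.rand(B, N, 3)    # point positions
feats = torch.rand(B, N, 8)  # arbitrary per-point features

dists = square_distance(xyz, xyz)     # [2, 1024, 1024] squared pairwise distances
idx = farthest_point_sample(xyz, 64)  # [2, 64] centroid indices (via DGL)
new_xyz, new_points = sample_and_group(64, 0.2, 32, xyz, feats)
print(new_xyz.shape)     # [2, 64, 3] patch centroids
print(new_points.shape)  # [2, 64, 32, 11] = 3 normalized coords + 8 features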
openshape/ppat_rgb.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch_redstone as rst
from einops import rearrange
from .pointnet_util import PointNetSetAbstraction


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *extra_args, **kwargs):
        return self.fn(self.norm(x), *extra_args, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., rel_pe=False):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        self.rel_pe = rel_pe
        if rel_pe:
            # Relative positional encoding computed from patch-centroid offsets.
            self.pe = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU(), nn.Conv2d(64, 1, 1))

    def forward(self, x, centroid_delta):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        pe = self.pe(centroid_delta) if self.rel_pe else 0
        dots = (torch.matmul(q, k.transpose(-1, -2)) + pe) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., rel_pe=False):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, rel_pe=rel_pe)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x, centroid_delta):
        for attn, ff in self.layers:
            x = attn(x, centroid_delta) + x
            x = ff(x) + x
        return x


class PointPatchTransformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, sa_dim, patches, prad, nsamp, in_dim=3, dim_head=64, rel_pe=False, patch_dropout=0) -> None:
        super().__init__()
        self.patches = patches
        self.patch_dropout = patch_dropout
        # Set abstraction groups the cloud into local patches and embeds each one.
        self.sa = PointNetSetAbstraction(npoint=patches, radius=prad, nsample=nsamp, in_channel=in_dim + 3, mlp=[64, 64, sa_dim], group_all=False)
        self.lift = nn.Sequential(nn.Conv1d(sa_dim + 3, dim, 1), rst.Lambda(lambda x: torch.permute(x, [0, 2, 1])), nn.LayerNorm([dim]))
        self.cls_token = nn.Parameter(torch.randn(dim))
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, 0.0, rel_pe)

    def forward(self, features):
        self.sa.npoint = self.patches
        if self.training:
            # Patch dropout: sample fewer patches during training.
            self.sa.npoint -= self.patch_dropout
        centroids, feature = self.sa(features[:, :3], features)
        x = self.lift(torch.cat([centroids, feature], dim=1))

        # Prepend a learnable class token, with a zero centroid to match.
        x = rst.supercat([self.cls_token, x], dim=-2)
        centroids = rst.supercat([centroids.new_zeros(1), centroids], dim=-1)

        centroid_delta = centroids.unsqueeze(-1) - centroids.unsqueeze(-2)
        x = self.transformer(x, centroid_delta)

        return x[:, 0]


class Projected(nn.Module):
    def __init__(self, ppat, proj) -> None:
        super().__init__()
        self.ppat = ppat
        self.proj = proj

    def forward(self, features: torch.Tensor):
        return self.proj(self.ppat(features))
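
Finally, a minimal sketch of driving the patch transformer directly (not part of the commit), reusing the B32-style hyperparameters from openshape/__init__.py; the random input stands in for a real colored point cloud, and both torch_redstone and dgl must be installed:

import torch
import torch.nn as nn
from openshape.ppat_rgb import PointPatchTransformer, Projected

# B32-style configuration: dim=512, depth=12, heads=8, mlp_dim=1024,
# sa_dim=128, patches=64, prad=0.4, nsamp=256, in_dim=6 (xyz + rgb).
model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6).eval()

pc = torch.rand(2, 6, 10000)  # [batch, xyz+rgb channels, points]
with torch.no_grad():
    feat = model(pc)          # [2, 512] global embedding from the cls token

# Projected adds a linear head mapping into a CLIP embedding space,
# e.g. 768-d as in the L14 variant:
projected = Projected(model, nn.Linear(512, 768))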