add model
Note: this view is limited to 50 files because the commit contains too many changes; the raw diff has the full change set.
- LICENSE +21 -0
- _script.py +145 -0
- dataset/__pycache__/data_utils.cpython-38.pyc +0 -0
- dataset/__pycache__/roofn3d_dataset.cpython-38.pyc +0 -0
- dataset/data_utils.py +48 -0
- dataset/roofn3d_dataset.py +235 -0
- format_dataset.py +53 -0
- hoho_train_checkpoint_epoch_90.pth +3 -0
- model/__pycache__/cluster_refine.cpython-38.pyc +0 -0
- model/__pycache__/edge_pred_net.cpython-38.pyc +0 -0
- model/__pycache__/model_utils.cpython-38.pyc +0 -0
- model/__pycache__/pointnet2.cpython-38.pyc +0 -0
- model/__pycache__/pointnet_stack_utils.cpython-38.pyc +0 -0
- model/__pycache__/pointnet_util.cpython-38.pyc +0 -0
- model/__pycache__/roofnet.cpython-38.pyc +0 -0
- model/cluster_refine.py +305 -0
- model/edge_pred_net.py +173 -0
- model/model_utils.py +156 -0
- model/pointnet2.py +305 -0
- model/pointnet_stack_utils.py +265 -0
- model/pointnet_util.py +518 -0
- model/roofnet.py +35 -0
- model_cfg.yaml +26 -0
- output/hoho_test/checkpoint_epoch_90_all.pth +3 -0
- output/hoho_test/test/log.txt +0 -0
- output/hoho_test/test/submission.parquet +3 -0
- output/hoho_train/ckpt/checkpoint_epoch_41.pth +3 -0
- output/hoho_train/ckpt/checkpoint_epoch_42.pth +3 -0
- output/hoho_train/ckpt/checkpoint_epoch_43.pth +3 -0
- output/hoho_train/ckpt/checkpoint_epoch_44.pth +3 -0
- output/hoho_train/ckpt/checkpoint_epoch_45.pth +3 -0
- output/hoho_train/log.txt +9 -0
- pc_util/setup.py +23 -0
- pc_util/src/ball_query.cpp +84 -0
- pc_util/src/ball_query_gpu.cu +270 -0
- pc_util/src/ball_query_gpu.h +38 -0
- pc_util/src/cluster.cpp +50 -0
- pc_util/src/cluster_gpu.cu +192 -0
- pc_util/src/cluster_gpu.h +34 -0
- pc_util/src/cuda_utils.h +15 -0
- pc_util/src/group_points.cpp +98 -0
- pc_util/src/group_points_gpu.cu +199 -0
- pc_util/src/group_points_gpu.h +36 -0
- pc_util/src/interpolate.cpp +148 -0
- pc_util/src/interpolate_gpu.cu +343 -0
- pc_util/src/interpolate_gpu.h +61 -0
- pc_util/src/pointnet2_api.cpp +41 -0
- pc_util/src/sampling.cpp +46 -0
- pc_util/src/sampling_gpu.cu +259 -0
- pc_util/src/sampling_gpu.h +29 -0
LICENSE
ADDED
MIT License

Copyright (c) 2024 Weihang Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
_script.py
ADDED
```python
### This is an example of the script that will be run in the test environment.
### Some parts of the code are compulsory and you should NOT CHANGE THEM.
### They are between '''---compulsory---''' comments.
### You can change the rest of the code to define and test your solution.
### However, you should not change the signature of the provided function.
### The script will save a "submission.parquet" file in the current directory.
### The actual logic of the solution is implemented in the `handcrafted_solution.py` file.
### The `handcrafted_solution.py` file is a placeholder for your solution.
### You should implement the logic of your solution in that file.
### You can use any additional files and subdirectories to organize your code.

'''---compulsory---'''
# import subprocess
# from pathlib import Path
# def install_package_from_local_file(package_name, folder='packages'):
#     """
#     Installs a package from a local .whl file or a directory containing .whl files using pip.
#
#     Parameters:
#     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
#     """
#     try:
#         pth = str(Path(folder) / package_name)
#         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
#                                "--no-index",         # Do not use package index
#                                "--find-links", pth,  # Look for packages in the specified directory or at the file
#                                package_name])        # Specify the package to install
#         print(f"Package installed successfully from {pth}")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install package from {pth}. Error: {e}")
#
# install_package_from_local_file('hoho')

import hoho; hoho.setup()  # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
# import subprocess
# import importlib
# from pathlib import Path

# ### The function below is useful for installing additional python wheels.
# def install_package_from_local_file(package_name, folder='packages'):
#     """
#     Installs a package from a local .whl file or a directory containing .whl files using pip.
#
#     Parameters:
#     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
#     """
#     try:
#         pth = str(Path(folder) / package_name)
#         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
#                                "--no-index",         # Do not use package index
#                                "--find-links", pth,  # Look for packages in the specified directory or at the file
#                                package_name])        # Specify the package to install
#         print(f"Package installed successfully from {pth}")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install package from {pth}. Error: {e}")

# pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
# install_package_from_local_file('webdataset')
# install_package_from_local_file('tqdm')

### Here you can import any library or module you want.
### The code below is used to read and parse the input dataset.
### Please, do not modify it.

import webdataset as wds
from tqdm import tqdm
from typing import Dict
import pandas as pd
from transformers import AutoTokenizer
import os
import time
import io
from PIL import Image as PImage
import numpy as np

from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
from hoho import proc, Sample

def convert_entry_to_human_readable(entry):
    out = {}
    already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
    for k, v in entry.items():
        if k in already_good:
            out[k] = v
            continue
        if k == 'points3d':
            out[k] = read_points3D_binary(fid=io.BytesIO(v))
        if k == 'cameras':
            out[k] = read_cameras_binary(fid=io.BytesIO(v))
        if k == 'images':
            out[k] = read_images_binary(fid=io.BytesIO(v))
        if k in ['ade20k', 'gestalt']:
            out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
        if k == 'depthcm':
            out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
    return out

'''---end of compulsory---'''

### The part below is used to define and test your solution.

from pathlib import Path

def save_submission(submission, path):
    """
    Saves the submission to a specified path.

    Parameters:
    submission (List[Dict]): The submission to save.
    path (str): The path to save the submission to.
    """
    sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
    sub.to_parquet(path)
    print(f"Submission saved to {path}")

if __name__ == "__main__":
    from handcrafted_solution import predict
    print("------------ Loading dataset ------------")
    params = hoho.get_params()
    dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')

    print('------------ Now you can do your solution ---------------')
    solution = []
    from concurrent.futures import ProcessPoolExecutor
    with ProcessPoolExecutor(max_workers=8) as pool:
        results = []
        for i, sample in enumerate(tqdm(dataset)):
            results.append(pool.submit(predict, sample, visualize=False))

        for i, result in enumerate(tqdm(results)):
            key, pred_vertices, pred_edges = result.result()
            solution.append({
                '__key__': key,
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges
            })
            if i % 100 == 0:
                # incrementally save the results in case we run out of time
                print(f"Processed {i} samples")
                # save_submission(solution, Path(params['output_path']) / "submission.parquet")
    print('------------ Saving results ---------------')
    save_submission(solution, Path(params['output_path']) / "submission.parquet")
    print("------------ Done ------------")
```
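The only contract `_script.py` imposes on `handcrafted_solution.py` is the `predict` call used above. A minimal hypothetical stub satisfying that interface (the body is a placeholder, not the actual solution shipped in this repo) might look like:

```python
# handcrafted_solution.py -- hypothetical stub matching the interface _script.py expects.
import numpy as np

def predict(entry, visualize=False):
    """Return (key, wf_vertices, wf_edges) for one webdataset entry."""
    key = entry['__key__']
    # Placeholder output: one vertex at the origin, no edges.
    wf_vertices = np.zeros((1, 3), dtype=np.float32)
    wf_edges = []
    return key, wf_vertices, wf_edges
```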
dataset/__pycache__/data_utils.cpython-38.pyc
ADDED
Binary file (1.5 kB)
dataset/__pycache__/roofn3d_dataset.cpython-38.pyc
ADDED
Binary file (6.38 kB)
dataset/data_utils.py
ADDED
```python
from torch.utils.data import DataLoader
# from .roofn3d_dataset import RoofN3dDataset
from dataset.roofn3d_dataset import RoofN3dDataset, HohoDataset
import numpy as np
import random

__all__ = {
    'RoofN3dDataset': RoofN3dDataset
}

class GaussianTransform:
    def __init__(self, sigma=(0.005, 0.015), clip=0.05, p=0.8):
        self.sigma = sigma
        self.clip = clip
        self.p = p

    def __call__(self, points):
        if np.random.rand(1) < self.p:
            last_sigma = np.random.rand(1) * (self.sigma[1] - self.sigma[0]) + self.sigma[0]
            row, col = points.shape
            jittered_point = np.clip(last_sigma * np.random.randn(row, col), -1 * self.clip, self.clip)
            jittered_point += points
            return jittered_point
        else:
            return points

def build_dataloader(key, xyz, batch_size, data_cfg, workers=1, logger=None):
    transform = GaussianTransform(sigma=(0.005, 0.010), clip=10, p=0.0)

    dataset = HohoDataset(key, xyz, transform, data_cfg, logger)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, num_workers=workers, collate_fn=dataset.collate_batch,
        shuffle=False)
    return dataloader

# def build_dataloader(path, batch_size, data_cfg, workers=16, logger=None, training=True):
#     path += '/train.txt' if training else '/test.txt'
#
#     if training:
#         transform = GaussianTransform(sigma=(0.005, 0.010), clip=10, p=0.8)
#     else:
#         transform = GaussianTransform(sigma=(0.005, 0.010), clip=10, p=0.0)
#
#     dataset = RoofN3dDataset(path, transform, data_cfg, logger)
#     dataloader = DataLoader(
#         dataset, batch_size=batch_size, pin_memory=True, num_workers=workers, collate_fn=dataset.collate_batch,
#         shuffle=training)
#     return dataloader
```
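A quick sanity check of `GaussianTransform` on synthetic points (values are illustrative; `p=1.0` forces the jitter branch):

```python
import numpy as np

points = np.random.rand(1024, 3)  # synthetic point cloud in the unit cube
jitter = GaussianTransform(sigma=(0.005, 0.015), clip=0.05, p=1.0)
noisy = jitter(points)
# The noise is clipped per axis, so no coordinate moves by more than `clip`.
assert np.abs(noisy - points).max() <= 0.05
```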
dataset/roofn3d_dataset.py
ADDED
```python
import numpy as np
from torch.utils.data import Dataset
from collections import defaultdict
import os
import shutil

def read_pts(pts_file):
    with open(pts_file, 'r') as f:
        lines = f.readlines()
    pts = np.array([f.strip().split(' ') for f in lines], dtype=np.float64)
    return pts


# def load_obj(obj_file):
#     vs, edges = [], set()
#     with open(obj_file, 'r') as f:
#         lines = f.readlines()
#     for f in lines:
#         vals = f.strip().split(' ')
#         if vals[0] == 'v':
#             vs.append(vals[1:])
#         else:
#             obj_data = np.array(vals[1:], dtype=int).reshape(-1, 1) - 1
#             idx = np.arange(len(obj_data)) - 1
#             cur_edge = np.concatenate([obj_data, obj_data[idx]], -1)
#             [edges.add(tuple(sorted(e))) for e in cur_edge]
#     vs = np.array(vs, dtype=np.float64)
#     edges = np.array(list(edges))
#     return vs, edges

def load_obj(obj_file):
    vs, edges = [], set()
    with open(obj_file, 'r') as f:
        lines = f.readlines()

    for line in lines:
        vals = line.strip().split(' ')
        if vals[0] == 'v':
            vs.append([float(coord) for coord in vals[1:]])
        elif vals[0] == 'l':
            vertex_indices = [int(idx) - 1 for idx in vals[1:]]  # Convert to zero-based index
            for i in range(len(vertex_indices) - 1):
                edge = tuple(sorted((vertex_indices[i], vertex_indices[i + 1])))
                edges.add(edge)

    vs = np.array(vs, dtype=np.float64)
    edges = np.array(list(edges), dtype=np.int32)

    return vs, edges

def writePoints(points, clsRoad):
    with open(clsRoad, 'w+') as file1:
        for i in range(len(points)):
            point = points[i]
            file1.write(str(point[0]))
            file1.write(' ')
            file1.write(str(point[1]))
            file1.write(' ')
            file1.write(str(point[2]))
            file1.write(' ')
            file1.write('\n')


class RoofN3dDataset(Dataset):
    def __init__(self, data_path, transform, data_cfg, logger=None):
        with open(data_path, 'r') as f:
            self.file_list = f.readlines()
        self.file_list = [f.strip() for f in self.file_list]
        flist = []
        for l in self.file_list:
            flist.append(l)
        self.file_list = flist

        self.npoint = data_cfg.NPOINT

        self.transform = transform

        if logger is not None:
            logger.info('Total samples: %d' % len(self))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        file_path = self.file_list[item]
        frame_id = file_path.split('/')[-1]
        points = read_pts(file_path + '/points.xyz')
        points = self.transform(points)

        if len(points) > self.npoint:
            idx = np.random.randint(0, len(points), self.npoint)
        else:
            idx = np.random.randint(0, len(points), self.npoint - len(points))
            idx = np.append(np.arange(0, len(points)), idx)
        np.random.shuffle(idx)

        points = points[idx]

        vectors, edges = load_obj(self.file_list[item] + '/polygon.obj')
        min_pt, max_pt = np.min(points, axis=0), np.max(points, axis=0)

        maxXYZ = np.max(max_pt)
        minXYZ = np.min(min_pt)
        min_pt[:] = minXYZ
        max_pt[:] = maxXYZ

        points = (points - min_pt) / (max_pt - min_pt)
        vectors = (vectors - min_pt) / (max_pt - min_pt)
        points = points.astype(np.float32)
        vectors = vectors.astype(np.float32)
        min_pt = min_pt.astype(np.float32)
        max_pt = max_pt.astype(np.float32)
        pt = np.concatenate((np.expand_dims(min_pt, 0), np.expand_dims(max_pt, 0)), axis=0)
        data_dict = {'points': points, 'vectors': vectors, 'edges': edges, 'frame_id': frame_id, 'minMaxPt': pt}
        return data_dict

    @staticmethod
    def collate_batch(batch_list, _unused=False):
        data_dict = defaultdict(list)
        for cur_sample in batch_list:
            for key, val in cur_sample.items():
                data_dict[key].append(val)
        batch_size = len(batch_list)
        ret = {}
        for key, val in data_dict.items():
            try:
                if key == 'points':
                    ret[key] = np.concatenate(val, axis=0).reshape([batch_size, -1, val[0].shape[-1]])
                elif key in ['vectors', 'edges']:
                    max_vec = max([len(x) for x in val])
                    batch_vecs = np.ones((batch_size, max_vec, val[0].shape[-1]), dtype=np.float32) * -1e1
                    for k in range(batch_size):
                        batch_vecs[k, :len(val[k]), :] = val[k]
                    ret[key] = batch_vecs
                elif key in ['frame_id']:
                    ret[key] = val
                elif key in ['minMaxPt']:
                    ret[key] = val
                else:
                    ret[key] = np.stack(val, axis=0)
            except:
                print('Error in collate_batch: key=%s' % key)
                raise TypeError

        ret['batch_size'] = batch_size
        return ret


class HohoDataset(Dataset):
    def __init__(self, key, xyz, transform, data_cfg, logger=None):

        self.npoint = data_cfg.NPOINT
        self.frame_id = key
        self.xyz = xyz
        self.transform = transform

        if logger is not None:
            logger.info('Total samples: %d' % len(self))

    def __len__(self):
        return 1

    def __getitem__(self, item):
        frame_id = self.frame_id
        # points = read_pts(file_path + '/points.xyz')
        # points = self.transform(points)
        points = self.xyz

        if len(points) > self.npoint:
            idx = np.random.randint(0, len(points), self.npoint)
        else:
            idx = np.random.randint(0, len(points), self.npoint - len(points))
            idx = np.append(np.arange(0, len(points)), idx)
        np.random.shuffle(idx)

        points = points[idx]

        # vectors, edges = load_obj(self.file_list[item] + '/polygon.obj')
        min_pt, max_pt = np.min(points, axis=0), np.max(points, axis=0)

        maxXYZ = np.max(max_pt)
        minXYZ = np.min(min_pt)
        min_pt[:] = minXYZ
        max_pt[:] = maxXYZ

        points = (points - min_pt) / (max_pt - min_pt)
        # vectors = (vectors - min_pt) / (max_pt - min_pt)
        points = points.astype(np.float32)
        # vectors = vectors.astype(np.float32)
        min_pt = min_pt.astype(np.float32)
        max_pt = max_pt.astype(np.float32)
        pt = np.concatenate((np.expand_dims(min_pt, 0), np.expand_dims(max_pt, 0)), axis=0)
        data_dict = {'points': points, 'vectors': None, 'edges': None, 'frame_id': frame_id, 'minMaxPt': pt}
        return data_dict

    @staticmethod
    def collate_batch(batch_list, _unused=False):
        data_dict = defaultdict(list)
        for cur_sample in batch_list:
            for key, val in cur_sample.items():
                data_dict[key].append(val)
        batch_size = len(batch_list)
        ret = {}
        for key, val in data_dict.items():
            try:
                if key == 'points':
                    ret[key] = np.concatenate(val, axis=0).reshape([batch_size, -1, val[0].shape[-1]])
                elif key in ['vectors', 'edges']:
                    # No ground-truth vectors/edges at inference time; leave them out of the batch.
                    continue
                elif key in ['frame_id']:
                    ret[key] = val
                elif key in ['minMaxPt']:
                    ret[key] = val
                else:
                    ret[key] = np.stack(val, axis=0)
            except:
                print('Error in collate_batch: key=%s' % key)
                raise TypeError

        ret['batch_size'] = batch_size
        return ret
```
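`load_obj` only consumes `v` (vertex) and `l` (polyline) records; a tiny hypothetical wireframe shows the edge extraction:

```python
# A triangle written as one closed polyline (hypothetical test file).
with open('/tmp/triangle.obj', 'w') as f:
    f.write('v 0 0 0\nv 1 0 0\nv 0 1 0\nl 1 2 3 1\n')

vs, edges = load_obj('/tmp/triangle.obj')
print(vs.shape)                                    # (3, 3)
print(sorted(tuple(map(int, e)) for e in edges))   # [(0, 1), (0, 2), (1, 2)]
```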
format_dataset.py
ADDED
```python
import os
import shutil
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def transfer_files(source_clean_xyz_dir, source_gt_dir, destination_dir, txt_file_path):
    create_directory(destination_dir)

    subdirectory_paths = []

    for filename in os.listdir(source_clean_xyz_dir):
        if filename.endswith('.xyz'):
            base_name = os.path.splitext(filename)[0]

            new_subdir = os.path.join(destination_dir, base_name)
            create_directory(new_subdir)

            subdirectory_paths.append(new_subdir)

            source_xyz = os.path.join(source_clean_xyz_dir, filename)
            destination_xyz = os.path.join(new_subdir, 'points.xyz')

            shutil.copy(source_xyz, destination_xyz)
            logging.info(f'Copied {source_xyz} to {destination_xyz}')

            source_obj = os.path.join(source_gt_dir, f'{base_name}.obj')
            destination_obj = os.path.join(new_subdir, 'polygon.obj')

            if os.path.exists(source_obj):
                shutil.copy(source_obj, destination_obj)
                logging.info(f'Copied {source_obj} to {destination_obj}')
            else:
                logging.warning(f'File not found: {source_obj}')

    with open(txt_file_path, 'w') as txt_file:
        for path in subdirectory_paths:
            txt_file.write(path + '\n')
    logging.info(f'Written subdirectory paths to {txt_file_path}')

# Define paths
source_clean_xyz_dir = 'Data/hoho_data_train/clean_xyz'
source_gt_dir = 'Data/hoho_data_train/gt'
destination_dir = 'Data/hoho_data_train'
txt_file_path = 'Data/hoho_data_train/train.txt'

# Run the transfer process
transfer_files(source_clean_xyz_dir, source_gt_dir, destination_dir, txt_file_path)
```
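After `transfer_files` runs, each subdirectory listed in `train.txt` should hold `points.xyz` plus, when a matching ground-truth file existed, `polygon.obj`. A quick verification sketch (paths as defined above):

```python
import os

with open('Data/hoho_data_train/train.txt') as f:
    sample_dirs = [line.strip() for line in f]

for d in sample_dirs[:3]:
    # Expect ['points.xyz', 'polygon.obj']; polygon.obj is absent when no GT matched.
    print(d, sorted(os.listdir(d)))
```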
hoho_train_checkpoint_epoch_90.pth
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:41b6815747720660fdd68b6272fa3bae40b4f09095b6ddf14540f7f26ee284fb
size 17019805
model/__pycache__/cluster_refine.cpython-38.pyc
ADDED
Binary file (9.14 kB)
model/__pycache__/edge_pred_net.cpython-38.pyc
ADDED
Binary file (5.32 kB)
model/__pycache__/model_utils.cpython-38.pyc
ADDED
Binary file (5.88 kB)
model/__pycache__/pointnet2.cpython-38.pyc
ADDED
Binary file (10.6 kB)
model/__pycache__/pointnet_stack_utils.cpython-38.pyc
ADDED
Binary file (8.86 kB)
model/__pycache__/pointnet_util.cpython-38.pyc
ADDED
Binary file (10.7 kB)
model/__pycache__/roofnet.cpython-38.pyc
ADDED
Binary file (1.32 kB)
model/cluster_refine.py
ADDED
```python
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .pointnet_stack_utils import *
from .model_utils import *
from scipy.optimize import linear_sum_assignment
from utils import loss_utils
import pc_util


class ClusterRefineNet(nn.Module):
    def __init__(self, model_cfg, input_channel):
        super().__init__()
        self.model_cfg = model_cfg
        self.matcher = HungarianMatcher(self.model_cfg.MatchRadius)
        sa_cfg = model_cfg.RefineSA
        mlps = sa_cfg.MLPs
        mlps = [[input_channel] + mlp for mlp in mlps]
        self.fea_refine_module = StackSAModuleMSG(
            radii=sa_cfg.Radii,
            nsamples=sa_cfg.Nsamples,
            mlps=mlps,
            use_xyz=True,
            pool_method='max_pool'
        )
        self.num_output_feature = sum([mlp[-1] for mlp in mlps])
        self.shared_fc = LinearBN(256, 128)
        self.drop = nn.Dropout(0.5)
        self.offset_fc = nn.Linear(128, 3)
        # self.cls_fc = nn.Linear(128, 1)
        if self.training:
            self.train_dict = {}
            # self.add_module(
            #     'cls_loss_func',
            #     loss_utils.SigmoidBCELoss()
            # )
            self.add_module(
                'reg_loss_func',
                loss_utils.WeightedSmoothL1Loss()
            )
            self.loss_weight = self.model_cfg.LossWeight

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)

    # tips: change from batch to stack
    def forward(self, batch_dict):
        offset_pts = batch_dict['points'].clone()
        offset = batch_dict['point_pred_offset']
        pts_score = batch_dict['point_pred_score']
        score_thresh = self.model_cfg.ScoreThresh
        offset_pts[pts_score > score_thresh] += offset[pts_score > score_thresh]
        pts_cluster = offset_pts.new_ones(offset_pts.shape) * -10
        pts_cluster[pts_score > score_thresh] = offset_pts[pts_score > score_thresh]
        cluster_idx = dbscan_cluster(self.model_cfg.Cluster.eps, self.model_cfg.Cluster.min_pts, pts_cluster)
        key_pts, num_cluster = get_cluster_pts(pts_cluster, cluster_idx)
        if self.training:
            new_pts, targets, labels, matches, new_xyz_batch_cnt = self.matcher(key_pts, batch_dict['vectors'])
            offset_targets = (targets - new_pts) / self.model_cfg.MatchRadius if new_pts is not None else None
            batch_dict['matches'] = matches
            self.train_dict.update({
                'keypoint_cls_label': labels,
                'keypoint_offset_label': offset_targets
            })
        else:
            pts_list, new_xyz_batch_cnt = [], []
            for i, pts in enumerate(key_pts):
                pts = pts[torch.sum(pts, -1) > -2e1]
                if len(pts) == 0:
                    new_xyz_batch_cnt.append(0)
                    continue
                new_xyz_batch_cnt.append(len(pts))
                pts_list.append(pts)
            if sum(new_xyz_batch_cnt) == 0:
                new_pts, new_xyz_batch_cnt = None, None
            else:
                new_pts = torch.cat(pts_list, 0)
                new_xyz_batch_cnt = new_pts.new_tensor(new_xyz_batch_cnt, dtype=torch.int32)
        if new_pts is None:
            exit()
        batch_idx = torch.zeros(new_pts.shape[0], device=new_pts.device)
        idx = 0
        for i, cnt in enumerate(new_xyz_batch_cnt):
            if cnt == 0:
                continue
            batch_idx[idx:idx + cnt] = i
            idx += cnt

        pos_mask = new_xyz_batch_cnt > 0
        offset_pts = offset_pts[pos_mask]
        xyz = offset_pts.view(-1, 3)
        xyz_batch_cnt = offset_pts.new_ones(offset_pts.shape[0], dtype=torch.int32) * offset_pts.shape[1]
        new_xyz_batch_cnt = new_xyz_batch_cnt[pos_mask]
        point_fea = batch_dict['point_features']
        point_fea = point_fea * pts_score.detach().unsqueeze(-1)
        point_fea = point_fea[pos_mask]
        point_fea = point_fea.contiguous().view(-1, point_fea.shape[-1])
        _, refine_fea = self.fea_refine_module(xyz, xyz_batch_cnt, new_pts, new_xyz_batch_cnt, point_fea)

        x = self.drop(self.shared_fc(refine_fea))
        pred_offset = self.offset_fc(x)
        # pred_cls = self.cls_fc(x)
        if self.training:
            self.train_dict.update({
                # 'keypoint_cls_pred': pred_cls,
                'keypoint_offset_pred': pred_offset
            })
        batch_dict['keypoint'] = torch.cat([batch_idx.view(-1, 1), new_pts], -1)
        batch_dict['keypoint_features'] = refine_fea
        # batch_dict['keypoint_pred_score'] = torch.sigmoid(pred_cls).squeeze(-1)
        batch_dict['refined_keypoint'] = pred_offset * self.model_cfg.MatchRadius + new_pts
        return batch_dict

    def loss(self, loss_dict, disp_dict):
        # pred_cls, pred_offset = self.train_dict['keypoint_cls_pred'], self.train_dict['keypoint_offset_pred']
        pred_offset = self.train_dict['keypoint_offset_pred']
        label_cls, label_offset = self.train_dict['keypoint_cls_label'], self.train_dict['keypoint_offset_label']
        # cls_loss = self.get_cls_loss(pred_cls, label_cls, self.loss_weight['cls_weight'])
        reg_loss = self.get_reg_loss(pred_offset, label_offset, label_cls, self.loss_weight['reg_weight'])
        loss = reg_loss
        # loss = cls_loss + reg_loss
        loss_dict.update({
            # 'refine_cls_loss': cls_loss.item(),
            'refine_offset_loss': reg_loss.item(),
            'refine_loss': loss.item()
        })

        # pred_cls = pred_cls.squeeze(-1)
        # label_cls = label_cls.squeeze(-1)
        # pred_logit = torch.sigmoid(pred_cls)
        # pred = torch.where(pred_logit >= 0.5, pred_logit.new_ones(pred_logit.shape),
        #                    pred_logit.new_zeros(pred_logit.shape))
        # acc = torch.sum((pred == label_cls) & (label_cls == 1)).item() / torch.sum(label_cls == 1).item()
        # disp_dict.update({'pos_acc': acc})
        return loss, loss_dict, disp_dict

    def get_cls_loss(self, pred, label, weight):
        batch_size = int(pred.shape[0])
        positives = label > 0
        negatives = label == 0
        cls_weights = (negatives * 1.0 + positives * 1.0).float()
        pos_normalizer = positives.sum(1, keepdim=True).float()
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_loss_src = self.cls_loss_func(pred.squeeze(-1), label, weights=cls_weights)  # [N, M]
        cls_loss = cls_loss_src.sum() / batch_size

        cls_loss = cls_loss * weight
        return cls_loss

    def get_reg_loss(self, pred, label, cls_label, weight):
        positives = cls_label > 0
        reg_weights = positives.float()
        pos_normalizer = positives.sum().float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
        reg_loss_src = self.reg_loss_func(pred.unsqueeze(dim=0), label.unsqueeze(dim=0), weights=reg_weights.unsqueeze(dim=0))
        reg_loss = reg_loss_src.sum()
        reg_loss = reg_loss * weight
        return reg_loss


class StackSAModuleMSG(nn.Module):

    def __init__(self, radii, nsamples, mlps, use_xyz, pool_method='max_pool'):
        """
        Args:
            radii: list of float, list of radii to group with
            nsamples: list of int, number of samples in each ball query
            mlps: list of list of int, spec of the pointnet before the global pooling for each scale
            use_xyz:
            pool_method: max_pool / avg_pool
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(QueryAndGroup(radius, nsample, use_xyz=use_xyz))
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            shared_mlps = []
            for k in range(len(mlp_spec) - 1):
                shared_mlps.extend([
                    nn.Conv2d(mlp_spec[k], mlp_spec[k + 1], kernel_size=1, bias=False),
                    nn.BatchNorm2d(mlp_spec[k + 1]),
                    nn.ReLU()
                ])
            self.mlps.append(nn.Sequential(*shared_mlps))
        self.pool_method = pool_method

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)

    def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features=None, empty_voxel_set_zeros=True):
        """
        :param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
        :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
        :param new_xyz: (M1 + M2 ..., 3)
        :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
        :param features: (N1 + N2 ..., C) tensor of the descriptors of the features
        :return:
            new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
            new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
        """
        new_features_list = []
        for k in range(len(self.groupers)):
            new_features, ball_idxs = self.groupers[k](
                xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features
            )  # (M1 + M2, C, nsample)
            new_features = new_features.permute(1, 0, 2).unsqueeze(dim=0)  # (1, C, M1 + M2 ..., nsample)
            new_features = self.mlps[k](new_features)  # (1, C, M1 + M2 ..., nsample)

            if self.pool_method == 'max_pool':
                new_features = F.max_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)]
                ).squeeze(dim=-1)  # (1, C, M1 + M2 ...)
            elif self.pool_method == 'avg_pool':
                new_features = F.avg_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)]
                ).squeeze(dim=-1)  # (1, C, M1 + M2 ...)
            else:
                raise NotImplementedError
            new_features = new_features.squeeze(dim=0).permute(1, 0)  # (M1 + M2 ..., C)
            new_features_list.append(new_features)

        new_features = torch.cat(new_features_list, dim=1)  # (M1 + M2 ..., C)

        return new_xyz, new_features


class HungarianMatcher(nn.Module):
    def __init__(self, match_r):
        super().__init__()
        self.dist_thresh = match_r

    # tips: matcher with dist threshold
    @torch.no_grad()
    def forward(self, output, targets):
        pts_list, target_list, label_list, match_list, new_xyz_batch_cnt = [], [], [], [], []
        for i in range(output.shape[0]):
            tmp_output, tmp_targets = output[i], targets[i]
            tmp_output = tmp_output[torch.sum(tmp_output, -1) > -2e1]
            if len(tmp_output) == 0:
                new_xyz_batch_cnt.append(0)
                continue
            tmp_targets = tmp_targets[torch.sum(tmp_targets, -1) > -2e1]
            vec_a = torch.sum(tmp_output.unsqueeze(1).repeat(1, tmp_targets.shape[0], 1) ** 2, -1)
            vec_b = torch.sum(tmp_targets.unsqueeze(0).repeat(tmp_output.shape[0], 1, 1) ** 2, -1)
            dist_matrix = vec_a + vec_b - 2 * torch.mm(tmp_output, tmp_targets.permute(1, 0))
            dist_matrix = F.relu(dist_matrix)
            dist_matrix = torch.sqrt(dist_matrix)

            out_ind, tar_ind = linear_sum_assignment(dist_matrix.cpu().numpy())
            out_ind, tar_ind = dist_matrix.new_tensor(out_ind, dtype=torch.int64), dist_matrix.new_tensor(tar_ind, dtype=torch.int64)
            dist_val = dist_matrix[out_ind, tar_ind]
            out_ind = out_ind[dist_val < self.dist_thresh]
            tar_ind = tar_ind[dist_val < self.dist_thresh]

            pts_list.append(tmp_output)
            tmp_label = tmp_targets.new_zeros(tmp_output.shape[0])
            tmp_label[out_ind] = 1.
            tmp_pts_target = tmp_targets.new_zeros(tmp_output.shape)
            tmp_pts_target[out_ind] = tmp_targets[tar_ind]
            tmp_match = tmp_targets.new_ones(tmp_output.shape[0], dtype=torch.int64) * -1
            tmp_match[out_ind] = tar_ind
            label_list.append(tmp_label)
            target_list.append(tmp_pts_target)
            match_list.append(tmp_match)
            new_xyz_batch_cnt.append(tmp_output.shape[0])
        if sum(new_xyz_batch_cnt) == 0:
            return None, None, None, None, None
        return torch.cat(pts_list, 0), torch.cat(target_list, 0), torch.cat(label_list, 0), torch.cat(match_list, 0), tmp_output.new_tensor(new_xyz_batch_cnt, dtype=torch.int32)
```
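The core of `HungarianMatcher` is a pairwise-distance matrix fed to `linear_sum_assignment`, with matches beyond the radius discarded. The self-contained sketch below reproduces just that step on toy CPU tensors (no `pc_util` or `loss_utils` required):

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def match_keypoints(pred, gt, dist_thresh=0.1):
    """One-to-one match of predicted keypoints to targets within dist_thresh."""
    # Same ||p - g||^2 expansion used in HungarianMatcher.forward.
    d2 = (pred ** 2).sum(-1, keepdim=True) + (gt ** 2).sum(-1) - 2 * pred @ gt.T
    dist = torch.sqrt(F.relu(d2))
    out_ind, tar_ind = linear_sum_assignment(dist.numpy())
    matched_dist = dist[torch.as_tensor(out_ind), torch.as_tensor(tar_ind)]
    keep = (matched_dist < dist_thresh).numpy()
    return out_ind[keep], tar_ind[keep]

pred = torch.tensor([[0.00, 0.0, 0.0], [1.00, 0.0, 0.0], [5.0, 5.0, 5.0]])
gt   = torch.tensor([[0.02, 0.0, 0.0], [1.01, 0.0, 0.0]])
print(match_keypoints(pred, gt))  # keypoints 0 and 1 match; the outlier is dropped
```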
model/edge_pred_net.py
ADDED
```python
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .pointnet_stack_utils import *
from .model_utils import *
from scipy.optimize import linear_sum_assignment
from utils import loss_utils
import pc_util
import itertools


class EdgeAttentionNet(nn.Module):
    def __init__(self, model_cfg, input_channel):
        super().__init__()
        self.model_cfg = model_cfg
        self.freeze = False

        self.att_layer = PairedPointAttention(input_channel)
        num_feature = self.att_layer.num_output_feature
        self.shared_fc = LinearBN(num_feature, num_feature)
        self.drop = nn.Dropout(0.5)
        self.cls_fc = nn.Linear(num_feature, 1)
        if self.training:
            self.train_dict = {}
            self.add_module(
                'cls_loss_func',
                loss_utils.SigmoidBCELoss()
            )
            self.loss_weight = self.model_cfg.LossWeight

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)

    def forward(self, batch_dict):
        batch_idx = batch_dict['keypoint'][:, 0]
        point_fea = batch_dict['keypoint_features']

        if self.training:
            matches = batch_dict['matches']
            edge_label = batch_dict['edges']
            bin_label_list = []
            for i, edge in enumerate(edge_label):
                mask = batch_idx == i
                tmp_idx = batch_idx[mask]
                if tmp_idx.shape[0] <= 1:
                    continue
                match = matches[mask]
                match_edge = list(itertools.combinations(match.cpu().numpy(), 2))
                match_edge = [tuple(sorted(e)) for e in match_edge]
                edge = [tuple(e) for e in edge.cpu().numpy()]
                label = edge_label.new_tensor([e in edge for e in match_edge])
                bin_label_list.append(label)
            self.train_dict['label'] = torch.cat(bin_label_list)

        idx = 0
        pair_idx_list = []
        pair_idx_list1, pair_idx_list2 = [], []
        for i in range(batch_dict['batch_size']):
            mask = batch_idx == i
            tmp_idx = batch_idx[mask]
            if tmp_idx.shape[0] <= 1:
                continue
            fea = point_fea[mask]
            pair_idx = itertools.combinations(range(fea.shape[0]), 2)
            pair_idx = point_fea.new_tensor(list(pair_idx))
            pair_idx_list.append(pair_idx)
            pair_idx_list1.append(pair_idx[:, 0] + idx)
            pair_idx_list2.append(pair_idx[:, 1] + idx)
            idx += tmp_idx.shape[0]
        print('pair_idx_list:', pair_idx_list)
        if pair_idx_list1 and pair_idx_list2:
            pair_idx1 = torch.cat(pair_idx_list1).long()
            pair_idx2 = torch.cat(pair_idx_list2).long()
            pair_fea1 = point_fea[pair_idx1]
            pair_fea2 = point_fea[pair_idx2]
            edge_fea = self.att_layer(pair_fea1, pair_fea2)
            edge_pred = self.cls_fc(self.drop(self.shared_fc(edge_fea)))
            batch_dict['pair_points'] = torch.cat(pair_idx_list, 0)
            batch_dict['edge_score'] = torch.sigmoid(edge_pred).view(-1)
            if self.training:
                self.train_dict['edge_pred'] = edge_pred
        else:
            print("Warning: pair_idx_list1 or pair_idx_list2 is empty!")
            batch_dict['pair_points'] = torch.tensor([])
            batch_dict['edge_score'] = torch.tensor([])
            if self.training:
                # No candidate pairs: store an empty prediction so loss() does not
                # reference an undefined `edge_pred`.
                self.train_dict['edge_pred'] = point_fea.new_zeros((0, 1))
        return batch_dict

    def loss(self, loss_dict, disp_dict):
        pred_cls = self.train_dict['edge_pred']
        label_cls = self.train_dict['label']
        cls_loss = self.get_cls_loss(pred_cls, label_cls, self.loss_weight['cls_weight'])
        loss = cls_loss
        loss_dict.update({
            'edge_cls_loss': cls_loss.item(),
            'edge_loss': loss.item()
        })

        pred_cls = pred_cls.squeeze(-1)
        label_cls = label_cls.squeeze(-1)
        pred_logit = torch.sigmoid(pred_cls)
        pred = torch.where(pred_logit >= 0.5, pred_logit.new_ones(pred_logit.shape),
                           pred_logit.new_zeros(pred_logit.shape))
        acc = torch.sum((pred == label_cls) & (label_cls == 1)).item() / torch.sum(label_cls == 1).item()
        # acc = torch.sum((pred == label_cls)).item() / len(label_cls.view(-1))
        disp_dict.update({'edge_acc': acc})
        return loss, loss_dict, disp_dict

    def get_cls_loss(self, pred, label, weight):
        positives = label > 0
        negatives = label == 0
        cls_weights = (negatives * 1.0 + positives * 1.0).float()
        pos_normalizer = positives.sum().float()
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_loss_src = self.cls_loss_func(pred.squeeze(-1), label, weights=cls_weights)  # [N, M]
        cls_loss = cls_loss_src.sum()

        cls_loss = cls_loss * weight
        return cls_loss


class PairedPointAttention(nn.Module):
    def __init__(self, input_channel):
        super().__init__()
        self.edge_att1 = nn.Sequential(
            nn.Linear(input_channel, input_channel),
            nn.BatchNorm1d(input_channel),
            nn.ReLU(),
            nn.Linear(input_channel, input_channel),
            nn.Sigmoid(),
        )
        self.edge_att2 = nn.Sequential(
            nn.Linear(input_channel, input_channel),
            nn.BatchNorm1d(input_channel),
            nn.ReLU(),
            nn.Linear(input_channel, input_channel),
            nn.Sigmoid(),
        )
        self.fea_fusion_layer = nn.MaxPool1d(2)

        self.num_output_feature = input_channel

    def forward(self, point_fea1, point_fea2):
        fusion_fea = point_fea1 + point_fea2
        att1 = self.edge_att1(fusion_fea)
        att2 = self.edge_att2(fusion_fea)
        att_fea1 = point_fea1 * att1
        att_fea2 = point_fea2 * att2
        fea = torch.cat([att_fea1.unsqueeze(1), att_fea2.unsqueeze(1)], 1)
        fea = self.fea_fusion_layer(fea.permute(0, 2, 1)).squeeze(-1)
        return fea
```
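`PairedPointAttention` gates each keypoint feature through attention computed from the pair's sum, then max-pools the two gated features into one edge feature. A shape check on random inputs:

```python
import torch

att = PairedPointAttention(input_channel=16)
att.eval()  # BatchNorm1d with default running stats; eval avoids small-batch issues
fea1 = torch.randn(8, 16)  # 8 candidate edges, 16-dim keypoint features each
fea2 = torch.randn(8, 16)
fused = att(fea1, fea2)
print(fused.shape)  # torch.Size([8, 16]) -- one fused feature per candidate edge
```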
model/model_utils.py
ADDED
```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pc_util
from torch.autograd import Function, Variable


class Conv2ds(nn.Sequential):
    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('conv%d' % (i + 1), Conv2dBN(in_cn, out_cn))


class Conv2dBN(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm2d(out_channel)
        self.conv = nn.Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))


class Conv1ds(nn.Sequential):
    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('conv%d' % (i + 1), Conv1dBN(in_cn, out_cn))


class Conv1dBN(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm1d(out_channel)
        self.conv = nn.Conv1d(in_channel, out_channel, 1)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))


class Linears(nn.Sequential):
    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('linear%d' % (i + 1), LinearBN(in_cn, out_cn))


class LinearBN(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm1d(out_channel)
        self.conv = nn.Linear(in_channel, out_channel)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))


def load_params_with_optimizer(net, filename, to_cpu=False, optimizer=None, logger=None):
    if not os.path.isfile(filename):
        raise FileNotFoundError

    logger.info('==> Loading parameters from checkpoint')
    checkpoint = torch.load(filename)
    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)

    net.load_state_dict(checkpoint['model_state'])

    if optimizer is not None:
        logger.info('==> Loading optimizer parameters from checkpoint')
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    logger.info('==> Done')

    return it, epoch


def load_params(net, filename, logger=None):
    if not os.path.isfile(filename):
        raise FileNotFoundError
    if logger is not None:
        logger.info('==> Loading parameters from checkpoint')
    checkpoint = torch.load(filename)

    net.load_state_dict(checkpoint['model_state'])
    if logger is not None:
        logger.info('==> Done')


class DBSCANCluster(Function):

    @staticmethod
    def forward(ctx, eps: float, min_pts: int, point: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param eps: float, dbscan eps
        :param min_pts: int, dbscan core point threshold
        :param point: (B, N, 3) xyz coordinates of the points
        :return:
            idx: (B, N) cluster idx
        """
        point = point.contiguous()

        B, N, _ = point.size()
        idx = torch.cuda.IntTensor(B, N).zero_() - 1

        pc_util.dbscan_wrapper(B, N, eps, min_pts, point, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, grad_out):
        return ()


dbscan_cluster = DBSCANCluster.apply


class GetClusterPts(Function):

    @staticmethod
    def forward(ctx, point: torch.Tensor, cluster_idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param point: (B, N, 3) xyz coordinates of the points
        :param cluster_idx: (B, N) cluster idx
        :return:
            key_pts: (B, M, 3) cluster center pts, M is max_num_cluster_class
            num_cluster: (B, M) cluster num, num of pts in each cluster class
        """
        cluster_idx = cluster_idx.contiguous()

        B, N = cluster_idx.size()
        M = torch.max(cluster_idx) + 1
        key_pts = torch.cuda.FloatTensor(B, M, 3).zero_()
        num_cluster = torch.cuda.IntTensor(B, M).zero_()
        pc_util.cluster_pts_wrapper(B, N, M, point, cluster_idx, key_pts, num_cluster)
        key_pts[key_pts * 1e4 == 0] = -1e1
        ctx.mark_non_differentiable(key_pts)
        ctx.mark_non_differentiable(num_cluster)
        return key_pts, num_cluster

    @staticmethod
    def backward(ctx, grad_out):
        return ()


get_cluster_pts = GetClusterPts.apply
```
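`dbscan_cluster` and `get_cluster_pts` wrap CUDA kernels from `pc_util`; on CPU, the same labels-then-centroids computation can be sketched with scikit-learn for a single `(N, 3)` cloud (an illustrative stand-in, not the extension itself):

```python
import numpy as np
from sklearn.cluster import DBSCAN

def cluster_centers_cpu(points, eps, min_pts):
    """CPU sketch of dbscan_cluster + get_cluster_pts for one point cloud."""
    labels = DBSCAN(eps=eps, min_samples=min_pts).fit_predict(points)  # -1 marks noise
    centers = [points[labels == c].mean(0) for c in range(labels.max() + 1)]
    return labels, (np.stack(centers) if centers else np.empty((0, 3)))
```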
model/pointnet2.py
ADDED
@@ -0,0 +1,305 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from .pointnet_util import *
+from .model_utils import *
+from utils import loss_utils
+
+
+class PointNet2(nn.Module):
+    def __init__(self, model_cfg, in_channel=3):
+        super().__init__()
+        self.model_cfg = model_cfg
+        self.sa1 = PointNetSAModule(256, 0.1, 16, in_channel, [32, 32, 64])
+        self.sa2 = PointNetSAModule(128, 0.2, 16, 64, [64, 64, 128])
+        self.sa3 = PointNetSAModule(64, 0.4, 16, 128, [128, 128, 256])
+        self.sa4 = PointNetSAModule(16, 0.8, 16, 256, [256, 256, 512])
+        self.fp4 = PointNetFPModule(768, [256, 256])
+        self.fp3 = PointNetFPModule(384, [256, 256])
+        self.fp2 = PointNetFPModule(320, [256, 128])
+        self.fp1 = PointNetFPModule(128, [128, 128, 128])
+        self.shared_fc = Conv1dBN(128, 128)
+        self.drop = nn.Dropout(0.5)
+        self.offset_fc = nn.Conv1d(128, 3, 1)
+        self.cls_fc = nn.Conv1d(128, 1, 1)
+        self.init_weights()
+        self.num_output_feature = 128
+        if self.training:
+            self.train_dict = {}
+            self.add_module(
+                'cls_loss_func',
+                loss_utils.SigmoidBCELoss()
+            )
+            self.add_module(
+                'reg_loss_func',
+                loss_utils.WeightedSmoothL1Loss()
+            )
+            self.loss_weight = self.model_cfg.LossWeight
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1.0)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, batch_dict):
+        xyz = batch_dict['points']
+        # vectors = batch_dict['vectors']
+        vectors = None
+        if self.training:
+            offset, cls = self.assign_targets(xyz, vectors, self.model_cfg.PosRadius)
+            self.train_dict.update({
+                'offset_label': offset,
+                'cls_label': cls
+            })
+
+        fea = xyz
+        l0_fea = fea.permute(0, 2, 1)
+        l0_xyz = xyz
+
+        l1_xyz, l1_fea = self.sa1(l0_xyz, l0_fea)
+        l2_xyz, l2_fea = self.sa2(l1_xyz, l1_fea)
+        l3_xyz, l3_fea = self.sa3(l2_xyz, l2_fea)
+        l4_xyz, l4_fea = self.sa4(l3_xyz, l3_fea)
+
+        l3_fea = self.fp4(l3_xyz, l4_xyz, l3_fea, l4_fea)
+        l2_fea = self.fp3(l2_xyz, l3_xyz, l2_fea, l3_fea)
+        l1_fea = self.fp2(l1_xyz, l2_xyz, l1_fea, l2_fea)
+        l0_fea = self.fp1(l0_xyz, l1_xyz, None, l1_fea)
+
+        x = self.drop(self.shared_fc(l0_fea))
+        pred_offset = self.offset_fc(x).permute(0, 2, 1)
+        pred_cls = self.cls_fc(x).permute(0, 2, 1)
+        if self.training:
+            self.train_dict.update({
+                'cls_pred': pred_cls,
+                'offset_pred': pred_offset
+            })
+        batch_dict['point_features'] = l0_fea.permute(0, 2, 1)
+        batch_dict['point_pred_score'] = torch.sigmoid(pred_cls).squeeze(-1)
+        batch_dict['point_pred_offset'] = pred_offset * self.model_cfg.PosRadius
+        return batch_dict
+
+    def loss(self, loss_dict, disp_dict):
+        pred_cls, pred_offset = self.train_dict['cls_pred'], self.train_dict['offset_pred']
+        label_cls, label_offset = self.train_dict['cls_label'], self.train_dict['offset_label']
+        cls_loss = self.get_cls_loss(pred_cls, label_cls, self.loss_weight['cls_weight'])
+        reg_loss = self.get_reg_loss(pred_offset, label_offset, label_cls, self.loss_weight['reg_weight'])
+        loss = cls_loss + reg_loss
+        loss_dict.update({
+            'pts_cls_loss': cls_loss.item(),
+            'pts_offset_loss': reg_loss.item(),
+            'pts_loss': loss.item()
+        })
+
+        pred_cls = pred_cls.squeeze(-1)
+        label_cls = label_cls.squeeze(-1)
+        pred_prob = torch.sigmoid(pred_cls)
+        pred = torch.where(pred_prob >= 0.5, pred_prob.new_ones(pred_prob.shape), pred_prob.new_zeros(pred_prob.shape))
+        # recall on positive points; the commented line below is overall accuracy
+        acc = torch.sum((pred == label_cls) & (label_cls == 1)).item() / torch.sum(label_cls == 1).item()
+        # acc = torch.sum(pred == label_cls).item() / len(label_cls.view(-1))
+        disp_dict.update({'pts_acc': acc})
+        return loss, loss_dict, disp_dict
+
+    def get_cls_loss(self, pred, label, weight):
+        batch_size = int(pred.shape[0])
+        positives = label > 0
+        negatives = label == 0
+        cls_weights = (negatives * 1.0 + positives * 1.0).float()
+        pos_normalizer = positives.sum(1, keepdim=True).float()
+        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
+        cls_loss_src = self.cls_loss_func(pred.squeeze(-1), label, weights=cls_weights)  # [N, M]
+        cls_loss = cls_loss_src.sum() / batch_size
+
+        cls_loss = cls_loss * weight
+        return cls_loss
+
+    def get_reg_loss(self, pred, label, cls_label, weight):
+        batch_size = int(pred.shape[0])
+        positives = cls_label > 0
+        reg_weights = positives.float()
+        pos_normalizer = positives.sum(1, keepdim=True).float()
+        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
+        reg_loss_src = self.reg_loss_func(pred, label, weights=reg_weights)  # [N, M]
+        reg_loss = reg_loss_src.sum() / batch_size
+        reg_loss = reg_loss * weight
+        return reg_loss
+
+    def assign_targets(self, points, gvs, radius):
+        idx = ball_center_query(radius, points, gvs).type(torch.int64)
+        batch_size = gvs.size()[0]
+        idx_add = torch.arange(batch_size).to(idx.device).unsqueeze(-1).repeat(1, idx.shape[-1]) * gvs.shape[1]
+        gvs = gvs.view(-1, 3)
+        idx_add += idx
+        target_points = gvs[idx_add.view(-1)].view(batch_size, -1, 3)
+        dis = target_points - points
+        dis[idx < 0] = 0
+        dis /= radius
+        label = torch.where(idx >= 0, torch.ones(idx.shape).to(idx.device),
+                            torch.zeros(idx.shape).to(idx.device))
+        return dis, label
+
+
+class PointNetSAModuleMSG(nn.Module):
+    def __init__(self, npoint, radii, nsamples, in_channel, mlps, use_xyz=True):
+        """
+        PointNet Set Abstraction Module
+        :param npoint: int
+        :param radii: list of float, radius in ball_query
+        :param nsamples: list of int, number of samples in ball_query
+        :param in_channel: int
+        :param mlps: list of list of int
+        :param use_xyz: bool
+        """
+        super().__init__()
+        assert len(radii) == len(nsamples) == len(mlps)
+        mlps = [[in_channel] + mlp for mlp in mlps]
+        self.npoint = npoint
+        self.groupers = nn.ModuleList()
+        self.mlps = nn.ModuleList()
+
+        for i in range(len(radii)):
+            r = radii[i]
+            nsample = nsamples[i]
+            mlp = mlps[i]
+            if use_xyz:
+                mlp[0] += 3
+            self.groupers.append(QueryAndGroup(r, nsample, use_xyz) if npoint is not None else GroupAll(use_xyz))
+            self.mlps.append(Conv2ds(mlp))
+
+    def forward(self, xyz, features, new_xyz=None):
+        """
+        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, C, N) tensor of the descriptors of the features
+        :param new_xyz:
+        :return:
+            new_xyz: (B, npoint, 3) tensor of the new features' xyz
+            new_features: (B, C1, npoint) tensor of the new_features descriptors
+        """
+        new_features_list = []
+        xyz = xyz.contiguous()
+        xyz_flipped = xyz.permute(0, 2, 1)
+        if new_xyz is None:
+            new_xyz = gather_operation(xyz_flipped, furthest_point_sample(
+                xyz, self.npoint, 1.0, 0.0)).permute(0, 2, 1) if self.npoint is not None else None
+
+        for i in range(len(self.groupers)):
+            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
+            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
+            new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)]).squeeze(-1)
+            new_features_list.append(new_features)
+
+        return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointNetSAModule(PointNetSAModuleMSG):
+    def __init__(self, npoint, radius, nsample, in_channel, mlp, use_xyz=True):
+        super().__init__(npoint, [radius], [nsample], in_channel, [mlp], use_xyz)
+
+
+class PointNetFPModule(nn.Module):
+    def __init__(self, in_channel, mlp):
+        super().__init__()
+        self.mlp = Conv2ds([in_channel] + mlp)
+
+    def forward(self, pts1, pts2, fea1, fea2):
+        """
+        :param pts1: (B, n, 3)
+        :param pts2: (B, m, 3) n > m
+        :param fea1: (B, C1, n)
+        :param fea2: (B, C2, m)
+        :return:
+            new_features: (B, mlp[-1], n)
+        """
+        if pts2 is not None:
+            dist, idx = three_nn(pts1, pts2)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+
+            interpolated_feats = three_interpolate(fea2, idx, weight)
+        else:
+            interpolated_feats = fea2.expand(*fea2.size()[0:2], pts1.size(1))
+
+        if fea1 is not None:
+            new_features = torch.cat([interpolated_feats, fea1], dim=1)  # (B, C2 + C1, n)
+        else:
+            new_features = interpolated_feats
+
+        new_features = new_features.unsqueeze(-1)
+        new_features = self.mlp(new_features)
+
+        return new_features.squeeze(-1)
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        :param radius: float, radius of ball
+        :param nsample: int, maximum number of features to gather in the ball
+        :param use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+        # _, idx = pointnet_util.knn_query(self.nsample, xyz, new_xyz)
+        xyz_trans = xyz.permute(0, 2, 1)
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_xyz -= new_xyz.permute(0, 2, 1).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "Cannot have no features and not use xyz as a feature!"
+            new_features = grouped_xyz
+
+        return new_features
+
+
+class GroupAll(nn.Module):
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: ignored
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, C + 3, 1, N)
+        """
+        grouped_xyz = xyz.permute(0, 2, 1).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
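
To make the feature-propagation step concrete: PointNetFPModule interpolates sparse features back to dense points with inverse-distance weights over the three nearest neighbors. Below is a plain-PyTorch reference of that computation for illustration only (assumes at least 3 sparse points); the model itself calls the three_nn/three_interpolate CUDA kernels:

    import torch

    def three_interpolate_reference(pts1, pts2, fea2):
        # pts1: (B, n, 3) dense points; pts2: (B, m, 3) sparse points; fea2: (B, C, m)
        dist = torch.cdist(pts1, pts2)                      # (B, n, m) pairwise distances
        dist, idx = dist.topk(3, dim=-1, largest=False)     # three nearest sparse neighbors
        weight = 1.0 / (dist + 1e-8)
        weight = weight / weight.sum(dim=-1, keepdim=True)  # normalized inverse-distance weights
        neigh = fea2.transpose(1, 2).unsqueeze(1).expand(-1, pts1.size(1), -1, -1)
        neigh = torch.gather(neigh, 2, idx.unsqueeze(-1).expand(-1, -1, -1, fea2.size(1)))
        return (neigh * weight.unsqueeze(-1)).sum(dim=2).transpose(1, 2)  # (B, C, n)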
model/pointnet_stack_utils.py
ADDED
@@ -0,0 +1,265 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Function, Variable
+import pc_util
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
+                new_xyz: torch.Tensor, new_xyz_batch_cnt):
+        """
+        Args:
+            ctx:
+            radius: float, radius of the balls
+            nsample: int, maximum number of features in the balls
+            xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+            xyz_batch_cnt: (batch_size), [N1, N2, ...]
+            new_xyz: (M1 + M2 ..., 3) centers of the ball query
+            new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
+
+        Returns:
+            idx: (M1 + M2, nsample) tensor with the indices of the features that form the query balls
+        """
+        assert new_xyz.is_contiguous()
+        assert new_xyz_batch_cnt.is_contiguous()
+        assert xyz.is_contiguous()
+        assert xyz_batch_cnt.is_contiguous()
+
+        B = xyz_batch_cnt.shape[0]
+        M = new_xyz.shape[0]
+        idx = torch.cuda.IntTensor(M, nsample).zero_()
+
+        pc_util.ball_query_wrapper_stack(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx)
+        empty_ball_mask = (idx[:, 0] == -1)
+        idx[empty_ball_mask] = 0
+        return idx, empty_ball_mask
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, features_batch_cnt: torch.Tensor,
+                idx: torch.Tensor, idx_batch_cnt: torch.Tensor):
+        """
+        Args:
+            ctx:
+            features: (N1 + N2 ..., C) tensor of features to group
+            features_batch_cnt: (batch_size) [N1, N2, ...] number of feature points per sample
+            idx: (M1 + M2 ..., nsample) tensor containing the indices of the features to group with
+            idx_batch_cnt: (batch_size) [M1, M2, ...] number of query centers per sample
+
+        Returns:
+            output: (M1 + M2, C, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert features_batch_cnt.is_contiguous()
+        assert idx.is_contiguous()
+        assert idx_batch_cnt.is_contiguous()
+
+        assert features.shape[0] == features_batch_cnt.sum(), \
+            'features: %s, features_batch_cnt: %s' % (str(features.shape), str(features_batch_cnt))
+        assert idx.shape[0] == idx_batch_cnt.sum(), \
+            'idx: %s, idx_batch_cnt: %s' % (str(idx.shape), str(idx_batch_cnt))
+
+        M, nsample = idx.size()
+        N, C = features.size()
+        B = idx_batch_cnt.shape[0]
+        output = torch.cuda.FloatTensor(M, C, nsample)
+
+        pc_util.group_points_wrapper_stack(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, output)
+
+        ctx.for_backwards = (B, N, idx, features_batch_cnt, idx_batch_cnt)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        """
+        Args:
+            ctx:
+            grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
+
+        Returns:
+            grad_features: (N1 + N2 ..., C) gradient of the features
+        """
+        B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
+
+        M, C, nsample = grad_out.size()
+        grad_features = Variable(torch.cuda.FloatTensor(N, C).zero_())
+
+        grad_out_data = grad_out.data.contiguous()
+        pc_util.group_points_grad_wrapper_stack(B, M, C, N, nsample, grad_out_data, idx,
+                                                idx_batch_cnt, features_batch_cnt, grad_features.data)
+        return grad_features, None, None, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        Args:
+            radius: float, radius of ball
+            nsample: int, maximum number of features to gather in the ball
+            use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
+                new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor,
+                features: torch.Tensor = None):
+        """
+        Args:
+            xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+            xyz_batch_cnt: (batch_size), [N1, N2, ...]
+            new_xyz: (M1 + M2 ..., 3) centers of the ball query
+            new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
+            features: (N1 + N2 ..., C) tensor of features to group
+
+        Returns:
+            new_features: (M1 + M2, C, nsample) tensor
+        """
+        assert xyz.shape[0] == xyz_batch_cnt.sum(), \
+            'xyz: %s, xyz_batch_cnt: %s' % (str(xyz.shape), str(xyz_batch_cnt))
+        assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
+            'new_xyz: %s, new_xyz_batch_cnt: %s' % (str(new_xyz.shape), str(new_xyz_batch_cnt))
+
+        # idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
+        idx, empty_ball_mask = ball_query(self.radius, self.nsample, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt)
+        grouped_xyz = grouping_operation(xyz, xyz_batch_cnt, idx, new_xyz_batch_cnt)  # (M1 + M2, 3, nsample)
+        grouped_xyz -= new_xyz.unsqueeze(-1)
+
+        grouped_xyz[empty_ball_mask] = 0
+
+        if features is not None:
+            grouped_features = grouping_operation(features, xyz_batch_cnt, idx, new_xyz_batch_cnt)  # (M1 + M2, C, nsample)
+            grouped_features[empty_ball_mask] = 0
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (M1 + M2 ..., C + 3, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "Cannot have no features and not use xyz as a feature!"
+            new_features = grouped_xyz
+
+        return new_features, idx
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int):
+        """
+        Args:
+            ctx:
+            xyz: (B, N, 3) where N > npoint
+            npoint: int, number of features in the sampled set
+
+        Returns:
+            output: (B, npoint) tensor containing the set
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pc_util.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class ThreeNN(Function):
+    @staticmethod
+    def forward(ctx, unknown, unknown_batch_cnt, known, known_batch_cnt):
+        """
+        Args:
+            ctx:
+            unknown: (N1 + N2..., 3)
+            unknown_batch_cnt: (batch_size), [N1, N2, ...]
+            known: (M1 + M2..., 3)
+            known_batch_cnt: (batch_size), [M1, M2, ...]
+
+        Returns:
+            dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
+            idx: (N1 + N2 ..., 3) index of the three nearest neighbors, range [0, M1+M2+...]
+        """
+        assert unknown.shape.__len__() == 2 and unknown.shape[1] == 3
+        assert known.shape.__len__() == 2 and known.shape[1] == 3
+        assert unknown_batch_cnt.__len__() == known_batch_cnt.__len__()
+
+        dist2 = unknown.new_zeros(unknown.shape)
+        idx = unknown_batch_cnt.new_zeros(unknown.shape).int()
+
+        pc_util.three_nn_wrapper_stack(
+            unknown.contiguous(), unknown_batch_cnt.contiguous(),
+            known.contiguous(), known_batch_cnt.contiguous(), dist2, idx
+        )
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
+        """
+        Args:
+            ctx:
+            features: (M1 + M2 ..., C)
+            idx: [N1 + N2 ..., 3]
+            weight: [N1 + N2 ..., 3]
+
+        Returns:
+            out_tensor: (N1 + N2 ..., C)
+        """
+        assert idx.shape[0] == weight.shape[0] and idx.shape[1] == weight.shape[1] == 3
+
+        ctx.three_interpolate_for_backward = (idx, weight, features.shape[0])
+        output = features.new_zeros((idx.shape[0], features.shape[1]))
+        pc_util.three_interpolate_wrapper_stack(features.contiguous(), idx.contiguous(), weight.contiguous(), output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        """
+        Args:
+            ctx:
+            grad_out: (N1 + N2 ..., C)
+
+        Returns:
+            grad_features: (M1 + M2 ..., C)
+        """
+        idx, weight, M = ctx.three_interpolate_for_backward
+        grad_features = grad_out.new_zeros((M, grad_out.shape[1]))
+        pc_util.three_interpolate_grad_wrapper_stack(
+            grad_out.contiguous(), idx.contiguous(), weight.contiguous(), grad_features
+        )
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+if __name__ == '__main__':
+    pass
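
Unlike the padded (B, N, 3) ops in pointnet2.py, these "stack" variants batch variable-sized clouds by concatenating them along dim 0 and carrying a per-sample count vector instead of padding. A small illustrative sketch of that layout (shapes and values are made up):

    import torch

    # Two samples with 4 and 2 points, stacked into a single (6, 3) tensor.
    xyz = torch.rand(6, 3).cuda().contiguous()
    xyz_batch_cnt = torch.tensor([4, 2], dtype=torch.int32).cuda()      # [N1, N2]
    # Three query centers: two for the first sample, one for the second.
    new_xyz = torch.rand(3, 3).cuda().contiguous()
    new_xyz_batch_cnt = torch.tensor([2, 1], dtype=torch.int32).cuda()  # [M1, M2]
    idx, empty_ball_mask = ball_query(0.2, 16, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt)
    # idx: (3, 16) neighbor indices per center; empty_ball_mask: (3,) centers with no hit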
model/pointnet_util.py
ADDED
@@ -0,0 +1,518 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+
+import pc_util
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int, wd: float = 1.0, wf: float = 0.0) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance.
+        :param ctx:
+        :param xyz: (B, N, C) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :param wd: float, weight of xyz distance
+        :param wf: float, weight of fea distance
+        :return:
+            output: (B, npoint) tensor containing the set
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        xyz = xyz.contiguous()
+
+        B, N, C = xyz.size()
+        device = torch.device('cuda')
+        output = torch.zeros(B, npoint, dtype=torch.int32, device=device)
+        temp = torch.full((B, N), 1e10, dtype=torch.float32, device=device)
+
+        pc_util.furthest_point_sampling_wrapper(B, C, N, npoint, wd, wf, xyz, temp, output)
+        ctx.mark_non_differentiable(output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        return ()
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N)
+        :param idx: (B, npoint) index tensor of the features to gather
+        :return:
+            output: (B, C, npoint)
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        features = features.contiguous()
+        idx = idx.contiguous()
+
+        B, npoint = idx.size()
+        _, C, N = features.size()
+        device = torch.device('cuda')
+        output = torch.zeros(B, C, npoint, dtype=torch.float32, device=device)
+
+        pc_util.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+        ctx.save_for_backward(idx, features)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, features = ctx.saved_tensors
+        B, npoint = idx.size()
+        _, C, N = features.size()
+
+        device = torch.device('cuda')
+        grad_features = torch.zeros(B, C, N, dtype=torch.float32, device=device)
+        grad_out_data = grad_out.contiguous()
+        pc_util.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features)
+        return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of 3 nearest neighbors
+        """
+        unknown = unknown.contiguous()
+        known = known.contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        pc_util.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return ()
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weighted linear interpolation on 3 features
+        :param ctx:
+        :param features: (B, C, M) feature descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, N) tensor of the interpolated features
+        """
+        features = features.contiguous()
+        idx = idx.contiguous()
+        weight = weight.contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.save_for_backward(idx, weight, features)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        pc_util.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, features = ctx.saved_tensors
+        m = features.size(2)
+        B, c, n = grad_out.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        grad_out_data = grad_out.data.contiguous()
+
+        pc_util.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        features = features.contiguous()
+        idx = idx.contiguous()
+
+        B, npoint, nsample = idx.size()
+        _, C, N = features.size()
+        device = torch.device('cuda')
+        output = torch.zeros(B, C, npoint, nsample, dtype=torch.float32, device=device)
+
+        pc_util.group_points_wrapper(B, C, N, npoint, nsample, features, idx, output)
+
+        ctx.save_for_backward(idx, features)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        idx, features = ctx.saved_tensors
+        B, C, N = features.size()
+
+        _, _, npoint, nsample = grad_out.size()
+        device = torch.device('cuda')
+        grad_features = torch.zeros(B, C, N, dtype=torch.float32, device=device)
+
+        grad_out_data = grad_out.contiguous()
+        pc_util.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features)
+        return grad_features, torch.zeros_like(idx)
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param nsample: int, maximum number of features in the balls
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        new_xyz = new_xyz.contiguous()
+        xyz = xyz.contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        device = torch.device('cuda')
+        idx = torch.zeros(B, npoint, nsample, dtype=torch.int32, device=device)
+
+        pc_util.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+        ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        return ()
+
+
+ball_query = BallQuery.apply
+
+
+class BallCenterQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, point: torch.Tensor, key_point: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param point: (B, N, 3) xyz coordinates of the features
+        :param key_point: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, N) index of the nearest key point within radius, -1 where none is found
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        point = point.contiguous()
+        key_point = key_point.contiguous()
+
+        B, N, _ = point.size()
+        npoint = key_point.size(1)
+        device = torch.device('cuda')
+        idx = torch.full((B, N), -1, dtype=torch.int32, device=device)
+
+        pc_util.ball_center_query_wrapper(B, N, npoint, radius, point, key_point, idx)
+        ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        return ()
+
+
+ball_center_query = BallCenterQuery.apply
+
+
+class KNNQuery(Function):
+
+    @staticmethod
+    def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the k nearest neighbors of new_xyz in xyz
+        :param ctx:
+        :param nsample: int, number of features in knn
+        :param xyz: (B, N, 3)
+        :param new_xyz: (B, npoint, 3)
+        :return:
+            dist: (B, npoint, nsample) l2 distance to knn
+            idx: (B, npoint, nsample) index of knn
+        """
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available, and no CPU fallback is implemented.")
+
+        new_xyz = new_xyz.contiguous()
+        xyz = xyz.contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        device = torch.device('cuda')
+        dist2 = torch.full((B, npoint, nsample), 1e4, dtype=torch.float32, device=device)
+        idx = torch.zeros((B, npoint, nsample), dtype=torch.int32, device=device)
+
+        pc_util.knn_query_wrapper(B, N, npoint, nsample, new_xyz, xyz, dist2, idx)
+        ctx.mark_non_differentiable(dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        return ()
+
+
+knn_query = KNNQuery.apply
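
As a sanity reference for the sampling op above: vanilla furthest point sampling in plain PyTorch (CPU-friendly, without the wd/wf feature-distance weighting that the CUDA wrapper exposes). Illustrative only:

    import torch

    def fps_reference(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        # xyz: (B, N, 3) -> (B, npoint) indices of the sampled set
        B, N, _ = xyz.shape
        idx = torch.zeros(B, npoint, dtype=torch.long)
        min_dist = torch.full((B, N), float('inf'))
        farthest = torch.zeros(B, dtype=torch.long)    # seed: point 0 of each cloud
        for i in range(npoint):
            idx[:, i] = farthest
            centroid = xyz[torch.arange(B), farthest].unsqueeze(1)  # (B, 1, 3)
            d = ((xyz - centroid) ** 2).sum(-1)                     # squared distances
            min_dist = torch.minimum(min_dist, d)
            farthest = min_dist.argmax(-1)             # furthest from the chosen set
        return idx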
model/roofnet.py
ADDED
@@ -0,0 +1,35 @@
+from .pointnet2 import PointNet2
+from .cluster_refine import ClusterRefineNet
+from .edge_pred_net import EdgeAttentionNet
+import torch.nn as nn
+from sklearn.cluster import DBSCAN
+
+
+class RoofNet(nn.Module):
+    def __init__(self, model_cfg, input_channel=3):
+        super().__init__()
+        self.use_edge = False
+        self.model_cfg = model_cfg
+        self.keypoint_det_net = PointNet2(model_cfg.PointNet2, input_channel)
+        self.cluster_refine_net = ClusterRefineNet(model_cfg.ClusterRefineNet, input_channel=self.keypoint_det_net.num_output_feature)
+        self.edge_att_net = EdgeAttentionNet(model_cfg.EdgeAttentionNet, input_channel=self.cluster_refine_net.num_output_feature)
+
+    def forward(self, batch_dict):
+        batch_dict = self.keypoint_det_net(batch_dict)
+        if self.use_edge:
+            batch_dict = self.cluster_refine_net(batch_dict)
+            batch_dict = self.edge_att_net(batch_dict)
+        if self.training:
+            loss = 0
+            loss_dict = {}
+            disp_dict = {}
+            tmp_loss, loss_dict, disp_dict = self.keypoint_det_net.loss(loss_dict, disp_dict)
+            loss += tmp_loss
+            if self.use_edge:
+                tmp_loss, loss_dict, disp_dict = self.cluster_refine_net.loss(loss_dict, disp_dict)
+                loss += tmp_loss
+                tmp_loss, loss_dict, disp_dict = self.edge_att_net.loss(loss_dict, disp_dict)
+                loss += tmp_loss
+            return loss, loss_dict, disp_dict
+        else:
+            return batch_dict
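
RoofNet gates the refinement and edge heads behind use_edge, which makes staged training possible: the keypoint detector alone first, then the full pipeline. A hypothetical training-step sketch under that reading; the import path, optimizer choice, and stage switch are assumptions, not part of this commit:

    import torch
    from model.roofnet import RoofNet   # assumed import path based on the tree above

    def train_step(net: RoofNet, opt: torch.optim.Optimizer, batch_dict: dict, enable_edge: bool):
        net.use_edge = enable_edge      # assumed stage switch (False during warm-up)
        loss, loss_dict, disp_dict = net(batch_dict)
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss_dict, disp_dict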
model_cfg.yaml
ADDED
@@ -0,0 +1,26 @@
+DATA:
+  NPOINT: 1024
+MODEL:
+  PointNet2:
+    PosRadius: 0.15
+    LossWeight: {
+      'cls_weight': 1.0,
+      'reg_weight': 1.0
+    }
+  ClusterRefineNet:
+    ScoreThresh: 0.5
+    MatchRadius: 0.2
+    Cluster:
+      eps: 0.05
+      min_pts: 5
+    RefineSA:
+      Radii: [0.1, 0.2]
+      Nsamples: [16, 16]
+      MLPs: [[128, 128], [128, 128]]
+    LossWeight: {
+      'reg_weight': 1.0
+    }
+  EdgeAttentionNet:
+    LossWeight: {
+      'cls_weight': 1.0
+    }
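
The dotted attribute access used throughout the model code (e.g. model_cfg.PosRadius, model_cfg.LossWeight) implies the YAML is wrapped in a dot-accessible dict. A loading sketch under that assumption; EasyDict is one common choice, not confirmed by this commit:

    import yaml
    from easydict import EasyDict  # assumption: any attribute-access dict wrapper works

    with open('model_cfg.yaml', 'r') as f:
        cfg = EasyDict(yaml.safe_load(f))

    net = RoofNet(cfg.MODEL)       # e.g. cfg.MODEL.PointNet2.PosRadius == 0.15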
output/hoho_test/checkpoint_epoch_90_all.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beadacd967ce200405fafbb9a3434191606dbda4ee4ab70ddf7f861064861cf7
+size 16980109
output/hoho_test/test/log.txt
ADDED
The diff for this file is too large to render.
output/hoho_test/test/submission.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:199824086b9925e081aa30308e43ab1e5c7c269907197516d67582ab75801008
+size 54126
output/hoho_train/ckpt/checkpoint_epoch_41.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f4a7fbbbad3fb17933ecab37a19e7555d6134445d092779c0934a7128e9ab8b
+size 17019805
output/hoho_train/ckpt/checkpoint_epoch_42.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:220030fbbe46caaddc307c32f26e8a9096ef128b0b9b68701a05acaa3b5e6520
+size 17019805
output/hoho_train/ckpt/checkpoint_epoch_43.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61afd8b6e889507930257152cabe814d6f938c76f199fac8cf3e2ad893364abe
+size 17019805
output/hoho_train/ckpt/checkpoint_epoch_44.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f674dfe549c1ce1988af4f91397a14cca5004178933d7902ae71316fa494bf12
+size 17019805
output/hoho_train/ckpt/checkpoint_epoch_45.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8846e86eff7717407c72704e9b4757c429ca8103cf800c2bfb10c27c58abab39
+size 17019805
output/hoho_train/log.txt
ADDED
@@ -0,0 +1,9 @@
+2024-05-28 18:15:32,724 INFO **********************Start logging**********************
+2024-05-28 18:15:32,725 INFO Total samples: 4328
+2024-05-28 18:15:33,700 INFO **********************Start training**********************
+2024-05-29 16:17:23,568 INFO **********************Start logging**********************
+2024-05-29 16:17:23,578 INFO Total samples: 4328
+2024-05-29 16:17:24,693 INFO ==> Loading parameters from checkpoint
+2024-05-29 16:17:24,732 INFO ==> Loading optimizer parameters from checkpoint
+2024-05-29 16:17:24,740 INFO ==> Done
+2024-05-29 16:17:24,740 INFO **********************Start training**********************
pc_util/setup.py
ADDED
@@ -0,0 +1,23 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+    name='pc_util',
+    version='1.0',
+    ext_modules=[
+        CUDAExtension('pc_util', [
+            'src/pointnet2_api.cpp',
+            'src/ball_query.cpp',
+            'src/ball_query_gpu.cu',
+            'src/group_points.cpp',
+            'src/group_points_gpu.cu',
+            'src/interpolate.cpp',
+            'src/interpolate_gpu.cu',
+            'src/sampling.cpp',
+            'src/sampling_gpu.cu',
+            'src/cluster.cpp',
+            'src/cluster_gpu.cu',
+        ], extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']})
+    ],
+    cmdclass={'build_ext': BuildExtension}
+)
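
Building follows the standard torch.utils.cpp_extension flow: from pc_util/, something like `python setup.py install` compiles the listed .cpp/.cu sources against the active PyTorch and CUDA toolchain and produces the pc_util module imported throughout model/.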
pc_util/src/ball_query.cpp
ADDED
@@ -0,0 +1,84 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+// #include <THC/THC.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "ball_query_gpu.h"
+
+// extern THCState *state;
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+#define CHECK_CUDA(x) do { \
+    if (!x.type().is_cuda()) { \
+        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+    if (!x.is_contiguous()) { \
+        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(new_xyz_tensor);
+    CHECK_INPUT(xyz_tensor);
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx);
+    return 1;
+}
+
+
+int ball_center_query_wrapper_fast(int b, int n, int m, float radius,
+    at::Tensor point_tensor, at::Tensor key_point_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(point_tensor);
+    CHECK_INPUT(key_point_tensor);
+    const float *point = point_tensor.data<float>();
+    const float *key_point = key_point_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    ball_center_query_kernel_launcher_fast(b, n, m, radius, point, key_point, idx);
+    return 1;
+}
+
+
+int knn_query_wrapper_fast(int b, int n, int m, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(new_xyz_tensor);
+    CHECK_INPUT(xyz_tensor);
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    knn_query_kernel_launcher_fast(b, n, m, nsample, new_xyz, xyz, dist2, idx);
+    return 1;
+}
+
+
+int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
+    at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(new_xyz_tensor);
+    CHECK_INPUT(xyz_tensor);
+    CHECK_INPUT(new_xyz_batch_cnt_tensor);
+    CHECK_INPUT(xyz_batch_cnt_tensor);
+
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data<int>();
+    const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
+    int *idx = idx_tensor.data<int>();
+
+    ball_query_kernel_launcher_stack(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
+    return 1;
+}
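
One practical implication of the CHECK_INPUT macros above: every tensor handed to pc_util must already be a contiguous CUDA tensor, and a violation calls exit(-1), killing the interpreter rather than raising a catchable Python exception. This is why the Python wrappers earlier in the diff call .contiguous() and assert contiguity before crossing the C++ boundary.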
pc_util/src/ball_query_gpu.cu
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "ball_query_gpu.h"
+#include "cuda_utils.h"
+
+
+__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
+    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
+    // new_xyz: (B, M, 3)
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, M, nsample)
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= m) return;
+
+    new_xyz += bs_idx * m * 3 + pt_idx * 3;
+    xyz += bs_idx * n * 3;
+    idx += bs_idx * m * nsample + pt_idx * nsample;
+
+    float radius2 = radius * radius;
+    float new_x = new_xyz[0];
+    float new_y = new_xyz[1];
+    float new_z = new_xyz[2];
+
+    int cnt = 0;
+    for (int k = 0; k < n; ++k) {
+        float x = xyz[k * 3 + 0];
+        float y = xyz[k * 3 + 1];
+        float z = xyz[k * 3 + 2];
+        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+        if (d2 < radius2){
+            if (cnt == 0){
+                for (int l = 0; l < nsample; ++l) {
+                    idx[l] = k;
+                }
+            }
+            idx[cnt] = k;
+            ++cnt;
+            if (cnt >= nsample) break;
+        }
+    }
+}
+
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
+    const float *new_xyz, const float *xyz, int *idx) {
+    // new_xyz: (B, M, 3)
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, M, nsample)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    ball_query_kernel_fast<<<blocks, threads>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void ball_center_query_kernel_fast(int b, int n, int m, float radius,
+    const float *__restrict__ point, const float *__restrict__ key_point, int *__restrict__ idx) {
+    // key_point: (B, M, 3)
+    // point: (B, N, 3)
+    // output:
+    //      idx: (B, N)
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= n) return;
+
+    point += bs_idx * n * 3 + pt_idx * 3;
+    key_point += bs_idx * m * 3;
+    idx += bs_idx * n + pt_idx;
+
+    float radius2 = radius * radius;
+    float point_x = point[0];
+    float point_y = point[1];
+    float point_z = point[2];
+
+    float bestd = 1e8;
+    for (int k = 0; k < m; ++k) {
+        float x = key_point[k * 3 + 0];
+        float y = key_point[k * 3 + 1];
+        float z = key_point[k * 3 + 2];
+        if (((x + 1) * (x + 1) + (y + 1) * (y + 1) + (z + 1) * (z + 1)) < 1e-4) break;
+        float d2 = (point_x - x) * (point_x - x) + (point_y - y) * (point_y - y) + (point_z - z) * (point_z - z);
+        if (d2 < radius2 && d2 < bestd){
+            idx[0] = k;
+            bestd = d2;
+        }
+    }
+}
+
+
+void ball_center_query_kernel_launcher_fast(int b, int n, int m, float radius,
+    const float *point, const float *key_point, int *idx) {
+    // point: (B, n, 3)
+    // key_point: (B, m, 3)
+    // output:
+    //      idx: (B, n)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    ball_center_query_kernel_fast<<<blocks, threads>>>(b, n, m, radius, point, key_point, idx);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void knn_query_kernel_fast(int b, int n, int m, int nsample, const float *__restrict__ new_xyz,
+    const float *__restrict__ xyz, float *__restrict__ dist2, int *__restrict__ idx) {
+    // new_xyz: (B, M, 3)
+    // xyz: (B, N, 3)
+    // output:
+    //      dist2: (B, M, nsample)
+    //      idx: (B, M, nsample)
+
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= m) return;
+
+    new_xyz += bs_idx * m * 3 + pt_idx * 3;
+    xyz += bs_idx * n * 3;
+    dist2 += bs_idx * m * nsample + pt_idx * nsample;
+    idx += bs_idx * m * nsample + pt_idx * nsample;
+
+    float nx = new_xyz[0];
+    float ny = new_xyz[1];
+    float nz = new_xyz[2];
+
+    for (int i = 0; i < n; ++i) {
+        float x = xyz[i * 3 + 0];
+        float y = xyz[i * 3 + 1];
+        float z = xyz[i * 3 + 2];
+        float d2 = (nx - x) * (nx - x) + (ny - y) * (ny - y) + (nz - z) * (nz - z);
+        if (d2 < dist2[nsample - 1]) {
+            dist2[nsample - 1] = d2;
+            idx[nsample - 1] = i;
+            for (int j = nsample - 2; j >= 0; j--) {
+                if (d2 < dist2[j]){
+                    dist2[j + 1] = dist2[j];
+                    dist2[j] = d2;
+                    idx[j + 1] = idx[j];
+                    idx[j] = i;
+                }
+            }
+        }
+    }
+}
+
+
+void knn_query_kernel_launcher_fast(int b, int n, int m, int nsample,
+    const float *new_xyz, const float *xyz, float *dist2, int *idx) {
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    knn_query_kernel_fast<<<blocks, threads>>>(b, n, m, nsample, new_xyz, xyz, dist2, idx);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void ball_query_kernel_stack(int B, int M, float radius, int nsample,
+    const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx) {
+    // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+    // :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
+    // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
+    // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
+    // output:
+    //      idx: (M, nsample)
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (pt_idx >= M) return;
+
+    int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0];
+    for (int k = 1; k < B; k++){
+        if (pt_idx < pt_cnt) break;
+        pt_cnt += new_xyz_batch_cnt[k];
+        bs_idx = k;
+    }
+
+    int xyz_batch_start_idx = 0;
+    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
+    // for (int k = 0; k < bs_idx; k++) new_xyz_batch_start_idx += new_xyz_batch_cnt[k];
+
+    new_xyz += pt_idx * 3;
+    xyz += xyz_batch_start_idx * 3;
+    idx += pt_idx * nsample;
+
+    float radius2 = radius * radius;
+    float new_x = new_xyz[0];
+    float new_y = new_xyz[1];
+    float new_z = new_xyz[2];
+    int n = xyz_batch_cnt[bs_idx];
+
+    int cnt = 0;
+    for (int k = 0; k < n; ++k) {
+        float x = xyz[k * 3 + 0];
+        float y = xyz[k * 3 + 1];
+        float z = xyz[k * 3 + 2];
+        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+        if (d2 < radius2){
+            if (cnt == 0){
+                for (int l = 0; l < nsample; ++l) {
+                    idx[l] = k;
+                }
+            }
+            idx[cnt] = k;
+            ++cnt;
+            if (cnt >= nsample) break;
+        }
+    }
+    if (cnt == 0) idx[0] = -1;
+}
+
+
+void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample,
+    const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx){
+    // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
+    // :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
+    // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
+    // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
+    // output:
+    //      idx: (M, nsample)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    ball_query_kernel_stack<<<blocks, threads>>>(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
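For reference, a single-query CPU sketch (not the shipped kernel) of ball_query_kernel_fast's semantics: collect up to nsample neighbour indices within the radius and pad every unused slot with the first hit, exactly as the CUDA loop does.

#include <algorithm>
#include <vector>

// Gather up to nsample neighbours of query q within `radius`, padding
// unused slots with the first hit (slots stay 0 if no neighbour exists).
std::vector<int> ball_query_ref(const float* xyz, int n, const float q[3],
                                float radius, int nsample) {
    std::vector<int> idx(nsample, 0);
    float r2 = radius * radius;
    int cnt = 0;
    for (int k = 0; k < n && cnt < nsample; ++k) {
        float dx = q[0] - xyz[k * 3 + 0];
        float dy = q[1] - xyz[k * 3 + 1];
        float dz = q[2] - xyz[k * 3 + 2];
        if (dx * dx + dy * dy + dz * dz < r2) {
            if (cnt == 0) std::fill(idx.begin(), idx.end(), k);  // pad with first hit
            idx[cnt++] = k;
        }
    }
    return idx;
}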
pc_util/src/ball_query_gpu.h
ADDED
@@ -0,0 +1,38 @@
+#ifndef _BALL_QUERY_GPU_H
+#define _BALL_QUERY_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
+    const float *new_xyz, const float *xyz, int *idx);
+
+int ball_center_query_wrapper_fast(int b, int n, int m, float radius,
+    at::Tensor point_tensor, at::Tensor key_point_tensor, at::Tensor idx_tensor);
+
+void ball_center_query_kernel_launcher_fast(int b, int n, int m, float radius,
+    const float *point, const float *key_point, int *idx);
+
+int knn_query_wrapper_fast(int b, int n, int m, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
+
+void knn_query_kernel_launcher_fast(int b, int n, int m, int nsample,
+    const float *new_xyz, const float *xyz, float *dist2, int *idx);
+
+
+int ball_query_wrapper_stack(int B, int M, float radius, int nsample,
+    at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
+    at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor);
+
+void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample,
+    const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx);
+
+#endif
pc_util/src/cluster.cpp
ADDED
@@ -0,0 +1,50 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+// #include <THC/THC.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "cluster_gpu.h"
+
+// extern THCState *state;
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+#define CHECK_CUDA(x) do { \
+    if (!x.type().is_cuda()) { \
+        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+    if (!x.is_contiguous()) { \
+        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+int dbscan_wrapper_fast(int b, int n, float eps, int min_pts, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(xyz_tensor);
+    const float *xyz = xyz_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    dbscan_kernel_launcher_fast(b, n, eps, min_pts, xyz, idx);
+    return 1;
+}
+
+
+int cluster_pts_wrapper_fast(int b, int n, int m, at::Tensor xyz_tensor, at::Tensor idx_tensor,
+    at::Tensor new_xyz_tensor, at::Tensor num_tensor) {
+    CHECK_INPUT(xyz_tensor);
+    CHECK_INPUT(idx_tensor);
+    const float *xyz = xyz_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *new_xyz = new_xyz_tensor.data<float>();
+    int *num = num_tensor.data<int>();
+
+    cluster_pts_kernel_launcher_fast(b, n, m, xyz, idx, new_xyz, num);
+    return 1;
+}
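A hedged usage sketch for the wrapper above, assuming the PyTorch C++ API is available at the call site; the helper name run_dbscan is made up. The -1 fill matters because the DBSCAN kernel only labels points whose idx entry is still -1.

#include <torch/torch.h>
#include "cluster_gpu.h"  // declares dbscan_wrapper_fast

// Run GPU DBSCAN on a (B, N, 3) float32 CUDA tensor; returns per-point
// cluster ids, with -1 left in place for noise points.
at::Tensor run_dbscan(at::Tensor xyz, float eps, int min_pts) {
    int b = xyz.size(0), n = xyz.size(1);
    at::Tensor idx = torch::full({b, n}, -1,
        torch::dtype(torch::kInt32).device(xyz.device()));
    dbscan_wrapper_fast(b, n, eps, min_pts, xyz.contiguous(), idx);
    return idx;
}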
pc_util/src/cluster_gpu.cu
ADDED
@@ -0,0 +1,192 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cluster_gpu.h"
+#include "cuda_utils.h"
+
+
+__device__ float get_dis(float x1, float y1, float z1, float x2, float y2, float z2) {
+    float dis = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2) + (z1 - z2) * (z1 - z2);
+    return sqrt(dis);
+}
+/*
+__device__ void dfs (int i, int c, int n, int min_pts, const int* pts_cnt, const int* pts_adj, int* idx, int label) {
+    idx[i] = c;
+    if(pts_cnt[i] < min_pts) return;
+
+    for(int j=0;j<n;j++) {
+        int adj = pts_adj[i * n + j];
+        printf("%d %d %d\n", i * n, i * n + j, adj);
+        if (adj == -1) break;
+        if (idx[adj] == -1)
+            dfs(adj, c, n, min_pts, pts_cnt, pts_adj, idx, label);
+    }
+}
+*/
+
+__global__ void dbscan_kernel_fast(int b, int n, float eps, int min_pts, const float *__restrict__ xyz, int *__restrict__ idx,
+    int *__restrict__ pts_cnt, int *__restrict__ pts_adj, int *__restrict__ pts_stack) {
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, N)
+    int bs_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b) return;
+
+    xyz += bs_idx * n * 3;
+    idx += bs_idx * n;
+    pts_cnt += bs_idx * n;
+    pts_stack += bs_idx * n;
+    pts_adj += bs_idx * n * n;
+
+    for(int i=0;i<n;i++) {
+        pts_cnt[i] = 0;
+        for(int j=0;j<n;j++) {
+            pts_adj[i * n + j] = -1;
+            if(i==j) continue;
+            float x1 = xyz[i * 3 + 0];
+            float y1 = xyz[i * 3 + 1];
+            float z1 = xyz[i * 3 + 2];
+            float x2 = xyz[j * 3 + 0];
+            float y2 = xyz[j * 3 + 1];
+            float z2 = xyz[j * 3 + 2];
+
+            if(get_dis(x2, y2, z2, -10.0, -10.0, -10.0) < 1e-3) continue;
+            if(get_dis(x1, y1, z1, x2, y2, z2) <= eps) {
+                pts_adj[i * n + pts_cnt[i]] = j;
+                pts_cnt[i] += 1;
+            }
+        }
+    }
+
+    int cluster_idx = 0;
+
+    for(int i=0;i<n;i++) {
+        if(idx[i] != -1) continue;
+
+        if(pts_cnt[i] >= min_pts) {
+            for(int j=0;j<n;j++)
+                pts_stack[j] = -1;
+            pts_stack[0] = i;
+            int stack_idx = 0;
+            int stack_len = 1;
+            while (stack_idx < n && pts_stack[stack_idx] != -1)
+            {
+                int pts_idx = pts_stack[stack_idx];
+                idx[pts_idx] = cluster_idx;
+                if(pts_cnt[pts_idx] < min_pts){
+                    stack_idx += 1;
+                    continue;
+                }
+                for(int j=0;j<n;j++) {
+                    int adj = pts_adj[pts_idx * n + j];
+                    if (adj == -1) break;
+                    if (idx[adj] == -1)
+                    {
+                        idx[adj] = -2;
+                        pts_stack[stack_len++] = adj;
+                    }
+                }
+                stack_idx += 1;
+            }
+            cluster_idx += 1;
+        }
+    }
+}
+
+
+void dbscan_kernel_launcher_fast(int b, int n, float eps, int min_pts, const float *xyz, int *idx) {
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, N)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(b, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    int* pts_cnt;
+    int* pts_stack;
+    int* pts_adj;
+
+    err = cudaMalloc((void**)&pts_cnt, b * n * sizeof(int));
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+
+    err = cudaMalloc((void**)&pts_stack, b * n * sizeof(int));
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+
+    err = cudaMalloc((void**)&pts_adj, b * n * n * sizeof(int));
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+
+    dbscan_kernel_fast<<<blocks, threads>>>(b, n, eps, min_pts, xyz, idx, pts_cnt, pts_adj, pts_stack);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    cudaFree(pts_cnt);
+    cudaFree(pts_stack);
+    cudaFree(pts_adj);
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void cluster_pts_kernel_fast(int b, int n, int m, const float *__restrict__ xyz, const int *__restrict__ idx,
+    float *__restrict__ new_xyz, int *__restrict__ num) {
+    int bs_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b) return;
+
+    xyz += bs_idx * n * 3;
+    idx += bs_idx * n;
+    new_xyz += bs_idx * m * 3;
+    num += bs_idx * m;
+
+    for(int i=0;i<n;i++) {
+        if (idx[i] == -1) continue;
+        int c_idx = idx[i];
+        new_xyz[c_idx * 3 + 0] += xyz[i * 3 + 0];
+        new_xyz[c_idx * 3 + 1] += xyz[i * 3 + 1];
+        new_xyz[c_idx * 3 + 2] += xyz[i * 3 + 2];
+        num[c_idx] += 1;
+    }
+    for(int i=0;i<m;i++) {
+        if (num[i] == 0) break;
+        new_xyz[i * 3 + 0] /= num[i];
+        new_xyz[i * 3 + 1] /= num[i];
+        new_xyz[i * 3 + 2] /= num[i];
+    }
+}
+
+
+void cluster_pts_kernel_launcher_fast(int b, int n, int m, const float *xyz, const int *idx, float *new_xyz, int *num) {
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(b, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    cluster_pts_kernel_fast<<<blocks, threads>>>(b, n, m, xyz, idx, new_xyz, num);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
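A CPU sketch of what cluster_pts_kernel_fast computes for one sample: per-cluster coordinate sums divided by member counts. The early break mirrors the kernel and assumes cluster ids form a contiguous prefix, as the DBSCAN kernel above produces.

#include <vector>

// Average the coordinates of every point assigned to a cluster; idx < 0
// marks noise. new_xyz gets m centroids, num the member counts.
void cluster_centroids(const float* xyz, const int* idx, int n, int m,
                       std::vector<float>& new_xyz, std::vector<int>& num) {
    new_xyz.assign(m * 3, 0.f);
    num.assign(m, 0);
    for (int i = 0; i < n; ++i) {
        if (idx[i] < 0) continue;  // skip noise / unlabelled points
        for (int d = 0; d < 3; ++d) new_xyz[idx[i] * 3 + d] += xyz[i * 3 + d];
        num[idx[i]] += 1;
    }
    for (int c = 0; c < m; ++c) {
        if (num[c] == 0) break;    // cluster ids are a contiguous prefix
        for (int d = 0; d < 3; ++d) new_xyz[c * 3 + d] /= num[c];
    }
}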
pc_util/src/cluster_gpu.h
ADDED
@@ -0,0 +1,34 @@
+#ifndef _CLUSTER_GPU_H
+#define _CLUSTER_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+int dbscan_wrapper_fast(int b, int n, float eps, int min_pts, at::Tensor xyz_tensor, at::Tensor idx_tensor);
+
+void dbscan_kernel_launcher_fast(int b, int n, float eps, int min_pts, const float *xyz, int *idx);
+
+int cluster_pts_wrapper_fast(int b, int n, int m, at::Tensor xyz_tensor, at::Tensor idx_tensor,
+    at::Tensor new_xyz_tensor, at::Tensor num_tensor);
+
+void cluster_pts_kernel_launcher_fast(int b, int n, int m, const float *xyz, const int *idx, float *new_xyz, int *num);
+
+
+int dbscan_wrapper_stack(int b, int n, float eps, int min_pts, at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
+    at::Tensor idx_tensor);
+
+void dbscan_kernel_launcher_stack(int b, int n, float eps, int min_pts,
+    const float *xyz, const int *xyz_batch_cnt, int *idx);
+
+int cluster_pts_wrapper_stack(int B, at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor,
+    at::Tensor new_xyz_tensor, at::Tensor cluster_cnt_tensor);
+
+void cluster_pts_kernel_launcher_stack(int B, const float *xyz, const int *xyz_batch_cnt, int *idx,
+    const float *new_xyz, const int *cluster_cnt);
+
+#endif
pc_util/src/cuda_utils.h
ADDED
@@ -0,0 +1,15 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+
+#define TOTAL_THREADS 1024
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+inline int opt_n_threads(int work_size) {
+    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
+
+    return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+#endif
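A worked example of the launch-size arithmetic above: DIVUP is a ceiling division, so 1000 work items at 256 threads per block need 4 blocks (the macros are repeated here so the snippet stands alone).

#include <cstdio>

#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

int main() {
    int work = 1000;
    // 1000 / 256 = 3, remainder 232 > 0, so one extra block: prints 4
    printf("blocks = %d\n", DIVUP(work, THREADS_PER_BLOCK));
    return 0;
}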
pc_util/src/group_points.cpp
ADDED
@@ -0,0 +1,98 @@
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+// #include <THC/THC.h>
+#include "group_points_gpu.h"
+
+// extern THCState *state;
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+#define CHECK_CUDA(x) do { \
+    if (!x.type().is_cuda()) { \
+        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+    if (!x.is_contiguous()) { \
+        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const float *grad_out = grad_out_tensor.data<float>();
+
+    group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points);
+    return 1;
+}
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+    group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out);
+    return 1;
+}
+
+
+int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor,
+    at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor) {
+
+    CHECK_INPUT(grad_out_tensor);
+    CHECK_INPUT(idx_tensor);
+    CHECK_INPUT(idx_batch_cnt_tensor);
+    CHECK_INPUT(features_batch_cnt_tensor);
+    CHECK_INPUT(grad_features_tensor);
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const int *idx_batch_cnt = idx_batch_cnt_tensor.data<int>();
+    const int *features_batch_cnt = features_batch_cnt_tensor.data<int>();
+    float *grad_features = grad_features_tensor.data<float>();
+
+    group_points_grad_kernel_launcher_stack(B, M, C, N, nsample, grad_out, idx, idx_batch_cnt, features_batch_cnt, grad_features);
+    return 1;
+}
+
+
+int group_points_wrapper_stack(int B, int M, int C, int nsample,
+    at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor,
+    at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor) {
+
+    CHECK_INPUT(features_tensor);
+    CHECK_INPUT(features_batch_cnt_tensor);
+    CHECK_INPUT(idx_tensor);
+    CHECK_INPUT(idx_batch_cnt_tensor);
+    CHECK_INPUT(out_tensor);
+
+    const float *features = features_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const int *features_batch_cnt = features_batch_cnt_tensor.data<int>();
+    const int *idx_batch_cnt = idx_batch_cnt_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+    group_points_kernel_launcher_stack(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, out);
+    return 1;
+}
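A hedged usage sketch for the batched grouping path, assuming the PyTorch C++ API; the helper name group is made up. Features laid out as (B, C, N) are gathered by neighbour indices (B, npoints, nsample), e.g. from ball query, into (B, C, npoints, nsample).

#include <torch/torch.h>
#include "group_points_gpu.h"  // declares group_points_wrapper_fast

// Gather neighbour features for each (query point, sample slot) pair.
at::Tensor group(at::Tensor points, at::Tensor idx) {
    int b = points.size(0), c = points.size(1), n = points.size(2);
    int npoints = idx.size(1), nsample = idx.size(2);
    at::Tensor out = torch::zeros({b, c, npoints, nsample}, points.options());
    group_points_wrapper_fast(b, c, n, npoints, nsample,
                              points.contiguous(), idx.contiguous(), out);
    return out;
}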
pc_util/src/group_points_gpu.cu
ADDED
@@ -0,0 +1,199 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "group_points_gpu.h"
+
+
+__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
+    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
+    // grad_out: (B, C, npoints, nsample)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      grad_points: (B, C, N)
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int pt_idx = index / nsample;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+    int sample_idx = index % nsample;
+    grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+    atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
+}
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+    const float *grad_out, const int *idx, float *grad_points) {
+    // grad_out: (B, C, npoints, nsample)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      grad_points: (B, C, N)
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_grad_kernel_fast<<<blocks, threads>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
+    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+    // points: (B, C, N)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      out: (B, C, npoints, nsample)
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int pt_idx = index / nsample;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+    int sample_idx = index % nsample;
+
+    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
+    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+    out[out_idx] = points[in_idx];
+}
+
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+    const float *points, const int *idx, float *out) {
+    // points: (B, C, N)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      out: (B, C, npoints, nsample)
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_kernel_fast<<<blocks, threads>>>(b, c, n, npoints, nsample, points, idx, out);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void group_points_grad_kernel_stack(int B, int M, int C, int N, int nsample,
+    const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features) {
+    // :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
+    // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of features to group with
+    // :param idx_batch_cnt: (batch_size), [M1, M2, ...] number of idx rows per batch element
+    // :param features_batch_cnt: (batch_size), [N1, N2, ...] number of features per batch element
+    // :return:
+    //     grad_features: (N1 + N2 ..., C) gradient of the features
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int sample_idx = index % nsample;
+    int C_idx = (index / nsample) % C;
+    int pt_idx = (index / nsample / C);
+
+    if (pt_idx >= M || C_idx >= C || sample_idx >= nsample) return;
+
+    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
+    for (int k = 1; k < B; k++){
+        if (pt_idx < pt_cnt) break;
+        pt_cnt += idx_batch_cnt[k];
+        bs_idx = k;
+    }
+
+    int features_batch_start_idx = 0;
+    for (int k = 0; k < bs_idx; k++) features_batch_start_idx += features_batch_cnt[k];
+
+    grad_out += pt_idx * C * nsample + C_idx * nsample + sample_idx;
+    idx += pt_idx * nsample + sample_idx;
+    grad_features += (features_batch_start_idx + idx[0]) * C + C_idx;
+
+    atomicAdd(grad_features, grad_out[0]);
+}
+
+void group_points_grad_kernel_launcher_stack(int B, int M, int C, int N, int nsample,
+    const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features) {
+    // :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward
+    // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of features to group with
+    // :param idx_batch_cnt: (batch_size), [M1, M2, ...] number of idx rows per batch element
+    // :param features_batch_cnt: (batch_size), [N1, N2, ...] number of features per batch element
+    // :return:
+    //     grad_features: (N1 + N2 ..., C) gradient of the features
+
+    cudaError_t err;
+    // dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 blocks(DIVUP(M * C * nsample, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_grad_kernel_stack<<<blocks, threads>>>(B, M, C, N, nsample, grad_out, idx, idx_batch_cnt, features_batch_cnt, grad_features);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void group_points_kernel_stack(int B, int M, int C, int nsample,
+    const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out) {
+    // :param features: (N1 + N2 ..., C) tensor of features to group
+    // :param features_batch_cnt: (batch_size), [N1, N2, ...] number of features per batch element
+    // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of features to group with
+    // :param idx_batch_cnt: (batch_size), [M1, M2, ...] number of idx rows per batch element
+    // :return:
+    //     output: (M1 + M2 ..., C, nsample) tensor
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int sample_idx = index % nsample;
+    int C_idx = (index / nsample) % C;
+    int pt_idx = (index / nsample / C);
+
+    if (pt_idx >= M || C_idx >= C || sample_idx >= nsample) return;
+
+    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
+    for (int k = 1; k < B; k++){
+        if (pt_idx < pt_cnt) break;
+        pt_cnt += idx_batch_cnt[k];
+        bs_idx = k;
+    }
+
+    int features_batch_start_idx = 0;
+    for (int k = 0; k < bs_idx; k++) features_batch_start_idx += features_batch_cnt[k];
+    features += features_batch_start_idx * C;
+
+    idx += pt_idx * nsample + sample_idx;
+    int in_idx = idx[0] * C + C_idx;
+    int out_idx = pt_idx * C * nsample + C_idx * nsample + sample_idx;
+
+    out[out_idx] = features[in_idx];
+}
+
+
+void group_points_kernel_launcher_stack(int B, int M, int C, int nsample,
+    const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out) {
+    // :param features: (N1 + N2 ..., C) tensor of features to group
+    // :param features_batch_cnt: (batch_size), [N1, N2, ...] number of features per batch element
+    // :param idx: (M1 + M2 ..., nsample) tensor containing the indices of features to group with
+    // :param idx_batch_cnt: (batch_size), [M1, M2, ...] number of idx rows per batch element
+    // :return:
+    //     output: (M1 + M2 ..., C, nsample) tensor
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(M * C * nsample, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_kernel_stack<<<blocks, threads>>>(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, out);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
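For reference, a plain-C++ sketch (not the shipped kernel) of group_points_kernel_fast's gather semantics: each output slot copies the feature of the indexed input point.

// out[b][c][p][s] = points[b][c][ idx[b][p][s] ], with all arrays flattened
// row-major exactly as the CUDA kernel addresses them.
void group_points_ref(const float* points, const int* idx, float* out,
                      int b, int c, int n, int npoints, int nsample) {
    for (int bi = 0; bi < b; ++bi)
        for (int ci = 0; ci < c; ++ci)
            for (int p = 0; p < npoints; ++p)
                for (int s = 0; s < nsample; ++s)
                    out[((bi * c + ci) * npoints + p) * nsample + s] =
                        points[(bi * c + ci) * n +
                               idx[(bi * npoints + p) * nsample + s]];
}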
pc_util/src/group_points_gpu.h
ADDED
@@ -0,0 +1,36 @@
+#ifndef _GROUP_POINTS_GPU_H
+#define _GROUP_POINTS_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+    const float *points, const int *idx, float *out);
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+    const float *grad_out, const int *idx, float *grad_points);
+
+int group_points_wrapper_stack(int B, int M, int C, int nsample,
+    at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor,
+    at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor);
+
+void group_points_kernel_launcher_stack(int B, int M, int C, int nsample,
+    const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out);
+
+int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor,
+    at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor);
+
+void group_points_grad_kernel_launcher_stack(int B, int M, int C, int N, int nsample,
+    const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features);
+
+#endif
pc_util/src/interpolate.cpp
ADDED
@@ -0,0 +1,148 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+// #include <THC/THC.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "interpolate_gpu.h"
+
+// extern THCState *state;
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+#define CHECK_CUDA(x) do { \
+    if (!x.type().is_cuda()) { \
+        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+    if (!x.is_contiguous()) { \
+        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+        exit(-1); \
+    } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
+    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    const float *unknown = unknown_tensor.data<float>();
+    const float *known = known_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx);
+}
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n,
+    at::Tensor points_tensor, at::Tensor idx_tensor,
+    at::Tensor weight_tensor, at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *out = out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out);
+}
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor,
+    at::Tensor weight_tensor, at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points);
+}
+
+
+void three_nn_wrapper_stack(at::Tensor unknown_tensor,
+    at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor,
+    at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor){
+    // unknown: (N1 + N2 ..., 3)
+    // unknown_batch_cnt: (batch_size), [N1, N2, ...]
+    // known: (M1 + M2 ..., 3)
+    // known_batch_cnt: (batch_size), [M1, M2, ...]
+    // Return:
+    //      dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
+    //      idx: (N1 + N2 ..., 3) index of the three nearest neighbors
+    CHECK_INPUT(unknown_tensor);
+    CHECK_INPUT(unknown_batch_cnt_tensor);
+    CHECK_INPUT(known_tensor);
+    CHECK_INPUT(known_batch_cnt_tensor);
+    CHECK_INPUT(dist2_tensor);
+    CHECK_INPUT(idx_tensor);
+
+    int batch_size = unknown_batch_cnt_tensor.size(0);
+    int N = unknown_tensor.size(0);
+    int M = known_tensor.size(0);
+    const float *unknown = unknown_tensor.data<float>();
+    const int *unknown_batch_cnt = unknown_batch_cnt_tensor.data<int>();
+    const float *known = known_tensor.data<float>();
+    const int *known_batch_cnt = known_batch_cnt_tensor.data<int>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    three_nn_kernel_launcher_stack(batch_size, N, M, unknown, unknown_batch_cnt, known, known_batch_cnt, dist2, idx);
+}
+
+
+void three_interpolate_wrapper_stack(at::Tensor features_tensor,
+    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor) {
+    // features_tensor: (M1 + M2 ..., C)
+    // idx_tensor: [N1 + N2 ..., 3]
+    // weight_tensor: [N1 + N2 ..., 3]
+    // Return:
+    //      out_tensor: (N1 + N2 ..., C)
+    CHECK_INPUT(features_tensor);
+    CHECK_INPUT(idx_tensor);
+    CHECK_INPUT(weight_tensor);
+    CHECK_INPUT(out_tensor);
+
+    int N = out_tensor.size(0);
+    int channels = features_tensor.size(1);
+    const float *features = features_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+    three_interpolate_kernel_launcher_stack(N, channels, features, idx, weight, out);
+}
+
+
+void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor,
+    at::Tensor weight_tensor, at::Tensor grad_features_tensor) {
+    // grad_out_tensor: (N1 + N2 ..., C)
+    // idx_tensor: [N1 + N2 ..., 3]
+    // weight_tensor: [N1 + N2 ..., 3]
+    // Return:
+    //      grad_features_tensor: (M1 + M2 ..., C)
+    CHECK_INPUT(grad_out_tensor);
+    CHECK_INPUT(idx_tensor);
+    CHECK_INPUT(weight_tensor);
+    CHECK_INPUT(grad_features_tensor);
+
+    int N = grad_out_tensor.size(0);
+    int channels = grad_out_tensor.size(1);
+    const float *grad_out = grad_out_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *grad_features = grad_features_tensor.data<float>();
+
+    // printf("N=%d, channels=%d\n", N, channels);
+    three_interpolate_grad_kernel_launcher_stack(N, channels, grad_out, idx, weight, grad_features);
+}
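The weights consumed by three_interpolate are typically normalized inverse-distance weights built from three_nn's squared distances; that computation lives outside these files, so the following is a sketch of the usual recipe under that assumption.

#include <cmath>

// Turn the three squared nearest-neighbor distances into normalized
// inverse-distance weights (the standard PointNet++-style recipe; the
// helper name and epsilon are illustrative, not from this commit).
void idw_weights(const float dist2[3], float weight[3]) {
    float recip[3], norm = 0.f;
    for (int i = 0; i < 3; ++i) {
        recip[i] = 1.0f / (dist2[i] + 1e-8f);  // epsilon guards divide-by-zero
        norm += recip[i];
    }
    for (int i = 0; i < 3; ++i) weight[i] = recip[i] / norm;  // weights sum to 1
}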
pc_util/src/interpolate_gpu.cu
ADDED
@@ -0,0 +1,343 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "interpolate_gpu.h"
+
+
+__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
+    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
+    // unknown: (B, N, 3)
+    // known: (B, M, 3)
+    // output:
+    //      dist2: (B, N, 3)
+    //      idx: (B, N, 3)
+
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= n) return;
+
+    unknown += bs_idx * n * 3 + pt_idx * 3;
+    known += bs_idx * m * 3;
+    dist2 += bs_idx * n * 3 + pt_idx * 3;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+    float ux = unknown[0];
+    float uy = unknown[1];
+    float uz = unknown[2];
+
+    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+    int besti1 = 0, besti2 = 0, besti3 = 0;
+    for (int k = 0; k < m; ++k) {
+        float x = known[k * 3 + 0];
+        float y = known[k * 3 + 1];
+        float z = known[k * 3 + 2];
+        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+        if (d < best1) {
+            best3 = best2; besti3 = besti2;
+            best2 = best1; besti2 = besti1;
+            best1 = d; besti1 = k;
+        }
+        else if (d < best2) {
+            best3 = best2; besti3 = besti2;
+            best2 = d; besti2 = k;
+        }
+        else if (d < best3) {
+            best3 = d; besti3 = k;
+        }
+    }
+    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
+}
+
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+    const float *known, float *dist2, int *idx) {
+    // unknown: (B, N, 3)
+    // known: (B, M, 3)
+    // output:
+    //      dist2: (B, N, 3)
+    //      idx: (B, N, 3)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    three_nn_kernel_fast<<<blocks, threads>>>(b, n, m, unknown, known, dist2, idx);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
+    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
+    // points: (B, C, M)
+    // idx: (B, N, 3)
+    // weight: (B, N, 3)
+    // output:
+    //      out: (B, C, N)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+    out += bs_idx * c * n + c_idx * n;
+
+    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
+}
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
+    const float *points, const int *idx, const float *weight, float *out) {
+    // points: (B, C, M)
+    // idx: (B, N, 3)
+    // weight: (B, N, 3)
+    // output:
+    //      out: (B, C, N)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_kernel_fast<<<blocks, threads>>>(b, c, m, n, points, idx, weight, out);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
+    // grad_out: (B, C, N)
+    // weight: (B, N, 3)
+    // output:
+    //      grad_points: (B, C, M)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    grad_points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
+}
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
+    const int *idx, const float *weight, float *grad_points) {
+    // grad_out: (B, C, N)
+    // weight: (B, N, 3)
+    // output:
+    //      grad_points: (B, C, M)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_grad_kernel_fast<<<blocks, threads>>>(b, c, n, m, grad_out, idx, weight, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_nn_kernel_stack(int batch_size, int N, int M, const float *unknown,
+    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
+    float *dist2, int *idx) {
+    // unknown: (N1 + N2 ..., 3)
+    // unknown_batch_cnt: (batch_size), [N1, N2, ...]
+    // known: (M1 + M2 ..., 3)
+    // known_batch_cnt: (batch_size), [M1, M2, ...]
+    // Return:
+    //      dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
+    //      idx: (N1 + N2 ..., 3) index of the three nearest neighbors
+
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (pt_idx >= N) return;
+
+    int bs_idx = 0, pt_cnt = unknown_batch_cnt[0];
+    for (int k = 1; k < batch_size; k++){
+        if (pt_idx < pt_cnt) break;
+        pt_cnt += unknown_batch_cnt[k];
+        bs_idx = k;
+    }
+
+    int cur_num_known_points = known_batch_cnt[bs_idx];
+
+    int known_batch_start_idx = 0;
+    for (int k = 0; k < bs_idx; k++) known_batch_start_idx += known_batch_cnt[k];
+
+    known += known_batch_start_idx * 3;
+    unknown += pt_idx * 3;
+    dist2 += pt_idx * 3;
+    idx += pt_idx * 3;
+
+    float ux = unknown[0];
+    float uy = unknown[1];
+    float uz = unknown[2];
+
+    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+    int besti1 = 0, besti2 = 0, besti3 = 0;
+    for (int k = 0; k < cur_num_known_points; ++k) {
+        float x = known[k * 3 + 0];
+        float y = known[k * 3 + 1];
+        float z = known[k * 3 + 2];
+        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+        if (d < best1) {
+            best3 = best2; besti3 = besti2;
+            best2 = best1; besti2 = besti1;
+            best1 = d; besti1 = k;
+        }
+        else if (d < best2) {
+            best3 = best2; besti3 = besti2;
+            best2 = d; besti2 = k;
+        }
+        else if (d < best3) {
+            best3 = d; besti3 = k;
+        }
+    }
+    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+    idx[0] = besti1 + known_batch_start_idx;
+    idx[1] = besti2 + known_batch_start_idx;
+    idx[2] = besti3 + known_batch_start_idx;
+}
+
+
+void three_nn_kernel_launcher_stack(int batch_size, int N, int M, const float *unknown,
+    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
+    float *dist2, int *idx) {
+    // unknown: (N1 + N2 ..., 3)
+    // unknown_batch_cnt: (batch_size), [N1, N2, ...]
+    // known: (M1 + M2 ..., 3)
+    // known_batch_cnt: (batch_size), [M1, M2, ...]
+    // Return:
+    //      dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
+    //      idx: (N1 + N2 ..., 3) index of the three nearest neighbors
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    three_nn_kernel_stack<<<blocks, threads>>>(
+        batch_size, N, M, unknown, unknown_batch_cnt,
+        known, known_batch_cnt, dist2, idx
+    );
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_kernel_stack(int N, int channels, const float *features,
+    const int *idx, const float *weight, float *out) {
+    // features: (M1 + M2 ..., C)
+    // idx: [N1 + N2 ..., 3]
+    // weight: [N1 + N2 ..., 3]
+    // Return:
+    //      out: (N1 + N2 ..., C)
+
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (pt_idx >= N || c_idx >= channels) return;
+
+    weight += pt_idx * 3;
+    idx += pt_idx * 3;
+    out += pt_idx * channels + c_idx;
+
+    out[0] = weight[0] * features[idx[0] * channels + c_idx] +
+             weight[1] * features[idx[1] * channels + c_idx] +
+             weight[2] * features[idx[2] * channels + c_idx];
+}
+
+
+void three_interpolate_kernel_launcher_stack(int N, int channels,
+    const float *features, const int *idx, const float *weight, float *out) {
+    // features: (M1 + M2 ..., C)
+    // idx: [N1 + N2 ..., 3]
+    // weight: [N1 + N2 ..., 3]
+    // Return:
+    //      out: (N1 + N2 ..., C)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK), channels);
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_kernel_stack<<<blocks, threads>>>(N, channels, features, idx, weight, out);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_grad_kernel_stack(int N, int channels, const float *grad_out,
+    const int *idx, const float *weight, float *grad_features) {
+    // grad_out_tensor: (N1 + N2 ..., C)
+    // idx_tensor: [N1 + N2 ..., 3]
+    // weight_tensor: [N1 + N2 ..., 3]
+    // Return:
+    //      grad_features_tensor: (M1 + M2 ..., C)
+
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (pt_idx >= N || c_idx >= channels) return;
+
+    grad_out += pt_idx * channels + c_idx;
+    weight += pt_idx * 3;
+    idx += pt_idx * 3;
+
+    // printf("pt_idx=%d, c_idx=%d, idx=(%d, %d, %d), grad_out=%f\n", pt_idx, c_idx, idx[0], idx[1], idx[2], grad_out[0]);
+
+    atomicAdd(grad_features + idx[0] * channels + c_idx, grad_out[0] * weight[0]);
+    atomicAdd(grad_features + idx[1] * channels + c_idx, grad_out[0] * weight[1]);
+    atomicAdd(grad_features + idx[2] * channels + c_idx, grad_out[0] * weight[2]);
+}
+
+
+void three_interpolate_grad_kernel_launcher_stack(int N, int channels, const float *grad_out,
+    const int *idx, const float *weight, float *grad_features) {
+    // grad_out_tensor: (N1 + N2 ..., C)
+    // idx_tensor: [N1 + N2 ..., 3]
+    // weight_tensor: [N1 + N2 ..., 3]
+    // Return:
+    //      grad_features_tensor: (M1 + M2 ..., C)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(N, THREADS_PER_BLOCK), channels);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_grad_kernel_stack<<<blocks, threads>>>(
+
N, channels, grad_out, idx, weight, grad_features
|
336 |
+
);
|
337 |
+
|
338 |
+
err = cudaGetLastError();
|
339 |
+
if (cudaSuccess != err) {
|
340 |
+
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
|
341 |
+
exit(-1);
|
342 |
+
}
|
343 |
+
}
|
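Usage sketch (an addition for illustration, not part of the commit): once the extension from pc_util/setup.py is built, the stack-format wrappers above can be driven from Python roughly as follows. The import name `pc_util` and all shapes are assumptions; the real Python-side wrapper lives in model/pointnet_stack_utils.py, and these raw wrappers fill preallocated outputs in place without autograd.

import torch
import pc_util  # hypothetical import name; see pc_util/setup.py for the actual one

# Two point clouds of 40 and 30 query points stacked into one (N1 + N2, 3) tensor.
unknown = torch.rand(70, 3).cuda()
unknown_batch_cnt = torch.tensor([40, 30], dtype=torch.int32).cuda()
# Two support sets of 15 and 10 points, stacked the same way.
known = torch.rand(25, 3).cuda()
known_batch_cnt = torch.tensor([15, 10], dtype=torch.int32).cuda()

dist2 = torch.zeros(70, 3).cuda()                   # squared distances to 3 NNs
idx = torch.zeros(70, 3, dtype=torch.int32).cuda()  # global indices into `known`
pc_util.three_nn_wrapper_stack(unknown, unknown_batch_cnt, known,
                               known_batch_cnt, dist2, idx)

# Inverse-distance weights, normalized per point, then feature interpolation.
w = 1.0 / torch.clamp(dist2, min=1e-8)
weight = w / w.sum(dim=1, keepdim=True)
features = torch.rand(25, 16).cuda()                # (M1 + M2, C) support features
out = torch.zeros(70, 16).cuda()                    # (N1 + N2, C) interpolated
pc_util.three_interpolate_wrapper_stack(features, idx, weight, out)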
pc_util/src/interpolate_gpu.h
ADDED
@@ -0,0 +1,61 @@
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx);


void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out);


void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points);


void three_nn_wrapper_stack(at::Tensor unknown_tensor,
    at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor,
    at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

void three_interpolate_wrapper_stack(at::Tensor features_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor,
    at::Tensor weight_tensor, at::Tensor grad_features_tensor);


void three_nn_kernel_launcher_stack(int batch_size, int N, int M, const float *unknown,
    const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt,
    float *dist2, int *idx);

void three_interpolate_kernel_launcher_stack(int N, int channels,
    const float *features, const int *idx, const float *weight, float *out);

void three_interpolate_grad_kernel_launcher_stack(int N, int channels, const float *grad_out,
    const int *idx, const float *weight, float *grad_features);

#endif
pc_util/src/pointnet2_api.cpp
ADDED
@@ -0,0 +1,41 @@
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
#include "cluster_gpu.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
    m.def("ball_center_query_wrapper", &ball_center_query_wrapper_fast, "ball_center_query_wrapper_fast");
    m.def("knn_query_wrapper", &knn_query_wrapper_fast, "knn_query_wrapper_fast");

    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");

    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");

    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");

    m.def("dbscan_wrapper", &dbscan_wrapper_fast, "dbscan_wrapper_fast");
    m.def("cluster_pts_wrapper", &cluster_pts_wrapper_fast, "cluster_pts_wrapper_fast");

    m.def("ball_query_wrapper_stack", &ball_query_wrapper_stack, "ball_query_wrapper_stack");

    m.def("group_points_wrapper_stack", &group_points_wrapper_stack, "group_points_wrapper_stack");
    m.def("group_points_grad_wrapper_stack", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");

    m.def("three_nn_wrapper_stack", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
    m.def("three_interpolate_wrapper_stack", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
    m.def("three_interpolate_grad_wrapper_stack", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
}
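Build note (a sketch under assumptions): pointnet2_api.cpp exposes all wrappers through a single pybind11 module, so the extension is compiled once and imported from Python. The canonical route is pc_util/setup.py, whose module name is defined there; the JIT variant below is only an illustrative alternative, with the name "pc_util" assumed.

# Option 1: the setup.py route
#   cd pc_util && python setup.py install
# Option 2: JIT-compile with torch.utils.cpp_extension.load (illustrative)
from torch.utils.cpp_extension import load

pc_util = load(
    name="pc_util",  # hypothetical name; TORCH_EXTENSION_NAME is derived from it
    sources=[
        "pc_util/src/pointnet2_api.cpp",
        "pc_util/src/ball_query.cpp", "pc_util/src/ball_query_gpu.cu",
        "pc_util/src/group_points.cpp", "pc_util/src/group_points_gpu.cu",
        "pc_util/src/sampling.cpp", "pc_util/src/sampling_gpu.cu",
        "pc_util/src/interpolate.cpp", "pc_util/src/interpolate_gpu.cu",
        "pc_util/src/cluster.cpp", "pc_util/src/cluster_gpu.cu",
    ],
)
print([n for n in dir(pc_util) if "wrapper" in n])  # list the bound functions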
pc_util/src/sampling.cpp
ADDED
@@ -0,0 +1,46 @@
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <vector>

#include "sampling_gpu.h"


int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out);
    return 1;
}


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();

    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points);
    return 1;
}


int furthest_point_sampling_wrapper(int b, int c, int n, int m, float w1, float w2,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    furthest_point_sampling_kernel_launcher(b, c, n, m, w1, w2, points, temp, idx);
    return 1;
}
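Usage sketch (illustrative, again assuming the extension imports as `pc_util`): furthest_point_sampling_wrapper selects m points per batch, trading off xyz distance against feature distance via w1 and w2. The scratch tensor must start at a large value, since the kernel keeps a running minimum in it.

import torch
import pc_util  # hypothetical import name

b, n, c, m = 2, 1024, 6, 256            # xyz plus 3 extra feature channels
points = torch.rand(b, n, c).cuda()     # (B, N, C); first 3 channels are xyz
temp = torch.full((b, n), 1e10).cuda()  # running min distances; must start large
idx = torch.zeros(b, m, dtype=torch.int32).cuda()

# w1 weights the squared xyz distance, w2 the squared feature distance.
pc_util.furthest_point_sampling_wrapper(b, c, n, m, 1.0, 0.5, points, temp, idx)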
pc_util/src/sampling_gpu.cu
ADDED
@@ -0,0 +1,259 @@
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "sampling_gpu.h"


__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, M)
    // output:
    // out: (B, C, M)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    points += bs_idx * c * n + c_idx * n;
    out[0] = points[idx[0]];
}

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *points, const int *idx, float *out) {
    // points: (B, C, N)
    // idx: (B, npoints)
    // output:
    // out: (B, C, npoints)

    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_kernel_fast<<<blocks, threads>>>(b, c, n, npoints, points, idx, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, M)
    // idx: (B, M)
    // output:
    // grad_points: (B, C, N)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    grad_points += bs_idx * c * n + c_idx * n;

    // atomic: several sampled indices may point at the same source point
    atomicAdd(grad_points + idx[0], grad_out[0]);
}

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *grad_out, const int *idx, float *grad_points) {
    // grad_out: (B, C, npoints)
    // idx: (B, npoints)
    // output:
    // grad_points: (B, C, N)

    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_grad_kernel_fast<<<blocks, threads>>>(b, c, n, npoints, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int c, int n, int m, float w1, float w2,
    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: (B, N, C), first 3 channels are xyz, the rest are features
    // temp: (B, N), running minimum distances; must be initialized to a large value
    // output:
    // idxs: (B, M)

    if (m <= 0) return;
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    dataset += batch_index * n * c;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * c + 0];
        float y1 = dataset[old * c + 1];
        float z1 = dataset[old * c + 2];

        for (int k = tid; k < n; k += stride) {
            float x2 = dataset[k * c + 0];
            float y2 = dataset[k * c + 1];
            float z2 = dataset[k * c + 2];

            float xyz_d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float fea_d = 0;
            for (int l = 3; l < c; l++) {
                fea_d += (dataset[old * c + l] - dataset[k * c + l]) * (dataset[old * c + l] - dataset[k * c + l]);
            }
            float d = w1 * xyz_d + w2 * fea_d;
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // tree reduction over the block: keep the candidate with the largest
        // running-minimum distance, i.e. the farthest remaining point
        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthest_point_sampling_kernel_launcher(int b, int c, int n, int m, float w1, float w2,
    const float *dataset, float *temp, int *idxs) {
    // dataset: (B, N, C)
    // temp: (B, N)
    // output:
    // idxs: (B, M)

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);

    switch (n_threads) {
        case 1024:
            furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 512:
            furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 256:
            furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 128:
            furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 64:
            furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 32:
            furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 16:
            furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 8:
            furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 4:
            furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 2:
            furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        case 1:
            furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs); break;
        default:
            furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, c, n, m, w1, w2, dataset, temp, idxs);
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
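For reference, the sampling criterion implemented above: with w1 and w2 as passed to the launcher, each candidate point k is scored against the last selected point `old` by

    d(k) = w1 * ||xyz_k - xyz_old||^2 + w2 * ||f_k - f_old||^2

The per-point running minimum of d over all points selected so far is kept in `temp`, and the block-wide tree reduction through `__update` then picks the candidate whose minimum is largest, i.e. the farthest remaining point under this weighted metric.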
pc_util/src/sampling_gpu.h
ADDED
@@ -0,0 +1,29 @@
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H

#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>


int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *points, const int *idx, float *out);


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *grad_out, const int *idx, float *grad_points);


int furthest_point_sampling_wrapper(int b, int c, int n, int m, float w1, float w2,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int c, int n, int m, float w1, float w2,
    const float *dataset, float *temp, int *idxs);

#endif