add hdm demo v1
Note: this diff view is limited to 50 files; the commit contains more changes than are shown below.
- README.md +15 -12
- app.py +177 -0
- configs/__init__.py +0 -0
- configs/structured.py +416 -0
- dataset/__init__.py +301 -0
- dataset/base_data.py +110 -0
- dataset/behave_paths.py +228 -0
- dataset/demo_dataset.py +198 -0
- dataset/img_utils.py +149 -0
- demo.py +280 -0
- diffusion_utils.py +313 -0
- examples/017450/k1.color.jpg +0 -0
- examples/017450/k1.obj_rend_mask.png +0 -0
- examples/017450/k1.person_mask.png +0 -0
- model/__init__.py +28 -0
- model/feature_model.py +160 -0
- model/model.py +303 -0
- model/model_coloring.py +84 -0
- model/model_diff_data.py +238 -0
- model/model_hoattn.py +457 -0
- model/model_utils.py +58 -0
- model/point_cloud_model.py +67 -0
- model/point_cloud_transformer_model.py +80 -0
- model/projection_model.py +273 -0
- model/pvcnn/__init__.py +0 -0
- model/pvcnn/modules/__init__.py +8 -0
- model/pvcnn/modules/ball_query.py +69 -0
- model/pvcnn/modules/frustum.py +138 -0
- model/pvcnn/modules/functional/__init__.py +7 -0
- model/pvcnn/modules/functional/backend.py +33 -0
- model/pvcnn/modules/functional/ball_query.py +19 -0
- model/pvcnn/modules/functional/devoxelization.py +42 -0
- model/pvcnn/modules/functional/grouping.py +32 -0
- model/pvcnn/modules/functional/interpolatation.py +38 -0
- model/pvcnn/modules/functional/loss.py +17 -0
- model/pvcnn/modules/functional/sampling.py +84 -0
- model/pvcnn/modules/functional/src/ball_query/ball_query.cpp +30 -0
- model/pvcnn/modules/functional/src/ball_query/ball_query.cu +59 -0
- model/pvcnn/modules/functional/src/ball_query/ball_query.cuh +8 -0
- model/pvcnn/modules/functional/src/ball_query/ball_query.hpp +10 -0
- model/pvcnn/modules/functional/src/bindings.cpp +37 -0
- model/pvcnn/modules/functional/src/cuda_utils.cuh +39 -0
- model/pvcnn/modules/functional/src/grouping/grouping.cpp +44 -0
- model/pvcnn/modules/functional/src/grouping/grouping.cu +85 -0
- model/pvcnn/modules/functional/src/grouping/grouping.cuh +9 -0
- model/pvcnn/modules/functional/src/grouping/grouping.hpp +10 -0
- model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp +65 -0
- model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu +181 -0
- model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh +16 -0
- model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp +16 -0
README.md
CHANGED
@@ -1,13 +1,16 @@
-emoji: 🌍
-colorFrom: yellow
-colorTo: green
-sdk: gradio
-sdk_version: 4.20.1
-app_file: app.py
-pinned: false
-license: cc-by-nc-4.0
----
+# HDM
+Official implementation of the Hierarchical Diffusion Model (HDM) from the CVPR'24 paper "Template Free Reconstruction of Human Object Interaction".
 
+[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/)|[Code](https://github.com/xiexh20/HDM)|[Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS )|[Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)
+
+
+## Citation
+```
+@inproceedings{xie2023template_free,
+    title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
+    author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
+    booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+    month = {June},
+    year = {2024},
+}
+```
app.py
ADDED
@@ -0,0 +1,177 @@
"""
Demo built with gradio
"""
import pickle as pkl
import sys, os
import os.path as osp
from typing import Iterable, Optional
from functools import partial

import trimesh
from torch.utils.data import DataLoader
import cv2
from accelerate import Accelerator
from tqdm import tqdm
from glob import glob

sys.path.append(os.getcwd())
import hydra
import torch
import numpy as np
import imageio
import gradio as gr
import plotly.graph_objs as go
import training_utils

from configs.structured import ProjectConfig
from demo import DemoRunner
from dataset.demo_dataset import DemoDataset


md_description="""
# HDM Interaction Reconstruction Demo
### Official Implementation of the paper \"Template Free Reconstruction of Human Object Interaction\", CVPR'24.
[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/)|[Code](https://github.com/xiexh20/HDM)|[Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS )|[Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)


Upload your own human object interaction image and get full 3D reconstruction!

## Citation
```
@inproceedings{xie2023template_free,
    title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
    author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
    booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2024},
}
```
"""

def plot_points(colors, coords):
    """
    use plotly to visualize 3D point with colors
    """
    trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode='markers',
                         marker=dict(
                             size=2,
                             color=colors
                         ))
    layout = go.Layout(
        scene=dict(
            xaxis=dict(
                title="",
                showgrid=False,
                zeroline=False,
                showline=False,
                ticks='',
                showticklabels=False
            ),
            yaxis=dict(
                title="",
                showgrid=False,
                zeroline=False,
                showline=False,
                ticks='',
                showticklabels=False
            ),
            zaxis=dict(
                title="",
                showgrid=False,
                zeroline=False,
                showline=False,
                ticks='',
                showticklabels=False
            ),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False
    )
    fig = go.Figure(data=[trace], layout=layout)
    return fig


def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed):
    """
    given user input, run inference
    :param runner:
    :param cfg:
    :param rgb: (h, w, 3), np array
    :param mask_hum: (h, w, 3), np array
    :param mask_obj: (h, w, 3), np array
    :param std_coverage: float value, used to estimate camera translation
    :param input_seed: random seed
    :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
    """
    # Set random seed
    training_utils.set_seed(int(input_seed))

    data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
                       std_coverage)
    batch = data.image2batch(rgb, mask_hum, mask_obj)

    out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
    points = out_stage2.points_packed().cpu().numpy()
    colors = out_stage2.features_packed().cpu().numpy()
    fig = plot_points(colors, points)
    # save tmp point cloud
    outdir = './results'
    os.makedirs(outdir, exist_ok=True)
    trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply")
    trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
                       out_stage1.features_packed().cpu().numpy()).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1.ply")
    return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply"


@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def main(cfg: ProjectConfig):
    # Setup model
    runner = DemoRunner(cfg)

    # Setup interface
    demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
    with demo:
        gr.Markdown(md_description)
        gr.HTML("""<h1 style="text-align:center; color:#10768c">HDM Demo</h1>""")
        gr.HTML("""<h3 style="text-align:center; color:#10768c">Instruction: Upload RGB, human, object masks and then click reconstruct.</h1>""")

        # Input data
        with gr.Row():
            input_rgb = gr.Image(label='Input RGB', type='numpy')
            input_mask_hum = gr.Image(label='Human mask', type='numpy')
        with gr.Row():
            input_mask_obj = gr.Image(label='Object mask', type='numpy')
            with gr.Column():
                # TODO: add hint for this value here
                input_std = gr.Number(label='Gaussian std coverage', value=3.5)
                input_seed = gr.Number(label='Random seed', value=42)
        # Output visualization
        with gr.Row():
            pc_plot = gr.Plot(label="Reconstructed point cloud")
            out_pc_download = gr.File(label="3D reconstruction for download")  # this allows downloading

        gr.HTML("""<br/>""")
        # Control
        with gr.Row():
            button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
            button_recon.click(fn=partial(inference, runner, cfg),
                               inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],
                               outputs=[pc_plot, out_pc_download])
        gr.HTML("""<br/>""")
        # Example input
        example_dir = cfg.run.code_dir_abs + "/examples"
        rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png'
        example_images = gr.Examples([
            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42],
            [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42],
            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42],
            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42],

        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],)

    # demo.launch(share=True)
    # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
    demo.queue(concurrency_count=3).launch(share=True)

if __name__ == '__main__':
    main()
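The Gradio blocks above are a thin wrapper around `inference`. For readers who want to run the model without the UI, a minimal headless sketch could look like the following. This is not part of the commit: it reuses `DemoRunner`, `DemoDataset.image2batch`, and `cfg.run.code_dir_abs` exactly as `app.py` does, and it assumes the masks expect RGB-ordered numpy images (as Gradio would pass them) and that the stage-1/stage-2 checkpoints configured in `ProjectConfig` are available.

```python
# Hypothetical headless driver for the HDM demo (a sketch, not the author's script).
import os.path as osp

import cv2
import hydra

import training_utils
from configs.structured import ProjectConfig
from dataset.demo_dataset import DemoDataset
from demo import DemoRunner


@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def run_once(cfg: ProjectConfig):
    training_utils.set_seed(42)
    runner = DemoRunner(cfg)

    # One of the bundled examples; cv2 loads BGR, so flip to RGB like Gradio would provide.
    example = osp.join(cfg.run.code_dir_abs, 'examples', '017450')
    rgb = cv2.imread(osp.join(example, 'k1.color.jpg'))[:, :, ::-1]
    mask_hum = cv2.imread(osp.join(example, 'k1.person_mask.png'))
    mask_obj = cv2.imread(osp.join(example, 'k1.obj_rend_mask.png'))

    # Same call pattern as inference() above: std coverage 3.0 matches the example row.
    data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size), 3.0)
    batch = data.image2batch(rgb, mask_hum, mask_obj)
    out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
    print(out_stage2.points_packed().shape)  # combined human+object point cloud


if __name__ == '__main__':
    run_once()
```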
configs/__init__.py
ADDED
File without changes
configs/structured.py
ADDED
@@ -0,0 +1,416 @@
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Iterable
import os.path as osp

from hydra.core.config_store import ConfigStore
from hydra.conf import RunDir


@dataclass
class CustomHydraRunDir(RunDir):
    dir: str = './outputs/${run.name}/single'


@dataclass
class RunConfig:
    name: str = 'debug'
    job: str = 'train'
    mixed_precision: str = 'fp16'  # 'no'
    cpu: bool = False
    seed: int = 42
    val_before_training: bool = True
    vis_before_training: bool = True
    limit_train_batches: Optional[int] = None
    limit_val_batches: Optional[int] = None
    max_steps: int = 100_000
    checkpoint_freq: int = 1_000
    val_freq: int = 5_000
    vis_freq: int = 5_000
    # vis_freq: int = 10_000
    log_step_freq: int = 20
    print_step_freq: int = 100

    # config to run demo
    stage1_name: str = 'stage1'  # experiment name of the stage 1 model
    stage2_name: str = 'stage2'  # experiment name of the stage 2 model
    image_path: str = ''  # the path to the images for running demo, can be a single file or a glob pattern

    # abs path to working dir
    code_dir_abs: str = osp.dirname(osp.dirname(osp.abspath(__file__)))

    # Inference configs
    num_inference_steps: int = 1000
    diffusion_scheduler: Optional[str] = 'ddpm'
    num_samples: int = 1
    # num_sample_batches: Optional[int] = None
    num_sample_batches: Optional[int] = 2000  # XH: change to 2
    sample_from_ema: bool = False
    sample_save_evolutions: bool = False  # temporarily set by default
    save_name: str = 'sample'  # XH: additional save name
    redo: bool = False

    # for parallel sampling in slurm
    batch_start: int = 0
    batch_end: Optional[int] = None

    # Training configs
    freeze_feature_model: bool = True

    # Coloring training configs
    coloring_training_noise_std: float = 0.0
    coloring_sample_dir: Optional[str] = None

    sample_mode: str = 'sample'  # whether from noise or from some intermediate steps
    sample_noise_step: int = 500  # add noise to GT up to some steps, and then denoise
    sample_save_gt: bool = True


@dataclass
class LoggingConfig:
    wandb: bool = True
    wandb_project: str = 'pc2'


@dataclass
class PointCloudProjectionModelConfig:
    # Feature extraction arguments
    image_size: int = '${dataset.image_size}'
    image_feature_model: str = 'vit_base_patch16_224_mae'  # or 'vit_small_patch16_224_msn' or 'identity'
    use_local_colors: bool = True
    use_local_features: bool = True
    use_global_features: bool = False
    use_mask: bool = True
    use_distance_transform: bool = True

    # Point cloud data arguments. Note these are here because the processing happens
    # inside the model, rather than inside the dataset.
    scale_factor: float = "${dataset.scale_factor}"
    colors_mean: float = 0.5
    colors_std: float = 0.5
    color_channels: int = 3
    predict_shape: bool = True
    predict_color: bool = False

    # added by XH
    load_sample_init: bool = False  # load init samples from file
    sample_init_scale: float = 1.0  # scale the initial pc samples
    test_init_with_gtpc: bool = False  # test time init samples with GT samples
    consistent_center: bool = True  # use consistent center prediction by CCD-3DR
    voxel_resolution_multiplier: float = 1  # increase network voxel resolution

    # predict binary segmentation
    predict_binary: bool = False  # True for stage 1 model, False for others
    lw_binary: float = 3.0  # to have roughly the same magnitude of the binary segmentation loss
    # for separate model
    binary_training_noise_std: float = 0.1  # from github doc for predicting color
    self_conditioning: bool = False


@dataclass
class PVCNNAEModelConfig(PointCloudProjectionModelConfig):
    "my own model config, must inherit parent class"
    model_name: str = 'pvcnn-ae'
    latent_dim: int = 1024
    num_dec_blocks: int = 6
    block_dims: List[int] = field(default_factory=lambda: [512, 256])
    num_points: int = 1500
    bottleneck_dim: int = -1  # the input dim to the last MLP layer


@dataclass
class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
    model_name: str = 'pc2-diff-ho'  # default as behave

    # Diffusion arguments
    beta_start: float = 1e-5  # 0.00085
    beta_end: float = 8e-3  # 0.012
    beta_schedule: str = 'linear'  # 'custom'
    dm_pred_type: str = 'epsilon'  # diffusion model prediction type, sample (x0) or noise

    # Point cloud model arguments
    point_cloud_model: str = 'pvcnn'
    point_cloud_model_embed_dim: int = 64

    dataset_type: str = '${dataset.type}'


@dataclass
class CrossAttnHOModelConfig(PointCloudDiffusionModelConfig):
    model_name: str = 'diff-ho-attn'

    attn_type: str = 'coord3d+posenc-learnable'
    attn_weight: float = 1.0
    point_visible_test: str = 'combine'  # To compute point visibility: use all points or only human/object points


@dataclass
class DirectTransModelConfig(PointCloudProjectionModelConfig):
    model_name: str = 'direct-transl-ho'

    pooling: str = "avg"
    act: str = 'gelu'
    out_act: str = 'relu'
    # feat_dims_transl: Iterable[Any] = (384, 256, 128, 6)  # cannot use List[int] https://github.com/facebookresearch/hydra/issues/1752#issuecomment-893174197
    # feat_dims_scale: Iterable[Any] = (384, 128, 64, 2)
    feat_dims_transl: List[int] = field(default_factory=lambda: [384, 256, 128, 6])
    feat_dims_scale: List[int] = field(default_factory=lambda: [384, 128, 64, 2])
    lw_transl: float = 10000.0
    lw_scale: float = 10000.0


@dataclass
class PointCloudColoringModelConfig(PointCloudProjectionModelConfig):
    # Projection arguments
    predict_shape: bool = False
    predict_color: bool = True

    # Point cloud model arguments
    point_cloud_model: str = 'pvcnn'
    point_cloud_model_layers: int = 1
    point_cloud_model_embed_dim: int = 64


@dataclass
class DatasetConfig:
    type: str


@dataclass
class PointCloudDatasetConfig(DatasetConfig):
    eval_split: str = 'val'
    max_points: int = 16_384
    image_size: int = 224
    scale_factor: float = 1.0
    restrict_model_ids: Optional[List] = None  # for only running on a subset of data points


@dataclass
class CO3DConfig(PointCloudDatasetConfig):
    type: str = 'co3dv2'
    # root: str = os.getenv('CO3DV2_DATASET_ROOT')
    root: str = "/BS/xxie-2/work/co3d/hydrant"
    category: str = 'hydrant'
    subset_name: str = 'fewview_dev'
    mask_images: bool = '${model.use_mask}'


@dataclass
class ShapeNetR2N2Config(PointCloudDatasetConfig):
    # added by XH
    fix_sample: bool = True
    category: str = 'chair'

    type: str = 'shapenet_r2n2'
    root: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
    r2n2_dir: str = "/BS/databases20/3d-r2n2"
    shapenet_dir: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
    preprocessed_r2n2_dir: str = "${dataset.root}/r2n2_preprocessed_renders"
    splits_file: str = "${dataset.root}/r2n2_standard_splits_from_ShapeNet_taxonomy.json"
    # splits_file: str = "${dataset.root}/pix2mesh_splits_val05.json"  # <-- incorrect
    scale_factor: float = 7.0
    point_cloud_filename: str = 'pointcloud_r2n2.npz'  # should use 'pointcloud_mesh.npz'


@dataclass
class BehaveDatasetConfig(PointCloudDatasetConfig):
    # added by XH
    type: str = 'behave'

    fix_sample: bool = True
    behave_dir: str = "/BS/xxie-5/static00/behave_release/sequences/"
    split_file: str = ""  # specify your dataset split file here
    scale_factor: float = 7.0  # use the same as shapenet
    sample_ratio_hum: float = 0.5
    image_size: int = 224

    normalize_type: str = 'comb'
    smpl_type: str = 'gt'  # use which SMPL mesh to obtain normalization parameters
    test_transl_type: str = 'norm'

    load_corr_points: bool = False  # load autoencoder points for object and SMPL
    uniform_obj_sample: bool = False

    # configs for direct translation prediction
    bkg_type: str = 'none'
    bbox_params: str = 'none'
    ho_segm_pred_path: Optional[str] = None
    use_gt_transl: bool = False

    cam_noise_std: float = 0.  # add noise to the camera pose
    sep_same_crop: bool = False  # use same input image crop to separate models
    aug_blur: float = 0.  # blur augmentation

    std_coverage: float = 3.5  # a heuristic value to estimate translation

    v2v_path: str = ''  # object v2v corr path


@dataclass
class ShapeDatasetConfig(BehaveDatasetConfig):
    "the dataset to train AE for aligned shapes"
    type: str = 'shape'
    fix_sample: bool = False
    split_file: str = "/BS/xxie-2/work/pc2-diff/experiments/splits/shapes-chair.pkl"


# TODO
@dataclass
class ShapeNetNMRConfig(PointCloudDatasetConfig):
    type: str = 'shapenet_nmr'
    shapenet_nmr_dir: str = "/work/lukemk/machine-learning-datasets/3d-reconstruction/ShapeNet_NMR/NMR_Dataset"
    synset_names: str = 'chair'  # comma-separated or 'all'
    augmentation: str = 'all'
    scale_factor: float = 7.0


@dataclass
class AugmentationConfig:
    # need to specify the variable type in order to define it properly
    max_radius: int = 0  # generate a random square to mask object, this is the radius for the square in pixel size, zero means no occlusion


@dataclass
class DataloaderConfig:
    # batch_size: int = 8  # 2 for debug
    batch_size: int = 16
    num_workers: int = 14  # 0 for debug # suggested by accelerator for gpu20


@dataclass
class LossConfig:
    diffusion_weight: float = 1.0
    rgb_weight: float = 1.0
    consistency_weight: float = 1.0


@dataclass
class CheckpointConfig:
    resume: Optional[str] = "test"
    resume_training: bool = True
    resume_training_optimizer: bool = True
    resume_training_scheduler: bool = True
    resume_training_state: bool = True


@dataclass
class ExponentialMovingAverageConfig:
    use_ema: bool = False
    # # From Diffusers EMA (should probably switch)
    # ema_inv_gamma: float = 1.0
    # ema_power: float = 0.75
    # ema_max_decay: float = 0.9999
    decay: float = 0.999
    update_every: int = 20


@dataclass
class OptimizerConfig:
    type: str
    name: str
    lr: float = 3e-4
    weight_decay: float = 0.0
    scale_learning_rate_with_batch_size: bool = False
    gradient_accumulation_steps: int = 1
    clip_grad_norm: Optional[float] = 50.0  # 5.0
    kwargs: Dict = field(default_factory=lambda: dict())


@dataclass
class AdadeltaOptimizerConfig(OptimizerConfig):
    type: str = 'torch'
    name: str = 'Adadelta'
    kwargs: Dict = field(default_factory=lambda: dict(
        weight_decay=1e-6,
    ))


@dataclass
class AdamOptimizerConfig(OptimizerConfig):
    type: str = 'torch'
    name: str = 'AdamW'
    weight_decay: float = 1e-6
    kwargs: Dict = field(default_factory=lambda: dict(betas=(0.95, 0.999)))


@dataclass
class SchedulerConfig:
    type: str
    kwargs: Dict = field(default_factory=lambda: dict())


@dataclass
class LinearSchedulerConfig(SchedulerConfig):
    type: str = 'transformers'
    kwargs: Dict = field(default_factory=lambda: dict(
        name='linear',
        num_warmup_steps=0,
        num_training_steps="${run.max_steps}",
    ))


@dataclass
class CosineSchedulerConfig(SchedulerConfig):
    type: str = 'transformers'
    kwargs: Dict = field(default_factory=lambda: dict(
        name='cosine',
        num_warmup_steps=2000,  # 0
        num_training_steps="${run.max_steps}",
    ))


@dataclass
class ProjectConfig:
    run: RunConfig
    logging: LoggingConfig
    dataset: PointCloudDatasetConfig
    augmentations: AugmentationConfig
    dataloader: DataloaderConfig
    loss: LossConfig
    model: PointCloudProjectionModelConfig
    ema: ExponentialMovingAverageConfig
    checkpoint: CheckpointConfig
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig

    defaults: List[Any] = field(default_factory=lambda: [
        'custom_hydra_run_dir',
        {'run': 'default'},
        {'logging': 'default'},
        {'model': 'ho-attn'},
        # {'dataset': 'co3d'},
        {'dataset': 'behave'},
        {'augmentations': 'default'},
        {'dataloader': 'default'},
        {'ema': 'default'},
        {'loss': 'default'},
        {'checkpoint': 'default'},
        {'optimizer': 'adam'},  # default adamw
        {'scheduler': 'linear'},
        # {'scheduler': 'cosine'},
    ])


cs = ConfigStore.instance()
cs.store(name='custom_hydra_run_dir', node=CustomHydraRunDir, package="hydra.run")
cs.store(group='run', name='default', node=RunConfig)
cs.store(group='logging', name='default', node=LoggingConfig)
cs.store(group='model', name='diffrec', node=PointCloudDiffusionModelConfig)
cs.store(group='model', name='coloring_model', node=PointCloudColoringModelConfig)
cs.store(group='model', name='direct-transl', node=DirectTransModelConfig)
cs.store(group='model', name='ho-attn', node=CrossAttnHOModelConfig)
cs.store(group='model', name='pvcnn-ae', node=PVCNNAEModelConfig)
cs.store(group='dataset', name='co3d', node=CO3DConfig)
# TODO
cs.store(group='dataset', name='shapenet_r2n2', node=ShapeNetR2N2Config)
cs.store(group='dataset', name='behave', node=BehaveDatasetConfig)
cs.store(group='dataset', name='shape', node=ShapeDatasetConfig)
# cs.store(group='dataset', name='shapenet_nmr', node=ShapeNetNMRConfig)
cs.store(group='augmentations', name='default', node=AugmentationConfig)
cs.store(group='dataloader', name='default', node=DataloaderConfig)
cs.store(group='loss', name='default', node=LossConfig)
cs.store(group='ema', name='default', node=ExponentialMovingAverageConfig)
cs.store(group='checkpoint', name='default', node=CheckpointConfig)
cs.store(group='optimizer', name='adadelta', node=AdadeltaOptimizerConfig)
cs.store(group='optimizer', name='adam', node=AdamOptimizerConfig)
cs.store(group='scheduler', name='linear', node=LinearSchedulerConfig)
cs.store(group='scheduler', name='cosine', node=CosineSchedulerConfig)
cs.store(name='configs', node=ProjectConfig)
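Because every group above is registered in Hydra's `ConfigStore` and `ProjectConfig` is stored under the name `configs`, a config can also be composed outside of `@hydra.main`, which is convenient for notebooks or tests. A minimal sketch, assuming a Hydra version that supports `version_base` (as `app.py` requires) and that importing `configs.structured` is enough to populate the store; the overrides shown are illustrative, not mandated by this commit:

```python
from hydra import initialize, compose
from omegaconf import OmegaConf

import configs.structured  # noqa: F401  -- importing this module registers all nodes in the ConfigStore

# 'configs' matches the config_name used by @hydra.main in app.py;
# config_path points at the configs/ package relative to this script.
with initialize(version_base='1.1', config_path='configs'):
    cfg = compose(config_name='configs',
                  overrides=['run.job=sample', 'dataset.std_coverage=3.8'])

print(OmegaConf.to_yaml(cfg.run))  # inspect the composed RunConfig
```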
dataset/__init__.py
ADDED
@@ -0,0 +1,301 @@
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pytorch3d
import torch
from torch.utils.data import SequentialSampler
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_loader_map_provider import \
    SequenceDataLoaderMapProvider
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2, registry)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader

from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
from .utils import DatasetMap
from .r2n2_my import R2N2Sample, collate_batched_meshes


def get_dataset(cfg: ProjectConfig):

    if cfg.dataset.type == 'co3dv2':
        dataset_cfg: CO3DConfig = cfg.dataset
        dataloader_cfg: DataloaderConfig = cfg.dataloader

        # Exclude bad and low-quality sequences, XH: why this is needed?
        exclude_sequence = []
        exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, []))
        exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, []))

        # Whether to load pointclouds
        kwargs = dict(
            remove_empty_masks=True,
            n_frames_per_sequence=1,
            load_point_clouds=True,
            max_points=dataset_cfg.max_points,
            image_height=dataset_cfg.image_size,
            image_width=dataset_cfg.image_size,
            mask_images=dataset_cfg.mask_images,
            exclude_sequence=exclude_sequence,
            pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids,
        )

        # Get dataset mapper
        dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2")
        expand_args_fields(dataset_map_provider_type)
        dataset_map_provider = dataset_map_provider_type(
            category=dataset_cfg.category,
            subset_name=dataset_cfg.subset_name,
            dataset_root=dataset_cfg.root,
            test_on_train=False,
            only_test_set=False,
            load_eval_batches=True,
            dataset_JsonIndexDataset_args=DictConfig(kwargs),
        )

        # Get datasets
        datasets = dataset_map_provider.get_dataset_map()  # how to select specific frames??

        # PATCH BUG WITH POINT CLOUD LOCATIONS!
        for dataset in (datasets["train"], datasets["val"]):
            # print(dataset.seq_annots.items())
            for key, ann in dataset.seq_annots.items():
                correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:])
                assert correct_point_cloud_path.is_file(), correct_point_cloud_path
                ann.point_cloud.path = str(correct_point_cloud_path)

        # Get dataloader mapper
        data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider")
        expand_args_fields(data_loader_map_provider_type)
        data_loader_map_provider = data_loader_map_provider_type(
            batch_size=dataloader_cfg.batch_size,
            num_workers=dataloader_cfg.num_workers,
        )

        # QUICK HACK: Patch the train dataset because it is not used but it throws an error
        if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0 and
                dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'):
            datasets = DatasetMap(train=datasets[dataset_cfg.eval_split], val=datasets[dataset_cfg.eval_split],
                                  test=datasets[dataset_cfg.eval_split])
            # XH: why all eval split?
            print('Note: You used restrict_model_ids and there were no ids in the train set.')

        # Get dataloaders
        dataloaders = data_loader_map_provider.get_data_loader_map(datasets)
        dataloader_train = dataloaders['train']
        dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split]

        # Replace validation dataloader sampler with SequentialSampler
        # seems to be randomly sampled? with a fixed random seed? but one cannot control which image is being sampled??
        dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source)

        # Modify for accelerate
        dataloader_train.batch_sampler.drop_last = True
        dataloader_val.batch_sampler.drop_last = False
    elif cfg.dataset.type == 'shapenet_r2n2':
        # from ..configs.structured import ShapeNetR2N2Config
        dataset_cfg: ShapeNetR2N2Config = cfg.dataset
        # for k in dataset_cfg:
        #     print(k)
        datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample,
                               dataset_cfg.image_size, cfg.augmentations,
                               s, dataset_cfg.shapenet_dir,
                               dataset_cfg.r2n2_dir, dataset_cfg.splits_file,
                               load_textures=False, return_all_views=True) for s in ['train', 'val', 'test']]
        dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)
        dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)

    elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl',
                              'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn',
                              'behave-test', 'behave-attn-test', 'behave-hum-pe', 'behave-hum-noscale',
                              'behave-hum-surf', 'behave-objv2v']:
        from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc
        from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale
        from .behave_dataset import BehaveHumanOnlySurfSample
        from .dtransl_dataset import DirectTranslDataset
        from .behave_paths import DataPaths
        from configs.structured import BehaveDatasetConfig
        from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest
        from .behave_dataset import BehaveObjOnlyV2V

        dataset_cfg: BehaveDatasetConfig = cfg.dataset
        # print(dataset_cfg.behave_dir)
        train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
        # exit(0)

        # split validation paths to only consider the selected batches
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths)/cfg.dataloader.batch_size))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
        val_paths = val_paths[cfg.run.batch_start*bs:end_idx*bs]

        if cfg.dataset.type == 'behave':
            train_type = BehaveDataset
            val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset
        elif cfg.dataset.type == 'behave-test':
            train_type = BehaveDataset
            val_datatype = BehaveTestOnly
        elif cfg.dataset.type == 'behave-objonly':
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-humonly':
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-hum-noscale':
            train_type = BehaveHumNoscale
            val_datatype = BehaveHumNoscale
        elif cfg.dataset.type == 'behave-hum-pe':
            train_type = BehaveHumanOnlyPosEnc
            val_datatype = BehaveHumanOnlyPosEnc
        elif cfg.dataset.type == 'behave-hum-surf':
            train_type = BehaveHumanOnlySurfSample
            val_datatype = BehaveHumanOnlySurfSample
        elif cfg.dataset.type == 'behave-humonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-objonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-dtransl':
            train_type = DirectTranslDataset
            val_datatype = DirectTranslDataset
        elif cfg.dataset.type == 'behave-attn':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnDataset
        elif cfg.dataset.type == 'behave-attn-test':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnTest
        elif cfg.dataset.type == 'behave-objv2v':
            train_type = BehaveObjOnlyV2V
            val_datatype = BehaveObjOnlyV2V
        else:
            raise NotImplementedError

        dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type='gt',
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   cam_noise_std=cfg.dataset.cam_noise_std,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   aug_blur=cfg.dataset.aug_blur,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)

        dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   test_transl_type=dataset_cfg.test_transl_type,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)
        # dataset_test = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
        #                            (dataset_cfg.image_size, dataset_cfg.image_size),
        #                            split='test', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
        #                            normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
        #                            load_corr_points=dataset_cfg.load_corr_points,
        #                            test_transl_type=dataset_cfg.test_transl_type,
        #                            uniform_obj_sample=dataset_cfg.uniform_obj_sample,
        #                            bkg_type=dataset_cfg.bkg_type,
        #                            bbox_params=dataset_cfg.bbox_params,
        #                            pred_binary=cfg.model.predict_binary,
        #                            ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
        #                            compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
        #                            use_gt_transl=cfg.dataset.use_gt_transl,
        #                            sep_same_crop=cfg.dataset.sep_same_crop)
        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)

        # datasets = [BehaveDataset(p, dataset_cfg.max_points, dataset_cfg.fix_sample,
        #                           (dataset_cfg.image_size, dataset_cfg.image_size),
        #                           split=s, sample_ratio_hum=dataset_cfg.sample_ratio_hum,
        #                           normalize_type=dataset_cfg.normalize_type) for p, s in zip([train_paths, val_paths, val_paths],
        #                                                                                      ['train', 'val', 'test'])]
        # dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
        #                               collate_fn=collate_batched_meshes,
        #                               num_workers=cfg.dataloader.num_workers, shuffle=True)
        # dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
        #                             collate_fn=collate_batched_meshes,
        #                             num_workers=cfg.dataloader.num_workers, shuffle=False)
        # dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
        #                             collate_fn=collate_batched_meshes,
        #                             num_workers=cfg.dataloader.num_workers, shuffle=False)
    elif cfg.dataset.type in ['shape']:
        from .shape_dataset import ShapeDataset
        from .behave_paths import DataPaths
        from configs.structured import ShapeDatasetConfig
        dataset_cfg: ShapeDatasetConfig = cfg.dataset

        train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
        val_paths = train_paths  # same as training, this is for overfitting
        # split validation paths to only consider the selected batches
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
        val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]

        dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                     (dataset_cfg.image_size, dataset_cfg.image_size),
                                     split='train', )
        dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train', )
        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
    else:
        raise NotImplementedError(cfg.dataset.type)

    return dataloader_train, dataloader_val, dataloader_vis
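`get_dataset` is the single entry point that returns the train, validation, and visualization dataloaders for whichever dataset type is configured. A rough usage sketch under a Hydra-composed config is shown below; it deliberately avoids assuming a batch schema, since the collated contents depend on dataset classes (e.g. `BehaveDataset`) that are not part of this diff:

```python
# Sketch: obtain the dataloaders exactly as the training/demo entry points would.
import hydra

from configs.structured import ProjectConfig
from dataset import get_dataset


@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def main(cfg: ProjectConfig):
    dataloader_train, dataloader_val, dataloader_vis = get_dataset(cfg)
    batch = next(iter(dataloader_val))
    # The batch layout depends on the configured dataset type and its collate function,
    # so inspect it rather than assuming fixed keys.
    print(type(batch), len(dataloader_train), len(dataloader_val))


if __name__ == '__main__':
    main()
```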
dataset/base_data.py
ADDED
@@ -0,0 +1,110 @@
from os import path as osp

import cv2
import numpy as np
from torch.utils.data import Dataset

from dataset.img_utils import masks2bbox, resize, crop


class BaseDataset(Dataset):
    def __init__(self, data_paths, input_size=(224, 224)):
        self.data_paths = data_paths  # RGB image files
        self.input_size = input_size
        opencv2py3d = np.eye(4)
        opencv2py3d[0, 0] = opencv2py3d[1, 1] = -1
        self.opencv2py3d = opencv2py3d

    def __len__(self):
        return len(self.data_paths)

    def load_masks(self, rgb_file):
        person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.png")
        if not osp.isfile(person_mask_file):
            person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.jpg")
        obj_mask_file = None
        for pat in [".obj_rend_mask.png", ".obj_rend_mask.jpg", ".obj_mask.png", ".obj_mask.jpg", ".object_rend.png"]:
            obj_mask_file = rgb_file.replace('.color.jpg', pat)
            if osp.isfile(obj_mask_file):
                break
        person_mask = cv2.imread(person_mask_file, cv2.IMREAD_GRAYSCALE)
        obj_mask = cv2.imread(obj_mask_file, cv2.IMREAD_GRAYSCALE)

        return person_mask, obj_mask

    def get_crop_params(self, mask_hum, mask_obj, bbox_exp=1.0):
        "compute bounding box based on masks"
        bmin, bmax = masks2bbox([mask_hum, mask_obj])
        crop_center = (bmin + bmax) // 2
        # crop_size = np.max(bmax - bmin)
        crop_size = int(np.max(bmax - bmin) * bbox_exp)
        if crop_size % 2 == 1:
            crop_size += 1  # make sure it is an even number
        return bmax, bmin, crop_center, crop_size

    def is_behave_dataset(self, image_width):
        assert image_width in [2048, 1920, 1024, 960], f'unknown image width {image_width}!'
        if image_width in [2048, 1024]:
            is_behave = True
        else:
            is_behave = False
        return is_behave

    def compute_K_roi(self, bbox_square,
                      image_width=2048,
                      image_height=1536,
                      fx=979.7844, fy=979.840,
                      cx=1018.952, cy=779.486):
        "return results in ndc coordinate, this is correct!!!"
        x, y, b, w = bbox_square
        assert b == w
        is_behave = self.is_behave_dataset(image_width)

        if is_behave:
            assert image_height / image_width == 0.75, f"invalid image aspect ratio: width={image_width}, height={image_height}"
            # the image might be rendered at different size
            ratio = image_width/2048.
            fx, fy = 979.7844*ratio, 979.840*ratio
            cx, cy = 1018.952*ratio, 779.486*ratio
        else:
            assert image_height / image_width == 9/16, f"invalid image aspect ratio: width={image_width}, height={image_height}"
            # intercap camera
            ratio = image_width/1920
            fx, fy = 918.457763671875*ratio, 918.4373779296875*ratio
            cx, cy = 956.9661865234375*ratio, 555.944580078125*ratio

        cx, cy = cx - x, cy - y
        scale = b/2.
        # in ndc
        cx_ = (scale - cx)/scale
        cy_ = (scale - cy)/scale
        fx_ = fx/scale
        fy_ = fy/scale

        K_roi = np.array([
            [fx_, 0, cx_, 0],
            [0., fy_, cy_, 0, ],
            [0, 0, 0, 1.],
            [0, 0, 1, 0]
        ])
        return K_roi

    def crop_full_image(self, mask_hum, mask_obj, rgb_full, crop_masks, bbox_exp=1.0):
        """
        crop the image based on the given masks
        :param mask_hum:
        :param mask_obj:
        :param rgb_full:
        :param crop_masks: a list of masks used to do the crop
        :return: Kroi, cropped human, object mask and RGB images (background masked out).
        """
        bmax, bmin, crop_center, crop_size = self.get_crop_params(*crop_masks, bbox_exp)
        rgb = resize(crop(rgb_full, crop_center, crop_size), self.input_size) / 255.
        person_mask = resize(crop(mask_hum, crop_center, crop_size), self.input_size) / 255.
        obj_mask = resize(crop(mask_obj, crop_center, crop_size), self.input_size) / 255.
        xywh = np.concatenate([crop_center - crop_size // 2, np.array([crop_size, crop_size])])
        Kroi = self.compute_K_roi(xywh, rgb_full.shape[1], rgb_full.shape[0])
        # mask bkg out
        mask_comb = (person_mask > 0.5) | (obj_mask > 0.5)
        rgb = rgb * np.expand_dims(mask_comb, -1)
        return Kroi, obj_mask, person_mask, rgb
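These crop helpers are what the demo and BEHAVE datasets build on: given the human and object masks, they crop a square region around both, resize it to the network input size, and return ROI-adjusted intrinsics in NDC convention. A small sketch of direct usage on one example frame follows; it assumes the image has a BEHAVE/InterCap resolution (the width assert in `is_behave_dataset` requires this) and that `masks2bbox`, `crop`, and `resize` from `dataset/img_utils.py` (added in this commit) behave as their names suggest:

```python
# Sketch: use the BaseDataset crop helpers on one bundled example frame.
import cv2

from dataset.base_data import BaseDataset

dataset = BaseDataset(['examples/017450/k1.color.jpg'], input_size=(224, 224))

rgb_file = dataset.data_paths[0]
rgb_full = cv2.imread(rgb_file)[:, :, ::-1]        # BGR -> RGB
mask_hum, mask_obj = dataset.load_masks(rgb_file)  # grayscale masks stored next to the RGB file

# Crop around the union of both masks and get the ROI-adjusted camera matrix
Kroi, obj_mask, person_mask, rgb = dataset.crop_full_image(
    mask_hum, mask_obj, rgb_full, crop_masks=[mask_hum, mask_obj], bbox_exp=1.0)

print(Kroi.shape, rgb.shape)  # (4, 4) intrinsics in NDC, (224, 224, 3) masked RGB crop
```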
dataset/behave_paths.py
ADDED
@@ -0,0 +1,228 @@
1 |
+
import glob
|
2 |
+
import os, re
|
3 |
+
import pickle as pkl
|
4 |
+
from os.path import join, basename, dirname, isfile
|
5 |
+
import os.path as osp
|
6 |
+
|
7 |
+
import cv2, json
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
# PROCESSED_PATH = paths['PROCESSED_PATH']
|
11 |
+
BEHAVE_PATH = "/BS/xxie-5/static00/behave_release/sequences/"
|
12 |
+
RECON_PATH = "/BS/xxie-5/static00/behave-train"
|
13 |
+
|
14 |
+
class DataPaths:
|
15 |
+
"""
|
16 |
+
class to handle path operations based on BEHAVE dataset structure
|
17 |
+
"""
|
18 |
+
def __init__(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def load_splits(split_file, dataset_path=None):
|
23 |
+
assert os.path.exists(dataset_path), f'the given dataset path {dataset_path} does not exist, please check if your training data are placed over there!'
|
24 |
+
train, val = DataPaths.get_train_test_from_pkl(split_file)
|
25 |
+
return train, val
|
26 |
+
# print(train[:5], val[:5])
|
27 |
+
if isinstance(train[0], list):
|
28 |
+
# video data
|
29 |
+
train_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in train]
|
30 |
+
val_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in val]
|
31 |
+
else:
|
32 |
+
train_full = [join(dataset_path, x) for x in train] # full path to the training data
|
33 |
+
val_full = [join(dataset_path, x) for x in val] # full path to the validation data files
|
34 |
+
# print(train_full[:5], val_full[:5])
|
35 |
+
return train_full, val_full
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def load_splits_online(split_file, dataset_path=BEHAVE_PATH):
|
39 |
+
"load rgb file, smpl and object mesh paths"
|
40 |
+
keys = ['rgb', 'smpl', 'obj']
|
41 |
+
types = ['train', 'val']
|
42 |
+
splits = {}
|
43 |
+
data = pkl.load(open(split_file, 'rb'))
|
44 |
+
for type in types:
|
45 |
+
for key in keys:
|
46 |
+
k = f'{type}_{key}'
|
47 |
+
splits[k] = [join(dataset_path, x) for x in data[k]]
|
48 |
+
return splits
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def get_train_test_from_pkl(pkl_file):
|
52 |
+
data = pkl.load(open(pkl_file, 'rb'))
|
53 |
+
return data['train'], data['test']
|
54 |
+
|
55 |
+
@staticmethod
|
56 |
+
def get_image_paths_seq(seq, tid=1, check_occlusion=False, pat='t*.000'):
|
57 |
+
"""
|
58 |
+
find all image paths in one sequence
|
59 |
+
:param seq: path to one behave sequence
|
60 |
+
:param tid: test on images from which camera
|
61 |
+
:param check_occlusion: whether to load full object mask and check occlusion ratio
|
62 |
+
:return: a list of paths to test image files
|
63 |
+
"""
|
64 |
+
image_files = sorted(glob.glob(seq + f"/{pat}/k{tid}.color.jpg"))
|
65 |
+
# print(image_files, seq + f"/{pat}/k{tid}.color.jpg")
|
66 |
+
if not check_occlusion:
|
67 |
+
return image_files
|
68 |
+
# check object occlusion ratio
|
69 |
+
valid_files = []
|
70 |
+
count = 0
|
71 |
+
for img_file in image_files:
|
72 |
+
mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.png')
|
73 |
+
if not os.path.isfile(mask_file):
|
74 |
+
mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.jpg')
|
75 |
+
full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.png')
|
76 |
+
if not os.path.isfile(full_mask_file):
|
77 |
+
full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.jpg')
|
78 |
+
if not isfile(mask_file) or not isfile(full_mask_file):
|
79 |
+
continue
|
80 |
+
|
81 |
+
mask = np.sum(cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) > 127)
|
82 |
+
mask_full = np.sum(cv2.imread(full_mask_file, cv2.IMREAD_GRAYSCALE) > 127)
|
83 |
+
if mask_full == 0:
|
84 |
+
count += 1
|
85 |
+
continue
|
86 |
+
|
87 |
+
ratio = mask / mask_full
|
88 |
+
if ratio > 0.3:
|
89 |
+
valid_files.append(img_file)
|
90 |
+
else:
|
91 |
+
count += 1
|
92 |
+
print(f'{mask_file} occluded by {1 - ratio}!')
|
93 |
+
return valid_files
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def get_kinect_id(rgb_file):
|
97 |
+
"extract kinect id from the rgb file"
|
98 |
+
filename = osp.basename(rgb_file)
|
99 |
+
try:
|
100 |
+
kid = int(filename.split('.')[0][1])
|
101 |
+
assert kid in [0, 1, 2, 3, 4, 5], f'found invalid kinect id {kid} for file {rgb_file}'
|
102 |
+
return kid
|
103 |
+
except Exception as e:
|
104 |
+
print(rgb_file)
|
105 |
+
raise ValueError()
|
106 |
+
|
107 |
+
@staticmethod
|
108 |
+
def get_seq_date(rgb_file):
|
109 |
+
"date for the sequence"
|
110 |
+
seq_name = str(rgb_file).split(os.sep)[-3]
|
111 |
+
date = seq_name.split('_')[0]
|
112 |
+
assert date in ['Date01', 'Date02', 'Date03', 'Date04', 'Date05', 'Date06', 'Date07',
|
113 |
+
"ICapS01", "ICapS02", "ICapS03", "Date08", "Date09"], f"invalid date for {rgb_file}"
|
114 |
+
return date
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def rgb2obj_path(rgb_file:str, save_name='fit01-smooth'):
|
118 |
+
"convert an rgb file to a obj mesh file"
|
119 |
+
ss = rgb_file.split(os.sep)
|
120 |
+
seq_name = ss[-3]
|
121 |
+
obj_name = seq_name.split('_')[2]
|
122 |
+
real_name = obj_name
|
123 |
+
if 'chair' in obj_name:
|
124 |
+
real_name = 'chair'
|
125 |
+
if 'ball' in obj_name:
|
126 |
+
real_name = 'sports ball'
|
127 |
+
|
128 |
+
frame_folder = osp.dirname(rgb_file)
|
129 |
+
mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
|
130 |
+
|
131 |
+
if not osp.isfile(mesh_file):
|
132 |
+
# synthetic data
|
133 |
+
mesh_file = osp.join(frame_folder, obj_name, save_name, f'{obj_name}_fit.ply')
|
134 |
+
return mesh_file
|
135 |
+
|
136 |
+
@staticmethod
|
137 |
+
def rgb2smpl_path(rgb_file:str, save_name='fit03'):
|
138 |
+
frame_folder = osp.dirname(rgb_file)
|
139 |
+
real_name = 'person'
|
140 |
+
mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
|
141 |
+
return mesh_file
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def rgb2seq_frame(rgb_file:str):
|
145 |
+
"rgb file to seq_name, frame time"
|
146 |
+
ss = rgb_file.split(os.sep)
|
147 |
+
return ss[-3], ss[-2]
|
148 |
+
|
149 |
+
@staticmethod
|
150 |
+
def rgb2recon_folder(rgb_file, save_name, recon_path):
|
151 |
+
"convert rgb file to the subfolder"
|
152 |
+
dataset_path = osp.dirname(osp.dirname(osp.dirname(rgb_file)))
|
153 |
+
recon_folder = osp.join(osp.dirname(rgb_file.replace(dataset_path, recon_path)), save_name)
|
154 |
+
return recon_folder
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def get_seq_name(rgb_file):
|
158 |
+
return osp.basename(osp.dirname(osp.dirname(rgb_file)))
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def rgb2template_path(rgb_file):
|
162 |
+
"return the path to the object template"
|
163 |
+
from recon.opt_utils import get_template_path
|
164 |
+
# seq_name = DataPaths.get_seq_name(rgb_file)
|
165 |
+
# obj_name = seq_name.split('_')[2]
|
166 |
+
obj_name = DataPaths.rgb2object_name(rgb_file)
|
167 |
+
path = get_template_path(BEHAVE_PATH+"/../objects", obj_name)
|
168 |
+
return path
|
169 |
+
|
170 |
+
@staticmethod
|
171 |
+
def rgb2object_name(rgb_file):
|
172 |
+
seq_name = DataPaths.get_seq_name(rgb_file)
|
173 |
+
obj_name = seq_name.split('_')[2]
|
174 |
+
return obj_name
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def rgb2recon_frame(rgb_file, recon_path=RECON_PATH):
|
178 |
+
"return the frame folder in recon path"
|
179 |
+
ss = rgb_file.split(os.sep)
|
180 |
+
seq_name, frame = ss[-3], ss[-2]
|
181 |
+
return osp.join(recon_path, seq_name, frame)
|
182 |
+
|
183 |
+
@staticmethod
|
184 |
+
def rgb2gender(rgb_file):
|
185 |
+
"find the gender of this image"
|
186 |
+
seq_name = str(rgb_file).split(os.sep)[-3]
|
187 |
+
sub = seq_name.split('_')[1]
|
188 |
+
return _sub_gender[sub]
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
def get_dataset_root(rgb_file):
|
192 |
+
"return the root path to all sequences"
|
193 |
+
from pathlib import Path
|
194 |
+
path = Path(rgb_file)
|
195 |
+
return str(path.parents[2])
|
196 |
+
|
197 |
+
@staticmethod
|
198 |
+
def seqname2gender(seq_name:str):
|
199 |
+
sub = seq_name.split('_')[1]
|
200 |
+
return _sub_gender[sub]
|
201 |
+
|
202 |
+
ICAP_PATH = "/BS/xxie-6/static00/InterCap" # assume same root folder
|
203 |
+
date_seqs = {
|
204 |
+
"Date01": BEHAVE_PATH + "/Date01_Sub01_backpack_back",
|
205 |
+
"Date02": BEHAVE_PATH + "/Date02_Sub02_backpack_back",
|
206 |
+
"Date03": BEHAVE_PATH + "/Date03_Sub03_backpack_back",
|
207 |
+
"Date04": BEHAVE_PATH + "/Date04_Sub05_backpack",
|
208 |
+
"Date05": BEHAVE_PATH + "/Date05_Sub05_backpack",
|
209 |
+
"Date06": BEHAVE_PATH + "/Date06_Sub07_backpack_back",
|
210 |
+
"Date07": BEHAVE_PATH + "/Date07_Sub04_backpack_back",
|
211 |
+
# "Date08": "/BS/xxie-6/static00/synthesize/Date08_Subxx_chairwood_synzv2-02",
|
212 |
+
"Date08": "/BS/xxie-6/static00/synz-backup/Date08_Subxx_chairwood_synzv2-02",
|
213 |
+
"Date09": "/BS/xxie-6/static00/synthesize/Date09_Subxx_obj01_icap", # InterCap sequence synz
|
214 |
+
"ICapS01": ICAP_PATH + "/ICapS01_sub01_obj01_Seg_0",
|
215 |
+
"ICapS02": ICAP_PATH + "/ICapS02_sub01_obj08_Seg_0",
|
216 |
+
"ICapS03": ICAP_PATH + "/ICapS03_sub07_obj05_Seg_0",
|
217 |
+
}
|
218 |
+
|
219 |
+
_sub_gender = {
|
220 |
+
"Sub01": 'male',
|
221 |
+
"Sub02": 'male',
|
222 |
+
"Sub03": 'male',
|
223 |
+
"Sub04": 'male',
|
224 |
+
"Sub05": 'male',
|
225 |
+
"Sub06": 'female',
|
226 |
+
"Sub07": 'female',
|
227 |
+
"Sub08": 'female',
|
228 |
+
}
|
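A minimal sketch of how these `DataPaths` helpers are typically chained for one frame; the frame path below is hypothetical and only assumes the standard `<root>/<seq_name>/<frame>/k<kid>.color.jpg` layout used throughout this file.

```python
from dataset.behave_paths import DataPaths

# Hypothetical BEHAVE-style frame path (sequence / frame / kinect view).
rgb_file = "/data/behave/sequences/Date03_Sub03_backpack_back/t0005.000/k1.color.jpg"

seq_name, frame = DataPaths.rgb2seq_frame(rgb_file)  # ('Date03_Sub03_backpack_back', 't0005.000')
kid = DataPaths.get_kinect_id(rgb_file)              # 1
date = DataPaths.get_seq_date(rgb_file)              # 'Date03'
gender = DataPaths.rgb2gender(rgb_file)              # 'male' (Sub03)
obj_name = DataPaths.rgb2object_name(rgb_file)       # 'backpack'
smpl_ply = DataPaths.rgb2smpl_path(rgb_file)         # .../t0005.000/person/fit03/person_fit.ply
obj_ply = DataPaths.rgb2obj_path(rgb_file)           # .../t0005.000/backpack/fit01-smooth/backpack_fit.ply
```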
dataset/demo_dataset.py
ADDED
@@ -0,0 +1,198 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .base_data import BaseDataset
|
7 |
+
from .behave_paths import DataPaths
|
8 |
+
from .img_utils import compute_translation, masks2bbox, crop
|
9 |
+
|
10 |
+
|
11 |
+
def padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio=0.75):
|
12 |
+
"""
|
13 |
+
pad images to have 4:3 aspect ratio
|
14 |
+
:param rgb: (H, W, 3)
|
15 |
+
:param person_mask:
|
16 |
+
:param obj_mask:
|
17 |
+
:return: all images at the given aspect ratio
|
18 |
+
"""
|
19 |
+
h, w = rgb.shape[:2]
|
20 |
+
if w > h * 1/aspect_ratio:
|
21 |
+
# pad top
|
22 |
+
h_4x3 = int(w * aspect_ratio)
|
23 |
+
pad_top = h_4x3 - h
|
24 |
+
rgb_pad = np.pad(rgb, ((pad_top, 0), (0, 0), (0, 0)))
|
25 |
+
person_mask = np.pad(person_mask, ((pad_top, 0), (0, 0))) if person_mask is not None else None
|
26 |
+
obj_mask = np.pad(obj_mask, ((pad_top, 0), (0, 0))) if obj_mask is not None else None
|
27 |
+
else:
|
28 |
+
# pad two side
|
29 |
+
w_new = np.lcm.reduce([h * 2, 16])  # least common multiple
|
30 |
+
h_4x3 = int(w_new * aspect_ratio)
|
31 |
+
pad_top = h_4x3 - h
|
32 |
+
pad_left = (w_new - w) // 2
|
33 |
+
pad_right = w_new - w - pad_left
|
34 |
+
rgb_pad = np.pad(rgb, ((pad_top, 0), (pad_left, pad_right), (0, 0)))
|
35 |
+
obj_mask = np.pad(obj_mask, ((pad_top, 0), (pad_left, pad_right))) if obj_mask is not None else None
|
36 |
+
person_mask = np.pad(person_mask, ((pad_top, 0), (pad_left, pad_right))) if person_mask is not None else None
|
37 |
+
return rgb_pad, obj_mask, person_mask
|
38 |
+
|
39 |
+
|
40 |
+
def recrop_input(rgb, person_mask, obj_mask, dataset_name='behave'):
|
41 |
+
"recrop input images"
|
42 |
+
exp_ratio = 1.42
|
43 |
+
if dataset_name == 'behave':
|
44 |
+
mean_center = np.array([1008, 995]) # mean RGB image crop center
|
45 |
+
behave_size = (2048, 1536)
|
46 |
+
new_size = (int(750 * exp_ratio), int(exp_ratio * 750))
|
47 |
+
else:
|
48 |
+
mean_center = np.array([904, 668]) # mean RGB image crop center for bottle sequences of ICAP
|
49 |
+
behave_size = (1920, 1080)
|
50 |
+
new_size = (int(593.925 * exp_ratio), int(exp_ratio * 593.925)) # mean width of bottle sequences
|
51 |
+
aspect_ratio = behave_size[1] / behave_size[0]
|
52 |
+
pad_top = mean_center[1] - new_size[0] // 2
|
53 |
+
pad_bottom = behave_size[1] - (mean_center[1] + new_size[0] // 2)
|
54 |
+
pad_left = mean_center[0] - new_size[0] // 2
|
55 |
+
pad_right = behave_size[0] - (mean_center[0] + new_size[0] // 2)
|
56 |
+
|
57 |
+
# First resize to the same aspect ratio
|
58 |
+
if rgb.shape[0] / rgb.shape[1] != aspect_ratio:
|
59 |
+
rgb, obj_mask, person_mask = padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio)
|
60 |
+
|
61 |
+
# Resize to the same size as behave image, to have a comparable pixel size
|
62 |
+
rgb = cv2.resize(rgb, behave_size)
|
63 |
+
mask_ps = cv2.resize(person_mask, behave_size)
|
64 |
+
mask_obj = cv2.resize(obj_mask, behave_size)
|
65 |
+
|
66 |
+
# Crop and resize the human + object patch
|
67 |
+
bmin, bmax = masks2bbox([mask_ps, mask_obj])
|
68 |
+
center = (bmin + bmax) // 2
|
69 |
+
crop_size = int(np.max(bmax - bmin) * exp_ratio) # larger crop to have background
|
70 |
+
img_crop = cv2.resize(crop(rgb, center, crop_size), new_size)
|
71 |
+
mask_ps = cv2.resize(crop(mask_ps, center, crop_size), new_size)
|
72 |
+
mask_obj = cv2.resize(crop(mask_obj, center, crop_size), new_size)
|
73 |
+
|
74 |
+
# Pad back to have same shape as behave image
|
75 |
+
img_full = np.pad(img_crop, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
|
76 |
+
mask_ps_full = np.pad(mask_ps, [[pad_top, pad_bottom], [pad_left, pad_right]])
|
77 |
+
mask_obj_full = np.pad(mask_obj, [[pad_top, pad_bottom], [pad_left, pad_right]])
|
78 |
+
|
79 |
+
# Make sure the image shape is the same
|
80 |
+
if img_full.shape[:2] != behave_size[::-1]:
|
81 |
+
img_full = cv2.resize(img_full, behave_size)
|
82 |
+
mask_ps_full = cv2.resize(mask_ps_full, behave_size)
|
83 |
+
mask_obj_full = cv2.resize(mask_obj_full, behave_size)
|
84 |
+
return img_full, mask_ps_full, mask_obj_full
|
85 |
+
|
86 |
+
|
87 |
+
class DemoDataset(BaseDataset):
|
88 |
+
def __init__(self, data_paths, input_size=(224, 224),
|
89 |
+
std_coverage=3.5, # used to estimate camera translation
|
90 |
+
):
|
91 |
+
super().__init__(data_paths, input_size)
|
92 |
+
self.std_coverage = std_coverage
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.data_paths)
|
96 |
+
|
97 |
+
def __getitem__(self, idx):
|
98 |
+
rgb_file = self.data_paths[idx]
|
99 |
+
mask_hum, mask_obj = self.load_masks(rgb_file)
|
100 |
+
rgb_full = cv2.imread(rgb_file)[:, :, ::-1]
|
101 |
+
|
102 |
+
return self.image2dict(mask_hum, mask_obj, rgb_full, rgb_file)
|
103 |
+
|
104 |
+
def image2dict(self, mask_hum, mask_obj, rgb_full, rgb_file=None):
|
105 |
+
"do all the necessary preprocessing for images"
|
106 |
+
if rgb_full.shape[:2] != mask_obj.shape[:2]:
|
107 |
+
raise ValueError(f"The given object mask shape {mask_obj.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
|
108 |
+
if rgb_full.shape[:2] != mask_hum.shape[:2]:
|
109 |
+
raise ValueError(f"The given human mask shape {mask_hum.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
|
110 |
+
|
111 |
+
if rgb_full.shape[:2] not in [(1080, 1920), (1536, 2048)]:
|
112 |
+
# crop and resize the image to behave image size
|
113 |
+
print(f"Recropping the input image and masks for {rgb_file}")
|
114 |
+
rgb_full, mask_hum, mask_obj = recrop_input(rgb_full, mask_hum, mask_obj)
|
115 |
+
color_h, color_w = rgb_full.shape[:2]
|
116 |
+
# Input to the first stage model: human + object crop
|
117 |
+
Kroi, objmask_fullcrop, psmask_fullcrop, rgb_fullcrop = self.crop_full_image(mask_hum.copy(),
|
118 |
+
mask_obj.copy(),
|
119 |
+
rgb_full.copy(),
|
120 |
+
[mask_hum, mask_obj],
|
121 |
+
1.00)
|
122 |
+
# Input to the second stage model: human and object crops
|
123 |
+
Kroi_h, masko_hum, maskh_hum, rgb_hum = self.crop_full_image(mask_hum.copy(),
|
124 |
+
mask_obj.copy(),
|
125 |
+
rgb_full.copy(),
|
126 |
+
[mask_hum, mask_hum], 1.05)
|
127 |
+
Kroi_o, masko_obj, maskh_obj, rgb_obj = self.crop_full_image(mask_hum.copy(),
|
128 |
+
mask_obj.copy(),
|
129 |
+
rgb_full.copy(),
|
130 |
+
[mask_obj, mask_obj], 1.5)
|
131 |
+
# Estimate camera translation
|
132 |
+
cent_transform = np.eye(4) # the transform applied to the mesh that moves it back to kinect camera frame
|
133 |
+
bmin_ho, bmax_ho = masks2bbox([mask_hum, mask_obj])
|
134 |
+
crop_size_ho = int(np.max(bmax_ho - bmin_ho) * 1.0)
|
135 |
+
if crop_size_ho % 2 == 1:
|
136 |
+
crop_size_ho += 1 # make sure it is an even number
|
137 |
+
is_behave = self.is_behave_dataset(rgb_full.shape[1])
|
138 |
+
if rgb_full.shape[1] not in [2048, 1920]:
|
139 |
+
raise ValueError('the image is not normalized to BEHAVE or ICAP size!')
|
140 |
+
indices = np.indices(rgb_full.shape[:2])
|
141 |
+
if np.sum(mask_obj > 127) < 5:
|
142 |
+
raise ValueError(f'not enough object mask found for {rgb_file}')
|
143 |
+
pts_h = np.stack([indices[1][mask_hum > 127], indices[0][mask_hum > 127]], -1)
|
144 |
+
pts_o = np.stack([indices[1][mask_obj > 127], indices[0][mask_obj > 127]], -1)
|
145 |
+
proj_cent_est = (np.mean(pts_h, 0) + np.mean(pts_o, 0)) / 2. # heuristic to obtain 2d projection center
|
146 |
+
transl_estimate = compute_translation(proj_cent_est, crop_size_ho, is_behave, self.std_coverage)
|
147 |
+
cent_transform[:3, 3] = transl_estimate / 7.0
|
148 |
+
radius = 0.5 # don't do normalization anymore
|
149 |
+
cent = transl_estimate / 7.0
|
150 |
+
comb = np.matmul(self.opencv2py3d, cent_transform)
|
151 |
+
R = torch.from_numpy(comb[:3, :3]).float()
|
152 |
+
T = torch.from_numpy(comb[:3, 3]).float() / (radius * 2)
|
153 |
+
data_dict = {
|
154 |
+
"R": R,
|
155 |
+
"T": T,
|
156 |
+
"K": torch.from_numpy(Kroi).float(),
|
157 |
+
"T_ho": torch.from_numpy(cent).float(), # translation for H+O
|
158 |
+
"image_path": rgb_file,
|
159 |
+
"image_size_hw": torch.tensor(self.input_size),
|
160 |
+
"images": torch.from_numpy(rgb_fullcrop).float().permute(2, 0, 1),
|
161 |
+
"masks": torch.from_numpy(np.stack([psmask_fullcrop, objmask_fullcrop], 0)).float(),
|
162 |
+
'orig_image_size': torch.tensor([color_h, color_w]),
|
163 |
+
|
164 |
+
# Human input to stage 2
|
165 |
+
"images_hum": torch.from_numpy(rgb_hum).float().permute(2, 0, 1),
|
166 |
+
"masks_hum": torch.from_numpy(np.stack([maskh_hum, masko_hum], 0)).float(),
|
167 |
+
"K_hum": torch.from_numpy(Kroi_h).float(),
|
168 |
+
|
169 |
+
# Object input to stage 2
|
170 |
+
"images_obj": torch.from_numpy(rgb_obj).float().permute(2, 0, 1),
|
171 |
+
"masks_obj": torch.from_numpy(np.stack([maskh_obj, masko_obj], 0)).float(),
|
172 |
+
"K_obj": torch.from_numpy(Kroi_o).float(),
|
173 |
+
|
174 |
+
# some normalization parameters
|
175 |
+
"gt_trans": cent,
|
176 |
+
'radius': radius,
|
177 |
+
"estimated_trans": transl_estimate,
|
178 |
+
}
|
179 |
+
return data_dict
|
180 |
+
|
181 |
+
def image2batch(self, rgb, mask_hum, mask_obj):
|
182 |
+
"""
|
183 |
+
given input image, convert it into a batch object ready for model inference
|
184 |
+
:param rgb: (h, w, 3), np array
|
185 |
+
:param mask_hum: (h, w, 3), np array
|
186 |
+
:param mask_obj: (h, w, 3), np array
|
187 |
+
:return:
|
188 |
+
"""
|
189 |
+
mask_hum = np.mean(mask_hum, -1)
|
190 |
+
mask_obj = np.mean(mask_obj, -1)
|
191 |
+
|
192 |
+
data_dict = self.image2dict(mask_hum, mask_obj, rgb, 'input image')
|
193 |
+
# convert dict to list
|
194 |
+
new_dict = {k:[v] for k, v in data_dict.items()}
|
195 |
+
|
196 |
+
return new_dict
|
197 |
+
|
198 |
+
|
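A minimal sketch of pushing a single in-memory image through `DemoDataset.image2batch`, as an interactive demo might; it uses the example files added in this commit and assumes the base dataset accepts an empty path list.

```python
import cv2
from dataset.demo_dataset import DemoDataset

rgb = cv2.imread("examples/017450/k1.color.jpg")[:, :, ::-1]    # (H, W, 3), BGR -> RGB
mask_hum = cv2.imread("examples/017450/k1.person_mask.png")     # (H, W, 3), values 0-255
mask_obj = cv2.imread("examples/017450/k1.obj_rend_mask.png")   # (H, W, 3), values 0-255

dataset = DemoDataset([], input_size=(224, 224), std_coverage=3.5)
batch = dataset.image2batch(rgb, mask_hum, mask_obj)  # dict of single-element lists
print(batch.keys())
print(batch["images"][0].shape)  # cropped H+O image as a (3, H, W) tensor
```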
dataset/img_utils.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
"""
|
2 |
+
common functions for image operations
|
3 |
+
"""
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def crop(img, center, crop_size):
|
10 |
+
"""
|
11 |
+
crop image around the given center, pad zeros for borders
|
12 |
+
:param img:
|
13 |
+
:param center: np array
|
14 |
+
:param crop_size: np array or a float size of the resulting crop
|
15 |
+
:return: a square crop around the center
|
16 |
+
"""
|
17 |
+
assert isinstance(img, np.ndarray)
|
18 |
+
h, w = img.shape[:2]
|
19 |
+
topleft = np.round(center - crop_size / 2).astype(int)
|
20 |
+
bottom_right = np.round(center + crop_size / 2).astype(int)
|
21 |
+
|
22 |
+
x1 = max(0, topleft[0])
|
23 |
+
y1 = max(0, topleft[1])
|
24 |
+
x2 = min(w - 1, bottom_right[0])
|
25 |
+
y2 = min(h - 1, bottom_right[1])
|
26 |
+
cropped = img[y1:y2, x1:x2]
|
27 |
+
|
28 |
+
p1 = max(0, -topleft[0]) # padding in x, top
|
29 |
+
p2 = max(0, -topleft[1]) # padding in y, top
|
30 |
+
p3 = max(0, bottom_right[0] - w + 1) # padding in x, bottom
|
31 |
+
p4 = max(0, bottom_right[1] - h + 1) # padding in y, bottom
|
32 |
+
|
33 |
+
dim = len(img.shape)
|
34 |
+
if dim == 3:
|
35 |
+
padded = np.pad(cropped, [[p2, p4], [p1, p3], [0, 0]])
|
36 |
+
elif dim == 2:
|
37 |
+
padded = np.pad(cropped, [[p2, p4], [p1, p3]])
|
38 |
+
else:
|
39 |
+
raise NotImplementedError  # NotImplemented is a constant, not an exception class, and cannot be raised
|
40 |
+
return padded
|
41 |
+
|
42 |
+
|
43 |
+
def resize(img, img_size, mode=cv2.INTER_LINEAR):
|
44 |
+
"""
|
45 |
+
resize image to the input
|
46 |
+
:param img:
|
47 |
+
:param img_size: (width, height) of the target image size
|
48 |
+
:param mode:
|
49 |
+
:return:
|
50 |
+
"""
|
51 |
+
h, w = img.shape[:2]
|
52 |
+
load_ratio = 1.0 * w / h
|
53 |
+
netin_ratio = 1.0 * img_size[0] / img_size[1]
|
54 |
+
assert load_ratio == netin_ratio, "image aspect ratio not matching, given image: {}, net input: {}".format(
|
55 |
+
img.shape, img_size)
|
56 |
+
resized = cv2.resize(img, img_size, interpolation=mode)
|
57 |
+
return resized
|
58 |
+
|
59 |
+
|
60 |
+
def masks2bbox(masks, threshold=127):
|
61 |
+
"""
|
62 |
+
|
63 |
+
:param masks: a list of (H, W) mask images
|
64 |
+
:param threshold: pixel values above this threshold count as foreground
|
65 |
+
:return: (bmin, bmax), the min and max corners of the bounding box enclosing all masks
|
66 |
+
"""
|
67 |
+
mask_comb = np.zeros_like(masks[0], dtype=bool)
|
68 |
+
for m in masks:
|
69 |
+
mask_comb = mask_comb | (m > threshold)
|
70 |
+
|
71 |
+
yid, xid = np.where(mask_comb)
|
72 |
+
bmin = np.array([xid.min(), yid.min()])
|
73 |
+
bmax = np.array([xid.max(), yid.max()])
|
74 |
+
return bmin, bmax
|
75 |
+
|
76 |
+
|
77 |
+
def compute_translation(crop_center, crop_size, is_behave=True, std_coverage=3.5):
|
78 |
+
"""
|
79 |
+
solve for an optimal translation that projects a Gaussian at the origin into the given 2D crop
|
80 |
+
Parameters
|
81 |
+
----------
|
82 |
+
crop_center: (x, y) of the crop center
|
83 |
+
crop_size: float, the size of the square crop
|
84 |
+
std_coverage: which edge point should be projected back to the edge of the 2d crop
|
85 |
+
|
86 |
+
Returns
|
87 |
+
-------
|
88 |
+
the estimated translation
|
89 |
+
|
90 |
+
"""
|
91 |
+
x0, y0 = crop_center
|
92 |
+
x1, y1 = x0 + crop_size/2, y0
|
93 |
+
x2, y2 = x0 - crop_size/2, y0
|
94 |
+
x3, y3 = x0, y0 + crop_size/2.
|
95 |
+
# predefined kinect intrinsics
|
96 |
+
if is_behave:
|
97 |
+
fx = 979.7844
|
98 |
+
fy = 979.840
|
99 |
+
cx = 1018.952
|
100 |
+
cy = 779.486
|
101 |
+
else:
|
102 |
+
# intercap camera
|
103 |
+
fx, fy = 918.457763671875, 918.4373779296875
|
104 |
+
cx, cy = 956.9661865234375, 555.944580078125
|
105 |
+
|
106 |
+
# construct the matrix
|
107 |
+
# A = np.array([
|
108 |
+
# [fx, 0, cx-x0, cx-x0, 0, 0],
|
109 |
+
# [0, fy, cy-y0, cy-y0, 0, 0],
|
110 |
+
# [fx, 0, cx-x1, 0, cx-x1, 0],
|
111 |
+
# [0, fy, cy-y1, 0, cy-y1, 0],
|
112 |
+
# [fx, 0, cx-x2, 0, 0, cx-x2],
|
113 |
+
# [0, fy, cy-y2, 0, 0, cy-y2]
|
114 |
+
# ]) # this matrix is low-rank because columns are linearly dependent: col3 - col4 = col5 + col6
|
115 |
+
# # find linearly dependent rows
|
116 |
+
# lambdas, V = np.linalg.eig(A)
|
117 |
+
# # print()
|
118 |
+
# # The linearly dependent row vectors
|
119 |
+
# print(lambdas == 0, np.linalg.det(A), A[lambdas == 0, :]) # some have determinant zero, some don't??
|
120 |
+
# print(np.linalg.inv(A))
|
121 |
+
|
122 |
+
# A = np.array([
|
123 |
+
# [fx, 0, cx - x0, cx - x0, 0, 0],
|
124 |
+
# [0, fy, cy - y0, cy - y0, 0, 0],
|
125 |
+
# [fx, 0, cx - x1, 0, cx - x1, 0],
|
126 |
+
# [0, fy, cy - y1, 0, cy - y1, 0],
|
127 |
+
# [fx, 0, cx - x3, 0, 0, cx - x3],
|
128 |
+
# [0, fy, cy - y3, 0, 0, cy - y3]
|
129 |
+
# ]) # this is also low rank!
|
130 |
+
# b = np.array([0, 0, -3*fx, 0, 0, -3*fy]).reshape((-1, 1))
|
131 |
+
# print("rank of the coefficient matrix:", np.linalg.matrix_rank(A)) # rank is 5! underconstrained matrix!
|
132 |
+
# x = np.matmul(np.linalg.inv(A), b)
|
133 |
+
|
134 |
+
# fix z0 as 0, then A is a full-rank matrix
|
135 |
+
# first two equations: origin (0, 0, 0) is projected to the crop center
|
136 |
+
# last two equations: edge point (3.5, 0, z) is projected to the edge of crop
|
137 |
+
A = np.array([
|
138 |
+
[fx, 0, cx-x0, cx-x0],
|
139 |
+
[0, fy, cy-y0, cy-y0],
|
140 |
+
[fx, 0, fx-x1, 0],
|
141 |
+
[0, fy, cy-y1, 0]
|
142 |
+
])
|
143 |
+
# b = np.array([0, 0, -3.5*fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
|
144 |
+
b = np.array([0, 0, -std_coverage * fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
|
145 |
+
x = np.matmul(np.linalg.inv(A), b) # use 4 or 5 does not really matter, same results
|
146 |
+
|
147 |
+
# A is always a full-rank matrix
|
148 |
+
|
149 |
+
return x.flatten()[:3]
|
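A minimal sketch of how `masks2bbox` and `compute_translation` combine to give the rough camera translation used in `dataset/demo_dataset.py`; the toy masks below are synthetic, and the division by 7.0 mirrors the normalization applied there.

```python
import numpy as np
from dataset.img_utils import masks2bbox, compute_translation

# Synthetic (H, W) masks at BEHAVE resolution (2048 x 1536), values in 0-255.
H, W = 1536, 2048
mask_hum = np.zeros((H, W), np.uint8); mask_hum[400:1100, 800:1100] = 255
mask_obj = np.zeros((H, W), np.uint8); mask_obj[700:1100, 1000:1400] = 255

bmin, bmax = masks2bbox([mask_hum, mask_obj])
crop_size = int(np.max(bmax - bmin))

# 2D projection centre: average of the human and object mask centres (same heuristic as the demo dataset).
idx = np.indices((H, W))
pts_h = np.stack([idx[1][mask_hum > 127], idx[0][mask_hum > 127]], -1)
pts_o = np.stack([idx[1][mask_obj > 127], idx[0][mask_obj > 127]], -1)
proj_cent = (np.mean(pts_h, 0) + np.mean(pts_o, 0)) / 2.

transl = compute_translation(proj_cent, crop_size, is_behave=True, std_coverage=3.5)
print(transl / 7.0)  # the demo dataset divides by 7.0 before using this as the H+O centre
```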
demo.py
ADDED
@@ -0,0 +1,280 @@
1 |
+
"""
|
2 |
+
Demo for template-free reconstruction
|
3 |
+
|
4 |
+
python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0
|
5 |
+
"""
|
6 |
+
import pickle as pkl
|
7 |
+
import sys, os
|
8 |
+
import os.path as osp
|
9 |
+
from typing import Iterable, Optional
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
from accelerate import Accelerator
|
13 |
+
from tqdm import tqdm
|
14 |
+
from glob import glob
|
15 |
+
|
16 |
+
sys.path.append(os.getcwd())
|
17 |
+
import hydra
|
18 |
+
import torch
|
19 |
+
import numpy as np
|
20 |
+
import imageio
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
23 |
+
from pytorch3d.structures import Pointclouds
|
24 |
+
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
|
25 |
+
from pytorch3d.io import IO
|
26 |
+
import torchvision.transforms.functional as TVF
|
27 |
+
from huggingface_hub import hf_hub_download
|
28 |
+
|
29 |
+
import training_utils
|
30 |
+
from configs.structured import ProjectConfig
|
31 |
+
from dataset.demo_dataset import DemoDataset
|
32 |
+
from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm
|
33 |
+
from render.pyt3d_wrapper import PcloudRenderer
|
34 |
+
|
35 |
+
|
36 |
+
class DemoRunner:
|
37 |
+
def __init__(self, cfg: ProjectConfig):
|
38 |
+
cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True
|
39 |
+
model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model)
|
40 |
+
cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False # stage 2 does not predict segmentation
|
41 |
+
model_stage2 = CrossAttenHODiffusionModel(**cfg.model)
|
42 |
+
|
43 |
+
# Load from checkpoint
|
44 |
+
# ckpt_file1 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage1_name}/single/checkpoint-latest.pth')
|
45 |
+
# self.load_checkpoint(ckpt_file1, model_stage1)
|
46 |
+
# ckpt_file2 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage2_name}/single/checkpoint-latest.pth')
|
47 |
+
# self.load_checkpoint(ckpt_file2, model_stage2)
|
48 |
+
# Load ckpt from hf
|
49 |
+
ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth')
|
50 |
+
self.load_checkpoint(ckpt_file1, model_stage1)
|
51 |
+
ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth')
|
52 |
+
self.load_checkpoint(ckpt_file2, model_stage2)
|
53 |
+
|
54 |
+
self.model_stage1, self.model_stage2 = model_stage1, model_stage2
|
55 |
+
self.model_stage1.eval()
|
56 |
+
self.model_stage2.eval()
|
57 |
+
self.model_stage1.to('cuda')
|
58 |
+
self.model_stage2.to('cuda')
|
59 |
+
|
60 |
+
self.cfg = cfg
|
61 |
+
self.io_pc = IO()
|
62 |
+
|
63 |
+
# For visualization
|
64 |
+
self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075)
|
65 |
+
self.rend_size = cfg.dataset.image_size
|
66 |
+
self.device = 'cuda'
|
67 |
+
|
68 |
+
def load_checkpoint(self, ckpt_file1, model_stage1):
|
69 |
+
checkpoint = torch.load(ckpt_file1, map_location='cpu')
|
70 |
+
state_dict, key = checkpoint['model'], 'model'
|
71 |
+
if any(k.startswith('module.') for k in state_dict.keys()):
|
72 |
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
73 |
+
print('Removed "module." from checkpoint state dict')
|
74 |
+
missing_keys, unexpected_keys = model_stage1.load_state_dict(state_dict, strict=False)
|
75 |
+
print(f'Loaded model checkpoint {key} from {ckpt_file1}')
|
76 |
+
if len(missing_keys):
|
77 |
+
print(f' - Missing_keys: {missing_keys}')
|
78 |
+
if len(unexpected_keys):
|
79 |
+
print(f' - Unexpected_keys: {unexpected_keys}')
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def run(self):
|
83 |
+
"simply run the demo on given images, and save the results"
|
84 |
+
# Set random seed
|
85 |
+
training_utils.set_seed(self.cfg.run.seed)
|
86 |
+
|
87 |
+
outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo')
|
88 |
+
os.makedirs(outdir, exist_ok=True)
|
89 |
+
cfg = self.cfg
|
90 |
+
|
91 |
+
# Init data
|
92 |
+
image_files = sorted(glob(cfg.run.image_path))
|
93 |
+
data = DemoDataset(image_files,
|
94 |
+
(cfg.dataset.image_size, cfg.dataset.image_size),
|
95 |
+
cfg.dataset.std_coverage)
|
96 |
+
dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size,
|
97 |
+
collate_fn=collate_batched_meshes,
|
98 |
+
num_workers=1, shuffle=False)
|
99 |
+
dataloader = dataloader
|
100 |
+
progress_bar = tqdm(dataloader)
|
101 |
+
for batch_idx, batch in enumerate(progress_bar):
|
102 |
+
progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}')
|
103 |
+
|
104 |
+
out_stage1, out_stage2 = self.forward_batch(batch, cfg)
|
105 |
+
|
106 |
+
bs = len(out_stage1)
|
107 |
+
camera_full = PerspectiveCameras(
|
108 |
+
R=torch.stack(batch['R']),
|
109 |
+
T=torch.stack(batch['T']),
|
110 |
+
K=torch.stack(batch['K']),
|
111 |
+
device='cuda',
|
112 |
+
in_ndc=True)
|
113 |
+
|
114 |
+
# save output
|
115 |
+
for i in range(bs):
|
116 |
+
image_path = str(batch['image_path'][i])  # each entry of the collated batch is a list; take the i-th sample
|
117 |
+
folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0]
|
118 |
+
out_i = osp.join(outdir, folder)
|
119 |
+
os.makedirs(out_i, exist_ok=True)
|
120 |
+
self.io_pc.save_pointcloud(data=out_stage1[i],
|
121 |
+
path=osp.join(out_i, f'{fname}_stage1.ply'))
|
122 |
+
self.io_pc.save_pointcloud(data=out_stage2[i],
|
123 |
+
path=osp.join(out_i, f'{fname}_stage2.ply'))
|
124 |
+
TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png'))
|
125 |
+
|
126 |
+
# Save metadata as well
|
127 |
+
metadata = dict(index=i,
|
128 |
+
camera=camera_full[i],
|
129 |
+
image_size_hw=batch['image_size_hw'][i],
|
130 |
+
image_path=batch['image_path'][i])
|
131 |
+
torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth'))
|
132 |
+
|
133 |
+
# Visualize
|
134 |
+
# front_camera = camera_full[i]
|
135 |
+
pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()],
|
136 |
+
features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()])
|
137 |
+
video_file = osp.join(out_i, f'{fname}_360view.mp4')
|
138 |
+
video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1)
|
139 |
+
|
140 |
+
# first render front view
|
141 |
+
rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask')
|
142 |
+
rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask')
|
143 |
+
comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1)
|
144 |
+
video_writer.append_data((comb*255).astype(np.uint8))
|
145 |
+
|
146 |
+
for azim in range(180, 180+360, 30):
|
147 |
+
R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),), )
|
148 |
+
side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),),
|
149 |
+
device=self.device,
|
150 |
+
R=R.repeat(2, 1, 1), T=T.repeat(2, 1),
|
151 |
+
focal_length=self.rend_size * 1.5,
|
152 |
+
principal_point=((self.rend_size / 2., self.rend_size / 2.),),
|
153 |
+
in_ndc=False)
|
154 |
+
rend, mask = self.renderer.render(pc_comb, side_camera, mode='mask')
|
155 |
+
|
156 |
+
imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy()]
|
157 |
+
imgs.extend([rend[0], rend[1]])
|
158 |
+
video_writer.append_data((np.concatenate(imgs, 1)*255).astype(np.uint8))
|
159 |
+
print(f"Visualization saved to {out_i}")
|
160 |
+
|
161 |
+
@torch.no_grad()
|
162 |
+
def forward_batch(self, batch, cfg):
|
163 |
+
"""
|
164 |
+
forward one batch
|
165 |
+
:param batch:
|
166 |
+
:param cfg:
|
167 |
+
:return: predicted point clouds of stage 1 and 2
|
168 |
+
"""
|
169 |
+
camera_full = PerspectiveCameras(
|
170 |
+
R=torch.stack(batch['R']),
|
171 |
+
T=torch.stack(batch['T']),
|
172 |
+
K=torch.stack(batch['K']),
|
173 |
+
device='cuda',
|
174 |
+
in_ndc=True)
|
175 |
+
out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points,
|
176 |
+
camera=camera_full,
|
177 |
+
image_rgb=torch.stack(batch['images']).to('cuda'),
|
178 |
+
mask=torch.stack(batch['masks']).to('cuda'),
|
179 |
+
scheduler=cfg.run.diffusion_scheduler,
|
180 |
+
num_inference_steps=cfg.run.num_inference_steps,
|
181 |
+
)
|
182 |
+
# segment and normalize human/object
|
183 |
+
bs = len(out_stage1)
|
184 |
+
pred_hum, pred_obj = [], [] # predicted human/object points
|
185 |
+
cent_hum_pred, cent_obj_pred = [], []
|
186 |
+
radius_hum_pred, radius_obj_pred = [], []
|
187 |
+
T_hum, T_obj = [], []
|
188 |
+
num_samples = int(cfg.dataset.max_points / 2)
|
189 |
+
for i in range(bs):
|
190 |
+
pc: Pointclouds = out_stage1[i]
|
191 |
+
vc = pc.features_packed().cpu() # (P, 3), human is light blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0]
|
192 |
+
points = pc.points_packed().cpu() # (P, 3)
|
193 |
+
mask_hum = vc[:, 2] > 0.5
|
194 |
+
pc_hum, pc_obj = points[mask_hum], points[~mask_hum]
|
195 |
+
# Up/Down-sample the points
|
196 |
+
pc_obj = self.upsample_predicted_pc(num_samples, pc_obj)
|
197 |
+
pc_hum = self.upsample_predicted_pc(num_samples, pc_hum)
|
198 |
+
|
199 |
+
# Normalize
|
200 |
+
cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True)
|
201 |
+
scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max())
|
202 |
+
scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max())
|
203 |
+
pc_hum = (pc_hum - cent_hum) / (2 * scale_hum)
|
204 |
+
pc_obj = (pc_obj - cent_obj) / (2 * scale_obj)
|
205 |
+
# Also update camera parameters for separate human + object
|
206 |
+
T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum)
|
207 |
+
T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj)
|
208 |
+
|
209 |
+
pred_hum.append(pc_hum)
|
210 |
+
pred_obj.append(pc_obj)
|
211 |
+
cent_hum_pred.append(cent_hum.squeeze(0))
|
212 |
+
cent_obj_pred.append(cent_obj.squeeze(0))
|
213 |
+
T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1])) # apply opencv to pytorch3d transform: flip x and y
|
214 |
+
T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1]))
|
215 |
+
radius_hum_pred.append(scale_hum)
|
216 |
+
radius_obj_pred.append(scale_obj)
|
217 |
+
# Pack data into a new batch dict
|
218 |
+
camera_hum = PerspectiveCameras(
|
219 |
+
R=torch.stack(batch['R']),
|
220 |
+
T=torch.stack(T_hum),
|
221 |
+
K=torch.stack(batch['K_hum']),
|
222 |
+
device='cuda',
|
223 |
+
in_ndc=True
|
224 |
+
)
|
225 |
+
camera_obj = PerspectiveCameras(
|
226 |
+
R=torch.stack(batch['R']),
|
227 |
+
T=torch.stack(T_obj),
|
228 |
+
K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!!
|
229 |
+
device='cuda',
|
230 |
+
in_ndc=True
|
231 |
+
)
|
232 |
+
# use pc from predicted
|
233 |
+
pc_hum = Pointclouds([x.to('cuda') for x in pred_hum])
|
234 |
+
pc_obj = Pointclouds([x.to('cuda') for x in pred_obj])
|
235 |
+
# use center and radius from predicted
|
236 |
+
cent_hum = torch.stack(cent_hum_pred, 0).to('cuda')
|
237 |
+
cent_obj = torch.stack(cent_obj_pred, 0).to('cuda') # B, 3
|
238 |
+
radius_hum = torch.stack(radius_hum_pred, 0).to('cuda') # B, 1
|
239 |
+
radius_obj = torch.stack(radius_obj_pred, 0).to('cuda')
|
240 |
+
out_stage2: Pointclouds = self.model_stage2.forward_sample(
|
241 |
+
num_points=num_samples,
|
242 |
+
camera=camera_hum,
|
243 |
+
image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'),
|
244 |
+
mask=torch.stack(batch['masks_hum'], 0).to('cuda'),
|
245 |
+
gt_pc=pc_hum,
|
246 |
+
rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'),
|
247 |
+
mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'),
|
248 |
+
pc_obj=pc_obj,
|
249 |
+
camera_obj=camera_obj,
|
250 |
+
cent_hum=cent_hum,
|
251 |
+
cent_obj=cent_obj,
|
252 |
+
radius_hum=radius_hum.unsqueeze(-1),
|
253 |
+
radius_obj=radius_obj.unsqueeze(-1),
|
254 |
+
sample_from_interm=True,
|
255 |
+
noise_step=cfg.run.sample_noise_step)
|
256 |
+
return out_stage1, out_stage2
|
257 |
+
|
258 |
+
def upsample_predicted_pc(self, num_samples, pc_obj):
|
259 |
+
"""
|
260 |
+
Up/Downsample the points to given number
|
261 |
+
:param num_samples: the target number
|
262 |
+
:param pc_obj: (N, 3)
|
263 |
+
:return: (num_samples, 3)
|
264 |
+
"""
|
265 |
+
if len(pc_obj) > num_samples:
|
266 |
+
ind_obj = np.random.choice(len(pc_obj), num_samples)
|
267 |
+
else:
|
268 |
+
ind_obj = np.concatenate([np.arange(len(pc_obj)), np.random.choice(len(pc_obj), num_samples - len(pc_obj))])
|
269 |
+
pc_obj = pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)]
|
270 |
+
return pc_obj
|
271 |
+
|
272 |
+
|
273 |
+
@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
|
274 |
+
def main(cfg: ProjectConfig):
|
275 |
+
runner = DemoRunner(cfg)
|
276 |
+
runner.run()
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == '__main__':
|
280 |
+
main()
|
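After running the command documented in the module docstring above, the results land in `outputs/demo/<folder>/`, where `<folder>` is the parent directory name of the input image. A minimal sketch of loading them back, assuming the bundled example image `examples/017450/k1.color.jpg` was used:

```python
import torch
from pytorch3d.io import IO

out_dir = "outputs/demo/017450"
stage1 = IO().load_pointcloud(f"{out_dir}/k1.color_stage1.ply")   # joint human + object prediction
stage2 = IO().load_pointcloud(f"{out_dir}/k1.color_stage2.ply")   # refined human and object points
meta = torch.load(f"{out_dir}/k1.color_meta.pth", map_location="cpu")  # camera, image size, image path
print(stage1.points_packed().shape, stage2.points_packed().shape, meta["image_path"])
```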
diffusion_utils.py
ADDED
@@ -0,0 +1,313 @@
1 |
+
import math
|
2 |
+
from typing import List, Optional, Sequence, Union
|
3 |
+
|
4 |
+
import imageio
|
5 |
+
import logging
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.utils.data
|
9 |
+
from PIL import Image
|
10 |
+
from torch.distributions import Normal
|
11 |
+
from torchvision.transforms.functional import to_pil_image
|
12 |
+
from torchvision.utils import make_grid
|
13 |
+
from tqdm import tqdm, trange
|
14 |
+
from pytorch3d.renderer import (
|
15 |
+
AlphaCompositor,
|
16 |
+
NormWeightedCompositor,
|
17 |
+
OrthographicCameras,
|
18 |
+
PointsRasterizationSettings,
|
19 |
+
PointsRasterizer,
|
20 |
+
PointsRenderer,
|
21 |
+
look_at_view_transform)
|
22 |
+
from pytorch3d.renderer.cameras import CamerasBase
|
23 |
+
from pytorch3d.structures import Pointclouds
|
24 |
+
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch
|
25 |
+
|
26 |
+
|
27 |
+
# Disable unnecessary imageio logging
|
28 |
+
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
|
29 |
+
|
30 |
+
|
31 |
+
def rotation_matrix(axis, theta):
|
32 |
+
"""
|
33 |
+
Return the rotation matrix associated with counterclockwise rotation about
|
34 |
+
the given axis by theta radians.
|
35 |
+
"""
|
36 |
+
axis = np.asarray(axis)
|
37 |
+
axis = axis / np.sqrt(np.dot(axis, axis))
|
38 |
+
a = np.cos(theta / 2.0)
|
39 |
+
b, c, d = -axis * np.sin(theta / 2.0)
|
40 |
+
aa, bb, cc, dd = a * a, b * b, c * c, d * d
|
41 |
+
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
|
42 |
+
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
|
43 |
+
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
|
44 |
+
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
|
45 |
+
|
46 |
+
|
47 |
+
def rotate(vertices, faces):
|
48 |
+
'''
|
49 |
+
vertices: [numpoints, 3]
|
50 |
+
'''
|
51 |
+
M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
|
52 |
+
N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
|
53 |
+
K = rotation_matrix([0, 0, 1], np.pi).transpose()
|
54 |
+
|
55 |
+
v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]]
|
56 |
+
return v, f
|
57 |
+
|
58 |
+
|
59 |
+
def norm(v, f):
|
60 |
+
v = (v - v.min()) / (v.max() - v.min()) - 0.5
|
61 |
+
|
62 |
+
return v, f
|
63 |
+
|
64 |
+
|
65 |
+
def getGradNorm(net):
|
66 |
+
pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
|
67 |
+
gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
|
68 |
+
return pNorm, gradNorm
|
69 |
+
|
70 |
+
|
71 |
+
def weights_init(m):
|
72 |
+
classname = m.__class__.__name__
|
73 |
+
if classname.find('Conv') != -1 and m.weight is not None:
|
74 |
+
torch.nn.init.xavier_normal_(m.weight)
|
75 |
+
elif classname.find('BatchNorm') != -1:
|
76 |
+
m.weight.data.normal_()
|
77 |
+
m.bias.data.fill_(0)
|
78 |
+
|
79 |
+
|
80 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
81 |
+
# Assumes data is integers [0, 1]
|
82 |
+
assert x.shape == means.shape == log_scales.shape
|
83 |
+
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
|
84 |
+
|
85 |
+
centered_x = x - means
|
86 |
+
inv_stdv = torch.exp(-log_scales)
|
87 |
+
plus_in = inv_stdv * (centered_x + 0.5)
|
88 |
+
cdf_plus = px0.cdf(plus_in)
|
89 |
+
min_in = inv_stdv * (centered_x - .5)
|
90 |
+
cdf_min = px0.cdf(min_in)
|
91 |
+
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
|
92 |
+
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * 1e-12))
|
93 |
+
cdf_delta = cdf_plus - cdf_min
|
94 |
+
|
95 |
+
log_probs = torch.where(
|
96 |
+
x < 0.001, log_cdf_plus,
|
97 |
+
torch.where(x > 0.999, log_one_minus_cdf_min,
|
98 |
+
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))))
|
99 |
+
assert log_probs.shape == x.shape
|
100 |
+
return log_probs
|
101 |
+
|
102 |
+
|
103 |
+
def fig2img(fig):
|
104 |
+
"""Convert a Matplotlib figure to a PIL Image and return it"""
|
105 |
+
import io
|
106 |
+
buf = io.BytesIO()
|
107 |
+
fig.savefig(buf)
|
108 |
+
buf.seek(0)
|
109 |
+
img = Image.open(buf)
|
110 |
+
return img
|
111 |
+
|
112 |
+
|
113 |
+
@torch.no_grad()
|
114 |
+
def visualize_distance_transform(
|
115 |
+
path_stem: str,
|
116 |
+
images: torch.Tensor,
|
117 |
+
) -> str:
|
118 |
+
output_file_image = f'{path_stem}.png'
|
119 |
+
if images.shape[3] in [1, 3]: # convert to (B, C, H, W)
|
120 |
+
images = images.permute(0, 3, 1, 2)
|
121 |
+
images = images[:, -1:] # (B, 1, H, W) # get only distances (not vectors for now, for simplicity)
|
122 |
+
image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1, normalize=True)
|
123 |
+
to_pil_image(image_grid).save(output_file_image)
|
124 |
+
return output_file_image
|
125 |
+
|
126 |
+
|
127 |
+
@torch.no_grad()
|
128 |
+
def visualize_image(
|
129 |
+
path_stem: str,
|
130 |
+
images: torch.Tensor,
|
131 |
+
mean: Union[torch.Tensor, float] = 0.5,
|
132 |
+
std: Union[torch.Tensor, float] = 0.5,
|
133 |
+
) -> str:
|
134 |
+
output_file_image = f'{path_stem}.png'
|
135 |
+
if images.shape[3] in [1, 3, 4]: # convert to (B, C, H, W)
|
136 |
+
images = images.permute(0, 3, 1, 2)
|
137 |
+
if images.shape[1] in [3, 4]: # normalize (single-channel images are not normalized)
|
138 |
+
images[:, :3] = images[:, :3] * std + mean # denormalize (color channels only, not alpha channel)
|
139 |
+
if images.shape[1] == 4: # normalize (single-channel images are not normalized)
|
140 |
+
image_alpha = images[:, 3:] # (B, 1, H, W)
|
141 |
+
bg_color = torch.tensor([230, 220, 250], device=images.device).reshape(1, 3, 1, 1) / 255
|
142 |
+
images = images[:, :3] * image_alpha + bg_color * (1 - image_alpha) # (B, 3, H, W)
|
143 |
+
image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1)
|
144 |
+
to_pil_image(image_grid).save(output_file_image)
|
145 |
+
return output_file_image
|
146 |
+
|
147 |
+
|
148 |
+
def ensure_point_cloud_has_colors(pointcloud: Pointclouds):
|
149 |
+
if pointcloud.features_padded() is None:
|
150 |
+
pointcloud = type(pointcloud)(points=pointcloud.points_padded(),
|
151 |
+
normals=pointcloud.normals_padded(), features=torch.zeros_like(pointcloud.points_padded()))
|
152 |
+
return pointcloud
|
153 |
+
|
154 |
+
|
155 |
+
@torch.no_grad()
|
156 |
+
def render_pointcloud_batch_pytorch3d(
|
157 |
+
cameras: CamerasBase,
|
158 |
+
pointclouds: Pointclouds,
|
159 |
+
image_size: int = 224,
|
160 |
+
radius: float = 0.01,
|
161 |
+
points_per_pixel: int = 10,
|
162 |
+
background_color: Sequence[float] = (0.78431373, 0.78431373, 0.78431373),
|
163 |
+
compositor: str = 'norm_weighted'
|
164 |
+
):
|
165 |
+
# Define the settings for rasterization and shading. Here we set the output image to be of size
|
166 |
+
# 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1
|
167 |
+
# and blur_radius=0.0. Refer to rasterize_points.py for explanations of these parameters.
|
168 |
+
raster_settings = PointsRasterizationSettings(
|
169 |
+
image_size=image_size,
|
170 |
+
radius=radius,
|
171 |
+
points_per_pixel=points_per_pixel,
|
172 |
+
)
|
173 |
+
|
174 |
+
# Rasterizer
|
175 |
+
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
176 |
+
|
177 |
+
# Compositor
|
178 |
+
if compositor == 'alpha':
|
179 |
+
compositor = AlphaCompositor(background_color=background_color)
|
180 |
+
elif compositor == 'norm_weighted':
|
181 |
+
compositor = NormWeightedCompositor(background_color=background_color)
|
182 |
+
else:
|
183 |
+
raise ValueError(compositor)
|
184 |
+
|
185 |
+
# Create a points renderer by compositing points using an weighted compositor (3D points are
|
186 |
+
# weighted according to their distance to a pixel and accumulated using a weighted sum)
|
187 |
+
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
188 |
+
|
189 |
+
# We cannot render a point cloud without colors, so add them if the pointcloud does
|
190 |
+
# not already have them
|
191 |
+
pointclouds = ensure_point_cloud_has_colors(pointclouds)
|
192 |
+
|
193 |
+
# Render batch of image
|
194 |
+
images = renderer(pointclouds)
|
195 |
+
|
196 |
+
return images
|
197 |
+
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def visualize_pointcloud_batch_pytorch3d(
|
201 |
+
pointclouds: Pointclouds,
|
202 |
+
output_file_video: Optional[str] = None,
|
203 |
+
output_file_image: Optional[str] = None,
|
204 |
+
cameras: Optional[CamerasBase] = None, # if None, we rotate
|
205 |
+
scale_factor: float = 1.0,
|
206 |
+
num_frames: int = 1, # note that it takes a while with 30 * batch_size frames
|
207 |
+
elev: int = 30,
|
208 |
+
):
|
209 |
+
"""Saves a video and a single image of a point cloud"""
|
210 |
+
assert 360 % num_frames == 0, 'please select a better number of frames'
|
211 |
+
|
212 |
+
# Sizes
|
213 |
+
B, N, C, F = *(pointclouds.points_padded().shape), num_frames
|
214 |
+
device = pointclouds.device
|
215 |
+
|
216 |
+
# If a camera has not been provided, we render from a rotating view around an image
|
217 |
+
if cameras is None:
|
218 |
+
|
219 |
+
# Create view transforms - R is (F, 3, 3) and T is (F, 3)
|
220 |
+
R, T = look_at_view_transform(dist=10.0, elev=elev, azim=list(range(0, 360, 360 // F)), degrees=True, device=device)
|
221 |
+
|
222 |
+
# Repeat
|
223 |
+
R = R.repeat_interleave(B, dim=0) # (F * B, 3, 3)
|
224 |
+
T = T.repeat_interleave(B, dim=0) # (F * B, 3)
|
225 |
+
points = pointclouds.points_padded().tile(F, 1, 1) # (F * B, num_points, 3)
|
226 |
+
colors = (torch.zeros_like(points) if pointclouds.features_padded() is None else
|
227 |
+
pointclouds.features_padded().tile(F, 1, 1)) # (F * B, num_points, 3)
|
228 |
+
|
229 |
+
# Initialize batch of cameras
|
230 |
+
cameras = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
|
231 |
+
|
232 |
+
# Wrap in Pointclouds (with color, even if the original point cloud had no color)
|
233 |
+
pointclouds = Pointclouds(points=points, features=colors).to(device)
|
234 |
+
|
235 |
+
# Render image
|
236 |
+
images = render_pointcloud_batch_pytorch3d(cameras, pointclouds)
|
237 |
+
|
238 |
+
# Convert images into grid
|
239 |
+
image_grids = []
|
240 |
+
images_for_grids = images.reshape(F, B, *images.shape[1:]).permute(0, 1, 4, 2, 3)
|
241 |
+
for image_for_grids in images_for_grids:
|
242 |
+
image_grid = make_grid(image_for_grids, nrow=int(math.sqrt(B)), pad_value=1)
|
243 |
+
image_grids.append(image_grid)
|
244 |
+
image_grids = torch.stack(image_grids, dim=0)
|
245 |
+
image_grids = image_grids.detach().cpu()
|
246 |
+
|
247 |
+
# Save image
|
248 |
+
if output_file_image is not None:
|
249 |
+
to_pil_image(image_grids[0]).save(output_file_image)
|
250 |
+
|
251 |
+
# Save video
|
252 |
+
if output_file_video:
|
253 |
+
video = (image_grids * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy()
|
254 |
+
imageio.mimwrite(output_file_video, video, fps=10)
|
255 |
+
|
256 |
+
|
257 |
+
@torch.no_grad()
|
258 |
+
def visualize_pointcloud_evolution_pytorch3d(
|
259 |
+
pointclouds: Pointclouds,
|
260 |
+
output_file_video: str,
|
261 |
+
camera: Optional[CamerasBase] = None, # if None, we rotate
|
262 |
+
scale_factor: float = 1.0,
|
263 |
+
):
|
264 |
+
|
265 |
+
# Device
|
266 |
+
B, device = len(pointclouds), pointclouds.device
|
267 |
+
|
268 |
+
# Cameras
|
269 |
+
if camera is None:
|
270 |
+
R, T = look_at_view_transform(dist=10.0, elev=30, azim=0, device=device)
|
271 |
+
camera = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
|
272 |
+
|
273 |
+
# Render
|
274 |
+
frames = render_pointcloud_batch_pytorch3d(camera, pointclouds)
|
275 |
+
|
276 |
+
# Save video
|
277 |
+
video = (frames.detach().cpu() * 255).to(torch.uint8).numpy()
|
278 |
+
imageio.mimwrite(output_file_video, video, fps=10)
|
279 |
+
|
280 |
+
|
281 |
+
def get_camera_index(cameras: CamerasBase, index: Optional[int] = None):
|
282 |
+
if index is None:
|
283 |
+
return cameras
|
284 |
+
kwargs = dict(
|
285 |
+
R=cameras.R[index].unsqueeze(0),
|
286 |
+
T=cameras.T[index].unsqueeze(0),
|
287 |
+
K=cameras.K[index].unsqueeze(0) if cameras.K is not None else None,
|
288 |
+
)
|
289 |
+
if hasattr(cameras, 'focal_length'):
|
290 |
+
kwargs['focal_length'] = cameras.focal_length[index].unsqueeze(0)
|
291 |
+
if hasattr(cameras, 'principal_point'):
|
292 |
+
kwargs['principal_point'] = cameras.principal_point[index].unsqueeze(0)
|
293 |
+
return type(cameras)(**kwargs).to(cameras.device)
|
294 |
+
|
295 |
+
|
296 |
+
def get_metadata(item) -> str:
|
297 |
+
s = '-------------\n'
|
298 |
+
for key in item.keys():
|
299 |
+
value = item[key]
|
300 |
+
if torch.is_tensor(value) and value.numel() < 25:
|
301 |
+
value_str = value
|
302 |
+
elif torch.is_tensor(value):
|
303 |
+
value_str = value.shape
|
304 |
+
elif isinstance(value, str):
|
305 |
+
value_str = value
|
306 |
+
elif isinstance(value, list) and 0 < len(value) and len(value) < 25 and isinstance(value[0], str):
|
307 |
+
value_str = value
|
308 |
+
elif isinstance(value, dict):
|
309 |
+
value_str = str({k: type(v) for k, v in value.items()})
|
310 |
+
else:
|
311 |
+
value_str = type(value)
|
312 |
+
s += f"{key:<30} {value_str}\n"
|
313 |
+
return s
|
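A quick sanity check of `rotation_matrix` above (a quaternion/Rodrigues-style construction): rotating the x-axis by 90 degrees about +z should give the y-axis, and the result should be orthonormal.

```python
import numpy as np
from diffusion_utils import rotation_matrix

R = rotation_matrix([0, 0, 1], np.pi / 2)  # counterclockwise 90 degrees about +z
print(np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0]))  # x-axis -> y-axis: True
print(np.allclose(R @ R.T, np.eye(3)))                              # orthonormal: True
```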
examples/017450/k1.color.jpg
ADDED
examples/017450/k1.obj_rend_mask.png
ADDED
examples/017450/k1.person_mask.png
ADDED
model/__init__.py
ADDED
@@ -0,0 +1,28 @@
from configs.structured import ProjectConfig
from .model import ConditionalPointCloudDiffusionModel
from .model_coloring import PointCloudColoringModel
from .model_utils import set_requires_grad
from .model_diff_data import ConditionalPCDiffusionSeparateSegm
from .model_hoattn import CrossAttenHODiffusionModel

def get_model(cfg: ProjectConfig):
    if cfg.model.model_name == 'pc2-diff':
        model = ConditionalPointCloudDiffusionModel(**cfg.model)
    elif cfg.model.model_name == 'pc2-diff-ho-sepsegm':
        model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        print("Using a separate model to predict segmentation label")
    elif cfg.model.model_name == 'diff-ho-attn':
        model = CrossAttenHODiffusionModel(**cfg.model)
        print("Using separate model for human + object with cross attention.")
    else:
        raise NotImplementedError
    if cfg.run.freeze_feature_model:
        set_requires_grad(model.feature_model, False)
    return model


def get_coloring_model(cfg: ProjectConfig):
    model = PointCloudColoringModel(**cfg.model)
    if cfg.run.freeze_feature_model:
        set_requires_grad(model.feature_model, False)
    return model
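A minimal sketch of how the two demo models are obtained through this factory, mirroring the model names and the `predict_binary` toggle used in `demo.py`; the Hydra composition assumes the configs in `configs/` are on the search path.

```python
import hydra
from configs.structured import ProjectConfig
from model import get_model

@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def main(cfg: ProjectConfig):
    cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True  # stage 1
    stage1 = get_model(cfg)
    cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False        # stage 2
    stage2 = get_model(cfg)
    print(type(stage1).__name__, type(stage2).__name__)

if __name__ == '__main__':
    main()
```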
model/feature_model.py
ADDED
@@ -0,0 +1,160 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
from torch import Tensor
from torchvision.transforms import functional as TVF


IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

MODEL_URLS = {
    'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
    'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
    'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
}

NORMALIZATION = {
    'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
}

MODEL_KWARGS = {
    'vit_base_patch16_224_mae': dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
    ),
    'vit_small_patch16_224_msn': dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
    ),
    'vit_large_patch7_224_msn': dict(
        patch_size=7, embed_dim=1024, depth=24, num_heads=16,
    )
}


class FeatureModel(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        model_name: str = 'vit_small_patch16_224_mae',
        global_pool: str = '',  # '' or 'token'
    ) -> None:
        super().__init__()
        self.model_name = model_name

        # Identity
        if self.model_name == 'identity':
            return

        # Create model
        self.model = VisionTransformer(
            img_size=image_size, num_classes=0, global_pool=global_pool,
            **MODEL_KWARGS[model_name])

        # Model properties
        self.feature_dim = self.model.embed_dim
        self.mean, self.std = NORMALIZATION[model_name]

        # # Modify MSN model with output head from training
        # if model_name.endswith('msn'):
        #     use_bn = True
        #     emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else
        #                768 if 'base' in model_name else 1024 if 'large' in model_name else 1280)
        #     hidden_dim = 2048
        #     output_dim = 256
        #     self.model.fc = None
        #     fc = OrderedDict([])
        #     fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim)
        #     if use_bn:
        #         fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim)
        #     fc['gelu1'] = torch.nn.GELU()
        #     fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim)
        #     if use_bn:
        #         fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim)
        #     fc['gelu2'] = torch.nn.GELU()
        #     fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
        #     self.model.fc = torch.nn.Sequential(fc)

        # Load pretrained checkpoint
        checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'target_encoder' in checkpoint:
            state_dict = checkpoint['target_encoder']
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # NOTE: Comment the line below if using the projection head, uncomment if not using it
            # See https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
            # for more info about the projection head
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
        else:
            raise NotImplementedError()
        state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # # Modify MSN model with output head from training
        # if model_name.endswith('msn'):
        #     self.fc = self.model.fc
        #     del self.model.fc
        # else:
        #     self.fc = nn.Identity()

        # NOTE: I've disabled the whole projection head stuff for simplicity for now
        self.fc = nn.Identity()

    def denormalize(self, img: Tensor):
        img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std])
        return torch.clip(img, 0, 1)

    def normalize(self, img: Tensor):
        return TVF.normalize(img, mean=self.mean, std=self.std)

    def forward(
        self,
        x: Tensor,
        return_type: str = 'features',
        return_upscaled_features: bool = True,
        return_projection_head_output: bool = False,
    ):
        """Normalizes the input `x` and runs it through `model` to obtain features"""
        assert return_type in {'cls_token', 'features', 'all'}

        # Identity
        if self.model_name == 'identity':
            return x

        # Normalize and forward
        B, C, H, W = x.shape
        x = self.normalize(x)
        feats = self.model(x)

        # Reshape to image-like size
        if return_type in {'features', 'all'}:
            B, T, D = feats.shape
            assert math.sqrt(T - 1).is_integer()
            HW_down = int(math.sqrt(T - 1))  # subtract one for CLS token
            output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2)  # (B, D, H_down, W_down)
            if return_upscaled_features:
                output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
                                             align_corners=False)  # (B, D, H_orig, W_orig)

        # Head for MSN
        output_cls = feats[:, 0]
        if return_projection_head_output and return_type in {'cls_token', 'all'}:
            output_cls = self.fc(output_cls)

        # Return
        if return_type == 'cls_token':
            return output_cls
        elif return_type == 'features':
            return output_feats
        else:
            return output_cls, output_feats
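A quick usage sketch (not part of the diff) of the feature extractor above, assuming the `vit_base_patch16_224_mae` backbone (embed dim 768) and that the pretrained checkpoint can be downloaded:

```python
import torch
from model.feature_model import FeatureModel

# Downloads the MAE ViT-B/16 checkpoint from MODEL_URLS on first use.
extractor = FeatureModel(image_size=224, model_name='vit_base_patch16_224_mae')
images = torch.rand(2, 3, 224, 224)  # RGB in [0, 1]; ImageNet normalization happens inside forward()

with torch.no_grad():
    feats = extractor(images, return_type='features')     # (2, 768, 224, 224), pixel-aligned after bilinear upscaling
    cls_tok = extractor(images, return_type='cls_token')  # (2, 768)
```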
model/model.py
ADDED
@@ -0,0 +1,303 @@
import inspect
import random
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm

from .model_utils import get_num_points, get_custom_betas
from .point_cloud_model import PointCloudModel
from .projection_model import PointCloudProjectionModel


class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):

    def __init__(
        self,
        beta_start: float,
        beta_end: float,
        beta_schedule: str,
        point_cloud_model: str,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)

        # Checks
        if not self.predict_shape:
            raise NotImplementedError('Must predict shape if performing diffusion.')

        # Create diffusion model schedulers which define the sampling timesteps
        self.dm_pred_type = kwargs.get('dm_pred_type', "epsilon")
        assert self.dm_pred_type in ['epsilon', 'sample']
        scheduler_kwargs = {"prediction_type": self.dm_pred_type}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddpm']  # this can be changed for inference

        # Create point cloud model for processing point cloud at each diffusion step
        self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)

        self.load_sample_init = kwargs.get('load_sample_init', False)
        self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
        self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)

        self.consistent_center = kwargs.get('consistent_center', False)
        self.cam_noise_std = kwargs.get('cam_noise_std', 0.0)  # add noise to camera based on timestamps

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,  # voxel resolution multiplier is 1.
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
        )

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):

        # Normalize colors and convert to tensor
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)  # this will not pack the point colors
        B, N, D = x_0.shape

        # Sample random noise
        noise = torch.randn_like(x_0)
        if self.consistent_center:
            # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample random timesteps for each point_cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to points
        x_t = self.scheduler.add_noise(x_0, noise, timestep)  # diffusion noisy adding, only add to the coordinate, not features

        # add noise to the camera pose, based on timestamps
        if self.cam_noise_std > 0.000001:
            # the noise is very different
            camera = camera.clone()
            camT = camera.T  # (B, 3)
            dist = torch.sqrt(torch.sum(camT**2, -1, keepdim=True))
            nratio = timestep[:, None] / self.scheduler.num_train_timesteps  # time-dependent noise
            tnoise = torch.randn(B, 3).to(dist.device)/3. * dist * self.cam_noise_std * nratio
            camera.T = camera.T + tnoise

        # Conditioning, the pixel-aligned feature is based on points with noise (new points)
        x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)

        # Forward
        loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)

        # Whether to return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss

    def compute_loss(self, noise, timestep, x_0, x_t_input):
        x_pred = torch.zeros_like(x_0)
        if self.self_conditioning:
            # self conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
            if random.uniform(0, 1.) > 0.5:
                with torch.no_grad():
                    x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
            noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
        else:
            noise_pred = self.point_cloud_model(x_t_input, timestep)
        # Check
        if not noise_pred.shape == noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')
        # Loss
        if self.dm_pred_type == 'epsilon':
            loss = F.mse_loss(noise_pred, noise)
        elif self.dm_pred_type == 'sample':
            loss = F.mse_loss(noise_pred, x_0)  # predicting sample
        else:
            raise NotImplementedError
        return loss, noise_pred

    def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
        "return: (B, N, D), the exact input to the diffusion model, x_t: (B, N, 3)"
        x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
                                                     image_rgb=image_rgb, mask=mask, t=timestep)
        return x_t_input

    @torch.no_grad()
    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
    ):

        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
        x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        x_pred = torch.zeros_like(x_t)

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        for i, t in enumerate(progress_bar):
            add_interm_output = (return_all_outputs and (
                    i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))
            # Conditioning
            x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
            if self.self_conditioning:
                x_t_input = torch.cat([x_t_input, x_pred], -1)  # add self-conditioning
            inference_binary = (i == len(progress_bar) - 1) | add_interm_output
            # One reverse step with conditioning
            x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
                                    inference_binary=inference_binary)  # (B, N, D), D=3 or 4
            x_pred = x_t  # for next iteration self conditioning

            # Append to output list if desired
            if add_interm_output:
                all_outputs.append(x_t)

        # Convert output back into a point cloud, undoing normalization and scaling
        output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True)  # this convert the points back to original scale
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]

        return (output, all_outputs) if return_all_outputs else output

    def get_x_T_channel(self):
        D = 3 + (self.color_channels if self.predict_color else 0)
        return D

    def initialize_x_T(self, device, gt_pc, shape, interm_steps: int = -1, scheduler=None):
        B, N, D = shape
        # Sample noise initialization
        if interm_steps > 0:
            # Sample from some intermediate steps
            x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
            noise = torch.randn(B, N, D, device=device)

            # always make sure the noise does not change the pc center, this is important to reduce 0.1cm CD!
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

            x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device))  # Add noise
        else:
            # Sample from random Gaussian
            x_t = torch.randn(B, N, D, device=device)

        x_t = x_t * self.sample_init_scale  # for test
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """
        run one reverse step to compute x_t
        :param extra_step_kwargs:
        :param scheduler:
        :param t: [1], diffusion time step
        :param x_t: (B, N, 3)
        :param x_t_input: conditional features (B, N, F)
        :param kwargs: other configurations to run diffusion step
        :return: denoised x_t
        """
        B = x_t.shape[0]
        # Forward
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm predition type for CCD!'
            # suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
        # Step
        x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t

    def setup_reverse_process(self, eta, num_inference_steps, scheduler):
        """
        setup diffusion chain, and others.
        """
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}
        return extra_step_kwargs

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """
        A wrapper around the forward method for training and inference
        """
        if isinstance(batch, dict):  # fixes a bug with multiprocessing where batch becomes a dict
            batch = FrameData(**batch)  # it really makes no sense, I do not understand it

        if mode == 'train':
            return self.forward_train(
                pc=batch.sequence_point_cloud,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
            return self.forward_sample(
                num_points=num_points,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        else:
            raise NotImplementedError()
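For reference (not part of the diff), the core of `forward_train` is the standard forward-diffusion step, plus the "consistent center" trick of re-centering the noise per point cloud. A minimal standalone sketch with illustrative beta values:

```python
import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

scheduler = DDPMScheduler(beta_start=1e-5, beta_end=8e-3, beta_schedule='linear', clip_sample=False)

x_0 = torch.randn(4, 1000, 3)                     # a batch of normalized point clouds (B, N, 3)
noise = torch.randn_like(x_0)
noise = noise - noise.mean(dim=1, keepdim=True)   # consistent_center: zero-mean noise keeps the cloud centroid fixed
t = torch.randint(0, scheduler.config.num_train_timesteps, (4,))

x_t = scheduler.add_noise(x_0, noise, t)          # noisy coordinates fed to the point cloud network
```

The network is then trained with an MSE loss against `noise` (epsilon prediction) or against `x_0` (sample prediction), matching `compute_loss` above.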
model/model_coloring.py
ADDED
@@ -0,0 +1,84 @@
from typing import Optional

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .point_cloud_transformer_model import PointCloudTransformerModel
from .projection_model import PointCloudProjectionModel


class PointCloudColoringModel(PointCloudProjectionModel):

    def __init__(
        self,
        point_cloud_model: str,
        point_cloud_model_layers: int,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)

        # Checks
        if self.predict_shape or not self.predict_color:
            raise NotImplementedError('Must predict color, not shape, for coloring')

        # Create point cloud model for processing point cloud
        self.point_cloud_model = PointCloudTransformerModel(
            num_layers=point_cloud_model_layers,
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )  # why use transformer instead???

    def _forward(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_point_cloud: bool = False,
        noise_std: float = 0.0,
    ):

        # Normalize colors and convert to tensor
        x = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x_points, x_colors = x[:, :, :3], x[:, :, 3:]

        # Add noise to points. TODO: Add to configs.
        x_input = x_points + torch.randn_like(x_points) * noise_std  # simulate noise of the predicted pc?

        # Conditioning
        # x_input = self.get_input_with_conditioning(x_input, camera=camera,
        #                                            image_rgb=image_rgb, mask=mask)
        # XH: edit to run
        x_input = self.get_input_with_conditioning(x_input, camera=camera,
                                                   image_rgb=image_rgb, mask=mask, t=None)

        # Forward
        pred_colors = self.point_cloud_model(x_input)

        # During inference, we return the point cloud with the predicted colors
        if return_point_cloud:
            pred_pointcloud = self.tensor_to_point_cloud(
                torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True)
            return pred_pointcloud

        # During training, we have ground truth colors and return the loss
        loss = F.mse_loss(pred_colors, x_colors)
        return loss

    def forward(self, batch: FrameData, **kwargs):
        """A wrapper around the forward method"""
        if isinstance(batch, dict):  # fixes a bug with multiprocessing where batch becomes a dict
            batch = FrameData(**batch)  # it really makes no sense, I do not understand it
        return self._forward(
            pc=batch.sequence_point_cloud,
            camera=batch.camera,
            image_rgb=batch.image_rgb,
            mask=batch.fg_probability,
            **kwargs,
        )
model/model_diff_data.py
ADDED
@@ -0,0 +1,238 @@
"""
model to deal with shapenet inputs and other datasets such as Behave and ProciGen
the model takes a different data dictionary in forward function
"""
import inspect
from typing import Optional
import numpy as np

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.datasets.r2n2.utils import BlenderCamera


from .model import ConditionalPointCloudDiffusionModel
from .model_utils import get_num_points


class ConditionalPCDiffusionShapenet(ConditionalPointCloudDiffusionModel):
    def forward(self, batch, mode: str = 'train', **kwargs):
        """
        take a batch of data from ShapeNet
        """
        images = torch.stack(batch['images'], 0).to('cuda')
        masks = torch.stack(batch['masks'], 0).to('cuda')
        pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
        camera = BlenderCamera(
            torch.stack(batch['R']),
            torch.stack(batch['T']),
            torch.stack(batch['K']), device='cuda'
        )

        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(pc))
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                **kwargs)
        else:
            raise NotImplementedError()


class ConditionalPCDiffusionBehave(ConditionalPointCloudDiffusionModel):
    "diffusion model for Behave dataset"
    def forward(self, batch, mode: str = 'train', **kwargs):
        images = torch.stack(batch['images'], 0).to('cuda')
        masks = torch.stack(batch['masks'], 0).to('cuda')
        pc = self.get_input_pc(batch)
        camera = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T']),
            K=torch.stack(batch['K']),
            device='cuda',
            in_ndc=True
        )
        grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
        num_points = kwargs.pop('num_points', get_num_points(pc))
        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                grid_df=grid_df,
                **kwargs)
        elif mode == 'sample':
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                **kwargs)
        else:
            raise NotImplementedError()

    def get_input_pc(self, batch):
        pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
        return pc


class ConditionalPCDiffusionSeparateSegm(ConditionalPCDiffusionBehave):
    "a separate model to predict binary labels, the final segmentation model"
    def __init__(self,
                 beta_start: float,
                 beta_end: float,
                 beta_schedule: str,
                 point_cloud_model: str,
                 point_cloud_model_embed_dim: int,
                 **kwargs,  # projection arguments
                 ):
        super(ConditionalPCDiffusionSeparateSegm, self).__init__(beta_start, beta_end, beta_schedule,
                                                                 point_cloud_model,
                                                                 point_cloud_model_embed_dim, **kwargs)
        # add a separate model to predict binary label
        from .point_cloud_transformer_model import PointCloudTransformerModel, PointCloudModel

        self.binary_model = PointCloudTransformerModel(
            num_layers=1,  # XH: use the default color model number of layers
            model_type=point_cloud_model,  # pvcnn
            embed_dim=point_cloud_model_embed_dim,  # save as pc shape model
            in_channels=self.in_channels,
            out_channels=1,
        )
        self.binary_training_noise_std = kwargs.get("binary_training_noise_std", 0.1)

        # re-initialize point cloud model
        assert self.predict_binary
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels - 1,  # not predicting binary from this anymore
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
        )

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        # first run shape forward, then binary label forward
        assert not return_intermediate_steps
        assert self.predict_binary
        loss_shape = super(ConditionalPCDiffusionSeparateSegm, self).forward_train(pc,
                                                                                   camera,
                                                                                   image_rgb,
                                                                                   mask,
                                                                                   return_intermediate_steps,
                                                                                   **kwargs)

        # binary label forward
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x_points, x_colors = x_0[:, :, :3], x_0[:, :, 3:]

        # Add noise to points.
        x_input = x_points + torch.randn_like(x_points) * self.binary_training_noise_std  # std=0.1
        x_input = self.get_input_with_conditioning(x_input, camera=camera,
                                                   image_rgb=image_rgb, mask=mask, t=None)

        # Forward
        pred_segm = self.binary_model(x_input)

        # use compressed bits
        df_grid = kwargs.get('grid_df', None).unsqueeze(1)  # (B, 1, resz, resy, resx)
        points = x_points.clone().detach() / self.scale_factor * 2  # , normalize to [-1, 1]
        points[:, :, 0], points[:, :, 2] = points[:, :, 2].clone(), points[:, :, 0].clone()  # swap, make sure clone is used!
        points = points.unsqueeze(1).unsqueeze(1)  # (B,1, 1, N, 3)
        with torch.no_grad():
            df_interp = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True).squeeze(1).squeeze(1)  # (B, 1, 1, 1, N)
            binary_label = df_interp[:, 0] > 0.5  # (B, 1, N)

        binary_pred = torch.sigmoid(pred_segm.squeeze(-1))  # add a sigmoid layer
        loss_binary = F.mse_loss(binary_pred, binary_label.float().squeeze(1).squeeze(1)) * self.lw_binary
        loss = loss_shape + loss_binary

        return loss, torch.tensor([loss_shape, loss_binary])

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        "return (B, N, 4), the 4-th channel is binary label"
        B = x_t.shape[0]
        # Forward
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm predition type!'
            # suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
        # Step: make sure only update the shape (first 3 channels)
        x_t = scheduler.step(noise_pred, t, x_t[:, :, :3], **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)

        # also add binary prediction
        if kwargs.get('inference_binary', False):
            pred_segm = self.binary_model(x_t_input)
        else:
            pred_segm = torch.zeros_like(x_t[:, :, 0:1])

        x_t = torch.cat([x_t, torch.sigmoid(pred_segm)], -1)

        return x_t

    def get_coord_feature(self, x_t):
        x_t_input = [x_t[:, :, :3]]
        return x_t_input

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        """
        take binary label into account
        :param self:
        :param x: (B, N, 4), the 4th channel is the binary segmentation, 1-human, 0-object
        :param denormalize: denormalize the per-point colors, from pc2
        :param unscale: undo point scaling, from pc2
        :return: pc with point colors if predict binary label or per-point color
        """
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)
        else:
            if self.predict_binary:
                assert x.shape[2] == 4
                # add color to predicted binary labels
                is_hum = x[:, :, 3] > 0.5
                features = []
                for mask in is_hum:
                    color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
                    color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device)  # human is light blue, object light green
                    features.append(color)
            else:
                assert x.shape[2] == 3
                features = None
            return Pointclouds(points=points, features=features)
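A small standalone sketch (not part of the diff) of how the binary human/object label is read off a 3D grid with `F.grid_sample`, as in `forward_train` above; the grid resolution and values are toy placeholders, and the axis reordering mirrors the x/z swap in the code:

```python
import torch
import torch.nn.functional as F

B, N, res = 2, 6, 32
df_grid = torch.rand(B, 1, res, res, res)     # stand-in for grid_df, stored as (B, 1, resz, resy, resx)
points = torch.rand(B, N, 3) * 2 - 1          # query points already normalized to [-1, 1]

# grid_sample expects xyz ordered to match (resx, resy, resz); hence the swap of the first and last axis
points = points[..., [2, 1, 0]]
points = points.unsqueeze(1).unsqueeze(1)     # (B, 1, 1, N, 3)

vals = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True)  # (B, 1, 1, 1, N)
binary_label = vals[:, 0, 0, 0] > 0.5         # (B, N) pseudo ground-truth label per point
```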
model/model_hoattn.py
ADDED
@@ -0,0 +1,457 @@
"""
model that use cross attention to predict human + object
"""

import inspect
import random
from typing import Optional
from torch import Tensor
import torch
import numpy as np

from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import CamerasBase
from .model_diff_data import ConditionalPCDiffusionBehave
from .pvcnn.pvcnn_ho import PVCNN2HumObj
import torch.nn.functional as F
from pytorch3d.renderer import PerspectiveCameras
from .model_utils import get_num_points
from tqdm import tqdm


class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        """use cross attention model"""
        if point_cloud_model == 'pvcnn':
            self.point_cloud_model = PVCNN2HumObj(embed_dim=point_cloud_model_embed_dim,
                                                  num_classes=self.out_channels,
                                                  extra_feature_channels=(self.in_channels - 3),
                                                  voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),
                                                  attn_type=kwargs.get('attn_type', 'simple-cross'),
                                                  attn_weight=kwargs.get("attn_weight", 1.0)
                                                  )
        else:
            raise ValueError(f"Unknown point cloud model {point_cloud_model}!")
        self.point_visible_test = kwargs.get("point_visible_test", 'single')  # when doing point visibility test, use only human points or human + object?
        assert self.point_visible_test in ['single', 'combine'], f'invalide point visible test option {self.point_visible_test}'
        # print(f"Point visibility test is based on {self.point_visible_test} point clouds!")

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        "additional input (RGB, mask, camera, and pc) for object is read from kwargs"
        # assert not self.consistent_center
        assert not self.self_conditioning

        # Normalize colors and convert to tensor
        x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True)  # this will not pack the point colors
        x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True)
        B, N, D = x0_h.shape

        # Sample random noise
        noise = torch.randn_like(x0_h)
        if self.consistent_center:
            # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample random timesteps for each point_cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)
        # timestep = torch.randint(0, 1, (B,),
        #                          device=self.device, dtype=torch.long)

        # Add noise to points
        xt_h = self.scheduler.add_noise(x0_h, noise, timestep)
        xt_o = self.scheduler.add_noise(x0_o, noise, timestep)
        norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)

        # get input conditioning
        x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask, norm_parms, timestep,
                                                               xt_h, xt_o)

        # Diffusion prediction
        noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input_h, x_t_input_o, timestep, norm_parms)

        # Check
        if not noise_pred_h.shape == noise.shape:
            raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}')
        if not noise_pred_o.shape == noise.shape:
            raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}')

        # Loss
        loss_h = F.mse_loss(noise_pred_h, noise)
        loss_o = F.mse_loss(noise_pred_o, noise)

        loss = loss_h + loss_o

        # Whether to return intermediate steps
        if return_intermediate_steps:
            return loss, (x0_h, xt_h, noise, noise_pred_h)

        return loss, torch.tensor([loss_h, loss_o])

    def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o):
        """
        compute image features for each point
        :param camera:
        :param image_rgb:
        :param kwargs:
        :param mask:
        :param norm_parms:
        :param timestep:
        :param xt_h:
        :param xt_o:
        :return:
        """
        if self.point_visible_test == 'single':
            # Visibility test is down independently for human and object
            x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera,
                                                           image_rgb=image_rgb, mask=mask, t=timestep)
            x_t_input_o = self.get_input_with_conditioning(xt_o, camera=kwargs.get('camera_obj'),
                                                           image_rgb=kwargs.get('rgb_obj'),
                                                           mask=kwargs.get('mask_obj'), t=timestep)
        elif self.point_visible_test == 'combine':
            # Combine human + object points to do visibility test and obtain features
            B, N = xt_h.shape[:2]  # (B, N, 3)
            # for human: transform object points first to H+O space, then to human space
            xt_o_in_ho = xt_o * 2 * norm_parms[1, :, 3:].unsqueeze(1) + norm_parms[1, :, :3].unsqueeze(1)
            xt_o_in_hum = (xt_o_in_ho - norm_parms[0, :, :3].unsqueeze(1)) / (2 * norm_parms[0, :, 3:].unsqueeze(1))
            # compute features for all points, take only first half feature for human
            x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera,
                                                           image_rgb=image_rgb, mask=mask, t=timestep)[:, :N]
            # for object: transform human points to H+O space, then to object space
            xt_h_in_ho = xt_h * 2 * norm_parms[0, :, 3:].unsqueeze(1) + norm_parms[0, :, :3].unsqueeze(1)
            xt_h_in_obj = (xt_h_in_ho - norm_parms[1, :, :3].unsqueeze(1)) / (2 * norm_parms[1, :, 3:].unsqueeze(1))
            x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1),
                                                           camera=kwargs.get('camera_obj'),
                                                           image_rgb=kwargs.get('rgb_obj'),
                                                           mask=kwargs.get('mask_obj'), t=timestep)[:, :N]
        else:
            raise NotImplementedError
        return x_t_input_h, x_t_input_o

    def forward(self, batch, mode: str = 'train', **kwargs):
        """"""
        images = torch.stack(batch['images'], 0).to('cuda')
        masks = torch.stack(batch['masks'], 0).to('cuda')
        pc = self.get_input_pc(batch)
        camera = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T_hum']),
            K=torch.stack(batch['K_hum']),
            device='cuda',
            in_ndc=True
        )
        grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
        num_points = kwargs.pop('num_points', get_num_points(pc))

        rgb_obj = torch.stack(batch['images_obj'], 0).to('cuda')
        masks_obj = torch.stack(batch['masks_obj'], 0).to('cuda')
        pc_obj = Pointclouds([x.to('cuda') for x in batch['pclouds_obj']])
        camera_obj = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T_obj']),
            K=torch.stack(batch['K_obj']),
            device='cuda',
            in_ndc=True
        )

        # normalization parameters
        cent_hum = torch.stack(batch['cent_hum'], 0).to('cuda')
        cent_obj = torch.stack(batch['cent_obj'], 0).to('cuda')  # B, 3
        radius_hum = torch.stack(batch['radius_hum'], 0).to('cuda')  # B, 1
        radius_obj = torch.stack(batch['radius_obj'], 0).to('cuda')

        # print(batch['image_path'])

        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                grid_df=grid_df,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
            )
        elif mode == 'sample':
            # this use GT centers to do projection
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
                **kwargs)
        elif mode == 'interm-gt':
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
                sample_from_interm=True,
                **kwargs)
        elif mode == 'interm-pred':
            # use camera from predicted
            camera = PerspectiveCameras(
                R=torch.stack(batch['R']),
                T=torch.stack(batch['T_hum_scaled']),
                K=torch.stack(batch['K_hum']),
                device='cuda',
                in_ndc=True
            )
            camera_obj = PerspectiveCameras(
                R=torch.stack(batch['R']),
                T=torch.stack(batch['T_obj_scaled']),
                K=torch.stack(batch['K_obj']),  # the camera should be human/object specific!!!
                device='cuda',
                in_ndc=True
            )
            # use pc from predicted
            pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
            pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
            # use center and radius from predicted
            cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
            cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda')  # B, 3
            radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda')  # B, 1
            radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')

            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
                sample_from_interm=True,
                **kwargs)
        elif mode == 'interm-pred-ts':
            # use only estimate translation and scale, but sample from gaussian
            # this works, the camera is GT!!!
            pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
            pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
            # use center and radius from predicted
            cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
            cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda')  # B, 3
            radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda')  # B, 1
            radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')
            # print(cent_hum[0], radius_hum[0], cent_obj[0], radius_obj[0])

            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
                sample_from_interm=False,
                **kwargs)
        else:
            raise NotImplementedError

    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
    ):
        "use two models to run diffusion forward, and also use translation and scale to put them back"
        assert not self.self_conditioning
        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        # sample from full steps or only a few steps
        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1

        xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        xt_o = self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler)

        # the segmentation mask
        segm_mask = torch.zeros(B, 2*N, 1).to(device)
        segm_mask[:, :N] = 1.0

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(self.get_reverse_timesteps(scheduler, interm_steps),
                            desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm)

        # print("Camera T:", camera.T[0], camera.R[0])
        # print("Camera_obj T:", kwargs.get('camera_obj').T[0], kwargs.get('camera_obj').R[0])

        norm_parms = self.pack_norm_params(kwargs)
        for i, t in enumerate(progress_bar):
            x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb,
                                                                   kwargs, mask,
                                                                   norm_parms,
                                                                   t,
                                                                   xt_h, xt_o)

            # One reverse step with conditioning
            xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0),
                                           torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs)  # (B, N, D), D=3

            if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)):
                # print(xt_h.shape, kwargs.get('cent_hum').shape, kwargs.get('radius_hum').shape)
                x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
                                 self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
                # print(x_t.shape, xt_o.shape)
                all_outputs.append(torch.cat([x_t, segm_mask], -1))
                # print("Updating intermediate...")

        # Convert output back into a point cloud, undoing normalization and scaling
        x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
                         self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
        x_t = torch.cat([x_t, segm_mask], -1)
        output = self.tensor_to_point_cloud(x_t, denormalize=False, unscale=False)  # this convert the points back to original scale
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs]

        return (output, all_outputs) if return_all_outputs else output

    def get_reverse_timesteps(self, scheduler, interm_steps: int):
        """

        :param scheduler:
        :param interm_steps: start from some intermediate steps
        :return:
        """
        if interm_steps > 0:
            timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
        else:
            timesteps = scheduler.timesteps.to(self.device)
        return timesteps

    def pack_norm_params(self, kwargs: dict, scale=True):
        scale_factor = self.scale_factor if scale else 1.0
        hum = torch.cat([kwargs.get('cent_hum')*scale_factor, kwargs.get('radius_hum')], -1)
        obj = torch.cat([kwargs.get('cent_obj')*scale_factor, kwargs.get('radius_obj')], -1)
        return torch.stack([hum, obj], 0)  # (2, B, 4)

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        "x_t: (2, B, D, N), x_t_input: (2, B, D, N)"
        norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)
        B = x_t.shape[1]
        # print(f"Step {t} Norm params:", norm_parms[:, 0, :])
        noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input[0], x_t_input[1], t.reshape(1).expand(B),
                                                            norm_parms)
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm predition type!'
            noise_pred_h = noise_pred_h - torch.mean(noise_pred_h, dim=1, keepdim=True)
            noise_pred_o = noise_pred_o - torch.mean(noise_pred_o, dim=1, keepdim=True)

        xt_h = scheduler.step(noise_pred_h, t, x_t[0], **extra_step_kwargs).prev_sample
        xt_o = scheduler.step(noise_pred_o, t, x_t[1], **extra_step_kwargs).prev_sample

        if self.consistent_center:
            xt_h = xt_h - torch.mean(xt_h, dim=1, keepdim=True)
            xt_o = xt_o - torch.mean(xt_o, dim=1, keepdim=True)

        return xt_h, xt_o

    def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True):
        """
        first denormalize, then apply center and scale to original H+O coordinate
        :param x:
        :param cent: (B, 3)
        :param radius: (B, 1)
        :param unscale:
        :return:
        """
        # denormalize: scale down.
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        # translation and scale back to H+O coordinate
        points = points * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1)
        return points

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        """
        take binary into account
        :param self:
        :param x: (B, N, 4)
        :param denormalize:
        :param unscale:
        :return:
        """
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)
        else:
            assert x.shape[2] == 4
            # add color to predicted binary labels
            is_hum = x[:, :, 3] > 0.5
            features = []
            for mask in is_hum:
                color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
                color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device)  # human is light blue, object light green
                features.append(color)
            return Pointclouds(points=points, features=features)
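For intuition (not part of the diff), the human- and object-centric point clouds are related to the joint H+O frame by a per-cloud center and radius, which is what `pack_norm_params`, `get_image_conditioning`, and `denormalize_pclouds` rely on. A minimal round-trip sketch; how the dataset actually computes `cent` and `radius` is an assumption here:

```python
import torch

B, N = 2, 8
points_ho = torch.randn(B, N, 3)                                       # points in the joint human+object frame
cent = points_ho.mean(dim=1)                                           # (B, 3), illustrative center
radius = points_ho.flatten(1).abs().max(dim=1, keepdim=True).values    # (B, 1), illustrative radius

# normalize into the per-entity frame: x_norm = (x - cent) / (2 * radius)
points_norm = (points_ho - cent.unsqueeze(1)) / (2 * radius.unsqueeze(-1))

# denormalize_pclouds inverts it: scale back up and re-apply the center
points_back = points_norm * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1)
assert torch.allclose(points_back, points_ho, atol=1e-5)
```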
model/model_utils.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from pytorch3d.structures import Pointclouds
|
6 |
+
|
7 |
+
|
8 |
+
def set_requires_grad(module: nn.Module, requires_grad: bool):
|
9 |
+
for p in module.parameters():
|
10 |
+
p.requires_grad_(requires_grad)
|
11 |
+
|
12 |
+
|
13 |
+
def compute_distance_transform(mask: torch.Tensor):
|
14 |
+
"""
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
mask (B, 1, H, W) or (B, 2, H, W) true for foreground
|
19 |
+
|
20 |
+
Returns
|
21 |
+
-------
|
22 |
+
the vector to the closest foreground pixel, zero if inside mask
|
23 |
+
|
24 |
+
"""
|
25 |
+
C = mask.shape[1]
|
26 |
+
assert C in [1, 2], f'invalid mask shape {mask.shape} found!'
|
27 |
+
|
28 |
+
image_size = mask.shape[-1]
|
29 |
+
|
30 |
+
dts = []
|
31 |
+
for i in range(C):
|
32 |
+
distance_transform = torch.stack([
|
33 |
+
torch.from_numpy(cv2.distanceTransform(
|
34 |
+
(1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
|
35 |
+
) / (image_size / 2))
|
36 |
+
for m in mask[:, i:i+1].squeeze(1).detach().cpu().numpy().astype(np.uint8)
|
37 |
+
]).unsqueeze(1).clip(0, 1).to(mask.device)
|
38 |
+
dts.append(distance_transform)
|
39 |
+
return torch.cat(dts, 1)
|
40 |
+
|
41 |
+
|
42 |
+
def default(x, d):
|
43 |
+
return d if x is None else x
|
44 |
+
|
45 |
+
|
46 |
+
def get_num_points(x: Pointclouds, /):
|
47 |
+
return x.points_padded().shape[1]
|
48 |
+
|
49 |
+
|
50 |
+
def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
|
51 |
+
"""Custom beta schedule"""
|
52 |
+
betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
|
53 |
+
warmup_frac = 0.3
|
54 |
+
warmup_time = int(num_train_timesteps * warmup_frac)
|
55 |
+
warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
56 |
+
warmup_time = min(warmup_time, num_train_timesteps)
|
57 |
+
betas[:warmup_time] = warmup_steps[:warmup_time]
|
58 |
+
return betas
|
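A minimal shape-check sketch for the distance-transform helper above (not part of the diff; sizes and the rectangular masks are assumed, and OpenCV must be installed):

```python
import torch

# Hypothetical human+object mask of shape (B, 2, H, W).
mask = torch.zeros(1, 2, 224, 224)
mask[:, 0, 60:160, 60:160] = 1    # fake human mask
mask[:, 1, 100:200, 120:220] = 1  # fake object mask

dt = compute_distance_transform(mask > 0.5)
# Per-channel distance to the nearest foreground pixel, normalized by H/2 and clipped to [0, 1].
print(dt.shape)  # torch.Size([1, 2, 224, 224])
```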
model/point_cloud_model.py
ADDED
@@ -0,0 +1,67 @@
from contextlib import nullcontext

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from torch import Tensor

from .pvcnn.pvcnn import PVCNN2
from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus
from .simple.simple_model import SimplePointModel


class PointCloudModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        super().__init__()
        self.model_type = model_type
        if self.model_type == 'pvcnn':
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = PVCNN2(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=(in_channels - 3),
                dropout=dropout, width_multiplier=width_multiplier,
                voxel_resolution_multiplier=voxel_resolution_multiplier
            )
            self.model.classifier[-1].bias.data.normal_(0, 1e-6)
            self.model.classifier[-1].weight.data.normal_(0, 1e-6)
        elif self.model_type == 'pvcnnplusplus':
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = PVCNN2PlusPlus(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=(in_channels - 3),
            )
            self.model.output_projection[-1].bias.data.normal_(0, 1e-6)
            self.model.output_projection[-1].weight.data.normal_(0, 1e-6)
        elif self.model_type == 'simple':
            self.autocast_context = nullcontext()
            self.model = SimplePointModel(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=(in_channels - 3),
            )
            self.model.output_projection.bias.data.normal_(0, 1e-6)
            self.model.output_projection.weight.data.normal_(0, 1e-6)
        else:
            raise NotImplementedError()

    def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor:
        """ Receives input of shape (B, N, in_channels) and returns output
        of shape (B, N, out_channels) """
        with self.autocast_context:
            if not ret_feats:
                return self.model(inputs.transpose(1, 2), t, ret_feats=False).transpose(1, 2)
            else:
                pred, feats = self.model(inputs.transpose(1, 2), t, ret_feats=True)
                return pred.transpose(1, 2), feats
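A rough usage sketch of the wrapper above (not from the repo; the channel counts are assumed, and the `pvcnn` backbone requires the compiled CUDA extension and a GPU):

```python
import torch

# Hypothetical sizes: 3 xyz channels plus 391 per-point conditioning channels.
model = PointCloudModel(model_type='pvcnn', in_channels=3 + 391, out_channels=3, embed_dim=64).cuda()
x_t = torch.randn(2, 4096, 3 + 391).cuda()   # (B, N, in_channels)
t = torch.randint(0, 1000, (2,)).cuda()      # diffusion timesteps
noise_pred = model(x_t, t)                   # (B, N, 3), near zero at init due to the tiny output weights
print(noise_pred.shape)
```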
model/point_cloud_transformer_model.py
ADDED
@@ -0,0 +1,80 @@
from typing import Optional

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from torch import Tensor
from timm.models.vision_transformer import Attention, LayerScale, DropPath, Mlp

from .point_cloud_model import PointCloudModel


class PointCloudModelBlock(nn.Module):

    def __init__(
        self,
        *,
        # Point cloud model
        dim: int,
        model_type: str = 'pvcnn',
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
        # Transformer model
        num_heads=6, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
        drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_attn=False
    ):
        super().__init__()

        # Point cloud model
        self.norm0 = norm_layer(dim)
        self.point_cloud_model = PointCloudModel(model_type=model_type,
            in_channels=dim, out_channels=dim, embed_dim=dim, dropout=dropout,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier)
        self.ls0 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Attention
        self.use_attn = use_attn
        if self.use_attn:
            self.norm1 = norm_layer(dim)
            self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
            self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
            self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def apply_point_cloud_model(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
        t = t if t is not None else torch.zeros(len(x), device=x.device, dtype=torch.long)
        return self.point_cloud_model(x, t)

    def forward(self, x: Tensor):
        x = x + self.drop_path0(self.ls0(self.apply_point_cloud_model(self.norm0(x))))
        if self.use_attn:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class PointCloudTransformerModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_layers: int, in_channels: int = 3, out_channels: int = 3, embed_dim: int = 64, **kwargs):
        super().__init__()
        self.num_layers = num_layers
        self.input_projection = nn.Linear(in_channels, embed_dim)
        self.blocks = nn.Sequential(*[PointCloudModelBlock(dim=embed_dim, **kwargs) for i in range(self.num_layers)])
        self.norm = nn.LayerNorm(embed_dim)
        self.output_projection = nn.Linear(embed_dim, out_channels)

    def forward(self, inputs: Tensor) -> Tensor:
        """ Receives input of shape (B, N, in_channels) and returns output
        of shape (B, N, out_channels) """
        x = self.input_projection(inputs)
        x = self.blocks(x)
        x = self.output_projection(x)
        return x
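A minimal shape sketch for the stacked model above (assumed sizes; the default `pvcnn` block type again needs the compiled CUDA ops and a GPU):

```python
import torch

# Hypothetical configuration: 4 blocks without attention, mapping 6 input channels to 3 outputs.
net = PointCloudTransformerModel(num_layers=4, in_channels=6, out_channels=3, embed_dim=64).cuda()
x = torch.rand(2, 2048, 6).cuda()  # (B, N, in_channels)
y = net(x)                         # (B, N, 3)
print(y.shape)
```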
model/projection_model.py
ADDED
@@ -0,0 +1,273 @@
from typing import Optional, Union

import torch
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
from diffusers import ModelMixin
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .feature_model import FeatureModel
from .model_utils import compute_distance_transform

SchedulerClass = Union[DDPMScheduler, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]


class PointCloudProjectionModel(ModelMixin):

    def __init__(
        self,
        image_size: int,
        image_feature_model: str,
        use_local_colors: bool = True,
        use_local_features: bool = True,
        use_global_features: bool = False,
        use_mask: bool = True,
        use_distance_transform: bool = True,
        predict_shape: bool = True,
        predict_color: bool = False,
        process_color: bool = False,
        image_color_channels: int = 3,  # for the input image, not the points
        color_channels: int = 3,  # for the points, not the input image
        colors_mean: float = 0.5,
        colors_std: float = 0.5,
        scale_factor: float = 1.0,
        # Rasterization settings
        raster_point_radius: float = 0.0075,  # point size
        raster_points_per_pixel: int = 1,  # a single point per pixel, for now
        bin_size: int = 0,
        model_name=None,
        # additional arguments added by XH
        load_sample_init=False,
        sample_init_scale=1.0,
        test_init_with_gtpc=False,
        consistent_center=False,  # from https://arxiv.org/pdf/2308.07837.pdf
        voxel_resolution_multiplier: int = 1,
        predict_binary: bool = False,  # predict a binary class label
        lw_binary: float = 1.0,
        binary_training_noise_std: float = 0.1,
        dm_pred_type: str = 'epsilon',  # diffusion prediction type
        self_conditioning=False,
        **kwargs,
    ):
        super().__init__()
        self.image_size = image_size
        self.scale_factor = scale_factor
        self.use_local_colors = use_local_colors
        self.use_local_features = use_local_features
        self.use_global_features = use_global_features
        self.use_mask = use_mask
        self.use_distance_transform = use_distance_transform
        self.predict_shape = predict_shape  # default False
        self.predict_color = predict_color  # default True
        self.process_color = process_color
        self.image_color_channels = image_color_channels
        self.color_channels = color_channels
        self.colors_mean = colors_mean
        self.colors_std = colors_std
        self.model_name = model_name
        print("PointCloud Model scale factor:", self.scale_factor, 'Model name:', self.model_name)
        self.predict_binary = predict_binary
        self.lw_binary = lw_binary
        self.self_conditioning = self_conditioning

        # Types of conditioning that are used
        self.use_local_conditioning = self.use_local_colors or self.use_local_features or self.use_mask
        self.use_global_conditioning = self.use_global_features
        self.kwargs = kwargs

        # Create feature model
        self.feature_model = FeatureModel(image_size, image_feature_model)

        # Input size
        self.in_channels = 3  # 3 for 3D point positions
        if self.use_local_colors:  # whether color should be an input
            self.in_channels += self.image_color_channels
        if self.use_local_features:
            self.in_channels += self.feature_model.feature_dim
        if self.use_global_features:
            self.in_channels += self.feature_model.feature_dim
        if self.use_mask:
            self.in_channels += 2 if self.use_distance_transform else 1
        if self.process_color:
            self.in_channels += self.color_channels  # point color added to input or not, default False
        if self.self_conditioning:
            self.in_channels += 3  # add self conditioning

        self.in_channels = self.add_extra_input_chennels(self.in_channels)

        if self.model_name in ['pc2-diff-ho-sepsegm', 'diff-ho-attn']:
            self.in_channels += 2 if self.use_distance_transform else 1

        # Output size
        self.out_channels = 0
        if self.predict_shape:
            self.out_channels += 3
        if self.predict_color:
            self.out_channels += self.color_channels
        if self.predict_binary:
            print("Output binary classification score!")
            self.out_channels += 1

        # Save rasterization settings
        self.raster_settings = PointsRasterizationSettings(
            image_size=(image_size, image_size),
            radius=raster_point_radius,
            points_per_pixel=raster_points_per_pixel,
            bin_size=bin_size,
        )

    def add_extra_input_chennels(self, input_channels):
        return input_channels

    def denormalize(self, x: Tensor, /, clamp: bool = True):
        x = x * self.colors_std + self.colors_mean
        return torch.clamp(x, 0, 1) if clamp else x

    def normalize(self, x: Tensor, /):
        x = (x - self.colors_mean) / self.colors_std
        return x

    def get_global_conditioning(self, image_rgb: Tensor):
        global_conditioning = []
        if self.use_global_features:
            global_conditioning.append(self.feature_model(image_rgb,
                                                          return_cls_token_only=True))  # (B, D)
        global_conditioning = torch.cat(global_conditioning, dim=1)  # (B, D_cond)
        return global_conditioning

    def get_local_conditioning(self, image_rgb: Tensor, mask: Tensor):
        """
        compute per-point conditioning
        Parameters
        ----------
        image_rgb: (B, 3, 224, 224), values normalized to 0-1, background is masked by the given mask
        mask: (B, 1, 224, 224), or (B, 2, 224, 224) for h+o
        """
        local_conditioning = []
        # import pdb; pdb.set_trace()

        if self.use_local_colors:  # XH: default True
            local_conditioning.append(self.normalize(image_rgb))
        if self.use_local_features:  # XH: default True
            local_conditioning.append(self.feature_model(image_rgb))  # I guess no mask here? feature model: 'vit_small_patch16_224_mae'
        if self.use_mask:  # default True
            local_conditioning.append(mask.float())
        if self.use_distance_transform:  # default True
            if not self.use_mask:
                raise ValueError('No mask for distance transform?')
            if mask.is_floating_point():
                mask = mask > 0.5
            local_conditioning.append(compute_distance_transform(mask))
        local_conditioning = torch.cat(local_conditioning, dim=1)  # (B, D_cond, H, W)
        return local_conditioning

    @torch.autocast('cuda', dtype=torch.float32)
    def surface_projection(
        self, points: Tensor, camera: CamerasBase, local_features: Tensor,
    ):
        B, C, H, W, device = *local_features.shape, local_features.device
        R = self.raster_settings.points_per_pixel
        N = points.shape[1]

        # Scale camera by scaling T. ASSUMES CAMERA IS LOOKING AT ORIGIN!
        camera = camera.clone()
        camera.T = camera.T * self.scale_factor

        # Create rasterizer
        rasterizer = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings)

        # Associate points with features via rasterization
        fragments = rasterizer(Pointclouds(points))  # (B, H, W, R)
        fragments_idx: Tensor = fragments.idx.long()
        visible_pixels = (fragments_idx > -1)  # (B, H, W, R)
        points_to_visible_pixels = fragments_idx[visible_pixels]

        # Reshape local features to (B, H, W, R, C)
        local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1)  # (B, H, W, R, C)

        # Get local features corresponding to visible points
        local_features_proj = torch.zeros(B * N, C, device=device)
        # local feature includes: raw RGB color, image features, mask, distance transform
        local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
        local_features_proj = local_features_proj.reshape(B, N, C)

        return local_features_proj

    def point_cloud_to_tensor(self, pc: Pointclouds, /, normalize: bool = False, scale: bool = False):
        """Converts a point cloud to a tensor, with color if and only if self.predict_color"""
        points = pc.points_padded() * (self.scale_factor if scale else 1)
        if self.predict_color and pc.features_padded() is not None:  # normalize color, not point locations
            colors = self.normalize(pc.features_padded()) if normalize else pc.features_padded()
            return torch.cat((points, colors), dim=2)
        else:
            return points

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)
        else:
            assert x.shape[2] == 3
            return Pointclouds(points=points)

    def get_input_with_conditioning(
        self,
        x_t: Tensor,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        t: Optional[Tensor],
    ):
        """ Extracts local features from the input image and projects them onto the points
        in the point cloud to obtain the input to the model. Then extracts global
        features, replicates them across points, and concats them to the input.
        image_rgb: masked background
        XH: why there is no positional encoding as described by the supp??
        """
        B, N = x_t.shape[:2]

        # Initial input is the point locations (and colors if and only if predicting color)
        x_t_input = self.get_coord_feature(x_t)

        # Local conditioning
        if self.use_local_conditioning:

            # Get local features and check that they are the same size as the input image
            local_features = self.get_local_conditioning(image_rgb=image_rgb, mask=mask)  # concatenate RGB + mask + RGB feature + distance transform
            if local_features.shape[-2:] != image_rgb.shape[-2:]:
                raise ValueError(f'{local_features.shape=} and {image_rgb.shape=}')

            # Project local features. Here that we only need the point locations, not colors
            local_features_proj = self.surface_projection(points=x_t[:, :, :3],
                camera=camera, local_features=local_features)  # (B, N, D_local)

            x_t_input.append(local_features_proj)

        # Global conditioning
        if self.use_global_conditioning:  # False

            # Get and repeat global features
            global_features = self.get_global_conditioning(image_rgb=image_rgb)  # (B, D_global)
            global_features = global_features.unsqueeze(1).expand(-1, N, -1)  # (B, D_global, N)

            x_t_input.append(global_features)

        # Concatenate together all the pointwise features
        x_t_input = torch.cat(x_t_input, dim=2)  # (B, N, D)

        return x_t_input

    def get_coord_feature(self, x_t):
        """get coordinate feature, for model that uses separate model to predict binary, we use first 3 channels only"""
        x_t_input = [x_t]
        return x_t_input

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """ The forward method may be defined differently for different models. """
        raise NotImplementedError()
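A back-of-the-envelope check of the per-point input width built in `__init__` above (a sketch with assumed numbers: a 384-dim ViT-small feature model and the default conditioning flags, with the extra mask/distance channels added for the `diff-ho-attn` variant):

```python
# xyz + normalized RGB + ViT feature + (mask + distance transform) for two masks
in_channels = 3 + 3 + 384 + 2 + 2
print(in_channels)  # 394, fed to the point cloud backbone as extra feature channels
```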
model/pvcnn/__init__.py
ADDED
File without changes
model/pvcnn/modules/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .ball_query import BallQuery, BallQueryHO
from .frustum import FrustumPointNetLoss
from .loss import KLLoss
from .pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
from .pvconv import PVConv, Attention, Swish, PVConvReLU
from .se import SE3d
from .shared_mlp import SharedMLP
from .voxelization import Voxelization
model/pvcnn/modules/ball_query.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn

from . import functional as F

__all__ = ['BallQuery']


class BallQuery(nn.Module):
    def __init__(self, radius, num_neighbors, include_coordinates=True):
        super().__init__()
        self.radius = radius
        self.num_neighbors = num_neighbors
        self.include_coordinates = include_coordinates

    def forward(self, points_coords, centers_coords, temb, points_features=None):
        points_coords = points_coords.contiguous()
        centers_coords = centers_coords.contiguous()
        neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
        neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
        neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)

        if points_features is None:
            assert self.include_coordinates, 'No Features For Grouping'
            neighbor_features = neighbor_coordinates
        else:
            neighbor_features = F.grouping(points_features, neighbor_indices)  # return [B, C, M, U] C=feat dim, M=# centers, U=# neighbours
            if self.include_coordinates:
                neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
        return neighbor_features, F.grouping(temb, neighbor_indices)

    def extra_repr(self):
        return 'radius={}, num_neighbors={}{}'.format(
            self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')


class BallQueryHO(nn.Module):
    "no point feature, but only relative and abs coordinate"
    def __init__(self, radius, num_neighbors, include_relative=False):
        super().__init__()
        self.radius = radius
        self.num_neighbors = num_neighbors
        self.include_relative = include_relative

    def forward(self, points_coords, centers_coords, points_features=None):
        """
        if not enough points inside the given radius, the entries will be zero
        if too many points inside the radius, the order is random??? (not sure)
        :param points_coords: (B, 3, N)
        :param centers_coords: (B, 3, M)
        :param points_features: None
        :return:
        """
        points_coords = points_coords.contiguous()
        centers_coords = centers_coords.contiguous()
        neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
        neighbor_coordinates = F.grouping(points_coords, neighbor_indices)  # (B, 3, M, U)
        if self.include_relative:
            neighbor_coordinates_rela = neighbor_coordinates - centers_coords.unsqueeze(-1)
            neighbor_coordinates = torch.cat([neighbor_coordinates, neighbor_coordinates_rela], 1)  # (B, 6, M, U)
        # flatten the coordinate
        neighbor_coordinates = neighbor_coordinates.permute(0, 1, 3, 2)  # (B, 3/6, U, M)
        neighbor_coordinates = torch.flatten(neighbor_coordinates, 1, 2)  # (B, 3*U, M)
        return neighbor_coordinates

    def extra_repr(self):
        return 'radius={}, num_neighbors={}{}'.format(
            self.radius, self.num_neighbors, ', include relative' if self.include_relative else '')
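A shape sketch for `BallQueryHO` (assumed sizes, requires the compiled CUDA backend and a GPU): with `include_relative=True`, absolute and relative coordinates are stacked and then flattened over the neighbor dimension, so the output is `(B, 6*U, M)`.

```python
import torch

bq = BallQueryHO(radius=0.2, num_neighbors=16, include_relative=True)
points = torch.rand(2, 3, 4096).cuda()    # (B, 3, N), e.g. human points
centers = torch.rand(2, 3, 128).cuda()    # (B, 3, M), e.g. object query centers
out = bq(points, centers)
print(out.shape)  # torch.Size([2, 96, 128]) = (B, 6*16, M)
```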
model/pvcnn/modules/frustum.py
ADDED
@@ -0,0 +1,138 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import functional as PF

__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']


class FrustumPointNetLoss(nn.Module):
    def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
                 corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
        super().__init__()
        self.box_loss_weight = box_loss_weight
        self.corners_loss_weight = corners_loss_weight
        self.heading_residual_loss_weight = heading_residual_loss_weight
        self.size_residual_loss_weight = size_residual_loss_weight

        self.num_heading_angle_bins = num_heading_angle_bins
        self.num_size_templates = num_size_templates
        self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
        self.register_buffer(
            'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
        )

    def forward(self, inputs, targets):
        mask_logits = inputs['mask_logits']  # (B, 2, N)
        center_reg = inputs['center_reg']  # (B, 3)
        center = inputs['center']  # (B, 3)
        heading_scores = inputs['heading_scores']  # (B, NH)
        heading_residuals_normalized = inputs['heading_residuals_normalized']  # (B, NH)
        heading_residuals = inputs['heading_residuals']  # (B, NH)
        size_scores = inputs['size_scores']  # (B, NS)
        size_residuals_normalized = inputs['size_residuals_normalized']  # (B, NS, 3)
        size_residuals = inputs['size_residuals']  # (B, NS, 3)

        mask_logits_target = targets['mask_logits']  # (B, N)
        center_target = targets['center']  # (B, 3)
        heading_bin_id_target = targets['heading_bin_id']  # (B, )
        heading_residual_target = targets['heading_residual']  # (B, )
        size_template_id_target = targets['size_template_id']  # (B, )
        size_residual_target = targets['size_residual']  # (B, 3)

        batch_size = center.size(0)
        batch_id = torch.arange(batch_size, device=center.device)

        # Basic Classification and Regression losses
        mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
        heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
        size_loss = F.cross_entropy(size_scores, size_template_id_target)
        center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
        center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)

        # Refinement losses for size/heading
        heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target]  # (B, )
        heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
        heading_residual_normalized_loss = PF.huber_loss(
            heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
        )
        size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target]  # (B, 3)
        size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
        size_residual_normalized_loss = PF.huber_loss(
            torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
        )

        # Bounding box losses
        heading = (heading_residuals[batch_id, heading_bin_id_target]
                   + self.heading_angle_bin_centers[heading_bin_id_target])  # (B, )
        # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
        size = (size_residuals[batch_id, size_template_id_target]
                + self.size_templates[size_template_id_target])  # (B, 3)
        corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
        heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
        size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
        corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
                                                                 sizes=size_target, with_flip=True)  # (B, 3, 8)
        corners_loss = PF.huber_loss(torch.min(
            torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
        ), delta=1.0)
        # Summing up
        loss = mask_loss + self.box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss
            + self.heading_residual_loss_weight * heading_residual_normalized_loss
            + self.size_residual_loss_weight * size_residual_normalized_loss
            + self.corners_loss_weight * corners_loss
        )

        return loss


def get_box_corners_3d(centers, headings, sizes, with_flip=False):
    """
    :param centers: coords of box centers, FloatTensor[N, 3]
    :param headings: heading angles, FloatTensor[N, ]
    :param sizes: box sizes, FloatTensor[N, 3]
    :param with_flip: bool, whether to return flipped box (headings + np.pi)
    :return:
        coords of box corners, FloatTensor[N, 3, 8]
        NOTE: corner points are in counter clockwise order, e.g.,
          2--1
        3--0 5
          7--4
    """
    l = sizes[:, 0]  # (N,)
    w = sizes[:, 1]  # (N,)
    h = sizes[:, 2]  # (N,)
    x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1)  # (N, 8)
    y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1)  # (N, 8)
    z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1)  # (N, 8)

    c = torch.cos(headings)  # (N,)
    s = torch.sin(headings)  # (N,)
    o = torch.ones_like(headings)  # (N,)
    z = torch.zeros_like(headings)  # (N,)

    centers = centers.unsqueeze(-1)  # (B, 3, 1)
    corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # roty matrix: (N, 3, 3)
    if with_flip:
        R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
        return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
    else:
        return torch.matmul(R, corners) + centers

    # centers = centers.unsqueeze(1)  # (B, 1, 3)
    # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1)  # (N, 8, 3)
    # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # if with_flip:
    #     RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    #     return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers  # (N, 8, 3)
    # else:
    #     return torch.matmul(corners, RT) + centers  # (N, 8, 3)

    # corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # corners = torch.matmul(R, corners) + centers.unsqueeze(2)  # (N, 3, 8)
    # corners = corners.transpose(1, 2)  # (N, 8, 3)
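A quick sanity check for `get_box_corners_3d` (not from the repo): a unit cube centered at the origin with zero heading should yield corners at plus/minus 0.5 along every axis.

```python
import torch

centers = torch.zeros(1, 3)
headings = torch.zeros(1)   # rotation matrix becomes the identity
sizes = torch.ones(1, 3)    # l = w = h = 1
corners = get_box_corners_3d(centers, headings, sizes)  # (1, 3, 8)
print(corners.abs().max().item())  # 0.5
```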
model/pvcnn/modules/functional/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .ball_query import ball_query
from .devoxelization import trilinear_devoxelize
from .grouping import grouping
from .interpolatation import nearest_neighbor_interpolate
from .loss import kl_loss, huber_loss
from .sampling import gather, furthest_point_sample, logits_mask
from .voxelization import avg_voxelize
|
model/pvcnn/modules/functional/backend.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from torch.utils.cpp_extension import load
|
5 |
+
|
6 |
+
|
7 |
+
gcc_path = os.getenv('CC', default='/usr/bin/gcc')
|
8 |
+
if not Path(gcc_path).is_file():
|
9 |
+
raise ValueError('Could not find your gcc, please replace it here.')
|
10 |
+
|
11 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
12 |
+
_backend = load(
|
13 |
+
name='_pvcnn_backend',
|
14 |
+
extra_cflags=['-O3', '-std=c++17'],
|
15 |
+
extra_cuda_cflags=[f'--compiler-bindir={gcc_path}'],
|
16 |
+
sources=[os.path.join(_src_path,'src', f) for f in [
|
17 |
+
'ball_query/ball_query.cpp',
|
18 |
+
'ball_query/ball_query.cu',
|
19 |
+
'grouping/grouping.cpp',
|
20 |
+
'grouping/grouping.cu',
|
21 |
+
'interpolate/neighbor_interpolate.cpp',
|
22 |
+
'interpolate/neighbor_interpolate.cu',
|
23 |
+
'interpolate/trilinear_devox.cpp',
|
24 |
+
'interpolate/trilinear_devox.cu',
|
25 |
+
'sampling/sampling.cpp',
|
26 |
+
'sampling/sampling.cu',
|
27 |
+
'voxelization/vox.cpp',
|
28 |
+
'voxelization/vox.cu',
|
29 |
+
'bindings.cpp',
|
30 |
+
]]
|
31 |
+
)
|
32 |
+
|
33 |
+
__all__ = ['_backend']
|
model/pvcnn/modules/functional/ball_query.py
ADDED
@@ -0,0 +1,19 @@
from torch.autograd import Function

from .backend import _backend

__all__ = ['ball_query']


def ball_query(centers_coords, points_coords, radius, num_neighbors):
    """
    :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
    :param points_coords: coordinates of points, FloatTensor[B, 3, N]
    :param radius: float, radius of ball query
    :param num_neighbors: int, maximum number of neighbors
    :return:
        neighbor_indices: indices of neighbors, IntTensor[B, M, U]
    """
    centers_coords = centers_coords.contiguous()
    points_coords = points_coords.contiguous()
    return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
model/pvcnn/modules/functional/devoxelization.py
ADDED
@@ -0,0 +1,42 @@
from torch.autograd import Function

from .backend import _backend

__all__ = ['trilinear_devoxelize']


class TrilinearDevoxelization(Function):
    @staticmethod
    def forward(ctx, features, coords, resolution, is_training=True):
        """
        :param ctx:
        :param coords: the coordinates of points, FloatTensor[B, 3, N]
        :param features: FloatTensor[B, C, R, R, R]
        :param resolution: int, the voxel resolution
        :param is_training: bool, training mode
        :return:
            FloatTensor[B, C, N]
        """
        B, C = features.shape[:2]
        features = features.contiguous().view(B, C, -1)
        coords = coords.contiguous()
        outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features)
        if is_training:
            ctx.save_for_backward(inds, wgts)
        ctx.r = resolution
        return outs

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param ctx:
        :param grad_output: gradient of outputs, FloatTensor[B, C, N]
        :return:
            gradient of inputs, FloatTensor[B, C, R, R, R]
        """
        inds, wgts = ctx.saved_tensors
        grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r)
        return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None


trilinear_devoxelize = TrilinearDevoxelization.apply
model/pvcnn/modules/functional/grouping.py
ADDED
@@ -0,0 +1,32 @@
from torch.autograd import Function

from .backend import _backend

__all__ = ['grouping']


class Grouping(Function):
    @staticmethod
    def forward(ctx, features, indices):
        """
        :param ctx:
        :param features: features of points, FloatTensor[B, C, N]
        :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
        :return:
            grouped_features: grouped features, FloatTensor[B, C, M, U]
        """
        features = features.contiguous()
        indices = indices.contiguous()
        ctx.save_for_backward(indices)
        ctx.num_points = features.size(-1)
        # print(features.dtype, features.shape)
        return _backend.grouping_forward(features, indices)

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
        return grad_features, None


grouping = Grouping.apply
model/pvcnn/modules/functional/interpolatation.py
ADDED
@@ -0,0 +1,38 @@
from torch.autograd import Function

from .backend import _backend

__all__ = ['nearest_neighbor_interpolate']


class NeighborInterpolation(Function):
    @staticmethod
    def forward(ctx, points_coords, centers_coords, centers_features):
        """
        :param ctx:
        :param points_coords: coordinates of points, FloatTensor[B, 3, N]
        :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
        :param centers_features: features of centers, FloatTensor[B, C, M]
        :return:
            points_features: features of points, FloatTensor[B, C, N]
        """
        centers_coords = centers_coords.contiguous()
        points_coords = points_coords.contiguous()
        centers_features = centers_features.contiguous()
        points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
            points_coords, centers_coords, centers_features
        )
        ctx.save_for_backward(indices, weights)
        ctx.num_centers = centers_coords.size(-1)
        return points_features

    @staticmethod
    def backward(ctx, grad_output):
        indices, weights = ctx.saved_tensors
        grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
            grad_output.contiguous(), indices, weights, ctx.num_centers
        )
        return None, None, grad_centers_features


nearest_neighbor_interpolate = NeighborInterpolation.apply
model/pvcnn/modules/functional/loss.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F

__all__ = ['kl_loss', 'huber_loss']


def kl_loss(x, y):
    x = F.softmax(x.detach(), dim=1)
    y = F.log_softmax(y, dim=1)
    return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))


def huber_loss(error, delta):
    abs_error = torch.abs(error)
    quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
    losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
    return torch.mean(losses)
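A worked example for `huber_loss` above (not part of the repo): with `delta = 1`, an error of 0.5 stays in the quadratic regime (0.5 * 0.5^2 = 0.125), while an error of 3 is linearized (0.5 * 1^2 + 1 * (3 - 1) = 2.5); the function returns the mean over all entries.

```python
import torch

err = torch.tensor([0.5, 3.0])
print(huber_loss(err, delta=1.0))  # (0.125 + 2.5) / 2 = 1.3125
```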
model/pvcnn/modules/functional/sampling.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np
import torch
from torch.autograd import Function

from .backend import _backend

__all__ = ['gather', 'furthest_point_sample', 'logits_mask']


class Gather(Function):
    @staticmethod
    def forward(ctx, features, indices):
        """
        Gather
        :param ctx:
        :param features: features of points, FloatTensor[B, C, N]
        :param indices: centers' indices in points, IntTensor[b, m]
        :return:
            centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
        """
        features = features.contiguous()
        indices = indices.int().contiguous()
        ctx.save_for_backward(indices)
        ctx.num_points = features.size(-1)
        return _backend.gather_features_forward(features, indices)

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
        return grad_features, None


gather = Gather.apply


def furthest_point_sample(coords, num_samples):
    """
    Uses iterative furthest point sampling to select a set of npoint features that have the largest
    minimum distance to the sampled point set
    :param coords: coordinates of points, FloatTensor[B, 3, N]
    :param num_samples: int, M
    :return:
        centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
    """
    coords = coords.contiguous()
    indices = _backend.furthest_point_sampling(coords, num_samples)
    return gather(coords, indices)


def logits_mask(coords, logits, num_points_per_object):
    """
    Use logits to sample points
    :param coords: coords of points, FloatTensor[B, 3, N]
    :param logits: binary classification logits, FloatTensor[B, 2, N]
    :param num_points_per_object: M, #points per object after masking, int
    :return:
        selected_coords: FloatTensor[B, 3, M]
        masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
        mask: mask to select points, BoolTensor[B, N]
    """
    batch_size, _, num_points = coords.shape
    mask = torch.lt(logits[:, 0, :], logits[:, 1, :])  # [B, N]
    num_candidates = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
    masked_coords = coords * mask.view(batch_size, 1, num_points)  # [B, C, N]
    masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
                                                                      torch.ones_like(num_candidates)).float()  # [B, C]
    selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
    for i in range(batch_size):
        current_mask = mask[i]  # [N]
        current_candidates = current_mask.nonzero().view(-1)
        current_num_candidates = current_candidates.numel()
        if current_num_candidates >= num_points_per_object:
            choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
            selected_indices[i] = current_candidates[choices]
        elif current_num_candidates > 0:
            choices = np.concatenate([
                np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
                np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
            ])
            np.random.shuffle(choices)
            selected_indices[i] = current_candidates[choices]
    selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
    return selected_coords, masked_coords_mean, mask
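A usage sketch for `furthest_point_sample` (assumed sizes; requires the compiled CUDA backend and a GPU): it subsamples a dense cloud down to a set of well-spread centers, which the PVCNN set-abstraction layers use as query points.

```python
import torch

coords = torch.rand(2, 3, 4096).cuda()         # (B, 3, N)
centers = furthest_point_sample(coords, 1024)  # (B, 3, 1024)
print(centers.shape)
```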
model/pvcnn/modules/functional/src/ball_query/ball_query.cpp
ADDED
@@ -0,0 +1,30 @@
#include "ball_query.hpp"
#include "ball_query.cuh"

#include "../utils.hpp"

at::Tensor ball_query_forward(at::Tensor centers_coords,
                              at::Tensor points_coords, const float radius,
                              const int num_neighbors) {
  CHECK_CUDA(centers_coords);
  CHECK_CUDA(points_coords);
  CHECK_CONTIGUOUS(centers_coords);
  CHECK_CONTIGUOUS(points_coords);
  CHECK_IS_FLOAT(centers_coords);
  CHECK_IS_FLOAT(points_coords);

  int b = centers_coords.size(0);
  int m = centers_coords.size(2);
  int n = points_coords.size(2);

  at::Tensor neighbors_indices = torch::zeros(
      {b, m, num_neighbors},
      at::device(centers_coords.device()).dtype(at::ScalarType::Int));

  ball_query(b, n, m, radius * radius, num_neighbors,
             centers_coords.data_ptr<float>(),
             points_coords.data_ptr<float>(),
             neighbors_indices.data_ptr<int>());

  return neighbors_indices;
}
model/pvcnn/modules/functional/src/ball_query/ball_query.cu
ADDED
@@ -0,0 +1,59 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../cuda_utils.cuh"

/*
  Function: ball query
  Args:
    b   : batch size
    n   : number of points in point clouds
    m   : number of query centers
    r2  : ball query radius ** 2
    u   : maximum number of neighbors
    centers_coords: coordinates of centers, FloatTensor[b, 3, m]
    points_coords : coordinates of points, FloatTensor[b, 3, n]
    neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
*/
__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
                                  const float *__restrict__ centers_coords,
                                  const float *__restrict__ points_coords,
                                  int *__restrict__ neighbors_indices) {
  int batch_index = blockIdx.x;
  int index = threadIdx.x;
  int stride = blockDim.x;
  points_coords += batch_index * n * 3;
  centers_coords += batch_index * m * 3;
  neighbors_indices += batch_index * m * u;

  for (int j = index; j < m; j += stride) {
    float center_x = centers_coords[j];
    float center_y = centers_coords[j + m];
    float center_z = centers_coords[j + m + m];
    for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
      float dx = center_x - points_coords[k];
      float dy = center_y - points_coords[k + n];
      float dz = center_z - points_coords[k + n + n];
      float d2 = dx * dx + dy * dy + dz * dz;
      if (d2 < r2) {
        if (cnt == 0) {
          for (int v = 0; v < u; ++v) {
            neighbors_indices[j * u + v] = k;
          }
        }
        neighbors_indices[j * u + cnt] = k;
        ++cnt;
      }
    }
  }
}

void ball_query(int b, int n, int m, float r2, int u,
                const float *centers_coords, const float *points_coords,
                int *neighbors_indices) {
  ball_query_kernel<<<b, optimal_num_threads(m), 0,
                      at::cuda::getCurrentCUDAStream()>>>(
      b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
  CUDA_CHECK_ERRORS();
}
model/pvcnn/modules/functional/src/ball_query/ball_query.cuh
ADDED
@@ -0,0 +1,8 @@
#ifndef _BALL_QUERY_CUH
#define _BALL_QUERY_CUH

void ball_query(int b, int n, int m, float r2, int u,
                const float *centers_coords, const float *points_coords,
                int *neighbors_indices);

#endif
model/pvcnn/modules/functional/src/ball_query/ball_query.hpp
ADDED
@@ -0,0 +1,10 @@
#ifndef _BALL_QUERY_HPP
#define _BALL_QUERY_HPP

#include <torch/extension.h>

at::Tensor ball_query_forward(at::Tensor centers_coords,
                              at::Tensor points_coords, const float radius,
                              const int num_neighbors);

#endif
model/pvcnn/modules/functional/src/bindings.cpp
ADDED
@@ -0,0 +1,37 @@
#include <pybind11/pybind11.h>

#include "ball_query/ball_query.hpp"
#include "grouping/grouping.hpp"
#include "interpolate/neighbor_interpolate.hpp"
#include "interpolate/trilinear_devox.hpp"
#include "sampling/sampling.hpp"
#include "voxelization/vox.hpp"

PYBIND11_MODULE(_pvcnn_backend, m) {
  m.def("gather_features_forward", &gather_features_forward,
        "Gather Centers' Features forward (CUDA)");
  m.def("gather_features_backward", &gather_features_backward,
        "Gather Centers' Features backward (CUDA)");
  m.def("furthest_point_sampling", &furthest_point_sampling_forward,
        "Furthest Point Sampling (CUDA)");
  m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
  m.def("grouping_forward", &grouping_forward,
        "Grouping Features forward (CUDA)");
  m.def("grouping_backward", &grouping_backward,
        "Grouping Features backward (CUDA)");
  m.def("three_nearest_neighbors_interpolate_forward",
        &three_nearest_neighbors_interpolate_forward,
        "3 Nearest Neighbors Interpolate forward (CUDA)");
  m.def("three_nearest_neighbors_interpolate_backward",
        &three_nearest_neighbors_interpolate_backward,
        "3 Nearest Neighbors Interpolate backward (CUDA)");

  m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
        "Trilinear Devoxelization forward (CUDA)");
  m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
        "Trilinear Devoxelization backward (CUDA)");
  m.def("avg_voxelize_forward", &avg_voxelize_forward,
        "Voxelization forward with average pooling (CUDA)");
  m.def("avg_voxelize_backward", &avg_voxelize_backward,
        "Voxelization backward (CUDA)");
}
model/pvcnn/modules/functional/src/cuda_utils.cuh
ADDED
@@ -0,0 +1,39 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#define MAXIMUM_THREADS 512

inline int optimal_num_threads(int work_size) {
  const int pow_2 = std::log2(static_cast<double>(work_size));
  return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
}

inline dim3 optimal_block_config(int x, int y) {
  const int x_threads = optimal_num_threads(x);
  const int y_threads =
      max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
  dim3 block_config(x_threads, y_threads, 1);
  return block_config;
}

#define CUDA_CHECK_ERRORS()                                           \
  {                                                                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
      fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n",  \
              cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
              __FILE__);                                              \
      exit(-1);                                                       \
    }                                                                 \
  }

#endif
model/pvcnn/modules/functional/src/grouping/grouping.cpp
ADDED
@@ -0,0 +1,44 @@
#include "grouping.hpp"
#include "grouping.cuh"

#include "../utils.hpp"

at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
  CHECK_CUDA(features);
  CHECK_CUDA(indices);
  CHECK_CONTIGUOUS(features);
  CHECK_CONTIGUOUS(indices);
  CHECK_IS_FLOAT(features);
  CHECK_IS_INT(indices);

  int b = features.size(0);
  int c = features.size(1);
  int n = features.size(2);
  int m = indices.size(1);
  int u = indices.size(2);
  at::Tensor output = torch::zeros(
      {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
  grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
           output.data_ptr<float>());
  return output;
}

at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
                             const int n) {
  CHECK_CUDA(grad_y);
  CHECK_CUDA(indices);
  CHECK_CONTIGUOUS(grad_y);
  CHECK_CONTIGUOUS(indices);
  CHECK_IS_FLOAT(grad_y);
  CHECK_IS_INT(indices);

  int b = grad_y.size(0);
  int c = grad_y.size(1);
  int m = indices.size(1);
  int u = indices.size(2);
  at::Tensor grad_x = torch::zeros(
      {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
  grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
                indices.data_ptr<int>(), grad_x.data_ptr<float>());
  return grad_x;
}
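`grouping_forward` gathers per-point features into `[b, c, m, u]` neighborhoods, and `grouping_backward` scatters gradients back to the `[b, c, n]` point features. A hedged shape-check sketch, assuming these extension sources are compiled into a standalone libtorch program with CUDA available (the include path and build setup here are hypothetical, not part of the diff):

```cpp
#include <iostream>
#include <torch/torch.h>
#include "grouping.hpp"  // assumed to sit next to grouping.cpp

int main() {
  // b=2 batches, c=16 channels, n=1024 points, m=256 centers, u=32 neighbors.
  auto features = torch::rand({2, 16, 1024}, torch::device(torch::kCUDA));
  auto indices = torch::randint(
      0, 1024, {2, 256, 32}, torch::device(torch::kCUDA).dtype(torch::kInt));

  auto grouped = grouping_forward(features, indices);      // [2, 16, 256, 32]
  auto grad_x = grouping_backward(grouped, indices, 1024); // [2, 16, 1024]
  std::cout << grouped.sizes() << " " << grad_x.sizes() << std::endl;
  return 0;
}
```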
model/pvcnn/modules/functional/src/grouping/grouping.cu
ADDED
@@ -0,0 +1,85 @@
#include <stdio.h>
#include <stdlib.h>

#include "../cuda_utils.cuh"

/*
  Function: grouping features of neighbors (forward)
  Args:
    b       : batch size
    c       : #channels of features
    n       : number of points in point clouds
    m       : number of query centers
    u       : maximum number of neighbors
    features: points' features, FloatTensor[b, c, n]
    indices : neighbor indices in points, IntTensor[b, m, u]
    out     : gathered features, FloatTensor[b, c, m, u]
*/
__global__ void grouping_kernel(int b, int c, int n, int m, int u,
                                const float *__restrict__ features,
                                const int *__restrict__ indices,
                                float *__restrict__ out) {
  int batch_index = blockIdx.x;
  features += batch_index * n * c;
  indices += batch_index * m * u;
  out += batch_index * m * u * c;

  const int index = threadIdx.y * blockDim.x + threadIdx.x;
  const int stride = blockDim.y * blockDim.x;
  for (int i = index; i < c * m; i += stride) {
    const int l = i / m;
    const int j = i % m;
    for (int k = 0; k < u; ++k) {
      out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
    }
  }
}

void grouping(int b, int c, int n, int m, int u, const float *features,
              const int *indices, float *out) {
  grouping_kernel<<<b, optimal_block_config(m, c), 0,
                    at::cuda::getCurrentCUDAStream()>>>(b, c, n, m, u, features,
                                                        indices, out);
  CUDA_CHECK_ERRORS();
}

/*
  Function: grouping features of neighbors (backward)
  Args:
    b       : batch size
    c       : #channels of features
    n       : number of points in point clouds
    m       : number of query centers
    u       : maximum number of neighbors
    grad_y  : grad of gathered features, FloatTensor[b, c, m, u]
    indices : neighbor indices in points, IntTensor[b, m, u]
    grad_x  : grad of points' features, FloatTensor[b, c, n]
*/
__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
                                     const float *__restrict__ grad_y,
                                     const int *__restrict__ indices,
                                     float *__restrict__ grad_x) {
  int batch_index = blockIdx.x;
  grad_y += batch_index * m * u * c;
  indices += batch_index * m * u;
  grad_x += batch_index * n * c;

  const int index = threadIdx.y * blockDim.x + threadIdx.x;
  const int stride = blockDim.y * blockDim.x;
  for (int i = index; i < c * m; i += stride) {
    const int l = i / m;
    const int j = i % m;
    for (int k = 0; k < u; ++k) {
      atomicAdd(grad_x + l * n + indices[j * u + k],
                grad_y[(l * m + j) * u + k]);
    }
  }
}

void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
                   const int *indices, float *grad_x) {
  grouping_grad_kernel<<<b, optimal_block_config(m, c), 0,
                         at::cuda::getCurrentCUDAStream()>>>(
      b, c, n, m, u, grad_y, indices, grad_x);
  CUDA_CHECK_ERRORS();
}
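The forward kernel is a pure gather: each output element copies one input feature selected by `indices`. A CPU reference with the same flat indexing, useful for unit-testing the kernel (illustration only, not part of the extension):

```cpp
#include <vector>

// out[b][c][m][u] = features[b][c][ indices[b][m][u] ], flattened row-major.
std::vector<float> grouping_cpu(int b, int c, int n, int m, int u,
                                const std::vector<float> &features,  // [b, c, n]
                                const std::vector<int> &indices) {   // [b, m, u]
  std::vector<float> out(static_cast<size_t>(b) * c * m * u, 0.f);
  for (int bi = 0; bi < b; ++bi)
    for (int l = 0; l < c; ++l)
      for (int j = 0; j < m; ++j)
        for (int k = 0; k < u; ++k)
          out[((bi * c + l) * m + j) * u + k] =
              features[(bi * c + l) * n + indices[(bi * m + j) * u + k]];
  return out;
}
```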
model/pvcnn/modules/functional/src/grouping/grouping.cuh
ADDED
@@ -0,0 +1,9 @@
#ifndef _GROUPING_CUH
#define _GROUPING_CUH

void grouping(int b, int c, int n, int m, int u, const float *features,
              const int *indices, float *out);
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
                   const int *indices, float *grad_x);

#endif
model/pvcnn/modules/functional/src/grouping/grouping.hpp
ADDED
@@ -0,0 +1,10 @@
#ifndef _GROUPING_HPP
#define _GROUPING_HPP

#include <torch/extension.h>

at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
                             const int n);

#endif
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp
ADDED
@@ -0,0 +1,65 @@
#include "neighbor_interpolate.hpp"
#include "neighbor_interpolate.cuh"

#include "../utils.hpp"

std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
                                            at::Tensor centers_coords,
                                            at::Tensor centers_features) {
  CHECK_CUDA(points_coords);
  CHECK_CUDA(centers_coords);
  CHECK_CUDA(centers_features);
  CHECK_CONTIGUOUS(points_coords);
  CHECK_CONTIGUOUS(centers_coords);
  CHECK_CONTIGUOUS(centers_features);
  CHECK_IS_FLOAT(points_coords);
  CHECK_IS_FLOAT(centers_coords);
  CHECK_IS_FLOAT(centers_features);

  int b = centers_features.size(0);
  int c = centers_features.size(1);
  int m = centers_features.size(2);
  int n = points_coords.size(2);

  at::Tensor indices = torch::zeros(
      {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
  at::Tensor weights = torch::zeros(
      {b, 3, n},
      at::device(points_coords.device()).dtype(at::ScalarType::Float));
  at::Tensor output = torch::zeros(
      {b, c, n},
      at::device(centers_features.device()).dtype(at::ScalarType::Float));

  three_nearest_neighbors_interpolate(
      b, c, m, n, points_coords.data_ptr<float>(),
      centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
      indices.data_ptr<int>(), weights.data_ptr<float>(),
      output.data_ptr<float>());
  return {output, indices, weights};
}

at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
                                                        at::Tensor indices,
                                                        at::Tensor weights,
                                                        const int m) {
  CHECK_CUDA(grad_y);
  CHECK_CUDA(indices);
  CHECK_CUDA(weights);
  CHECK_CONTIGUOUS(grad_y);
  CHECK_CONTIGUOUS(indices);
  CHECK_CONTIGUOUS(weights);
  CHECK_IS_FLOAT(grad_y);
  CHECK_IS_INT(indices);
  CHECK_IS_FLOAT(weights);

  int b = grad_y.size(0);
  int c = grad_y.size(1);
  int n = grad_y.size(2);
  at::Tensor grad_x = torch::zeros(
      {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
  three_nearest_neighbors_interpolate_grad(
      b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
      weights.data_ptr<float>(), grad_x.data_ptr<float>());
  return grad_x;
}
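The forward op returns three tensors: interpolated point features `[b, c, n]`, nearest-center indices `[b, 3, n]`, and interpolation weights `[b, 3, n]`; the backward op scatters gradients back onto the centers. A hedged shape-check sketch under the same assumptions as the grouping example above (standalone libtorch build with CUDA; not part of the diff):

```cpp
#include <torch/torch.h>
#include "neighbor_interpolate.hpp"  // assumed to sit next to neighbor_interpolate.cpp

int main() {
  // b=1, c=8 channels, m=128 centers, n=1024 points to interpolate onto.
  auto points = torch::rand({1, 3, 1024}, torch::device(torch::kCUDA));
  auto centers = torch::rand({1, 3, 128}, torch::device(torch::kCUDA));
  auto feats = torch::rand({1, 8, 128}, torch::device(torch::kCUDA));

  auto outs = three_nearest_neighbors_interpolate_forward(points, centers, feats);
  // outs[0]: [1, 8, 1024] features, outs[1]: [1, 3, 1024] indices,
  // outs[2]: [1, 3, 1024] weights
  auto grad_x = three_nearest_neighbors_interpolate_backward(
      outs[0], outs[1], outs[2], /*m=*/128);  // [1, 8, 128]
  return 0;
}
```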
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu
ADDED
@@ -0,0 +1,181 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../cuda_utils.cuh"

/*
  Function: three nearest neighbors
  Args:
    b : batch size
    n : number of points in point clouds
    m : number of query centers
    points_coords : coordinates of points, FloatTensor[b, 3, n]
    centers_coords: coordinates of centers, FloatTensor[b, 3, m]
    weights       : weights of nearest 3 centers to the point,
                    FloatTensor[b, 3, n]
    indices       : indices of nearest 3 centers to the point,
                    IntTensor[b, 3, n]
*/
__global__ void three_nearest_neighbors_kernel(
    int b, int n, int m, const float *__restrict__ points_coords,
    const float *__restrict__ centers_coords, float *__restrict__ weights,
    int *__restrict__ indices) {
  int batch_index = blockIdx.x;
  int index = threadIdx.x;
  int stride = blockDim.x;
  points_coords += batch_index * 3 * n;
  weights += batch_index * 3 * n;
  indices += batch_index * 3 * n;
  centers_coords += batch_index * 3 * m;

  for (int j = index; j < n; j += stride) {
    float ux = points_coords[j];
    float uy = points_coords[j + n];
    float uz = points_coords[j + n + n];

    double best0 = 1e40, best1 = 1e40, best2 = 1e40;
    int besti0 = 0, besti1 = 0, besti2 = 0;
    for (int k = 0; k < m; ++k) {
      float x = centers_coords[k];
      float y = centers_coords[k + m];
      float z = centers_coords[k + m + m];
      float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
      if (d < best2) {
        best2 = d;
        besti2 = k;
        if (d < best1) {
          best2 = best1;
          besti2 = besti1;
          best1 = d;
          besti1 = k;
          if (d < best0) {
            best1 = best0;
            besti1 = besti0;
            best0 = d;
            besti0 = k;
          }
        }
      }
    }
    best0 = max(min(1e10f, best0), 1e-10f);
    best1 = max(min(1e10f, best1), 1e-10f);
    best2 = max(min(1e10f, best2), 1e-10f);
    float d0d1 = best0 * best1;
    float d0d2 = best0 * best2;
    float d1d2 = best1 * best2;
    float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
    weights[j] = d1d2 * d0d1d2;
    indices[j] = besti0;
    weights[j + n] = d0d2 * d0d1d2;
    indices[j + n] = besti1;
    weights[j + n + n] = d0d1 * d0d1d2;
    indices[j + n + n] = besti2;
  }
}

/*
  Function: interpolate three nearest neighbors (forward)
  Args:
    b : batch size
    c : #channels of features
    m : number of query centers
    n : number of points in point clouds
    centers_features: features of centers, FloatTensor[b, c, m]
    indices         : indices of nearest 3 centers to the point,
                      IntTensor[b, 3, n]
    weights         : weights for interpolation, FloatTensor[b, 3, n]
    out             : features of points, FloatTensor[b, c, n]
*/
__global__ void three_nearest_neighbors_interpolate_kernel(
    int b, int c, int m, int n, const float *__restrict__ centers_features,
    const int *__restrict__ indices, const float *__restrict__ weights,
    float *__restrict__ out) {
  int batch_index = blockIdx.x;
  centers_features += batch_index * m * c;
  indices += batch_index * n * 3;
  weights += batch_index * n * 3;
  out += batch_index * n * c;

  const int index = threadIdx.y * blockDim.x + threadIdx.x;
  const int stride = blockDim.y * blockDim.x;
  for (int i = index; i < c * n; i += stride) {
    const int l = i / n;
    const int j = i % n;
    float w1 = weights[j];
    float w2 = weights[j + n];
    float w3 = weights[j + n + n];
    int i1 = indices[j];
    int i2 = indices[j + n];
    int i3 = indices[j + n + n];

    out[i] = centers_features[l * m + i1] * w1 +
             centers_features[l * m + i2] * w2 +
             centers_features[l * m + i3] * w3;
  }
}

void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
                                         const float *points_coords,
                                         const float *centers_coords,
                                         const float *centers_features,
                                         int *indices, float *weights,
                                         float *out) {
  three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0,
                                   at::cuda::getCurrentCUDAStream()>>>(
      b, n, m, points_coords, centers_coords, weights, indices);
  three_nearest_neighbors_interpolate_kernel<<<
      b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
      b, c, m, n, centers_features, indices, weights, out);
  CUDA_CHECK_ERRORS();
}

/*
  Function: interpolate three nearest neighbors (backward)
  Args:
    b : batch size
    c : #channels of features
    m : number of query centers
    n : number of points in point clouds
    grad_y  : grad of features of points, FloatTensor[b, c, n]
    indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
    weights : weights for interpolation, FloatTensor[b, 3, n]
    grad_x  : grad of features of centers, FloatTensor[b, c, m]
*/
__global__ void three_nearest_neighbors_interpolate_grad_kernel(
    int b, int c, int n, int m, const float *__restrict__ grad_y,
    const int *__restrict__ indices, const float *__restrict__ weights,
    float *__restrict__ grad_x) {
  int batch_index = blockIdx.x;
  grad_y += batch_index * n * c;
  indices += batch_index * n * 3;
  weights += batch_index * n * 3;
  grad_x += batch_index * m * c;

  const int index = threadIdx.y * blockDim.x + threadIdx.x;
  const int stride = blockDim.y * blockDim.x;
  for (int i = index; i < c * n; i += stride) {
    const int l = i / n;
    const int j = i % n;
    float w1 = weights[j];
    float w2 = weights[j + n];
    float w3 = weights[j + n + n];
    int i1 = indices[j];
    int i2 = indices[j + n];
    int i3 = indices[j + n + n];
    atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
    atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
    atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
  }
}

void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
                                              const float *grad_y,
                                              const int *indices,
                                              const float *weights,
                                              float *grad_x) {
  three_nearest_neighbors_interpolate_grad_kernel<<<
      b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
      b, c, n, m, grad_y, indices, weights, grad_x);
  CUDA_CHECK_ERRORS();
}
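The weight computation in `three_nearest_neighbors_kernel` is inverse-(squared-)distance weighting over the three nearest centers: with clamped squared distances `d0 <= d1 <= d2`, each weight is proportional to `1/d_i` and the three weights sum to one. A CPU reference of just that step (illustration only, using floats where the kernel accumulates in double):

```cpp
#include <algorithm>
#include <array>

// w0 = d1*d2 / (d0*d1 + d0*d2 + d1*d2), and cyclically for w1, w2,
// i.e. w_i ∝ 1/d_i with w0 + w1 + w2 = 1.
std::array<float, 3> three_nn_weights(float d0, float d1, float d2) {
  d0 = std::max(std::min(1e10f, d0), 1e-10f);
  d1 = std::max(std::min(1e10f, d1), 1e-10f);
  d2 = std::max(std::min(1e10f, d2), 1e-10f);
  const float norm = 1.0f / (d0 * d1 + d0 * d2 + d1 * d2);
  return {d1 * d2 * norm, d0 * d2 * norm, d0 * d1 * norm};
}
```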
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh
ADDED
@@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_CUH
#define _NEIGHBOR_INTERPOLATE_CUH

void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
                                         const float *points_coords,
                                         const float *centers_coords,
                                         const float *centers_features,
                                         int *indices, float *weights,
                                         float *out);
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
                                              const float *grad_y,
                                              const int *indices,
                                              const float *weights,
                                              float *grad_x);

#endif
model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp
ADDED
@@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_HPP
#define _NEIGHBOR_INTERPOLATE_HPP

#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
                                            at::Tensor centers_coords,
                                            at::Tensor centers_features);
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
                                                        at::Tensor indices,
                                                        at::Tensor weights,
                                                        const int m);

#endif