Commit 531dfb5 by xxie
Parent: a6d435a

add ddim support
Files changed:
- app.py (+15, -8)
- configs/structured.py (+1, -0)
- demo.py (+6, -1)
- model/model_hoattn.py (+20, -6)
app.py
CHANGED
@@ -127,7 +127,7 @@ def plot_points(colors, coords):
     return fig
 
 
-def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
+def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls, input_scheduler):
     """
     given user input, run inference
     :param runner:
@@ -138,6 +138,7 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
     :param std_coverage: float value, used to estimate camera translation
     :param input_seed: random seed
     :param input_cls: the object category of the input image
+    :param input_scheduler: reverse sampling scheduler, ddim or ddpm
     :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
     """
     log = ""
@@ -153,6 +154,8 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
         log += f"Reloading fine-tuned checkpoint of category {input_cls}\n"
         runner.reload_checkpoint(input_cls)
 
+    cfg.run.diffusion_scheduler = input_scheduler
+    cfg.run.num_inference_steps = 1000 if input_scheduler == 'ddpm' else 100
     out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
     points = out_stage2.points_packed().cpu().numpy()
     colors = out_stage2.features_packed().cpu().numpy()
@@ -204,6 +207,10 @@ def main(cfg: ProjectConfig):
                                       'chair', 'skateboard', 'suitcase', 'table'],
                              value='general')
         input_seed = gr.Number(label='Random seed', value=42)
+        input_scheduler = gr.Dropdown(label='Diffusion scheduler',
+                                      info='Reverse diffusion scheduler: DDIM is 10x faster',
+                                      choices=['ddpm', 'ddim'],
+                                      value='ddim')
         # Output visualization
         with gr.Row():
             pc_plot = gr.Plot(label="Reconstructed point cloud")
@@ -217,20 +224,20 @@ def main(cfg: ProjectConfig):
         with gr.Row():
             button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
         button_recon.click(fn=partial(inference, runner, cfg),
-                           inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],
+                           inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls, input_scheduler],
                            outputs=[pc_plot, out_pc_download, out_log])
         gr.HTML("""<br/>""")
         # Example input
         example_dir = cfg.run.code_dir_abs+"/examples"
         rgb, ps, obj = 'k1_color.jpg', 'k1_person_mask.png', 'k1_obj_rend_mask.png'
         example_images = gr.Examples([
-            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard'],
-            [f"{example_dir}/205904/{rgb}", f"{example_dir}/205904/{ps}", f"{example_dir}/205904/{obj}", 3.2, 42, 'suitcase'],
-            [f"{example_dir}/066241/{rgb}", f"{example_dir}/066241/{ps}", f"{example_dir}/066241/{obj}", 3.5, 42, 'backpack'],
-            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair'],
-            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair'],
+            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard', 'ddim'],
+            [f"{example_dir}/205904/{rgb}", f"{example_dir}/205904/{ps}", f"{example_dir}/205904/{obj}", 3.2, 42, 'suitcase', 'ddim'],
+            [f"{example_dir}/066241/{rgb}", f"{example_dir}/066241/{ps}", f"{example_dir}/066241/{obj}", 3.5, 42, 'backpack', 'ddim'],
+            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair', 'ddim'],
+            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair', 'ddim'],
 
-        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],)
+        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls, input_scheduler],)
 
         gr.Markdown(citation_str)
 
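Note on the app change: `inference` now threads the UI's scheduler choice into the config, running the full 1000-step reverse chain for DDPM but only 100 steps for DDIM, which is where the dropdown's "10x faster" hint comes from. As a rough sketch of how such a string choice typically becomes a diffusers scheduler object (the Space's actual dispatch happens inside `runner.forward_batch`, which this commit does not touch; `make_scheduler` below is a hypothetical helper, not repo code):

```python
# Hypothetical helper: map the UI's 'ddpm'/'ddim' string to a diffusers scheduler.
from diffusers.schedulers import DDPMScheduler, DDIMScheduler

def make_scheduler(name: str, num_train_timesteps: int = 1000):
    if name == 'ddpm':
        sched = DDPMScheduler(num_train_timesteps=num_train_timesteps)
        sched.set_timesteps(num_train_timesteps)  # full 1000-step reverse walk
    elif name == 'ddim':
        sched = DDIMScheduler(num_train_timesteps=num_train_timesteps)
        sched.set_timesteps(100)  # subsampled grid -> ~10x fewer network calls
    else:
        raise ValueError(f"unknown scheduler: {name}")
    return sched
```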
configs/structured.py
CHANGED
@@ -127,6 +127,7 @@ class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
     beta_end: float = 8e-3  # 0.012
     beta_schedule: str = 'linear'  # 'custom'
     dm_pred_type: str = 'epsilon'  # diffusion model prediction type, sample (x0) or noise
+    ddim_eta: float = 1.0  # DDIM eta parameter; eta=0 is the DDIM default and gives deterministic generation
 
     # Point cloud model arguments
     point_cloud_model: str = 'pvcnn'
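For context on the new `ddim_eta` field: in DDIM sampling, eta scales the noise injected at each reverse step. eta=0 gives deterministic generation, while eta=1 injects DDPM-posterior-level noise at every step; the default chosen here (1.0) therefore keeps sampling stochastic even under DDIM. A minimal sketch of the noise scale eta controls, following Eq. 16 of the DDIM paper (Song et al. 2021) — the function name is illustrative, and this mirrors what diffusers computes internally:

```python
import math

def ddim_sigma(eta: float, alpha_bar_t: float, alpha_bar_prev: float) -> float:
    # eta = 0.0 -> sigma = 0: fully deterministic DDIM updates
    # eta = 1.0 -> sigma matches the DDPM posterior: stochastic sampling
    return eta * math.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)
                           * (1 - alpha_bar_t / alpha_bar_prev))
```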
demo.py
CHANGED
@@ -180,6 +180,7 @@ class DemoRunner:
             mask=torch.stack(batch['masks']).to('cuda'),
             scheduler=cfg.run.diffusion_scheduler,
             num_inference_steps=cfg.run.num_inference_steps,
+            eta=cfg.model.ddim_eta,
         )
         # segment and normalize human/object
         bs = len(out_stage1)
@@ -254,7 +255,11 @@ class DemoRunner:
             radius_hum=radius_hum.unsqueeze(-1),
             radius_obj=radius_obj.unsqueeze(-1),
             sample_from_interm=True,
-            noise_step=cfg.run.sample_noise_step)
+            noise_step=cfg.run.sample_noise_step,
+            scheduler=cfg.run.diffusion_scheduler,
+            num_inference_steps=cfg.run.num_inference_steps,
+            eta=cfg.model.ddim_eta,
+        )
         return out_stage1, out_stage2
 
     def upsample_predicted_pc(self, num_samples, pc_obj):
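Both sampling calls (the initial prediction and the second stage that resamples from an intermediate noise step) now forward `scheduler`, `num_inference_steps`, and `eta`. One detail worth noting, as an assumption about the downstream code rather than something shown in this diff: in diffusers, `DDIMScheduler.step()` accepts an `eta` argument while `DDPMScheduler.step()` does not, so a loop serving both schedulers has to branch. A self-contained sketch of that pattern (`model`, `x`, and `cond` are placeholders, not names from this repo):

```python
import torch
from diffusers.schedulers import DDIMScheduler

@torch.no_grad()
def reverse_sample(model, scheduler, x, cond, eta: float = 1.0):
    # Walk the reverse chain; forward eta only where the scheduler supports it.
    for t in scheduler.timesteps:
        noise_pred = model(x, t, cond)
        if isinstance(scheduler, DDIMScheduler):
            x = scheduler.step(noise_pred, t, x, eta=eta).prev_sample
        else:
            x = scheduler.step(noise_pred, t, x).prev_sample
    return x
```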
model/model_hoattn.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
 
 from pytorch3d.structures import Pointclouds
 from pytorch3d.renderer import CamerasBase
+from diffusers.schedulers import DDPMScheduler, DDIMScheduler
 from .model_diff_data import ConditionalPCDiffusionBehave
 from .pvcnn.pvcnn_ho import PVCNN2HumObj
 import torch.nn.functional as F
@@ -375,17 +376,30 @@ class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
 
         return (output, all_outputs) if return_all_outputs else output
 
-    def get_reverse_timesteps(self, scheduler, interm_steps:int):
+    def get_reverse_timesteps(self, scheduler, interm_steps: int):
         """
-
+        get the timesteps to run reverse diffusion
         :param scheduler:
-        :param interm_steps: start from some intermediate steps
+        :param interm_steps: start from some intermediate step; the step count refers to the DDPM scheduler
+        and is recomputed accordingly for DDIM
         :return:
         """
-        if interm_steps > 0:
-            timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
+        if isinstance(scheduler, DDPMScheduler):
+            # DDPM: directly reverse N steps from interm_steps
+            if interm_steps > 0:
+                timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
+            else:
+                timesteps = scheduler.timesteps.to(self.device)
+        elif isinstance(scheduler, DDIMScheduler):
+            if interm_steps > 0:
+                # compute a step ratio, and find the intermediate steps for DDIM
+                step_ratio = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+                timesteps = (np.arange(0, interm_steps, step_ratio)).round()[::-1].copy().astype(np.int64)
+                timesteps = torch.from_numpy(timesteps).to(self.device)
+            else:
+                timesteps = scheduler.timesteps.to(self.device)
         else:
-            timesteps = scheduler.timesteps.to(self.device)
+            raise NotImplementedError
         return timesteps
 
     def pack_norm_params(self, kwargs:dict, scale=True):
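The new DDIM branch maps an intermediate DDPM step count onto DDIM's coarser timestep grid: `step_ratio` is the spacing between consecutive DDIM timesteps, so `np.arange(0, interm_steps, step_ratio)` keeps only the DDIM grid points below `interm_steps`. A worked example of that arithmetic, with illustrative values:

```python
import numpy as np

num_train_timesteps = 1000   # DDPM training schedule length
num_inference_steps = 100    # DDIM's subsampled grid
interm_steps = 250           # start from DDPM step 250, not from pure noise

step_ratio = num_train_timesteps // num_inference_steps      # -> 10
timesteps = np.arange(0, interm_steps, step_ratio).round()[::-1].astype(np.int64)
print(timesteps)  # [240 230 ... 10 0]: 25 DDIM steps instead of 250 DDPM steps
```

Using the same `step_ratio` as the scheduler matters because `DDIMScheduler.step()` derives the previous timestep from that `num_train_timesteps // num_inference_steps` spacing, so the subsampled grid here stays consistent with the scheduler's own updates.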