""" | |
Demo for template-free reconstruction | |
python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0 | |
""" | |
import sys, os
import os.path as osp
from glob import glob

import numpy as np
import torch
from torch.utils.data import DataLoader
import imageio
from tqdm import tqdm
from pytorch3d.datasets import collate_batched_meshes
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.io import IO
import torchvision.transforms.functional as TVF
from huggingface_hub import hf_hub_download

sys.path.append(os.getcwd())
import hydra
import training_utils
from configs.structured import ProjectConfig
from dataset.demo_dataset import DemoDataset
from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm
from render.pyt3d_wrapper import PcloudRenderer


class DemoRunner:
    def __init__(self, cfg: ProjectConfig):
        # Stage 1: joint human+object diffusion that also predicts a binary segmentation
        cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True
        model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        # Stage 2: cross-attention refinement; it does not predict segmentation
        cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False
        model_stage2 = CrossAttenHODiffusionModel(**cfg.model)

        # Load checkpoints from the Hugging Face Hub
        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth')
        self.load_checkpoint(ckpt_file1, model_stage1)
        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth')
        self.load_checkpoint(ckpt_file2, model_stage2)

        self.model_stage1, self.model_stage2 = model_stage1, model_stage2
        self.model_stage1.eval()
        self.model_stage2.eval()
        self.model_stage1.to('cuda')
        self.model_stage2.to('cuda')

        self.cfg = cfg
        self.io_pc = IO()

        # For visualization
        self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075)
        self.rend_size = cfg.dataset.image_size
        self.device = 'cuda'

    def load_checkpoint(self, ckpt_file, model, device='cpu'):
        """Load a checkpoint into the given model, stripping any DDP 'module.' prefix."""
        checkpoint = torch.load(ckpt_file, map_location=device)
        state_dict, key = checkpoint['model'], 'model'
        if any(k.startswith('module.') for k in state_dict.keys()):
            state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                          for k, v in state_dict.items()}
            print('Removed "module." from checkpoint state dict')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f'Loaded model checkpoint {key} from {ckpt_file}')
        if len(missing_keys):
            print(f' - Missing_keys: {missing_keys}')
        if len(unexpected_keys):
            print(f' - Unexpected_keys: {unexpected_keys}')

    def reload_checkpoint(self, cat_name):
        """Load checkpoints of models fine-tuned on a specific category."""
        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage1_name}-{cat_name}.pth')
        self.load_checkpoint(ckpt_file1, self.model_stage1, device=self.device)
        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage2_name}-{cat_name}.pth')
        self.load_checkpoint(ckpt_file2, self.model_stage2, device=self.device)

    def run(self):
        """Run the demo on the given images and save the results."""
        # Set random seed
        training_utils.set_seed(self.cfg.run.seed)
        outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo')
        os.makedirs(outdir, exist_ok=True)

        cfg = self.cfg
        # Init data
        image_files = sorted(glob(cfg.run.image_path))
        data = DemoDataset(image_files,
                           (cfg.dataset.image_size, cfg.dataset.image_size),
                           cfg.dataset.std_coverage)
        dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size,
                                collate_fn=collate_batched_meshes,
                                num_workers=1, shuffle=False)
        progress_bar = tqdm(dataloader)
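        # Each batch yields two point clouds per image: stage 1 (joint human+object,
        # colored by the predicted segmentation) and stage 2 (the refined result).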
        for batch_idx, batch in enumerate(progress_bar):
            progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}')
            out_stage1, out_stage2 = self.forward_batch(batch, cfg)

            bs = len(out_stage1)
            camera_full = PerspectiveCameras(
                R=torch.stack(batch['R']),
                T=torch.stack(batch['T']),
                K=torch.stack(batch['K']),
                device='cuda',
                in_ndc=True)

            # Save output
            for i in range(bs):
                image_path = str(batch['image_path'][i])
                folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0]
                out_i = osp.join(outdir, folder)
                os.makedirs(out_i, exist_ok=True)
                self.io_pc.save_pointcloud(data=out_stage1[i],
                                           path=osp.join(out_i, f'{fname}_stage1.ply'))
                self.io_pc.save_pointcloud(data=out_stage2[i],
                                           path=osp.join(out_i, f'{fname}_stage2.ply'))
                TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png'))

                # Save metadata as well
                metadata = dict(index=i,
                                camera=camera_full[i],
                                image_size_hw=batch['image_size_hw'][i],
                                image_path=batch['image_path'][i])
                torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth'))
                # Visualize
                pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()],
                                      features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()])
                video_file = osp.join(out_i, f'{fname}_360view.mp4')
                video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1)

                # First render the front view
                rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask')
                rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask')
                comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1)
                video_writer.append_data((comb * 255).astype(np.uint8))

                # Then render side views rotating around the vertical axis
                for azim in range(180, 180 + 360, 30):
                    R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),))
                    side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),),
                                                     device=self.device,
                                                     R=R.repeat(2, 1, 1), T=T.repeat(2, 1),
                                                     focal_length=self.rend_size * 1.5,
                                                     principal_point=((self.rend_size / 2., self.rend_size / 2.),),
                                                     in_ndc=False)
                    rend, _ = self.renderer.render(pc_comb, side_camera, mode='mask')

                    imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy()]
                    imgs.extend([rend[0], rend[1]])
                    video_writer.append_data((np.concatenate(imgs, 1) * 255).astype(np.uint8))
                video_writer.close()
                print(f"Visualization saved to {out_i}")

    def forward_batch(self, batch, cfg):
        """
        Forward one batch through both stages.
        :param batch: a batch from DemoDataset, collated by collate_batched_meshes
        :param cfg: the project config
        :return: predicted point clouds of stage 1 and 2
        """
        camera_full = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T']),
            K=torch.stack(batch['K']),
            device='cuda',
            in_ndc=True)
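        # Stage 1: sample a joint human+object point cloud conditioned on the full
        # image and mask; per-point colors encode the predicted segmentation.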
        out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points,
                                                      camera=camera_full,
                                                      image_rgb=torch.stack(batch['images']).to('cuda'),
                                                      mask=torch.stack(batch['masks']).to('cuda'),
                                                      scheduler=cfg.run.diffusion_scheduler,
                                                      num_inference_steps=cfg.run.num_inference_steps,
                                                      eta=cfg.model.ddim_eta)
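        # Split the stage-1 points by the predicted color (blue channel) and normalize
        # each part to a unit-diameter ball: p' = (p - center) / (2 * max_p ||p - center||).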
        # Segment and normalize human/object
        bs = len(out_stage1)
        pred_hum, pred_obj = [], []  # predicted human/object points
        cent_hum_pred, cent_obj_pred = [], []
        radius_hum_pred, radius_obj_pred = [], []
        T_hum, T_obj = [], []
        num_samples = int(cfg.dataset.max_points / 2)
        for i in range(bs):
            pc: Pointclouds = out_stage1[i]
            vc = pc.features_packed().cpu()  # (P, 3), human is light blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0]
            points = pc.points_packed().cpu()  # (P, 3)
            mask_hum = vc[:, 2] > 0.5
            pc_hum, pc_obj = points[mask_hum], points[~mask_hum]

            # Up/down-sample the points
            pc_obj = self.upsample_predicted_pc(num_samples, pc_obj)
            pc_hum = self.upsample_predicted_pc(num_samples, pc_hum)

            # Normalize
            cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True)
            scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max())
            scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max())
            pc_hum = (pc_hum - cent_hum) / (2 * scale_hum)
            pc_obj = (pc_obj - cent_obj) / (2 * scale_obj)

            # Also update camera parameters for separate human + object
            T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum)
            T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj)

            pred_hum.append(pc_hum)
            pred_obj.append(pc_obj)
            cent_hum_pred.append(cent_hum.squeeze(0))
            cent_obj_pred.append(cent_obj.squeeze(0))
            T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1]))  # apply OpenCV to PyTorch3D transform: flip x and y
            T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1]))
            radius_hum_pred.append(scale_hum)
            radius_obj_pred.append(scale_obj)
        # Pack data into a new batch dict
        camera_hum = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(T_hum),
            K=torch.stack(batch['K_hum']),
            device='cuda',
            in_ndc=True
        )
        camera_obj = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(T_obj),
            K=torch.stack(batch['K_obj']),  # the camera must be human/object specific
            device='cuda',
            in_ndc=True
        )
        # Use point clouds from the stage-1 prediction
        pc_hum = Pointclouds([x.to('cuda') for x in pred_hum])
        pc_obj = Pointclouds([x.to('cuda') for x in pred_obj])
        # Use center and radius from the stage-1 prediction
        cent_hum = torch.stack(cent_hum_pred, 0).to('cuda')
        cent_obj = torch.stack(cent_obj_pred, 0).to('cuda')  # (B, 3)
        radius_hum = torch.stack(radius_hum_pred, 0).to('cuda')  # (B,)
        radius_obj = torch.stack(radius_obj_pred, 0).to('cuda')
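        # Stage 2: jointly refine human and object, cross-attending between the two
        # branches; sampling starts from an intermediate noise step rather than pure noise.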
        out_stage2: Pointclouds = self.model_stage2.forward_sample(
            num_points=num_samples,
            camera=camera_hum,
            image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'),
            mask=torch.stack(batch['masks_hum'], 0).to('cuda'),
            gt_pc=pc_hum,
            rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'),
            mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'),
            pc_obj=pc_obj,
            camera_obj=camera_obj,
            cent_hum=cent_hum,
            cent_obj=cent_obj,
            radius_hum=radius_hum.unsqueeze(-1),  # (B, 1)
            radius_obj=radius_obj.unsqueeze(-1),
            sample_from_interm=True,
            noise_step=cfg.run.sample_noise_step,
            scheduler=cfg.run.diffusion_scheduler,
            num_inference_steps=cfg.run.num_inference_steps,
            eta=cfg.model.ddim_eta,
        )
        return out_stage1, out_stage2

    def upsample_predicted_pc(self, num_samples, pc_obj):
        """
        Up/down-sample the points to the given number.
        :param num_samples: the target number of points
        :param pc_obj: (N, 3)
        :return: (num_samples, 3)
        """
        if len(pc_obj) > num_samples:
            # Downsample without replacement to avoid duplicated points
            ind_obj = np.random.choice(len(pc_obj), num_samples, replace=False)
        else:
            # Keep all points and pad by resampling (with replacement) from the existing ones
            ind_obj = np.concatenate([np.arange(len(pc_obj)),
                                      np.random.choice(len(pc_obj), num_samples - len(pc_obj))])
        pc_obj = pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)]
        return pc_obj
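

# A minimal sketch (not called anywhere) of how the files written by DemoRunner.run
# could be read back for downstream use; `out_dir` and `fname` are hypothetical and
# must point to an existing outputs/demo/<folder> directory and image stem.
def load_demo_outputs(out_dir: str, fname: str):
    """Load the stage-2 point cloud and the metadata saved by DemoRunner.run."""
    pc = IO().load_pointcloud(osp.join(out_dir, f'{fname}_stage2.ply'))
    meta = torch.load(osp.join(out_dir, f'{fname}_meta.pth'), map_location='cpu')
    return pc, meta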


# Note: main() below is called without arguments, so it must be wrapped by hydra;
# the config_path/config_name here are assumed to match this repo's hydra setup
# (a ProjectConfig registered under configs/) and may need adjusting.
@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def main(cfg: ProjectConfig):
    runner = DemoRunner(cfg)
    runner.run()


if __name__ == '__main__':
    main()
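
# Hypothetical usage sketch (not executed): to sample with a category fine-tuned
# model instead of the generic one, point the runner at the corresponding
# '<stage_name>-<cat_name>.pth' checkpoints before running:
#   runner = DemoRunner(cfg)
#   runner.reload_checkpoint(cat_name)  # cat_name must match a hosted checkpoint suffix
#   runner.run()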