from typing import Optional, Union

import torch
from diffusers import ModelMixin
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .feature_model import FeatureModel
from .model_utils import compute_distance_transform

SchedulerClass = Union[DDPMScheduler, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]

class PointCloudProjectionModel(ModelMixin):

    def __init__(
        self,
        image_size: int,
        image_feature_model: str,
        use_local_colors: bool = True,
        use_local_features: bool = True,
        use_global_features: bool = False,
        use_mask: bool = True,
        use_distance_transform: bool = True,
        predict_shape: bool = True,
        predict_color: bool = False,
        process_color: bool = False,
        image_color_channels: int = 3,  # for the input image, not the points
        color_channels: int = 3,  # for the points, not the input image
        colors_mean: float = 0.5,
        colors_std: float = 0.5,
        scale_factor: float = 1.0,
        # Rasterization settings
        raster_point_radius: float = 0.0075,  # point size
        raster_points_per_pixel: int = 1,  # a single point per pixel, for now
        bin_size: int = 0,
        model_name: Optional[str] = None,
        # Additional arguments added by XH
        load_sample_init: bool = False,
        sample_init_scale: float = 1.0,
        test_init_with_gtpc: bool = False,
        consistent_center: bool = False,  # from https://arxiv.org/pdf/2308.07837.pdf
        voxel_resolution_multiplier: int = 1,
        predict_binary: bool = False,  # predict a binary class label
        lw_binary: float = 1.0,
        binary_training_noise_std: float = 0.1,
        dm_pred_type: str = 'epsilon',  # diffusion prediction type
        self_conditioning: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.image_size = image_size
        self.scale_factor = scale_factor
        self.use_local_colors = use_local_colors
        self.use_local_features = use_local_features
        self.use_global_features = use_global_features
        self.use_mask = use_mask
        self.use_distance_transform = use_distance_transform
        self.predict_shape = predict_shape  # default True
        self.predict_color = predict_color  # default False
        self.process_color = process_color
        self.image_color_channels = image_color_channels
        self.color_channels = color_channels
        self.colors_mean = colors_mean
        self.colors_std = colors_std
        self.model_name = model_name
        print("PointCloud Model scale factor:", self.scale_factor, "Model name:", self.model_name)
        self.predict_binary = predict_binary
        self.lw_binary = lw_binary
        self.self_conditioning = self_conditioning

        # Types of conditioning that are used
        self.use_local_conditioning = self.use_local_colors or self.use_local_features or self.use_mask
        self.use_global_conditioning = self.use_global_features
        self.kwargs = kwargs

        # Create the feature model
        self.feature_model = FeatureModel(image_size, image_feature_model)
        # Input size: 3 channels for the 3D point positions, plus conditioning
        self.in_channels = 3
        if self.use_local_colors:  # whether the projected image color is an input
            self.in_channels += self.image_color_channels
        if self.use_local_features:
            self.in_channels += self.feature_model.feature_dim
        if self.use_global_features:
            self.in_channels += self.feature_model.feature_dim
        if self.use_mask:
            self.in_channels += 2 if self.use_distance_transform else 1
        if self.process_color:  # whether the point color is added to the input; default False
            self.in_channels += self.color_channels
        if self.self_conditioning:
            self.in_channels += 3  # add self-conditioning channels
        self.in_channels = self.add_extra_input_chennels(self.in_channels)

        if self.model_name in ['pc2-diff-ho-sepsegm', 'diff-ho-attn']:
            self.in_channels += 2 if self.use_distance_transform else 1
        # Output size
        self.out_channels = 0
        if self.predict_shape:
            self.out_channels += 3
        if self.predict_color:
            self.out_channels += self.color_channels
        if self.predict_binary:
            print("Output binary classification score!")
            self.out_channels += 1

        # Save the rasterization settings
        self.raster_settings = PointsRasterizationSettings(
            image_size=(image_size, image_size),
            radius=raster_point_radius,
            points_per_pixel=raster_points_per_pixel,
            bin_size=bin_size,
        )
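
    # Channel-accounting sketch (illustrative, not executed): assuming the default
    # flags above and a ViT-S/16 backbone whose feature_dim is 384 (an assumption;
    # the real value comes from FeatureModel), the per-point input width is
    #     3 (xyz) + 3 (local colors) + 384 (local features)
    #       + 2 (mask + distance transform) = 392
    # and the output width is 3 (the predicted xyz noise).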
    def add_extra_input_chennels(self, input_channels: int) -> int:
        """Hook for subclasses to add extra input channels; identity by default."""
        return input_channels

    def denormalize(self, x: Tensor, /, clamp: bool = True):
        x = x * self.colors_std + self.colors_mean
        return torch.clamp(x, 0, 1) if clamp else x

    def normalize(self, x: Tensor, /):
        x = (x - self.colors_mean) / self.colors_std
        return x
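
    # Example (illustrative): with the defaults colors_mean=0.5 and colors_std=0.5,
    # normalize() maps colors from [0, 1] to [-1, 1] and denormalize() inverts it:
    #     x = torch.tensor([0.0, 0.5, 1.0])
    #     normalize(x)               # tensor([-1., 0., 1.])
    #     denormalize(normalize(x))  # tensor([0.0, 0.5, 1.0])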

    def get_global_conditioning(self, image_rgb: Tensor):
        global_conditioning = []
        if self.use_global_features:
            # CLS token of the feature backbone, shape (B, D)
            global_conditioning.append(
                self.feature_model(image_rgb, return_cls_token_only=True))
        global_conditioning = torch.cat(global_conditioning, dim=1)  # (B, D_cond)
        return global_conditioning

    def get_local_conditioning(self, image_rgb: Tensor, mask: Tensor):
        """Compute per-pixel conditioning maps, later projected onto the points.

        Parameters
        ----------
        image_rgb: (B, 3, 224, 224), values normalized to [0, 1], background
            masked out by the given mask
        mask: (B, 1, 224, 224), or (B, 2, 224, 224) for human + object
        """
        local_conditioning = []
        if self.use_local_colors:  # default True
            local_conditioning.append(self.normalize(image_rgb))
        if self.use_local_features:  # default True
            # No mask is applied here; the feature model (e.g.
            # 'vit_small_patch16_224_mae') sees the masked RGB image directly
            local_conditioning.append(self.feature_model(image_rgb))
        if self.use_mask:  # default True
            local_conditioning.append(mask.float())
        if self.use_distance_transform:  # default True
            if not self.use_mask:
                raise ValueError('No mask for distance transform?')
            if mask.is_floating_point():
                mask = mask > 0.5
            local_conditioning.append(compute_distance_transform(mask))
        local_conditioning = torch.cat(local_conditioning, dim=1)  # (B, D_cond, H, W)
        return local_conditioning
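
    # Shape sketch (illustrative): for a batch of 2 images with an assumed ViT
    # feature dim of 384 and a 1-channel mask, the concatenation is
    #     RGB (2, 3, 224, 224) + features (2, 384, 224, 224)
    #       + mask (2, 1, 224, 224) + distance transform (2, 1, 224, 224)
    #     -> (2, 389, 224, 224)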

    def surface_projection(
        self, points: Tensor, camera: CamerasBase, local_features: Tensor,
    ):
        B, C, H, W, device = *local_features.shape, local_features.device
        R = self.raster_settings.points_per_pixel
        N = points.shape[1]

        # Scale the camera by scaling T. ASSUMES THE CAMERA IS LOOKING AT THE ORIGIN!
        camera = camera.clone()
        camera.T = camera.T * self.scale_factor

        # Create the rasterizer
        rasterizer = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings)

        # Associate points with pixels via rasterization
        fragments = rasterizer(Pointclouds(points))  # (B, H, W, R)
        fragments_idx: Tensor = fragments.idx.long()
        visible_pixels = (fragments_idx > -1)  # (B, H, W, R)
        points_to_visible_pixels = fragments_idx[visible_pixels]

        # Reshape the local features to (B, H, W, R, C)
        local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1)

        # Get the local features corresponding to visible points; the local features
        # comprise the raw RGB colors, image features, mask, and distance transform
        local_features_proj = torch.zeros(B * N, C, device=device)
        local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
        local_features_proj = local_features_proj.reshape(B, N, C)
        return local_features_proj
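
    # Usage sketch (illustrative; shapes assumed): given noisy points x_t of shape
    # (B, N, 3), a PyTorch3D camera, and conditioning maps of shape (B, C, H, W),
    #     feats = self.surface_projection(x_t[:, :, :3], camera, cond_maps)
    # returns per-point features (B, N, C); points that rasterize to no pixel
    # receive all-zero features.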

    def point_cloud_to_tensor(self, pc: Pointclouds, /, normalize: bool = False, scale: bool = False):
        """Converts a point cloud to a tensor, with color if and only if self.predict_color"""
        points = pc.points_padded() * (self.scale_factor if scale else 1)
        if self.predict_color and pc.features_padded() is not None:
            # Normalize the colors, not the point locations
            colors = self.normalize(pc.features_padded()) if normalize else pc.features_padded()
            return torch.cat((points, colors), dim=2)
        else:
            return points

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)
        else:
            assert x.shape[2] == 3
            return Pointclouds(points=points)
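
    # Round-trip sketch (illustrative):
    #     x = model.point_cloud_to_tensor(pc, normalize=True, scale=True)
    #     pc2 = model.tensor_to_point_cloud(x, denormalize=True, unscale=True)
    # recovers the original (padded) point cloud, since the scaling and the color
    # normalization are each inverted.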

    def get_input_with_conditioning(
        self,
        x_t: Tensor,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        t: Optional[Tensor],
    ):
        """Extracts local features from the input image and projects them onto the
        points in the point cloud to obtain the input to the model. Then extracts
        global features, replicates them across points, and concatenates them to
        the input.

        image_rgb has the background masked out.
        XH: why is there no positional encoding, as described in the supplementary?
        """
        B, N = x_t.shape[:2]

        # The initial input is the point locations (and colors iff predicting color)
        x_t_input = self.get_coord_feature(x_t)

        # Local conditioning
        if self.use_local_conditioning:
            # Get the local features (RGB + image features + mask + distance
            # transform, concatenated) and check that they match the input image size
            local_features = self.get_local_conditioning(image_rgb=image_rgb, mask=mask)
            if local_features.shape[-2:] != image_rgb.shape[-2:]:
                raise ValueError(f'{local_features.shape=} and {image_rgb.shape=}')
            # Project the local features; only the point locations are needed here, not colors
            local_features_proj = self.surface_projection(
                points=x_t[:, :, :3], camera=camera, local_features=local_features)  # (B, N, D_local)
            x_t_input.append(local_features_proj)

        # Global conditioning (default False)
        if self.use_global_conditioning:
            # Get the global features and repeat them across points
            global_features = self.get_global_conditioning(image_rgb=image_rgb)  # (B, D_global)
            global_features = global_features.unsqueeze(1).expand(-1, N, -1)  # (B, N, D_global)
            x_t_input.append(global_features)

        # Concatenate all the pointwise features together
        x_t_input = torch.cat(x_t_input, dim=2)  # (B, N, D)
        return x_t_input

    def get_coord_feature(self, x_t):
        """Get the coordinate feature. Models that use a separate network to
        predict the binary label use only the first 3 channels."""
        x_t_input = [x_t]
        return x_t_input

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """The forward method may be defined differently for different models."""
        raise NotImplementedError()
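

# Subclassing sketch (illustrative only; `PointwiseNet` is hypothetical and the
# batch fields follow PyTorch3D's FrameData): a concrete model implements
# forward() on top of the conditioning helpers above, e.g. for an
# epsilon-prediction training step with a diffusers scheduler:
#
#     class MyProjectionModel(PointCloudProjectionModel):
#         def __init__(self, scheduler: SchedulerClass, **kwargs):
#             super().__init__(**kwargs)
#             self.scheduler = scheduler
#             self.net = PointwiseNet(self.in_channels, self.out_channels)
#
#         def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
#             x_0 = self.point_cloud_to_tensor(batch.sequence_point_cloud,
#                                              normalize=True, scale=True)
#             noise = torch.randn_like(x_0)
#             t = torch.randint(0, self.scheduler.config.num_train_timesteps,
#                               (x_0.shape[0],), device=x_0.device)
#             x_t = self.scheduler.add_noise(x_0, noise, t)
#             x_t_input = self.get_input_with_conditioning(
#                 x_t, camera=batch.camera, image_rgb=batch.image_rgb,
#                 mask=batch.fg_probability, t=t)
#             return torch.nn.functional.mse_loss(self.net(x_t_input, t), noise)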