# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from pytorch3d.transforms.rotation_conversions import (
    matrix_to_quaternion,
    quaternion_to_matrix,
)
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras


def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR_logFL",
    log_focal_length_bias=1.8,
    min_focal_length=0.1,
    max_focal_length=20,
):
    """
    Convert a batch of pose encodings to PyTorch3D cameras.

    Args:
        pose_encoding: A tensor of shape `BxNxC`, containing a batch of
            `BxN` `C`-dimensional pose encodings.
        pose_encoding_type: The type of pose encoding;
            only "absT_quaR_logFL" is supported.
        log_focal_length_bias: Bias added back to the predicted log focal
            length before exponentiating (it is subtracted during training
            so that the mean of logFL stays close to 0).
        min_focal_length: Lower bound used to clamp the recovered focal length.
        max_focal_length: Upper bound used to clamp the recovered focal length.

    Returns:
        A `PerspectiveCameras` object holding the `BxN` recovered cameras.
    """

    batch_size, num_poses, _ = pose_encoding.shape
    # Reshape to (B*N)xC so each pose is decoded independently.
    pose_encoding_reshaped = pose_encoding.reshape(-1, pose_encoding.shape[-1])

    if pose_encoding_type == "absT_quaR_logFL":
        # Fixed layout: dims 0:3 are absT, 3:7 are quaR, 7:9 are logFL.
        # TODO: convert logFL to 1 dim, consistent with the paper.
        abs_T = pose_encoding_reshaped[:, :3]
        quaternion_R = pose_encoding_reshaped[:, 3:7]
        R = quaternion_to_matrix(quaternion_R)

        log_focal_length = pose_encoding_reshaped[:, 7:9]

        # log_focal_length_bias is a hyperparameter chosen so that the mean of
        # logFL stays close to 0 during training; add it back before
        # exponentiating to recover the focal length.
        focal_length = (log_focal_length + log_focal_length_bias).exp()

        # Clamp to avoid unreasonable focal length values.
        focal_length = torch.clamp(
            focal_length, min=min_focal_length, max=max_focal_length
        )
    else:
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    pred_cameras = PerspectiveCameras(
        focal_length=focal_length,
        R=R,
        T=abs_T,
        device=R.device,
    )
    return pred_cameras
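

# Usage sketch (illustrative only, not part of the original module): builds a
# dummy BxNx9 "absT_quaR_logFL" encoding with identity rotations, zero
# translations, and logFL of 0, then decodes it into cameras. Shapes follow
# the docstring above; the dummy tensor values are assumptions for the demo.
if __name__ == "__main__":
    dummy_encoding = torch.zeros(2, 3, 9)  # B=2 batches of N=3 poses
    dummy_encoding[..., 3] = 1.0  # identity quaternion (real part first: w=1)
    cameras = pose_encoding_to_camera(dummy_encoding)
    print(cameras.R.shape)  # torch.Size([6, 3, 3])
    print(cameras.focal_length.shape)  # torch.Size([6, 2])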