# PoseDiffusion_MVP/util/geometry_guided_sampling.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import Dict, List, Optional, Union
from util.camera_transform import pose_encoding_to_camera
from util.get_fundamental_matrix import get_fundamental_matrices
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras


def geometry_guided_sampling(
    model_mean: torch.Tensor,
    t: int,
    matches_dict: Dict,
    GGS_cfg: Dict,
):
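    """Refine the predicted pose mean with geometry-guided sampling (GGS).

    The pose encoding ``model_mean`` is optimized at diffusion step ``t`` so that
    the cameras it decodes to better satisfy the epipolar constraints of the 2D
    matches in ``matches_dict``. The dict is expected to hold numpy arrays:
    ``kp1`` / ``kp2`` of shape (N, 2) with the matched keypoints, ``i12`` of shape
    (N, 2) with the frame-index pair of each match, and ``img_shape`` = (b, c, h, w).
    ``GGS_cfg`` provides the hyper-parameters forwarded to ``GGS_optimize``.
    """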
    # pre-process matches
    b, c, h, w = matches_dict["img_shape"]
    device = model_mean.device

    def _to_device(tensor):
        return torch.from_numpy(tensor).to(device)

    kp1 = _to_device(matches_dict["kp1"])
    kp2 = _to_device(matches_dict["kp2"])
    i12 = _to_device(matches_dict["i12"])
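
    # Flatten each (frame_i, frame_j) pair from i12 into a single index into the
    # b * b stack of fundamental matrices built from (i1, i2) below.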
    pair_idx = i12[:, 0] * b + i12[:, 1]
    pair_idx = pair_idx.long()

    def _to_homogeneous(tensor):
        return torch.nn.functional.pad(tensor, [0, 1], value=1)

    kp1_homo = _to_homogeneous(kp1)
    kp2_homo = _to_homogeneous(kp2)

    i1, i2 = [
        i.reshape(-1) for i in torch.meshgrid(torch.arange(b), torch.arange(b))
    ]
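
    # i1 / i2 enumerate all b * b ordered frame pairs.
    # NOTE: compute_sampson_distance unpacks this dict positionally via .values(),
    # so the key order below must stay in sync with that unpacking.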
    processed_matches = {
        "kp1_homo": kp1_homo,
        "kp2_homo": kp2_homo,
        "i1": i1,
        "i2": i2,
        "h": h,
        "w": w,
        "pair_idx": pair_idx,
    }

    # conduct GGS
    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)

    # Optimize FL, R, and T separately
    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=False,
        update_FL=True,
        **GGS_cfg,
    )  # only optimize FL

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=True,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize R

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=True,
        update_R=False,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize T

    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)
    return model_mean
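

# A minimal usage sketch (kept as comments; the array shapes and the example
# config below are assumptions inferred from this file, not an official API):
#
#   import numpy as np
#   matches_dict = {
#       "img_shape": (b, c, h, w),             # b frames of size h x w
#       "kp1": np.zeros((n, 2), np.float32),   # keypoints in the first frame of each pair
#       "kp2": np.zeros((n, 2), np.float32),   # matched keypoints in the second frame
#       "i12": np.zeros((n, 2), np.int64),     # (first, second) frame index of each match
#   }
#   GGS_cfg = dict(alpha=0.0001, learning_rate=1e-2, iter_num=100, sampson_max=10,
#                  min_matches=10, pose_encoding_type="absT_quaR_logFL")
#   model_mean = geometry_guided_sampling(model_mean, t, matches_dict, GGS_cfg)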


def GGS_optimize(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R: bool = True,
    update_T: bool = True,
    update_FL: bool = True,
    # the args below come from **GGS_cfg
    alpha: float = 0.0001,
    learning_rate: float = 1e-2,
    iter_num: int = 100,
    sampson_max: int = 10,
    min_matches: int = 10,
    pose_encoding_type: str = "absT_quaR_logFL",
    **kwargs,
):
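    """Refine ``model_mean`` by SGD on the mean Sampson distance of the matches.

    Only the pose components whose ``update_*`` flag is True keep gradients; the
    others are detached inside ``compute_sampson_distance``. Gradients are clipped
    adaptively (see the comment in the loop) before each optimizer step.
    """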
    with torch.enable_grad():
        model_mean.requires_grad_(True)

        if update_R and update_T and update_FL:
            iter_num = iter_num * 2

        optimizer = torch.optim.SGD([model_mean], lr=learning_rate, momentum=0.9)
        batch_size = model_mean.shape[1]

        for _ in range(iter_num):
            valid_sampson, sampson_to_print = compute_sampson_distance(
                model_mean,
                t,
                processed_matches,
                update_R=update_R,
                update_T=update_T,
                update_FL=update_FL,
                pose_encoding_type=pose_encoding_type,
                sampson_max=sampson_max,
            )

            if min_matches > 0:
                valid_match_per_frame = len(valid_sampson) / batch_size
                if valid_match_per_frame < min_matches:
                    print("Drop this pair because of insufficient valid matches")
                    break

            loss = valid_sampson.mean()
            optimizer.zero_grad()
            loss.backward()
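
            # Adaptive gradient clipping: cap the gradient norm at
            # alpha * ||model_mean|| / lr, so a plain SGD step (ignoring momentum)
            # moves the pose encoding by at most about alpha * ||model_mean||.
            # The mask keeps only entries that actually received a gradient
            # when measuring that norm.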
            grads = model_mean.grad
            grad_norm = grads.norm()
            grad_mask = (grads.abs() > 0).detach()
            model_mean_norm = (model_mean * grad_mask).norm()
            max_norm = alpha * model_mean_norm / learning_rate

            total_norm = torch.nn.utils.clip_grad_norm_(model_mean, max_norm)
            optimizer.step()

        print(f"t={t:02d} | sampson={sampson_to_print:05f}")

    model_mean = model_mean.detach()
    return model_mean


def compute_sampson_distance(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R=True,
    update_T=True,
    update_FL=True,
    pose_encoding_type: str = "absT_quaR_logFL",
    sampson_max: int = 10,
):
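    """Decode ``model_mean`` into cameras and score the matches by Sampson distance.

    Returns the per-match Sampson distances below ``sampson_max`` (the values the
    optimizer actually minimizes) and a clamped mean used only for logging.
    """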
    camera = pose_encoding_to_camera(model_mean, pose_encoding_type)

    # pick the mean of the predicted focal length
    camera.focal_length = camera.focal_length.mean(dim=0).repeat(
        len(camera.focal_length), 1
    )

    if not update_R:
        camera.R = camera.R.detach()
    if not update_T:
        camera.T = camera.T.detach()
    if not update_FL:
        camera.focal_length = camera.focal_length.detach()

    kp1_homo, kp2_homo, i1, i2, he, wi, pair_idx = processed_matches.values()

    F_2_to_1 = get_fundamental_matrices(camera, he, wi, i1, i2, l2_normalize_F=False)
    F = F_2_to_1.permute(0, 2, 1)  # y1^T F y2 = 0

    def _sampson_distance(F, kp1_homo, kp2_homo, pair_idx):
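        # First-order (Sampson) approximation of the epipolar error per match:
        #   d = (kp1^T F kp2)^2 / ((F^T kp1)_x^2 + (F^T kp1)_y^2 + (F kp2)_x^2 + (F kp2)_y^2)
        # where "left" = kp1^T F (a row vector) and "right" = F kp2 (a column vector).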
        left = torch.bmm(kp1_homo[:, None], F[pair_idx])
        right = torch.bmm(F[pair_idx], kp2_homo[..., None])

        bottom = (
            left[:, :, 0].square()
            + left[:, :, 1].square()
            + right[:, 0, :].square()
            + right[:, 1, :].square()
        )
        top = torch.bmm(left, kp2_homo[..., None]).square()

        sampson = top[:, 0] / bottom
        return sampson

    sampson = _sampson_distance(
        F,
        kp1_homo.float(),
        kp2_homo.float(),
        pair_idx,
    )

    sampson_to_print = sampson.detach().clone().clamp(max=sampson_max).mean()
    sampson = sampson[sampson < sampson_max]

    return sampson, sampson_to_print