# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Union

import torch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras

from util.camera_transform import pose_encoding_to_camera
from util.get_fundamental_matrix import get_fundamental_matrices


def geometry_guided_sampling(
    model_mean: torch.Tensor,
    t: int,
    matches_dict: Dict,
    GGS_cfg: Dict,
):
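    # Expected inputs (inferred from the code below, not a documented API):
    #   matches_dict["img_shape"]: (b, c, h, w) of the image batch
    #   matches_dict["kp1"], ["kp2"]: (N, 2) numpy arrays of matched 2D
    #       keypoints, assumed to be in pixel coordinates
    #   matches_dict["i12"]: (N, 2) numpy array of (view_i, view_j) indices
    #       saying which image pair each match belongs to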
    # pre-process matches
    b, c, h, w = matches_dict["img_shape"]
    device = model_mean.device

    def _to_device(tensor):
        return torch.from_numpy(tensor).to(device)

    kp1 = _to_device(matches_dict["kp1"])
    kp2 = _to_device(matches_dict["kp2"])
    i12 = _to_device(matches_dict["i12"])

    # flatten each (view_i, view_j) pair into a single index into the
    # b * b stack of fundamental matrices
    pair_idx = i12[:, 0] * b + i12[:, 1]
    pair_idx = pair_idx.long()

    def _to_homogeneous(tensor):
        # append a constant 1 to each 2D keypoint -> homogeneous coordinates
        return torch.nn.functional.pad(tensor, [0, 1], value=1)

    kp1_homo = _to_homogeneous(kp1)
    kp2_homo = _to_homogeneous(kp2)

    # all ordered (i, j) view pairs; indexing="ij" makes the default
    # behaviour explicit and avoids the deprecation warning on recent PyTorch
    i1, i2 = [
        i.reshape(-1)
        for i in torch.meshgrid(torch.arange(b), torch.arange(b), indexing="ij")
    ]
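    # Worked example: for b = 3 the grid above enumerates all 9 ordered view
    # pairs (0,0), (0,1), ..., (2,2), and a match between views 0 and 2 maps
    # to pair_idx = 0 * 3 + 2 = 2, its row in the stacked fundamental matrices.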
    processed_matches = {
        "kp1_homo": kp1_homo,
        "kp2_homo": kp2_homo,
        "i1": i1,
        "i2": i2,
        "h": h,
        "w": w,
        "pair_idx": pair_idx,
    }
    # conduct GGS: first a joint update over all parameters
    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)

    # then optimize FL, R, and T separately
    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=False,
        update_FL=True,
        **GGS_cfg,
    )  # only optimize FL

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=True,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize R

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=True,
        update_R=False,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize T

    # finish with another joint pass
    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)
    return model_mean


def GGS_optimize(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R: bool = True,
    update_T: bool = True,
    update_FL: bool = True,
    # the args below come from **GGS_cfg
    alpha: float = 0.0001,
    learning_rate: float = 1e-2,
    iter_num: int = 100,
    sampson_max: int = 10,
    min_matches: int = 10,
    pose_encoding_type: str = "absT_quaR_logFL",
    **kwargs,
):
    with torch.enable_grad():
        model_mean.requires_grad_(True)

        # the joint pass (all three flags set) runs twice as many iterations
        if update_R and update_T and update_FL:
            iter_num = iter_num * 2

        optimizer = torch.optim.SGD([model_mean], lr=learning_rate, momentum=0.9)
        batch_size = model_mean.shape[1]

        for _ in range(iter_num):
            valid_sampson, sampson_to_print = compute_sampson_distance(
                model_mean,
                t,
                processed_matches,
                update_R=update_R,
                update_T=update_T,
                update_FL=update_FL,
                pose_encoding_type=pose_encoding_type,
                sampson_max=sampson_max,
            )

            if min_matches > 0:
                valid_match_per_frame = len(valid_sampson) / batch_size
                if valid_match_per_frame < min_matches:
                    print("Drop this pair because of insufficient valid matches")
                    break

            loss = valid_sampson.mean()
            optimizer.zero_grad()
            loss.backward()
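            # Adaptive gradient clipping: cap the gradient norm at
            # alpha * ||model_mean|| / lr, so one SGD step moves the encoding
            # by at most ~alpha of its own norm (measured only over the
            # entries that actually receive gradient, via grad_mask).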
            grads = model_mean.grad
            grad_norm = grads.norm()
            grad_mask = (grads.abs() > 0).detach()
            model_mean_norm = (model_mean * grad_mask).norm()

            max_norm = alpha * model_mean_norm / learning_rate
            total_norm = torch.nn.utils.clip_grad_norm_(model_mean, max_norm)
            optimizer.step()
print(f"t={t:02d} | sampson={sampson_to_print:05f}") | |

    model_mean = model_mean.detach()
    return model_mean


def compute_sampson_distance(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R=True,
    update_T=True,
    update_FL=True,
    pose_encoding_type: str = "absT_quaR_logFL",
    sampson_max: int = 10,
):
    camera = pose_encoding_to_camera(model_mean, pose_encoding_type)

    # pick the mean of the predicted focal lengths and share it across frames
    camera.focal_length = camera.focal_length.mean(dim=0).repeat(
        len(camera.focal_length), 1
    )

    # detach the components that should not receive gradient in this pass
    if not update_R:
        camera.R = camera.R.detach()
    if not update_T:
        camera.T = camera.T.detach()
    if not update_FL:
        camera.focal_length = camera.focal_length.detach()

    # unpacking relies on dict insertion order matching processed_matches above
    kp1_homo, kp2_homo, i1, i2, he, wi, pair_idx = processed_matches.values()

    F_2_to_1 = get_fundamental_matrices(camera, he, wi, i1, i2, l2_normalize_F=False)
    F = F_2_to_1.permute(0, 2, 1)  # y1^T F y2 = 0
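    # First-order (Sampson) approximation of the reprojection error for a
    # correspondence (x1, x2) under fundamental matrix F:
    #
    #     d(x1, x2) = (x1^T F x2)^2 /
    #                 ((F^T x1)_1^2 + (F^T x1)_2^2 + (F x2)_1^2 + (F x2)_2^2)
    #
    # In the helper below, left = x1^T F gives the (F^T x1) terms and
    # right = F x2 gives the (F x2) terms.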
    def _sampson_distance(F, kp1_homo, kp2_homo, pair_idx):
        left = torch.bmm(kp1_homo[:, None], F[pair_idx])
        right = torch.bmm(F[pair_idx], kp2_homo[..., None])

        bottom = (
            left[:, :, 0].square()
            + left[:, :, 1].square()
            + right[:, 0, :].square()
            + right[:, 1, :].square()
        )
        top = torch.bmm(left, kp2_homo[..., None]).square()
        sampson = top[:, 0] / bottom
        return sampson

    sampson = _sampson_distance(
        F,
        kp1_homo.float(),
        kp2_homo.float(),
        pair_idx,
    )

    # report the clamped mean for logging, but optimize only the matches
    # whose Sampson error falls below the sampson_max threshold
    sampson_to_print = sampson.detach().clone().clamp(max=sampson_max).mean()
    sampson = sampson[sampson < sampson_max]

    return sampson, sampson_to_print
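

# A minimal usage sketch (not part of the original module): the shapes and
# config keys below are assumptions read off the signatures above, and the
# util/pytorch3d imports mean this only runs inside the PoseDiffusion repo.
if __name__ == "__main__":
    import numpy as np

    b, c, h, w = 4, 3, 224, 224  # hypothetical 4-frame batch
    n_matches = 100

    matches_dict = {
        "img_shape": (b, c, h, w),
        # hypothetical keypoints in pixel coordinates
        "kp1": (np.random.rand(n_matches, 2) * w).astype(np.float32),
        "kp2": (np.random.rand(n_matches, 2) * w).astype(np.float32),
        # distinct (view_i, view_j) pair per match
        "i12": np.stack(
            [np.random.choice(b, 2, replace=False) for _ in range(n_matches)]
        ),
    }
    GGS_cfg = {
        "alpha": 0.0001,
        "learning_rate": 1e-2,
        "iter_num": 100,
        "sampson_max": 10,
        "min_matches": 10,
        "pose_encoding_type": "absT_quaR_logFL",
    }

    # model_mean: (1, b, D) pose encodings from the diffusion model; D depends
    # on pose_encoding_type (absT_quaR_logFL -> 3 transl. + 4 quat. + 2 logFL = 9).
    model_mean = torch.randn(1, b, 9)
    model_mean = geometry_guided_sampling(
        model_mean, t=10, matches_dict=matches_dict, GGS_cfg=GGS_cfg
    )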