Spaces: Runtime error
hugoycj committed • Commit 3d3e4e9 • 1 Parent(s): 7bf852f
Initial commit
Browse files
- .gitattributes +1 -0
- .gitignore +1 -0
- app.py +280 -0
- cfgs/default.yaml +40 -0
- cfgs/fast.yaml +40 -0
- examples/71165193657__AED15223-1435-44B6-AFC1-884527CE1642.mp4 +3 -0
- models/__init__.py +12 -0
- models/denoiser.py +179 -0
- models/gaussian_diffuser.py +410 -0
- models/image_feature_extractor.py +108 -0
- models/pose_diffusion_model.py +126 -0
- packages.txt +0 -0
- pre-requirements.txt +2 -0
- requirements.txt +5 -0
- util/__init__.py +7 -0
- util/camera_transform.py +63 -0
- util/embedding.py +62 -0
- util/geometry_guided_sampling.py +215 -0
- util/get_fundamental_matrix.py +57 -0
- util/load_img_folder.py +157 -0
- util/match_extraction.py +175 -0
- util/metric.py +22 -0
- util/utils.py +17 -0
- weights/co3d_model_Apr16.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
app.py
ADDED
@@ -0,0 +1,280 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from datetime import datetime
import os
import time
import torch
from typing import Dict, List, Optional, Union
from omegaconf import OmegaConf, DictConfig
import hydra
from hydra.utils import instantiate, get_original_cwd
import time
from functools import partial
import matplotlib.pyplot as plt
import shutil
from util.utils import seed_all_random_engines
from util.match_extraction import extract_match
from util.load_img_folder import load_and_preprocess_images
from util.geometry_guided_sampling import geometry_guided_sampling
from pytorch3d.vis.plotly_vis import get_camera_wireframe
import subprocess
import tempfile
import gradio as gr

def plot_cameras(ax, cameras, color: str = "blue"):
    """
    Plots a set of `cameras` objects into the matplotlib axis `ax` with
    color `color`.
    """
    cam_wires_canonical = get_camera_wireframe().cuda()[None]
    cam_trans = cameras.get_world_to_view_transform().inverse()
    cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
    plot_handles = []
    for wire in cam_wires_trans:
        # the Z and Y axes are flipped intentionally here!
        x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
        (h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
        plot_handles.append(h)
    return plot_handles

def create_matplotlib_figure(pred_cameras):
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.clear()
    handle_cam = plot_cameras(ax, pred_cameras, color="#FF7D1E")
    plot_radius = 3
    ax.set_xlim3d([-plot_radius, plot_radius])
    ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
    ax.set_zlim3d([-plot_radius, plot_radius])
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_zlabel("y")
    labels_handles = {
        "Estimated cameras": handle_cam[0],
    }
    ax.legend(
        labels_handles.values(),
        labels_handles.keys(),
        loc="upper center",
        bbox_to_anchor=(0.5, 0),
    )

    return plt

import os
import json
import tempfile
from PIL import Image


def convert_extrinsics_pytorch3d_to_opengl(extrinsics: torch.Tensor) -> torch.Tensor:
    """
    Convert extrinsics from PyTorch3D coordinate system to OpenGL coordinate system.

    Args:
        extrinsics (torch.Tensor): a 4x4 extrinsic matrix in PyTorch3D coordinate system.

    Returns:
        torch.Tensor: a 4x4 extrinsic matrix in OpenGL coordinate system.
    """
    # Create a transformation matrix that flips the Z-axis
    flip_z = torch.eye(4)
    flip_z[2, 2] = -1
    flip_z[0, 0] = -1

    # Multiply the extrinsic matrix by the transformation matrix
    extrinsics_opengl = torch.mm(extrinsics, flip_z)

    return extrinsics_opengl

import json
from typing import List, Dict, Any

def create_camera_json(extrinsics: Any, focal_length_world: float, principle_points: List[float], image_size: int) -> Dict[str, Any]:
    # Initialize the dictionary
    camera_dict = {
        "w": image_size,
        "h": image_size,
        "fl_x": float(focal_length_world[0]),
        "fl_y": float(focal_length_world[1]),
        "cx": float(principle_points[0]),
        "cy": float(principle_points[1]),
        "k1": 0.0,  # Assuming these values are not provided
        "k2": 0.0,  # Assuming these values are not provided
        "p1": 0.0,  # Assuming these values are not provided
        "p2": 0.0,  # Assuming these values are not provided
        "camera_model": "OPENCV",
        "frames": []
    }

    # Add frames to the dictionary
    for i, extrinsic in enumerate(extrinsics):
        frame = {
            "file_path": f"images/frame_{str(i).zfill(5)}.jpg",
            "transform_matrix": extrinsic.tolist(),
            "colmap_im_id": i
        }
        # Convert numpy float32 to Python's native float
        frame["transform_matrix"] = [[float(element) for element in row] for row in frame["transform_matrix"]]
        camera_dict["frames"].append(frame)

    return camera_dict

def archieve_images_and_transforms(images, pred_cameras, image_size):
    images_array = images.permute(0, 2, 3, 1).cpu().numpy() * 255
    images_pil = [Image.fromarray(image.astype('uint8')) for image in images_array]

    with tempfile.TemporaryDirectory() as temp_dir:
        images_dir = os.path.join(temp_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)

        images_path = []
        for i, image in enumerate(images_pil):
            image_path = os.path.join(images_dir, 'frame_{:05d}.jpg'.format(i))
            image.save(image_path)
            images_path.append(image_path)

        cam_trans = pred_cameras.get_world_to_view_transform()
        extrinsics = cam_trans.inverse().get_matrix().cpu()
        extrinsics = [convert_extrinsics_pytorch3d_to_opengl(extrinsic.T) for extrinsic in extrinsics]

        focal_length_ndc = pred_cameras.focal_length.mean(dim=0).cpu().numpy()
        focal_length_world = focal_length_ndc * image_size / 2
        principle_points = [image_size / 2, image_size / 2]
        camera_dict = create_camera_json(extrinsics, focal_length_world, principle_points, image_size)

        json_path = os.path.join(temp_dir, 'transforms.json')
        with open(json_path, 'w') as f:
            json.dump(camera_dict, f, indent=4)

        project_name = datetime.now().strftime("%Y%m%d-%H%M%S")
        shutil.make_archive(f'/tmp/{project_name}', 'zip', temp_dir)
        return f'/tmp/{project_name}.zip'

def estimate_images_pose(image_folder, mode) -> None:
    print("Selected mode:", mode)
    with hydra.initialize(config_path="./cfgs/"):
        cfg = hydra.compose(config_name=mode)

    OmegaConf.set_struct(cfg, False)
    print("Model Config:")
    print(OmegaConf.to_yaml(cfg))

    # Check for GPU availability and set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the model
    model = instantiate(cfg.MODEL, _recursive_=False)

    # Load and preprocess images
    images, image_info = load_and_preprocess_images(image_folder, cfg.image_size)

    # Load checkpoint
    ckpt_path = os.path.join(cfg.ckpt)
    if os.path.isfile(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint, strict=True)
        print(f"Loaded checkpoint from: {ckpt_path}")
    else:
        raise ValueError(f"No checkpoint found at: {ckpt_path}")

    # Move model and images to the GPU
    model = model.to(device)
    images = images.to(device)

    # Evaluation Mode
    model.eval()

    # Seed random engines
    seed_all_random_engines(cfg.seed)

    # Start the timer
    start_time = time.time()

    # Perform match extraction
    if cfg.GGS.enable:
        # Optional TODO: remove the keypoints outside the cropped region?

        kp1, kp2, i12 = extract_match(image_folder, image_info)

        keys = ["kp1", "kp2", "i12", "img_shape"]
        values = [kp1, kp2, i12, images.shape]
        matches_dict = dict(zip(keys, values))

        cfg.GGS.pose_encoding_type = cfg.MODEL.pose_encoding_type
        GGS_cfg = OmegaConf.to_container(cfg.GGS)

        cond_fn = partial(
            geometry_guided_sampling, matches_dict=matches_dict, GGS_cfg=GGS_cfg
        )
        print("\033[92m=====> Sampling with GGS <=====\033[0m")
    else:
        cond_fn = None
        print("\033[92m=====> Sampling without GGS <=====\033[0m")

    # Forward
    with torch.no_grad():
        # Obtain predicted camera parameters
        # pred_cameras is a PerspectiveCameras object with attributes
        # pred_cameras.R, pred_cameras.T, pred_cameras.focal_length

        # The poses and focal length are defined as
        # NDC coordinate system in
        # https://github.com/facebookresearch/pytorch3d/blob/main/docs/notes/cameras.md
        pred_cameras = model(
            image=images, cond_fn=cond_fn, cond_start_step=cfg.GGS.start_step
        )

    # Stop the timer and calculate elapsed time
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Time taken: {:.4f} seconds".format(elapsed_time))

    zip_path = archieve_images_and_transforms(images, pred_cameras, cfg.image_size)
    return create_matplotlib_figure(pred_cameras), zip_path

def extract_frames_from_video(video_path: str) -> str:
    """
    Extracts frames from a video file and saves them in a temporary directory.
    Returns the path to the directory containing the frames.
    """
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", "fps=1",
        output_path
    ]
    subprocess.run(command, check=True)
    return temp_dir

def estimate_video_pose(video_path: str, mode: str) -> plt.Figure:
    """
    Estimates the pose of objects in a video.
    """
    # Extract frames from the video
    image_folder = extract_frames_from_video(video_path)
    # Estimate the pose for each frame
    fig = estimate_images_pose(image_folder, mode)
    return fig

if __name__ == "__main__":
    examples = [["examples/" + img, 'fast'] for img in os.listdir("examples/")]
    # Create a Gradio interface
    iface = gr.Interface(
        fn=estimate_video_pose,
        inputs=[gr.inputs.Video(label='video', type='mp4'),
                gr.inputs.Radio(choices=['fast', 'precise'], default='fast',
                                label='Estimation Model: fast is quick, usually within 1 second; precise has higher accuracy, but usually takes several minutes')],
        outputs=['plot', 'file'],
        title="PoseDiffusion Demo: Solving Pose Estimation via Diffusion-aided Bundle Adjustment",
        description="Upload a video for object pose estimation. The object should be centrally located within the frame.",
        examples=examples,
        cache_examples=True
    )
    iface.launch()
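For reference, the same pipeline can be exercised without the Gradio UI by calling the functions above directly. The snippet below is a minimal sketch, not part of the commit: it assumes a CUDA GPU (plot_cameras calls .cuda()), the checkpoint weights/co3d_model_Apr16.pth, and a hypothetical placeholder clip path.

# Hypothetical local smoke test reusing the functions defined in app.py above.
from app import extract_frames_from_video, estimate_images_pose

if __name__ == "__main__":
    # "examples/video.mp4" is a placeholder path; any short object-centric clip works.
    frame_dir = extract_frames_from_video("examples/video.mp4")
    fig, zip_path = estimate_images_pose(frame_dir, mode="fast")
    fig.savefig("estimated_cameras.png")   # wireframe plot of the estimated cameras
    print("NeRF-style transforms archived at:", zip_path)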
cfgs/default.yaml
ADDED
@@ -0,0 +1,40 @@
image_folder: samples/apple
image_size: 224
ckpt: weights/co3d_model_Apr16.pth
seed: 0

GGS:
  enable: True
  start_step: 10
  learning_rate: 0.01
  iter_num: 100
  sampson_max: 10
  min_matches: 10
  alpha: 0.0001


MODEL:
  _target_: models.PoseDiffusionModel

  pose_encoding_type: absT_quaR_logFL

  IMAGE_FEATURE_EXTRACTOR:
    _target_: models.MultiScaleImageFeatureExtractor
    freeze: False

  DENOISER:
    _target_: models.Denoiser
    TRANSFORMER:
      _target_: models.TransformerEncoderWrapper
      d_model: 512
      nhead: 4
      dim_feedforward: 1024
      num_encoder_layers: 8
      dropout: 0.1
      batch_first: True
      norm_first: True


  DIFFUSER:
    _target_: models.GaussianDiffusion
    beta_schedule: custom
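The nested _target_ entries above are resolved by Hydra's instantiate the same way app.py does it: the top-level MODEL node builds models.PoseDiffusionModel, which then instantiates the extractor, denoiser, and diffuser from its sub-nodes itself (that is why _recursive_=False is passed). The only difference between default.yaml and fast.yaml below is GGS.enable. A minimal sketch, assuming it is run from the repository root:

# Sketch: compose cfgs/default.yaml and build the model as app.py does.
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

with hydra.initialize(config_path="./cfgs/"):
    cfg = hydra.compose(config_name="default")      # or "fast" to skip GGS
print(OmegaConf.to_yaml(cfg.GGS))                    # GGS.enable differs between the two configs
model = instantiate(cfg.MODEL, _recursive_=False)    # -> models.PoseDiffusionModel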
cfgs/fast.yaml
ADDED
@@ -0,0 +1,40 @@
image_folder: samples/apple
image_size: 224
ckpt: weights/co3d_model_Apr16.pth
seed: 0

GGS:
  enable: False
  start_step: 10
  learning_rate: 0.01
  iter_num: 100
  sampson_max: 10
  min_matches: 10
  alpha: 0.0001


MODEL:
  _target_: models.PoseDiffusionModel

  pose_encoding_type: absT_quaR_logFL

  IMAGE_FEATURE_EXTRACTOR:
    _target_: models.MultiScaleImageFeatureExtractor
    freeze: False

  DENOISER:
    _target_: models.Denoiser
    TRANSFORMER:
      _target_: models.TransformerEncoderWrapper
      d_model: 512
      nhead: 4
      dim_feedforward: 1024
      num_encoder_layers: 8
      dropout: 0.1
      batch_first: True
      norm_first: True


  DIFFUSER:
    _target_: models.GaussianDiffusion
    beta_schedule: custom
examples/71165193657__AED15223-1435-44B6-AFC1-884527CE1642.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:691ae5feb1286b531ad974a7e3a5859bb6de7e26b3e4f21eb208afc4af8038e7
size 512683
models/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .pose_diffusion_model import PoseDiffusionModel


from .denoiser import Denoiser, TransformerEncoderWrapper
from .gaussian_diffuser import GaussianDiffusion
from .image_feature_extractor import MultiScaleImageFeatureExtractor
models/denoiser.py
ADDED
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from util.embedding import TimeStepEmbedding, PoseEmbedding

import torch
import torch.nn as nn

from hydra.utils import instantiate


logger = logging.getLogger(__name__)


class Denoiser(nn.Module):
    def __init__(
        self,
        TRANSFORMER: Dict,
        target_dim: int = 9,  # TODO: reduce fl dim from 2 to 1
        pivot_cam_onehot: bool = True,
        z_dim: int = 384,
        mlp_hidden_dim: int = 128,
    ):
        super().__init__()

        self.pivot_cam_onehot = pivot_cam_onehot
        self.target_dim = target_dim

        self.time_embed = TimeStepEmbedding()
        self.pose_embed = PoseEmbedding(target_dim=self.target_dim)

        first_dim = (
            self.time_embed.out_dim
            + self.pose_embed.out_dim
            + z_dim
            + int(self.pivot_cam_onehot)
        )

        d_model = TRANSFORMER.d_model
        self._first = nn.Linear(first_dim, d_model)

        # slightly different from the paper that
        # we use 2 encoder layers and 6 decoder layers
        # here we use a transformer with 8 encoder layers
        # call TransformerEncoderWrapper() to build an encoder-only transformer
        self._trunk = instantiate(TRANSFORMER, _recursive_=False)

        # TODO: change the implementation of MLP to a more mature one
        self._last = MLP(
            d_model,
            [mlp_hidden_dim, self.target_dim],
            norm_layer=nn.LayerNorm,
        )

    def forward(
        self,
        x: torch.Tensor,  # B x N x dim
        t: torch.Tensor,  # B
        z: torch.Tensor,  # B x N x dim_z
    ):
        B, N, _ = x.shape

        t_emb = self.time_embed(t)
        # expand t from B x C to B x N x C
        t_emb = t_emb.view(B, 1, t_emb.shape[-1]).expand(-1, N, -1)

        x_emb = self.pose_embed(x)

        if self.pivot_cam_onehot:
            # add the one hot vector identifying the first camera as pivot
            cam_pivot_id = torch.zeros_like(z[..., :1])
            cam_pivot_id[:, 0, ...] = 1.0
            z = torch.cat([z, cam_pivot_id], dim=-1)

        feed_feats = torch.cat([x_emb, t_emb, z], dim=-1)

        input_ = self._first(feed_feats)

        feats_ = self._trunk(input_)

        output = self._last(feats_)

        return output


def TransformerEncoderWrapper(
    d_model: int,
    nhead: int,
    num_encoder_layers: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    norm_first: bool = True,
    batch_first: bool = True,
):
    encoder_layer = torch.nn.TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        batch_first=batch_first,
        norm_first=norm_first,
    )

    _trunk = torch.nn.TransformerEncoder(encoder_layer, num_encoder_layers)
    return _trunk


class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional):
            Norm layer that will be stacked on top of the convolution layer.
            If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional):
            Activation function which will be stacked on top of the
            normalization layer (if not None), otherwise on top of the
            conv layer. If ``None`` this layer won't be used.
            Default: ``torch.nn.ReLU``
        inplace (bool): Parameter for the activation layer, which can
            optionally do the operation in-place. Default ``True``
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[
            Callable[..., torch.nn.Module]
        ] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        norm_first: bool = False,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from
        # the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels

        for hidden_dim in hidden_channels[:-1]:
            if norm_first and norm_layer is not None:
                layers.append(norm_layer(in_dim))

            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))

            if not norm_first and norm_layer is not None:
                layers.append(norm_layer(hidden_dim))

            layers.append(activation_layer(**params))

            if dropout > 0:
                layers.append(torch.nn.Dropout(dropout, **params))

            in_dim = hidden_dim

        if norm_first and norm_layer is not None:
            layers.append(norm_layer(in_dim))

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        if dropout > 0:
            layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
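As a shape sanity check: the denoiser maps a noisy pose encoding x (B x N x 9), a timestep t (B,), and per-frame image features z (B x N x 384 for DINO ViT-S/16) to a tensor of the same 9-dimensional encoding. A minimal sketch with the transformer settings from cfgs/default.yaml, using random inputs purely for illustration:

# Sketch: run the Denoiser once on random tensors to confirm the expected shapes.
import torch
from omegaconf import OmegaConf
from models.denoiser import Denoiser

transformer_cfg = OmegaConf.create({
    "_target_": "models.TransformerEncoderWrapper",
    "d_model": 512, "nhead": 4, "dim_feedforward": 1024,
    "num_encoder_layers": 8, "dropout": 0.1,
    "batch_first": True, "norm_first": True,
})
denoiser = Denoiser(TRANSFORMER=transformer_cfg)

B, N = 1, 20                      # one scene with 20 frames
x = torch.randn(B, N, 9)          # noisy absT_quaR_logFL encodings
t = torch.randint(0, 100, (B,))   # one diffusion timestep per scene
z = torch.randn(B, N, 384)        # per-frame image features
out = denoiser(x, t, z)
print(out.shape)                  # torch.Size([1, 20, 9])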
models/gaussian_diffuser.py
ADDED
@@ -0,0 +1,410 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torchvision import transforms as T, utils

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from PIL import Image
from tqdm.auto import tqdm
from typing import Any, Dict, List, Optional, Tuple, Union

# constants

ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])

# helpers functions


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = (
        torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    )
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        timesteps=100,
        sampling_timesteps=None,
        beta_1=0.0001,
        beta_T=0.1,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="custom",
        p2_loss_weight_gamma=0.0,
        p2_loss_weight_k=1,
    ):
        super().__init__()

        self.objective = objective

        assert objective in {
            "pred_noise",
            "pred_x0",
        }, "objective must be either pred_noise (predict noise) \
            or pred_x0 (predict image start)"

        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.loss_type = loss_type
        self.objective = objective
        self.beta_schedule = beta_schedule
        self.p2_loss_weight_gamma = p2_loss_weight_gamma
        self.p2_loss_weight_k = p2_loss_weight_k

        self.init_diff_hyper(
            self.timesteps,
            self.sampling_timesteps,
            self.beta_1,
            self.beta_T,
            self.loss_type,
            self.objective,
            self.beta_schedule,
            self.p2_loss_weight_gamma,
            self.p2_loss_weight_k,
        )

    def init_diff_hyper(
        self,
        timesteps,
        sampling_timesteps,
        beta_1,
        beta_T,
        loss_type,
        objective,
        beta_schedule,
        p2_loss_weight_gamma,
        p2_loss_weight_k,
    ):
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "custom":
            betas = torch.linspace(
                beta_1, beta_T, timesteps, dtype=torch.float64
            )
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # sampling related parameters
        self.sampling_timesteps = default(
            sampling_timesteps, timesteps
        )  # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps

        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        )

        # calculate p2 reweighting
        register_buffer(
            "p2_loss_weight",
            (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -p2_loss_weight_gamma,
        )

    # helper functions
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )

        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return (
            posterior_mean,
            posterior_variance,
            posterior_log_variance_clipped,
        )

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def model_predictions(self, x, t, z, x_self_cond=None):
        model_output = self.model(x, t, z)

        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)

        elif self.objective == "pred_x0":
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
    ):
        preds = self.model_predictions(x, t, z)

        x_start = preds.pred_x_start

        if clip_denoised:
            raise NotImplementedError(
                "We don't clip the output because \
                pose does not have a clear bound."
            )

        (
            model_mean,
            posterior_variance,
            posterior_log_variance,
        ) = self.q_posterior(x_start=x_start, x_t=x, t=t)

        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
        cond_fn=None,
        cond_start_step=0,
    ):
        b, *_, device = *x.shape, x.device
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long
        )
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            z=z,
            x_self_cond=x_self_cond,
            clip_denoised=clip_denoised,
        )

        if cond_fn is not None and t < cond_start_step:
            model_mean = cond_fn(model_mean, t)
            noise = 0.0
        else:
            noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0

        pred = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    @torch.no_grad()
    def p_sample_loop(
        self,
        shape,
        z: torch.Tensor,
        cond_fn=None,
        cond_start_step=0,
    ):
        batch, device = shape[0], self.betas.device

        # Init here
        pose = torch.randn(shape, device=device)

        x_start = None

        pose_process = []
        pose_process.append(pose.unsqueeze(0))

        for t in reversed(range(0, self.num_timesteps)):
            pose, _ = self.p_sample(
                x=pose,
                t=t,
                z=z,
                cond_fn=cond_fn,
                cond_start_step=cond_start_step,
            )
            pose_process.append(pose.unsqueeze(0))

        return pose, torch.cat(pose_process)

    @torch.no_grad()
    def sample(self, shape, z, cond_fn=None, cond_start_step=0):
        # TODO: add more variants
        sample_fn = self.p_sample_loop
        return sample_fn(
            shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
        )

    def p_losses(
        self,
        x_start,
        t,
        z=None,
        noise=None,
    ):
        noise = default(noise, lambda: torch.randn_like(x_start))
        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        model_out = self.model(x, t, z)

        if self.objective == "pred_noise":
            target = noise
            x_0_pred = self.predict_start_from_noise(x, t, model_out)
        elif self.objective == "pred_x0":
            target = x_start
            x_0_pred = model_out
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target, reduction="none")

        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)

        return {
            "loss": loss,
            "noise": noise,
            "x_0_pred": x_0_pred,
            "x_t": x,
            "t": t,
        }

    def forward(self, pose, z=None, *args, **kwargs):
        b = len(pose)
        t = torch.randint(
            0, self.num_timesteps, (b,), device=pose.device
        ).long()
        return self.p_losses(pose, t, z=z, *args, **kwargs)

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")
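At inference time only the reverse process is used: sample starts from Gaussian noise of the requested shape and denoises it over num_timesteps steps, letting cond_fn (the geometry-guided sampling below) adjust the predicted mean once t drops below cond_start_step. A minimal sketch of how this loop is driven, with a stand-in for the Denoiser so the snippet stays self-contained (the real model assigns the Denoiser above to .model):

# Sketch: drive the reverse diffusion with a stand-in denoiser (illustrative only).
import torch
from models.gaussian_diffuser import GaussianDiffusion

diffuser = GaussianDiffusion(beta_schedule="custom")    # 100 steps, as in the cfgs
# Stand-in for the Denoiser: predicts "noise" of the same shape as x.
diffuser.model = lambda x, t, z: torch.zeros_like(x)

B, N, target_dim = 1, 20, 9
z = torch.randn(B, N, 384)                               # image features
pose_encoding, trajectory = diffuser.sample(shape=[B, N, target_dim], z=z)
print(pose_encoding.shape, trajectory.shape)             # [1, 20, 9] and [101, 1, 20, 9]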
models/image_feature_extractor.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable


import torch
import torch.nn as nn
import torchvision

import io
from PIL import Image
import numpy as np

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class MultiScaleImageFeatureExtractor(nn.Module):
    def __init__(
        self,
        modelname: str = "dino_vits16",
        freeze: bool = False,
        scale_factors: list = [1, 1 / 2, 1 / 3],
    ):
        super().__init__()
        self.freeze = freeze
        self.scale_factors = scale_factors

        if "res" in modelname:
            self._net = getattr(torchvision.models, modelname)(pretrained=True)
            self._output_dim = self._net.fc.weight.shape[1]
            self._net.fc = nn.Identity()
        elif "dino" in modelname:
            self._net = torch.hub.load("facebookresearch/dino:main", modelname)
            self._output_dim = self._net.norm.weight.shape[0]
        else:
            raise ValueError(f"Unknown model name {modelname}")

        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 3, 1, 1),
                persistent=False,
            )

        if self.freeze:
            for param in self.parameters():
                param.requires_grad = False

    def get_output_dim(self):
        return self._output_dim

    def forward(self, image_rgb: torch.Tensor) -> torch.Tensor:
        img_normed = self._resnet_normalize_image(image_rgb)

        features = self._compute_multiscale_features(img_normed)

        return features

    def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
        return (img - self._resnet_mean) / self._resnet_std

    def _compute_multiscale_features(
        self, img_normed: torch.Tensor
    ) -> torch.Tensor:
        multiscale_features = None

        if len(self.scale_factors) <= 0:
            raise ValueError(
                f"Wrong format of self.scale_factors: {self.scale_factors}"
            )

        for scale_factor in self.scale_factors:
            if scale_factor == 1:
                inp = img_normed
            else:
                inp = self._resize_image(img_normed, scale_factor)

            if multiscale_features is None:
                multiscale_features = self._net(inp)
            else:
                multiscale_features += self._net(inp)

        averaged_features = multiscale_features / len(self.scale_factors)
        return averaged_features

    @staticmethod
    def _resize_image(image: torch.Tensor, scale_factor: float) -> torch.Tensor:
        return nn.functional.interpolate(
            image,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
        )
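The extractor averages DINO ViT-S/16 global features over three image scales, so each frame yields a single 384-dimensional vector regardless of resolution. A minimal sketch (illustrative; the first call downloads the DINO weights via torch.hub):

# Sketch: extract averaged multi-scale DINO features for a batch of frames.
import torch
from models.image_feature_extractor import MultiScaleImageFeatureExtractor

extractor = MultiScaleImageFeatureExtractor()      # dino_vits16, scales 1, 1/2, 1/3
frames = torch.rand(20, 3, 224, 224)               # 20 RGB frames in [0, 1]
with torch.no_grad():
    feats = extractor(frames)
print(feats.shape)                                 # torch.Size([20, 384])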
models/pose_diffusion_model.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Standard library imports
import base64
import io
import logging
import math
import pickle
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

# Third-party library imports
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import (
    se3_exp_map,
    se3_log_map,
    Transform3d,
    so3_relative_angle,
)
from util.camera_transform import pose_encoding_to_camera

import models
from hydra.utils import instantiate
from pytorch3d.renderer.cameras import PerspectiveCameras


logger = logging.getLogger(__name__)


class PoseDiffusionModel(nn.Module):
    def __init__(
        self,
        pose_encoding_type: str,
        IMAGE_FEATURE_EXTRACTOR: Dict,
        DIFFUSER: Dict,
        DENOISER: Dict,
    ):
        """Initializes a PoseDiffusion model.

        Args:
            pose_encoding_type (str):
                Defines the encoding type for extrinsics and intrinsics.
                Currently, only `"absT_quaR_logFL"` is supported -
                a concatenation of the translation vector,
                rotation quaternion, and logarithm of focal length.
            IMAGE_FEATURE_EXTRACTOR (Dict):
                Configuration for the image feature extractor.
            DIFFUSER (Dict):
                Configuration for the diffuser.
            DENOISER (Dict):
                Configuration for the denoiser.
        """

        super().__init__()

        self.pose_encoding_type = pose_encoding_type

        self.image_feature_extractor = instantiate(
            IMAGE_FEATURE_EXTRACTOR, _recursive_=False
        )
        self.diffuser = instantiate(DIFFUSER, _recursive_=False)

        denoiser = instantiate(DENOISER, _recursive_=False)
        self.diffuser.model = denoiser

        self.target_dim = denoiser.target_dim

    def forward(
        self,
        image: torch.Tensor,
        gt_cameras: Optional[CamerasBase] = None,
        sequence_name: Optional[List[str]] = None,
        cond_fn=None,
        cond_start_step=0,
    ):
        """
        Forward pass of the PoseDiffusionModel.

        Args:
            image (torch.Tensor):
                Input image tensor, Bx3xHxW.
            gt_cameras (Optional[CamerasBase], optional):
                Camera object. Defaults to None.
            sequence_name (Optional[List[str]], optional):
                List of sequence names. Defaults to None.
            cond_fn ([type], optional):
                Conditional function. Wrapper for GGS or other functions.
            cond_start_step (int, optional):
                The sampling step to start using the conditional function.

        Returns:
            PerspectiveCameras: PyTorch3D camera object.
        """

        z = self.image_feature_extractor(image)

        z = z.unsqueeze(0)

        B, N, _ = z.shape
        target_shape = [B, N, self.target_dim]

        # sampling
        pose_encoding, pose_encoding_diffusion_samples = self.diffuser.sample(
            shape=target_shape,
            z=z,
            cond_fn=cond_fn,
            cond_start_step=cond_start_step,
        )

        # convert the encoded representation to PyTorch3D cameras
        pred_cameras = pose_encoding_to_camera(
            pose_encoding, pose_encoding_type=self.pose_encoding_type
        )

        return pred_cameras
packages.txt
ADDED
File without changes
pre-requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch==1.13.0
torchvision==0.14.0
requirements.txt
ADDED
@@ -0,0 +1,5 @@
hydra-core
omegaconf
opencv-python
einops
git+https://github.com/facebookresearch/PoseDiffusion.git
util/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

util/camera_transform.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.transforms.rotation_conversions import (
    matrix_to_quaternion,
    quaternion_to_matrix,
)
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras


def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR_logFL",
    log_focal_length_bias=1.8,
    min_focal_length=0.1,
    max_focal_length=20,
):
    """
    Args:
        pose_encoding: A tensor of shape `BxNxC`, containing a batch of
            `BxN` `C`-dimensional pose encodings.
        pose_encoding_type: The type of pose encoding,
            only "absT_quaR_logFL" is supported.
    """

    batch_size, num_poses, _ = pose_encoding.shape
    pose_encoding_reshaped = pose_encoding.reshape(
        -1, pose_encoding.shape[-1]
    )  # Reshape to BNxC

    if pose_encoding_type == "absT_quaR_logFL":
        # forced that 3 for absT, 4 for quaR, 2 logFL
        # TODO: converted to 1 dim for logFL, consistent with our paper
        abs_T = pose_encoding_reshaped[:, :3]
        quaternion_R = pose_encoding_reshaped[:, 3:7]
        R = quaternion_to_matrix(quaternion_R)

        log_focal_length = pose_encoding_reshaped[:, 7:9]

        # log_focal_length_bias was the hyperparameter
        # to ensure the mean of logFL close to 0 during training
        # Now converted back
        focal_length = (log_focal_length + log_focal_length_bias).exp()

        # clamp to avoid weird fl values
        focal_length = torch.clamp(
            focal_length, min=min_focal_length, max=max_focal_length
        )
    else:
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    pred_cameras = PerspectiveCameras(
        focal_length=focal_length,
        R=R,
        T=abs_T,
        device=R.device,
    )

    return pred_cameras
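The 9-dimensional encoding therefore splits as translation (3) + quaternion in PyTorch3D's (w, x, y, z) convention (4) + log focal length (2), and the decoder re-applies the training-time bias of 1.8 before exponentiating and clamping the focal length. A minimal sketch on random encodings (illustrative only; with random, unnormalized quaternions the rotations are not valid, real encodings come from the diffuser):

# Sketch: decode a batch of absT_quaR_logFL encodings into PyTorch3D cameras.
import torch
from util.camera_transform import pose_encoding_to_camera

B, N = 1, 20
enc = torch.randn(B, N, 9)                 # [tx ty tz | qw qx qy qz | log_fx log_fy]
cams = pose_encoding_to_camera(enc)        # PerspectiveCameras with B*N entries
print(cams.R.shape, cams.T.shape, cams.focal_length.shape)   # (20, 3, 3) (20, 3) (20, 2)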
util/embedding.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import math
from pytorch3d.renderer import HarmonicEmbedding


class TimeStepEmbedding(nn.Module):
    # learned from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py
    def __init__(self, dim=256, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

        self.linear = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.SiLU(),
            nn.Linear(dim // 2, dim // 2),
        )

        self.out_dim = dim // 2

    def _compute_freqs(self, half):
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        return freqs

    def forward(self, timesteps):
        half = self.dim // 2
        freqs = self._compute_freqs(half).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )

        output = self.linear(embedding)
        return output


class PoseEmbedding(nn.Module):
    def __init__(self, target_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        self._emb_pose = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=append_input
        )

        self.out_dim = self._emb_pose.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        e_pose_encoding = self._emb_pose(pose_encoding)
        return e_pose_encoding
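These two modules fix the input width of the denoiser's first linear layer: TimeStepEmbedding outputs dim // 2 = 128 features, and PoseEmbedding with 10 harmonic functions plus the appended input gives 9 * (2 * 10 + 1) = 189 features, so together with the 384-dimensional image feature and the 1-dimensional pivot flag the concatenation is 128 + 189 + 384 + 1 = 702, matching first_dim in Denoiser. A quick check:

# Sketch: confirm the embedding widths that feed Denoiser._first.
import torch
from util.embedding import TimeStepEmbedding, PoseEmbedding

time_embed = TimeStepEmbedding()                 # dim=256 -> out_dim 128
pose_embed = PoseEmbedding(target_dim=9)         # 9 * (2*10 + 1) = 189
print(time_embed.out_dim, pose_embed.out_dim)    # 128 189
print(time_embed(torch.tensor([10])).shape)      # torch.Size([1, 128])
print(pose_embed(torch.randn(1, 20, 9)).shape)   # torch.Size([1, 20, 189])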
util/geometry_guided_sampling.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from typing import Dict, List, Optional, Union
from util.camera_transform import pose_encoding_to_camera
from util.get_fundamental_matrix import get_fundamental_matrices
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras


def geometry_guided_sampling(
    model_mean: torch.Tensor,
    t: int,
    matches_dict: Dict,
    GGS_cfg: Dict,
):
    # pre-process matches
    b, c, h, w = matches_dict["img_shape"]
    device = model_mean.device

    def _to_device(tensor):
        return torch.from_numpy(tensor).to(device)

    kp1 = _to_device(matches_dict["kp1"])
    kp2 = _to_device(matches_dict["kp2"])
    i12 = _to_device(matches_dict["i12"])

    pair_idx = i12[:, 0] * b + i12[:, 1]
    pair_idx = pair_idx.long()

    def _to_homogeneous(tensor):
        return torch.nn.functional.pad(tensor, [0, 1], value=1)

    kp1_homo = _to_homogeneous(kp1)
    kp2_homo = _to_homogeneous(kp2)

    i1, i2 = [
        i.reshape(-1) for i in torch.meshgrid(torch.arange(b), torch.arange(b))
    ]

    processed_matches = {
        "kp1_homo": kp1_homo,
        "kp2_homo": kp2_homo,
        "i1": i1,
        "i2": i2,
        "h": h,
        "w": w,
        "pair_idx": pair_idx,
    }

    # conduct GGS
    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)

    # Optimize FL, R, and T separately
    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=False,
        update_FL=True,
        **GGS_cfg,
    )  # only optimize FL

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=False,
        update_R=True,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize R

    model_mean = GGS_optimize(
        model_mean,
        t,
        processed_matches,
        update_T=True,
        update_R=False,
        update_FL=False,
        **GGS_cfg,
    )  # only optimize T

    model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)
    return model_mean


def GGS_optimize(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R: bool = True,
    update_T: bool = True,
    update_FL: bool = True,
    # the args below come from **GGS_cfg
    alpha: float = 0.0001,
    learning_rate: float = 1e-2,
    iter_num: int = 100,
    sampson_max: int = 10,
    min_matches: int = 10,
    pose_encoding_type: str = "absT_quaR_logFL",
    **kwargs,
):
    with torch.enable_grad():
        model_mean.requires_grad_(True)

        if update_R and update_T and update_FL:
            iter_num = iter_num * 2

        optimizer = torch.optim.SGD(
            [model_mean], lr=learning_rate, momentum=0.9
        )
        batch_size = model_mean.shape[1]

        for _ in range(iter_num):
            valid_sampson, sampson_to_print = compute_sampson_distance(
                model_mean,
                t,
                processed_matches,
                update_R=update_R,
                update_T=update_T,
                update_FL=update_FL,
                pose_encoding_type=pose_encoding_type,
                sampson_max=sampson_max,
            )

            if min_matches > 0:
                valid_match_per_frame = len(valid_sampson) / batch_size
                if valid_match_per_frame < min_matches:
                    print(
                        "Drop this pair because of insufficient valid matches"
                    )
                    break

            loss = valid_sampson.mean()
            optimizer.zero_grad()
            loss.backward()

            grads = model_mean.grad
            grad_norm = grads.norm()
            grad_mask = (grads.abs() > 0).detach()
            model_mean_norm = (model_mean * grad_mask).norm()

            max_norm = alpha * model_mean_norm / learning_rate

            total_norm = torch.nn.utils.clip_grad_norm_(model_mean, max_norm)
            optimizer.step()

        print(f"t={t:02d} | sampson={sampson_to_print:05f}")
        model_mean = model_mean.detach()
    return model_mean


def compute_sampson_distance(
    model_mean: torch.Tensor,
    t: int,
    processed_matches: Dict,
    update_R=True,
    update_T=True,
    update_FL=True,
    pose_encoding_type: str = "absT_quaR_logFL",
    sampson_max: int = 10,
):
    camera = pose_encoding_to_camera(model_mean, pose_encoding_type)

    # pick the mean of the predicted focal length
    camera.focal_length = camera.focal_length.mean(dim=0).repeat(
        len(camera.focal_length), 1
    )

    if not update_R:
        camera.R = camera.R.detach()

    if not update_T:
        camera.T = camera.T.detach()

    if not update_FL:
        camera.focal_length = camera.focal_length.detach()

    kp1_homo, kp2_homo, i1, i2, he, wi, pair_idx = processed_matches.values()
    F_2_to_1 = get_fundamental_matrices(
        camera, he, wi, i1, i2, l2_normalize_F=False
    )
    F = F_2_to_1.permute(0, 2, 1)  # y1^T F y2 = 0

    def _sampson_distance(F, kp1_homo, kp2_homo, pair_idx):
        left = torch.bmm(kp1_homo[:, None], F[pair_idx])
        right = torch.bmm(F[pair_idx], kp2_homo[..., None])

        bottom = (
            left[:, :, 0].square()
            + left[:, :, 1].square()
            + right[:, 0, :].square()
            + right[:, 1, :].square()
        )
        top = torch.bmm(left, kp2_homo[..., None]).square()

        sampson = top[:, 0] / bottom
        return sampson

    sampson = _sampson_distance(
        F,
        kp1_homo.float(),
        kp2_homo.float(),
        pair_idx,
    )

    sampson_to_print = sampson.detach().clone().clamp(max=sampson_max).mean()
    sampson = sampson[sampson < sampson_max]

    return sampson, sampson_to_print
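For context, a minimal sketch of how this hook is likely wired into the diffusion sampler (the variables kp1, kp2, i12, images, and GGS_cfg are assumptions about the calling code, not part of this file):

from functools import partial

# hypothetical wiring: bind the extracted matches and the GGS config, then let the
# sampler call cond_fn(model_mean, t) on the predicted mean at each denoising step
matches_dict = {"kp1": kp1, "kp2": kp2, "i12": i12, "img_shape": images.shape}
cond_fn = partial(geometry_guided_sampling, matches_dict=matches_dict, GGS_cfg=GGS_cfg)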
util/get_fundamental_matrix.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import pytorch3d
from pytorch3d.utils import opencv_from_cameras_projection
from pytorch3d.transforms.so3 import hat
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras


def get_fundamental_matrices(
    camera: CamerasBase,
    height: int,
    width: int,
    index1: torch.LongTensor,
    index2: torch.LongTensor,
    l2_normalize_F=False,
):
    """Compute fundamental matrices for given camera parameters."""
    batch_size = camera.R.shape[0]

    # Convert to opencv / colmap / Hartley&Zisserman convention
    image_size_t = (
        torch.LongTensor([height, width])[None]
        .repeat(batch_size, 1)
        .to(camera.device)
    )
    R, t, K = opencv_from_cameras_projection(camera, image_size=image_size_t)

    F, E = get_fundamental_matrix(
        K[index1], R[index1], t[index1], K[index2], R[index2], t[index2]
    )

    if l2_normalize_F:
        F_scale = torch.norm(F, dim=(1, 2))
        F_scale = F_scale.clamp(min=0.0001)
        F = F / F_scale[:, None, None]

    return F


def get_fundamental_matrix(K1, R1, t1, K2, R2, t2):
    E = get_essential_matrix(R1, t1, R2, t2)
    F = K2.inverse().permute(0, 2, 1).matmul(E).matmul(K1.inverse())
    return F, E  # p2^T F p1 = 0


def get_essential_matrix(R1, t1, R2, t2):
    R12 = R2.matmul(R1.permute(0, 2, 1))
    t12 = t2 - R12.matmul(t1[..., None])[..., 0]
    E_R = R12
    E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0]
    E = E_R.matmul(hat(E_t))
    return E
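As a sanity check on the convention above, a short sketch with synthetic cameras (all values are placeholders): a correspondence p1 in view index1[k] and p2 in view index2[k] should satisfy p2^T F[k] p1 ≈ 0.

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras

B, H, W = 3, 224, 224
cameras = PerspectiveCameras(
    R=torch.eye(3)[None].repeat(B, 1, 1),   # identity rotations for simplicity
    T=torch.tensor([[0.0, 0.0, 2.0], [0.5, 0.0, 2.0], [0.0, 0.5, 2.0]]),
    focal_length=torch.full((B, 2), 2.0),
)
index1 = torch.tensor([0, 0, 1])
index2 = torch.tensor([1, 2, 2])
F = get_fundamental_matrices(cameras, H, W, index1, index2)  # shape (3, 3, 3)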
util/load_img_folder.py
ADDED
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import (
    Any,
    ClassVar,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)


def load_and_preprocess_images(
    folder_path: str, image_size: int = 224, mode: str = "bilinear"
) -> Tuple[torch.Tensor, Dict]:
    image_paths = [
        os.path.join(folder_path, file)
        for file in os.listdir(folder_path)
        if file.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    image_paths.sort()

    images = []
    bboxes_xyxy = []
    scales = []
    for path in image_paths:
        image = _load_image(path)
        image, bbox_xyxy, min_hw = _center_crop_square(image)
        minscale = image_size / min_hw

        imre = F.interpolate(
            torch.from_numpy(image)[None],
            size=(image_size, image_size),
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
        )[0]

        images.append(imre.numpy())
        bboxes_xyxy.append(bbox_xyxy.numpy())
        scales.append(minscale)

    images_tensor = torch.from_numpy(np.stack(images))

    # assume all the images have the same shape for GGS
    image_info = {
        "size": (min_hw, min_hw),
        "bboxes_xyxy": np.stack(bboxes_xyxy),
        "resized_scales": np.stack(scales),
    }
    return images_tensor, image_info


# helper functions


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _center_crop_square(image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor, int]:
    h, w = image.shape[1:]
    min_dim = min(h, w)
    top = (h - min_dim) // 2
    left = (w - min_dim) // 2
    cropped_image = image[:, top : top + min_dim, left : left + min_dim]

    # bbox_xywh: the cropped region
    bbox_xywh = torch.tensor([left, top, min_dim, min_dim])

    # the format from xywh to xyxy
    bbox_xyxy = _clamp_box_to_image_bounds_and_round(
        _get_clamp_bbox(
            bbox_xywh,
            box_crop_context=0.0,
        ),
        image_size_hw=(h, w),
    )
    return cropped_image, bbox_xyxy, min_dim


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            "squashed image!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(
        bbox[2:], 2
    )  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


if __name__ == "__main__":
    # Example usage:
    folder_path = "path/to/your/folder"
    image_size = 224
    images_tensor, image_info = load_and_preprocess_images(folder_path, image_size)
    print(images_tensor.shape)
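Usage sketch (the folder path is a placeholder): besides the resized frames, the returned image_info carries the crop boxes and scales that the match extraction and GGS steps use to map keypoints into the resized pixel space.

images, image_info = load_and_preprocess_images("examples/frames", image_size=224)
# images: (N, 3, 224, 224) float tensor in [0, 1]
# image_info["bboxes_xyxy"]: per-frame center-crop boxes in original-resolution pixels
# image_info["resized_scales"]: per-frame scale factors from the crop to 224 x 224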
util/match_extraction.py
ADDED
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pycolmap
from typing import Optional, List, Dict, Any
from hloc import (
    extract_features,
    logger,
    match_features,
    pairs_from_exhaustive,
)
from hloc.triangulation import (
    import_features,
    import_matches,
    estimation_and_geometric_verification,
    parse_option_args,
    OutputCapture,
)
from hloc.utils.database import (
    COLMAPDatabase,
    image_ids_to_pair_id,
    pair_id_to_image_ids,
)
from hloc.reconstruction import create_empty_db, import_images, get_image_ids


def extract_match(image_folder_path: str, image_info: Dict):
    # Now only supports SPSG
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_mapping = os.path.join(tmpdir, "mapping")
        os.makedirs(tmp_mapping)
        for filename in os.listdir(image_folder_path):
            if filename.lower().endswith(
                (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")
            ):
                shutil.copy(
                    os.path.join(image_folder_path, filename),
                    os.path.join(tmp_mapping, filename),
                )
        matches, keypoints = run_hloc(tmpdir)

    # From the format of colmap to PyTorch3D
    kp1, kp2, i12 = colmap_keypoint_to_pytorch3d(matches, keypoints, image_info)

    return kp1, kp2, i12


def colmap_keypoint_to_pytorch3d(matches, keypoints, image_info):
    kp1, kp2, i12 = [], [], []
    bbox_xyxy, scale = image_info["bboxes_xyxy"], image_info["resized_scales"]

    for idx in keypoints:
        # coordinate change from COLMAP to OpenCV
        cur_keypoint = keypoints[idx] - 0.5

        # go to the coordinate after cropping
        # use idx - 1 here because the COLMAP format starts from 1 instead of 0
        cur_keypoint = cur_keypoint - [
            bbox_xyxy[idx - 1][0],
            bbox_xyxy[idx - 1][1],
        ]
        cur_keypoint = cur_keypoint * scale[idx - 1]
        keypoints[idx] = cur_keypoint

    for (r_idx, q_idx), pair_match in matches.items():
        if pair_match is not None:
            kp1.append(keypoints[r_idx][pair_match[:, 0]])
            kp2.append(keypoints[q_idx][pair_match[:, 1]])

            i12_pair = np.array([[r_idx - 1, q_idx - 1]])
            i12.append(np.repeat(i12_pair, len(pair_match), axis=0))

    if kp1:
        kp1, kp2, i12 = map(np.concatenate, (kp1, kp2, i12), (0, 0, 0))
    else:
        kp1 = kp2 = i12 = None

    return kp1, kp2, i12


def run_hloc(output_dir: str):
    # learned from
    # https://github.com/cvg/Hierarchical-Localization/blob/master/pipeline_SfM.ipynb

    images = Path(output_dir)
    outputs = Path(os.path.join(output_dir, "output"))
    sfm_pairs = outputs / "pairs-sfm.txt"
    sfm_dir = outputs / "sfm"
    features = outputs / "features.h5"
    matches = outputs / "matches.h5"

    feature_conf = extract_features.confs[
        "superpoint_inloc"
    ]  # or superpoint_max
    matcher_conf = match_features.confs["superpoint+lightglue"]

    references = [
        p.relative_to(images).as_posix()
        for p in (images / "mapping/").iterdir()
    ]

    extract_features.main(
        feature_conf, images, image_list=references, feature_path=features
    )
    pairs_from_exhaustive.main(sfm_pairs, image_list=references)
    match_features.main(
        matcher_conf, sfm_pairs, features=features, matches=matches
    )

    matches, keypoints = compute_matches_and_keypoints(
        sfm_dir, images, sfm_pairs, features, matches, image_list=references
    )

    return matches, keypoints


def compute_matches_and_keypoints(
    sfm_dir: Path,
    image_dir: Path,
    pairs: Path,
    features: Path,
    matches: Path,
    camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO,
    verbose: bool = False,
    min_match_score: Optional[float] = None,
    image_list: Optional[List[str]] = None,
    image_options: Optional[Dict[str, Any]] = None,
) -> pycolmap.Reconstruction:
    # learned from
    # https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/reconstruction.py

    sfm_dir.mkdir(parents=True, exist_ok=True)
    database = sfm_dir / "database.db"

    create_empty_db(database)
    import_images(image_dir, database, camera_mode, image_list, image_options)
    image_ids = get_image_ids(database)
    import_features(image_ids, database, features)
    import_matches(image_ids, database, pairs, matches, min_match_score)
    estimation_and_geometric_verification(database, pairs, verbose)

    db = COLMAPDatabase.connect(database)

    matches = dict(
        (
            pair_id_to_image_ids(pair_id),
            _blob_to_array_safe(data, np.uint32, (-1, 2)),
        )
        for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
    )

    keypoints = dict(
        (image_id, _blob_to_array_safe(data, np.float32, (-1, 2)))
        for image_id, data in db.execute("SELECT image_id, data FROM keypoints")
    )

    db.close()

    return matches, keypoints


def _blob_to_array_safe(blob, dtype, shape=(-1,)):
    if blob is not None:
        # np.frombuffer replaces the deprecated np.fromstring for binary blobs
        return np.frombuffer(blob, dtype=dtype).reshape(*shape)
    else:
        return blob
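A sketch of the intended call pattern, turning the SuperPoint + LightGlue matches into the dictionary expected by geometry_guided_sampling (the surrounding variables come from load_and_preprocess_images and are assumptions here):

kp1, kp2, i12 = extract_match(image_folder_path, image_info)
if kp1 is not None:
    matches_dict = {
        "kp1": kp1,                  # (M, 2) keypoints in frames i12[:, 0]
        "kp2": kp2,                  # (M, 2) keypoints in frames i12[:, 1]
        "i12": i12,                  # (M, 2) zero-based image index pairs
        "img_shape": images.shape,   # (N, C, H, W) of the preprocessed frames
    }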
util/metric.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
import numpy as np
import torch


def compute_ARE(rotation1, rotation2):
    if isinstance(rotation1, torch.Tensor):
        rotation1 = rotation1.cpu().detach().numpy()
    if isinstance(rotation2, torch.Tensor):
        rotation2 = rotation2.cpu().detach().numpy()

    R_rel = np.einsum("Bij,Bjk ->Bik", rotation1.transpose(0, 2, 1), rotation2)
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    error = theta * 180 / np.pi
    return np.minimum(error, np.abs(180 - error))
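compute_ARE returns the angular distance in degrees between corresponding rotations; a quick check with identical inputs (illustrative only) should give zeros:

R = torch.eye(3)[None].repeat(5, 1, 1)   # five identity rotations, shape (5, 3, 3)
print(compute_ARE(R, R))                 # -> array of five (near-)zero errors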
util/utils.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random

import numpy as np
import torch
import tempfile


def seed_all_random_engines(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
weights/co3d_model_Apr16.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7084b19cddce8dcc8f9197a8bbcf330fd0edf1c0c97b628c35180d8a18edbeb
size 155952931