hugoycj committed on
Commit 3d3e4e9
1 Parent(s): 7bf852f

Initial commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from datetime import datetime
8
+ import os
9
+ import time
10
+ import torch
11
+ from typing import Dict, List, Optional, Union
12
+ from omegaconf import OmegaConf, DictConfig
13
+ import hydra
14
+ from hydra.utils import instantiate, get_original_cwd
15
+ import time
16
+ from functools import partial
17
+ import matplotlib.pyplot as plt
18
+ import shutil
19
+ from util.utils import seed_all_random_engines
20
+ from util.match_extraction import extract_match
21
+ from util.load_img_folder import load_and_preprocess_images
22
+ from util.geometry_guided_sampling import geometry_guided_sampling
23
+ from pytorch3d.vis.plotly_vis import get_camera_wireframe
24
+ import subprocess
25
+ import tempfile
26
+ import gradio as gr
27
+
28
+ def plot_cameras(ax, cameras, color: str = "blue"):
29
+ """
30
+ Plots a set of `cameras` objects into the matplotlib axis `ax` with
31
+ color `color`.
32
+ """
33
+ cam_wires_canonical = get_camera_wireframe().cuda()[None]
34
+ cam_trans = cameras.get_world_to_view_transform().inverse()
35
+ cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
36
+ plot_handles = []
37
+ for wire in cam_wires_trans:
38
+ # the Z and Y axes are flipped intentionally here!
39
+ x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
40
+ (h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
41
+ plot_handles.append(h)
42
+ return plot_handles
43
+
44
+ def create_matplotlib_figure(pred_cameras):
45
+ fig = plt.figure()
46
+ ax = fig.add_subplot(projection="3d")
47
+ ax.clear()
48
+ handle_cam = plot_cameras(ax, pred_cameras, color="#FF7D1E")
49
+ plot_radius = 3
50
+ ax.set_xlim3d([-plot_radius, plot_radius])
51
+ ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
52
+ ax.set_zlim3d([-plot_radius, plot_radius])
53
+ ax.set_xlabel("x")
54
+ ax.set_ylabel("z")
55
+ ax.set_zlabel("y")
56
+ labels_handles = {
57
+ "Estimated cameras": handle_cam[0],
58
+ }
59
+ ax.legend(
60
+ labels_handles.values(),
61
+ labels_handles.keys(),
62
+ loc="upper center",
63
+ bbox_to_anchor=(0.5, 0),
64
+ )
65
+
66
+ return plt
67
+
68
+ import os
69
+ import json
70
+ import tempfile
71
+ from PIL import Image
72
+
73
+
74
+ def convert_extrinsics_pytorch3d_to_opengl(extrinsics: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Convert extrinsics from PyTorch3D coordinate system to OpenGL coordinate system.
77
+
78
+ Args:
79
+ extrinsics (torch.Tensor): a 4x4 extrinsic matrix in PyTorch3D coordinate system.
80
+
81
+ Returns:
82
+ torch.Tensor: a 4x4 extrinsic matrix in OpenGL coordinate system.
83
+ """
84
+ # Create a transformation matrix that flips the Z-axis
85
+ flip_z = torch.eye(4)
86
+ flip_z[2, 2] = -1
87
+ flip_z[0, 0] = -1
88
+
89
+ # Multiply the extrinsic matrix by the transformation matrix
90
+ extrinsics_opengl = torch.mm(extrinsics, flip_z)
91
+
92
+ return extrinsics_opengl
93
+
94
+ import json
95
+ from typing import List, Dict, Any
96
+
97
+ def create_camera_json(extrinsics: Any, focal_length_world: float, principle_points: List[float], image_size: int) -> str:
98
+ # Initialize the dictionary
99
+ camera_dict = {
100
+ "w": image_size,
101
+ "h": image_size,
102
+ "fl_x": float(focal_length_world[0]),
103
+ "fl_y": float(focal_length_world[1]),
104
+ "cx": float(principle_points[0]),
105
+ "cy": float(principle_points[1]),
106
+ "k1": 0.0, # Assuming these values are not provided
107
+ "k2": 0.0, # Assuming these values are not provided
108
+ "p1": 0.0, # Assuming these values are not provided
109
+ "p2": 0.0, # Assuming these values are not provided
110
+ "camera_model": "OPENCV",
111
+ "frames": []
112
+ }
113
+
114
+ # Add frames to the dictionary
115
+ for i, extrinsic in enumerate(extrinsics):
116
+ frame = {
117
+ "file_path": f"images/frame_{str(i).zfill(5)}.jpg",
118
+ "transform_matrix": extrinsic.tolist(),
119
+ "colmap_im_id": i
120
+ }
121
+ # Convert numpy float32 to Python's native float
122
+ frame["transform_matrix"] = [[float(element) for element in row] for row in frame["transform_matrix"]]
123
+ camera_dict["frames"].append(frame)
124
+
125
+ return camera_dict
126
+
127
+ def archieve_images_and_transforms(images, pred_cameras, image_size):
128
+ images_array = images.permute(0, 2, 3, 1).cpu().numpy() * 255
129
+ images_pil = [Image.fromarray(image.astype('uint8')) for image in images_array]
130
+
131
+ with tempfile.TemporaryDirectory() as temp_dir:
132
+ images_dir = os.path.join(temp_dir, 'images')
133
+ os.makedirs(images_dir, exist_ok=True)
134
+
135
+ images_path = []
136
+ for i, image in enumerate(images_pil):
137
+ image_path = os.path.join(images_dir, 'frame_{:05d}.jpg'.format(i))
138
+ image.save(image_path)
139
+ images_path.append(image_path)
140
+
141
+ cam_trans = pred_cameras.get_world_to_view_transform()
142
+ extrinsics = cam_trans.inverse().get_matrix().cpu()
143
+ extrinsics = [convert_extrinsics_pytorch3d_to_opengl(extrinsic.T) for extrinsic in extrinsics]
144
+
145
+ focal_length_ndc = pred_cameras.focal_length.mean(dim=0).cpu().numpy()
146
+ focal_length_world = focal_length_ndc * image_size / 2
147
+ principle_points = [image_size / 2, image_size / 2]
148
+ camera_dict = create_camera_json(extrinsics, focal_length_world, principle_points, image_size)
149
+
150
+ json_path = os.path.join(temp_dir, 'transforms.json')
151
+ with open(json_path, 'w') as f:
152
+ json.dump(camera_dict, f, indent=4)
153
+
154
+ project_name = datetime.now().strftime("%Y%m%d-%H%M%S")
155
+ shutil.make_archive(f'/tmp/{project_name}', 'zip', temp_dir)
156
+ return f'/tmp/{project_name}.zip'
157
+
158
+ def estimate_images_pose(image_folder, mode):
159
+ print("Selected mode:", mode)
160
+ with hydra.initialize(config_path="./cfgs/"):
161
+ cfg = hydra.compose(config_name=mode)
162
+
163
+ OmegaConf.set_struct(cfg, False)
164
+ print("Model Config:")
165
+ print(OmegaConf.to_yaml(cfg))
166
+
167
+ # Check for GPU availability and set the device
168
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
169
+
170
+ # Instantiate the model
171
+ model = instantiate(cfg.MODEL, _recursive_=False)
172
+
173
+ # Load and preprocess images
174
+ images, image_info = load_and_preprocess_images(image_folder, cfg.image_size)
175
+
176
+ # Load checkpoint
177
+ ckpt_path = os.path.join(cfg.ckpt)
178
+ if os.path.isfile(ckpt_path):
179
+ checkpoint = torch.load(ckpt_path, map_location=device)
180
+ model.load_state_dict(checkpoint, strict=True)
181
+ print(f"Loaded checkpoint from: {ckpt_path}")
182
+ else:
183
+ raise ValueError(f"No checkpoint found at: {ckpt_path}")
184
+
185
+ # Move model and images to the GPU
186
+ model = model.to(device)
187
+ images = images.to(device)
188
+
189
+ # Evaluation Mode
190
+ model.eval()
191
+
192
+ # Seed random engines
193
+ seed_all_random_engines(cfg.seed)
194
+
195
+ # Start the timer
196
+ start_time = time.time()
197
+
198
+ # Perform match extraction
199
+ if cfg.GGS.enable:
200
+ # Optional TODO: remove the keypoints outside the cropped region?
201
+
202
+ kp1, kp2, i12 = extract_match(image_folder, image_info)
203
+
204
+ keys = ["kp1", "kp2", "i12", "img_shape"]
205
+ values = [kp1, kp2, i12, images.shape]
206
+ matches_dict = dict(zip(keys, values))
207
+
208
+ cfg.GGS.pose_encoding_type = cfg.MODEL.pose_encoding_type
209
+ GGS_cfg = OmegaConf.to_container(cfg.GGS)
210
+
211
+ cond_fn = partial(
212
+ geometry_guided_sampling, matches_dict=matches_dict, GGS_cfg=GGS_cfg
213
+ )
214
+ print("=====> Sampling with GGS <=====")
215
+ else:
216
+ cond_fn = None
217
+ print("=====> Sampling without GGS <=====")
218
+
219
+ # Forward
220
+ with torch.no_grad():
221
+ # Obtain predicted camera parameters
222
+ # pred_cameras is a PerspectiveCameras object with attributes
223
+ # pred_cameras.R, pred_cameras.T, pred_cameras.focal_length
224
+
225
+ # The poses and focal length are defined as
226
+ # NDC coordinate system in
227
+ # https://github.com/facebookresearch/pytorch3d/blob/main/docs/notes/cameras.md
228
+ pred_cameras = model(
229
+ image=images, cond_fn=cond_fn, cond_start_step=cfg.GGS.start_step
230
+ )
231
+
232
+ # Stop the timer and calculate elapsed time
233
+ end_time = time.time()
234
+ elapsed_time = end_time - start_time
235
+ print("Time taken: {:.4f} seconds".format(elapsed_time))
236
+
237
+ zip_path = archieve_images_and_transforms(images, pred_cameras, cfg.image_size)
238
+ return create_matplotlib_figure(pred_cameras), zip_path
239
+
240
+ def extract_frames_from_video(video_path: str) -> str:
241
+ """
242
+ Extracts frames from a video file and saves them in a temporary directory.
243
+ Returns the path to the directory containing the frames.
244
+ """
245
+ temp_dir = tempfile.mkdtemp()
246
+ output_path = os.path.join(temp_dir, "%03d.jpg")
247
+ command = [
248
+ "ffmpeg",
249
+ "-i", video_path,
250
+ "-vf", "fps=1",
251
+ output_path
252
+ ]
253
+ subprocess.run(command, check=True)
254
+ return temp_dir
255
+
256
+ def estimate_video_pose(video_path: str, mode: str):
257
+ """
258
+ Estimates camera poses for the frames of a video; returns the matplotlib figure and the path of the exported archive.
259
+ """
260
+ # Extract frames from the video
261
+ image_folder = extract_frames_from_video(video_path)
262
+ # Estimate the pose for each frame and package the results
263
+ fig, zip_path = estimate_images_pose(image_folder, mode)
264
+ return fig, zip_path
265
+
266
+ if __name__ == "__main__":
267
+ examples = [["examples/" + img, 'fast'] for img in os.listdir("examples/")]
268
+ # Create a Gradio interface
269
+ iface = gr.Interface(
270
+ fn=estimate_video_pose,
271
+ inputs=[gr.inputs.Video(label='video', type='mp4'),
272
+ gr.inputs.Radio(choices=['fast', 'precise'], default='fast',
273
+ label='Estimation model: fast is quick, usually within 1 second; precise has higher accuracy but usually takes several minutes')],
274
+ outputs=['plot', 'file'],
275
+ title="PoseDiffusion Demo: Solving Pose Estimation via Diffusion-aided Bundle Adjustment",
276
+ description="Upload a video for object pose estimation. The object should be centrally located within the frame.",
277
+ examples=examples,
278
+ cache_examples=True
279
+ )
280
+ iface.launch()
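
A minimal way to exercise this pipeline without the Gradio UI, shown as a hedged sketch rather than part of the commit; it assumes ffmpeg, a CUDA device, and the checkpoint referenced in cfgs/fast.yaml are available, and uses the example video shipped above:

from app import estimate_video_pose

# run the fast (GGS-disabled) profile on the bundled example clip
fig, zip_path = estimate_video_pose(
    "examples/71165193657__AED15223-1435-44B6-AFC1-884527CE1642.mp4",
    mode="fast",
)
fig.savefig("/tmp/estimated_cameras.png")          # camera wireframe plot
print("NeRF-style transforms archive:", zip_path)  # images/ + transforms.json zip
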
cfgs/default.yaml ADDED
@@ -0,0 +1,40 @@
1
+ image_folder: samples/apple
2
+ image_size: 224
3
+ ckpt: weights/co3d_model_Apr16.pth
4
+ seed: 0
5
+
6
+ GGS:
7
+ enable: True
8
+ start_step: 10
9
+ learning_rate: 0.01
10
+ iter_num: 100
11
+ sampson_max: 10
12
+ min_matches: 10
13
+ alpha: 0.0001
14
+
15
+
16
+ MODEL:
17
+ _target_: models.PoseDiffusionModel
18
+
19
+ pose_encoding_type: absT_quaR_logFL
20
+
21
+ IMAGE_FEATURE_EXTRACTOR:
22
+ _target_: models.MultiScaleImageFeatureExtractor
23
+ freeze: False
24
+
25
+ DENOISER:
26
+ _target_: models.Denoiser
27
+ TRANSFORMER:
28
+ _target_: models.TransformerEncoderWrapper
29
+ d_model: 512
30
+ nhead: 4
31
+ dim_feedforward: 1024
32
+ num_encoder_layers: 8
33
+ dropout: 0.1
34
+ batch_first: True
35
+ norm_first: True
36
+
37
+
38
+ DIFFUSER:
39
+ _target_: models.GaussianDiffusion
40
+ beta_schedule: custom
cfgs/fast.yaml ADDED
@@ -0,0 +1,40 @@
1
+ image_folder: samples/apple
2
+ image_size: 224
3
+ ckpt: weights/co3d_model_Apr16.pth
4
+ seed: 0
5
+
6
+ GGS:
7
+ enable: False
8
+ start_step: 10
9
+ learning_rate: 0.01
10
+ iter_num: 100
11
+ sampson_max: 10
12
+ min_matches: 10
13
+ alpha: 0.0001
14
+
15
+
16
+ MODEL:
17
+ _target_: models.PoseDiffusionModel
18
+
19
+ pose_encoding_type: absT_quaR_logFL
20
+
21
+ IMAGE_FEATURE_EXTRACTOR:
22
+ _target_: models.MultiScaleImageFeatureExtractor
23
+ freeze: False
24
+
25
+ DENOISER:
26
+ _target_: models.Denoiser
27
+ TRANSFORMER:
28
+ _target_: models.TransformerEncoderWrapper
29
+ d_model: 512
30
+ nhead: 4
31
+ dim_feedforward: 1024
32
+ num_encoder_layers: 8
33
+ dropout: 0.1
34
+ batch_first: True
35
+ norm_first: True
36
+
37
+
38
+ DIFFUSER:
39
+ _target_: models.GaussianDiffusion
40
+ beta_schedule: custom
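
A small sketch (not in the commit) of how app.py composes these profiles with Hydra; it assumes the script is run from the repo root so that ./cfgs/ resolves:

import hydra
from omegaconf import OmegaConf

with hydra.initialize(config_path="./cfgs/"):
    fast_cfg = hydra.compose(config_name="fast")
    default_cfg = hydra.compose(config_name="default")

print(fast_cfg.GGS.enable, default_cfg.GGS.enable)   # False vs True: the only difference between the profiles
print(OmegaConf.to_yaml(fast_cfg.MODEL.DENOISER))    # transformer settings instantiated by Hydra
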
examples/71165193657__AED15223-1435-44B6-AFC1-884527CE1642.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:691ae5feb1286b531ad974a7e3a5859bb6de7e26b3e4f21eb208afc4af8038e7
3
+ size 512683
models/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .pose_diffusion_model import PoseDiffusionModel
8
+
9
+
10
+ from .denoiser import Denoiser, TransformerEncoderWrapper
11
+ from .gaussian_diffuser import GaussianDiffusion
12
+ from .image_feature_extractor import MultiScaleImageFeatureExtractor
models/denoiser.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from collections import defaultdict
9
+ from dataclasses import field, dataclass
10
+ from typing import Any, Dict, List, Optional, Tuple, Union, Callable
11
+ from util.embedding import TimeStepEmbedding, PoseEmbedding
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from hydra.utils import instantiate
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class Denoiser(nn.Module):
23
+ def __init__(
24
+ self,
25
+ TRANSFORMER: Dict,
26
+ target_dim: int = 9, # TODO: reduce fl dim from 2 to 1
27
+ pivot_cam_onehot: bool = True,
28
+ z_dim: int = 384,
29
+ mlp_hidden_dim: bool = 128,
30
+ ):
31
+ super().__init__()
32
+
33
+ self.pivot_cam_onehot = pivot_cam_onehot
34
+ self.target_dim = target_dim
35
+
36
+ self.time_embed = TimeStepEmbedding()
37
+ self.pose_embed = PoseEmbedding(target_dim=self.target_dim)
38
+
39
+ first_dim = (
40
+ self.time_embed.out_dim
41
+ + self.pose_embed.out_dim
42
+ + z_dim
43
+ + int(self.pivot_cam_onehot)
44
+ )
45
+
46
+ d_model = TRANSFORMER.d_model
47
+ self._first = nn.Linear(first_dim, d_model)
48
+
49
+ # slightly different from the paper that
50
+ # we use 2 encoder layers and 6 decoder layers
51
+ # here we use a transformer with 8 encoder layers
52
+ # call TransformerEncoderWrapper() to build a encoder-only transformer
53
+ self._trunk = instantiate(TRANSFORMER, _recursive_=False)
54
+
55
+ # TODO: change the implementation of MLP to a more mature one
56
+ self._last = MLP(
57
+ d_model,
58
+ [mlp_hidden_dim, self.target_dim],
59
+ norm_layer=nn.LayerNorm,
60
+ )
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor, # B x N x dim
65
+ t: torch.Tensor, # B
66
+ z: torch.Tensor, # B x N x dim_z
67
+ ):
68
+ B, N, _ = x.shape
69
+
70
+ t_emb = self.time_embed(t)
71
+ # expand t from B x C to B x N x C
72
+ t_emb = t_emb.view(B, 1, t_emb.shape[-1]).expand(-1, N, -1)
73
+
74
+ x_emb = self.pose_embed(x)
75
+
76
+ if self.pivot_cam_onehot:
77
+ # add the one hot vector identifying the first camera as pivot
78
+ cam_pivot_id = torch.zeros_like(z[..., :1])
79
+ cam_pivot_id[:, 0, ...] = 1.0
80
+ z = torch.cat([z, cam_pivot_id], dim=-1)
81
+
82
+ feed_feats = torch.cat([x_emb, t_emb, z], dim=-1)
83
+
84
+ input_ = self._first(feed_feats)
85
+
86
+ feats_ = self._trunk(input_)
87
+
88
+ output = self._last(feats_)
89
+
90
+ return output
91
+
92
+
93
+ def TransformerEncoderWrapper(
94
+ d_model: int,
95
+ nhead: int,
96
+ num_encoder_layers: int,
97
+ dim_feedforward: int = 2048,
98
+ dropout: float = 0.1,
99
+ norm_first: bool = True,
100
+ batch_first: bool = True,
101
+ ):
102
+ encoder_layer = torch.nn.TransformerEncoderLayer(
103
+ d_model=d_model,
104
+ nhead=nhead,
105
+ dim_feedforward=dim_feedforward,
106
+ dropout=dropout,
107
+ batch_first=batch_first,
108
+ norm_first=norm_first,
109
+ )
110
+
111
+ _trunk = torch.nn.TransformerEncoder(encoder_layer, num_encoder_layers)
112
+ return _trunk
113
+
114
+
115
+ class MLP(torch.nn.Sequential):
116
+ """This block implements the multi-layer perceptron (MLP) module.
117
+
118
+ Args:
119
+ in_channels (int): Number of channels of the input
120
+ hidden_channels (List[int]): List of the hidden channel dimensions
121
+ norm_layer (Callable[..., torch.nn.Module], optional):
122
+ Norm layer that will be stacked on top of the convolution layer.
123
+ If ``None`` this layer wont be used. Default: ``None``
124
+ activation_layer (Callable[..., torch.nn.Module], optional):
125
+ Activation function which will be stacked on top of the
126
+ normalization layer (if not None), otherwise on top of the
127
+ conv layer. If ``None`` this layer wont be used.
128
+ Default: ``torch.nn.ReLU``
129
+ inplace (bool): Parameter for the activation layer, which can
130
+ optionally do the operation in-place. Default ``True``
131
+ bias (bool): Whether to use bias in the linear layer. Default ``True``
132
+ dropout (float): The probability for the dropout layer. Default: 0.0
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ in_channels: int,
138
+ hidden_channels: List[int],
139
+ norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
140
+ activation_layer: Optional[
141
+ Callable[..., torch.nn.Module]
142
+ ] = torch.nn.ReLU,
143
+ inplace: Optional[bool] = True,
144
+ bias: bool = True,
145
+ norm_first: bool = False,
146
+ dropout: float = 0.0,
147
+ ):
148
+ # The addition of `norm_layer` is inspired from
149
+ # the implementation of TorchMultimodal:
150
+ # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
151
+ params = {} if inplace is None else {"inplace": inplace}
152
+
153
+ layers = []
154
+ in_dim = in_channels
155
+
156
+ for hidden_dim in hidden_channels[:-1]:
157
+ if norm_first and norm_layer is not None:
158
+ layers.append(norm_layer(in_dim))
159
+
160
+ layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
161
+
162
+ if not norm_first and norm_layer is not None:
163
+ layers.append(norm_layer(hidden_dim))
164
+
165
+ layers.append(activation_layer(**params))
166
+
167
+ if dropout > 0:
168
+ layers.append(torch.nn.Dropout(dropout, **params))
169
+
170
+ in_dim = hidden_dim
171
+
172
+ if norm_first and norm_layer is not None:
173
+ layers.append(norm_layer(in_dim))
174
+
175
+ layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
176
+ if dropout > 0:
177
+ layers.append(torch.nn.Dropout(dropout, **params))
178
+
179
+ super().__init__(*layers)
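
An illustrative sketch, not part of the repo, that drives Denoiser on random tensors with the same transformer settings as cfgs/default.yaml; the tensor sizes below are arbitrary stand-ins:

import torch
from omegaconf import OmegaConf
from models.denoiser import Denoiser

transformer_cfg = OmegaConf.create({
    "_target_": "models.TransformerEncoderWrapper",
    "d_model": 512, "nhead": 4, "dim_feedforward": 1024,
    "num_encoder_layers": 8, "dropout": 0.1,
    "batch_first": True, "norm_first": True,
})
denoiser = Denoiser(TRANSFORMER=transformer_cfg, target_dim=9, z_dim=384)

B, N = 1, 8                       # one sequence of 8 frames
x = torch.randn(B, N, 9)          # noisy pose encodings (absT_quaR_logFL)
t = torch.randint(0, 100, (B,))   # diffusion timesteps
z = torch.randn(B, N, 384)        # per-frame image features
print(denoiser(x, t, z).shape)    # torch.Size([1, 8, 9])
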
models/gaussian_diffuser.py ADDED
@@ -0,0 +1,410 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
8
+
9
+ import math
10
+ import copy
11
+ from pathlib import Path
12
+ from random import random
13
+ from functools import partial
14
+ from collections import namedtuple
15
+ from multiprocessing import cpu_count
16
+
17
+ import torch
18
+ from torch import nn, einsum
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import Dataset, DataLoader
21
+
22
+ from torch.optim import Adam
23
+ from torchvision import transforms as T, utils
24
+
25
+ from einops import rearrange, reduce
26
+ from einops.layers.torch import Rearrange
27
+
28
+ from PIL import Image
29
+ from tqdm.auto import tqdm
30
+ from typing import Any, Dict, List, Optional, Tuple, Union
31
+
32
+ # constants
33
+
34
+ ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])
35
+
36
+ # helper functions
37
+
38
+
39
+ def exists(x):
40
+ return x is not None
41
+
42
+
43
+ def default(val, d):
44
+ if exists(val):
45
+ return val
46
+ return d() if callable(d) else d
47
+
48
+
49
+ def extract(a, t, x_shape):
50
+ b, *_ = t.shape
51
+ out = a.gather(-1, t)
52
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
53
+
54
+
55
+ def linear_beta_schedule(timesteps):
56
+ scale = 1000 / timesteps
57
+ beta_start = scale * 0.0001
58
+ beta_end = scale * 0.02
59
+ return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
60
+
61
+
62
+ def cosine_beta_schedule(timesteps, s=0.008):
63
+ """
64
+ cosine schedule
65
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
66
+ """
67
+ steps = timesteps + 1
68
+ x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
69
+ alphas_cumprod = (
70
+ torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
71
+ )
72
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
73
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
74
+ return torch.clip(betas, 0, 0.999)
75
+
76
+
77
+ class GaussianDiffusion(nn.Module):
78
+ def __init__(
79
+ self,
80
+ timesteps=100,
81
+ sampling_timesteps=None,
82
+ beta_1=0.0001,
83
+ beta_T=0.1,
84
+ loss_type="l1",
85
+ objective="pred_noise",
86
+ beta_schedule="custom",
87
+ p2_loss_weight_gamma=0.0,
88
+ p2_loss_weight_k=1,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.objective = objective
93
+
94
+ assert objective in {
95
+ "pred_noise",
96
+ "pred_x0",
97
+ }, "objective must be either pred_noise (predict noise) \
98
+ or pred_x0 (predict image start)"
99
+
100
+ self.timesteps = timesteps
101
+ self.sampling_timesteps = sampling_timesteps
102
+ self.beta_1 = beta_1
103
+ self.beta_T = beta_T
104
+ self.loss_type = loss_type
105
+ self.objective = objective
106
+ self.beta_schedule = beta_schedule
107
+ self.p2_loss_weight_gamma = p2_loss_weight_gamma
108
+ self.p2_loss_weight_k = p2_loss_weight_k
109
+
110
+ self.init_diff_hyper(
111
+ self.timesteps,
112
+ self.sampling_timesteps,
113
+ self.beta_1,
114
+ self.beta_T,
115
+ self.loss_type,
116
+ self.objective,
117
+ self.beta_schedule,
118
+ self.p2_loss_weight_gamma,
119
+ self.p2_loss_weight_k,
120
+ )
121
+
122
+ def init_diff_hyper(
123
+ self,
124
+ timesteps,
125
+ sampling_timesteps,
126
+ beta_1,
127
+ beta_T,
128
+ loss_type,
129
+ objective,
130
+ beta_schedule,
131
+ p2_loss_weight_gamma,
132
+ p2_loss_weight_k,
133
+ ):
134
+ if beta_schedule == "linear":
135
+ betas = linear_beta_schedule(timesteps)
136
+ elif beta_schedule == "cosine":
137
+ betas = cosine_beta_schedule(timesteps)
138
+ elif beta_schedule == "custom":
139
+ betas = torch.linspace(
140
+ beta_1, beta_T, timesteps, dtype=torch.float64
141
+ )
142
+ else:
143
+ raise ValueError(f"unknown beta schedule {beta_schedule}")
144
+
145
+ alphas = 1.0 - betas
146
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
147
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
148
+
149
+ (timesteps,) = betas.shape
150
+ self.num_timesteps = int(timesteps)
151
+ self.loss_type = loss_type
152
+
153
+ # sampling related parameters
154
+ self.sampling_timesteps = default(
155
+ sampling_timesteps, timesteps
156
+ ) # default num sampling timesteps to number of timesteps at training
157
+
158
+ assert self.sampling_timesteps <= timesteps
159
+
160
+ # helper function to register buffer from float64 to float32
161
+ register_buffer = lambda name, val: self.register_buffer(
162
+ name, val.to(torch.float32)
163
+ )
164
+
165
+ register_buffer("betas", betas)
166
+ register_buffer("alphas_cumprod", alphas_cumprod)
167
+ register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
168
+
169
+ # calculations for diffusion q(x_t | x_{t-1}) and others
170
+ register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
171
+ register_buffer(
172
+ "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
173
+ )
174
+ register_buffer(
175
+ "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
176
+ )
177
+ register_buffer(
178
+ "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
179
+ )
180
+ register_buffer(
181
+ "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
182
+ )
183
+
184
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
185
+ posterior_variance = (
186
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
187
+ )
188
+
189
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
190
+ register_buffer("posterior_variance", posterior_variance)
191
+
192
+ # below: log calculation clipped because the posterior variance is 0
193
+ # at the beginning of the diffusion chain
194
+ register_buffer(
195
+ "posterior_log_variance_clipped",
196
+ torch.log(posterior_variance.clamp(min=1e-20)),
197
+ )
198
+ register_buffer(
199
+ "posterior_mean_coef1",
200
+ betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
201
+ )
202
+ register_buffer(
203
+ "posterior_mean_coef2",
204
+ (1.0 - alphas_cumprod_prev)
205
+ * torch.sqrt(alphas)
206
+ / (1.0 - alphas_cumprod),
207
+ )
208
+
209
+ # calculate p2 reweighting
210
+ register_buffer(
211
+ "p2_loss_weight",
212
+ (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
213
+ ** -p2_loss_weight_gamma,
214
+ )
215
+
216
+ # helper functions
217
+ def predict_start_from_noise(self, x_t, t, noise):
218
+ return (
219
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
220
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
221
+ )
222
+
223
+ def predict_noise_from_start(self, x_t, t, x0):
224
+ return (
225
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
226
+ ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
227
+
228
+ def q_posterior(self, x_start, x_t, t):
229
+ posterior_mean = (
230
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
231
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
232
+ )
233
+
234
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
235
+ posterior_log_variance_clipped = extract(
236
+ self.posterior_log_variance_clipped, t, x_t.shape
237
+ )
238
+ return (
239
+ posterior_mean,
240
+ posterior_variance,
241
+ posterior_log_variance_clipped,
242
+ )
243
+
244
+ def q_sample(self, x_start, t, noise=None):
245
+ noise = default(noise, lambda: torch.randn_like(x_start))
246
+ return (
247
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
248
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
249
+ * noise
250
+ )
251
+
252
+ def model_predictions(self, x, t, z, x_self_cond=None):
253
+ model_output = self.model(x, t, z)
254
+
255
+ if self.objective == "pred_noise":
256
+ pred_noise = model_output
257
+ x_start = self.predict_start_from_noise(x, t, model_output)
258
+
259
+ elif self.objective == "pred_x0":
260
+ pred_noise = self.predict_noise_from_start(x, t, model_output)
261
+ x_start = model_output
262
+
263
+ return ModelPrediction(pred_noise, x_start)
264
+
265
+ def p_mean_variance(
266
+ self,
267
+ x: torch.Tensor, # B x N_x x dim
268
+ t: int,
269
+ z: torch.Tensor,
270
+ x_self_cond=None,
271
+ clip_denoised=False,
272
+ ):
273
+ preds = self.model_predictions(x, t, z)
274
+
275
+ x_start = preds.pred_x_start
276
+
277
+ if clip_denoised:
278
+ raise NotImplementedError(
279
+ "We don't clip the output because \
280
+ pose does not have a clear bound."
281
+ )
282
+
283
+ (
284
+ model_mean,
285
+ posterior_variance,
286
+ posterior_log_variance,
287
+ ) = self.q_posterior(x_start=x_start, x_t=x, t=t)
288
+
289
+ return model_mean, posterior_variance, posterior_log_variance, x_start
290
+
291
+ @torch.no_grad()
292
+ def p_sample(
293
+ self,
294
+ x: torch.Tensor, # B x N_x x dim
295
+ t: int,
296
+ z: torch.Tensor,
297
+ x_self_cond=None,
298
+ clip_denoised=False,
299
+ cond_fn=None,
300
+ cond_start_step=0,
301
+ ):
302
+ b, *_, device = *x.shape, x.device
303
+ batched_times = torch.full(
304
+ (x.shape[0],), t, device=x.device, dtype=torch.long
305
+ )
306
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(
307
+ x=x,
308
+ t=batched_times,
309
+ z=z,
310
+ x_self_cond=x_self_cond,
311
+ clip_denoised=clip_denoised,
312
+ )
313
+
314
+ if cond_fn is not None and t < cond_start_step:
315
+ model_mean = cond_fn(model_mean, t)
316
+ noise = 0.0
317
+ else:
318
+ noise = torch.randn_like(x) if t > 0 else 0.0 # no noise if t == 0
319
+
320
+ pred = model_mean + (0.5 * model_log_variance).exp() * noise
321
+ return pred, x_start
322
+
323
+ @torch.no_grad()
324
+ def p_sample_loop(
325
+ self,
326
+ shape,
327
+ z: torch.Tensor,
328
+ cond_fn=None,
329
+ cond_start_step=0,
330
+ ):
331
+ batch, device = shape[0], self.betas.device
332
+
333
+ # Init here
334
+ pose = torch.randn(shape, device=device)
335
+
336
+ x_start = None
337
+
338
+ pose_process = []
339
+ pose_process.append(pose.unsqueeze(0))
340
+
341
+ for t in reversed(range(0, self.num_timesteps)):
342
+ pose, _ = self.p_sample(
343
+ x=pose,
344
+ t=t,
345
+ z=z,
346
+ cond_fn=cond_fn,
347
+ cond_start_step=cond_start_step,
348
+ )
349
+ pose_process.append(pose.unsqueeze(0))
350
+
351
+ return pose, torch.cat(pose_process)
352
+
353
+ @torch.no_grad()
354
+ def sample(self, shape, z, cond_fn=None, cond_start_step=0):
355
+ # TODO: add more variants
356
+ sample_fn = self.p_sample_loop
357
+ return sample_fn(
358
+ shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
359
+ )
360
+
361
+ def p_losses(
362
+ self,
363
+ x_start,
364
+ t,
365
+ z=None,
366
+ noise=None,
367
+ ):
368
+ noise = default(noise, lambda: torch.randn_like(x_start))
369
+ # noise sample
370
+ x = self.q_sample(x_start=x_start, t=t, noise=noise)
371
+
372
+ model_out = self.model(x, t, z)
373
+
374
+ if self.objective == "pred_noise":
375
+ target = noise
376
+ x_0_pred = self.predict_start_from_noise(x, t, model_out)
377
+ elif self.objective == "pred_x0":
378
+ target = x_start
379
+ x_0_pred = model_out
380
+ else:
381
+ raise ValueError(f"unknown objective {self.objective}")
382
+
383
+ loss = self.loss_fn(model_out, target, reduction="none")
384
+
385
+ loss = reduce(loss, "b ... -> b (...)", "mean")
386
+ loss = loss * extract(self.p2_loss_weight, t, loss.shape)
387
+
388
+ return {
389
+ "loss": loss,
390
+ "noise": noise,
391
+ "x_0_pred": x_0_pred,
392
+ "x_t": x,
393
+ "t": t,
394
+ }
395
+
396
+ def forward(self, pose, z=None, *args, **kwargs):
397
+ b = len(pose)
398
+ t = torch.randint(
399
+ 0, self.num_timesteps, (b,), device=pose.device
400
+ ).long()
401
+ return self.p_losses(pose, t, z=z, *args, **kwargs)
402
+
403
+ @property
404
+ def loss_fn(self):
405
+ if self.loss_type == "l1":
406
+ return F.l1_loss
407
+ elif self.loss_type == "l2":
408
+ return F.mse_loss
409
+ else:
410
+ raise ValueError(f"invalid loss type {self.loss_type}")
models/image_feature_extractor.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import math
9
+ import warnings
10
+ from collections import defaultdict
11
+ from dataclasses import field, dataclass
12
+ from typing import Any, Dict, List, Optional, Tuple, Union, Callable
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torchvision
18
+
19
+ import io
20
+ from PIL import Image
21
+ import numpy as np
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
26
+ _RESNET_STD = [0.229, 0.224, 0.225]
27
+
28
+
29
+ class MultiScaleImageFeatureExtractor(nn.Module):
30
+ def __init__(
31
+ self,
32
+ modelname: str = "dino_vits16",
33
+ freeze: bool = False,
34
+ scale_factors: list = [1, 1 / 2, 1 / 3],
35
+ ):
36
+ super().__init__()
37
+ self.freeze = freeze
38
+ self.scale_factors = scale_factors
39
+
40
+ if "res" in modelname:
41
+ self._net = getattr(torchvision.models, modelname)(pretrained=True)
42
+ self._output_dim = self._net.fc.weight.shape[1]
43
+ self._net.fc = nn.Identity()
44
+ elif "dino" in modelname:
45
+ self._net = torch.hub.load("facebookresearch/dino:main", modelname)
46
+ self._output_dim = self._net.norm.weight.shape[0]
47
+ else:
48
+ raise ValueError(f"Unknown model name {modelname}")
49
+
50
+ for name, value in (
51
+ ("_resnet_mean", _RESNET_MEAN),
52
+ ("_resnet_std", _RESNET_STD),
53
+ ):
54
+ self.register_buffer(
55
+ name,
56
+ torch.FloatTensor(value).view(1, 3, 1, 1),
57
+ persistent=False,
58
+ )
59
+
60
+ if self.freeze:
61
+ for param in self.parameters():
62
+ param.requires_grad = False
63
+
64
+ def get_output_dim(self):
65
+ return self._output_dim
66
+
67
+ def forward(self, image_rgb: torch.Tensor) -> torch.Tensor:
68
+ img_normed = self._resnet_normalize_image(image_rgb)
69
+
70
+ features = self._compute_multiscale_features(img_normed)
71
+
72
+ return features
73
+
74
+ def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
75
+ return (img - self._resnet_mean) / self._resnet_std
76
+
77
+ def _compute_multiscale_features(
78
+ self, img_normed: torch.Tensor
79
+ ) -> torch.Tensor:
80
+ multiscale_features = None
81
+
82
+ if len(self.scale_factors) <= 0:
83
+ raise ValueError(
84
+ f"Wrong format of self.scale_factors: {self.scale_factors}"
85
+ )
86
+
87
+ for scale_factor in self.scale_factors:
88
+ if scale_factor == 1:
89
+ inp = img_normed
90
+ else:
91
+ inp = self._resize_image(img_normed, scale_factor)
92
+
93
+ if multiscale_features is None:
94
+ multiscale_features = self._net(inp)
95
+ else:
96
+ multiscale_features += self._net(inp)
97
+
98
+ averaged_features = multiscale_features / len(self.scale_factors)
99
+ return averaged_features
100
+
101
+ @staticmethod
102
+ def _resize_image(image: torch.Tensor, scale_factor: float) -> torch.Tensor:
103
+ return nn.functional.interpolate(
104
+ image,
105
+ scale_factor=scale_factor,
106
+ mode="bilinear",
107
+ align_corners=False,
108
+ )
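
A quick sketch (not in the commit): the extractor averages DINO ViT-S/16 features over the three scale factors above, giving one 384-d vector per frame; it downloads the DINO weights from torch.hub on first use:

import torch
from models.image_feature_extractor import MultiScaleImageFeatureExtractor

extractor = MultiScaleImageFeatureExtractor(modelname="dino_vits16", freeze=True).eval()
frames = torch.rand(4, 3, 224, 224)              # N x 3 x H x W images in [0, 1]
with torch.no_grad():
    feats = extractor(frames)
print(feats.shape, extractor.get_output_dim())   # torch.Size([4, 384]) 384
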
models/pose_diffusion_model.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Standard library imports
8
+ import base64
9
+ import io
10
+ import logging
11
+ import math
12
+ import pickle
13
+ import warnings
14
+ from collections import defaultdict
15
+ from dataclasses import field, dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ # Third-party library imports
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from PIL import Image
23
+
24
+ from pytorch3d.renderer.cameras import CamerasBase
25
+ from pytorch3d.transforms import (
26
+ se3_exp_map,
27
+ se3_log_map,
28
+ Transform3d,
29
+ so3_relative_angle,
30
+ )
31
+ from util.camera_transform import pose_encoding_to_camera
32
+
33
+ import models
34
+ from hydra.utils import instantiate
35
+ from pytorch3d.renderer.cameras import PerspectiveCameras
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ class PoseDiffusionModel(nn.Module):
42
+ def __init__(
43
+ self,
44
+ pose_encoding_type: str,
45
+ IMAGE_FEATURE_EXTRACTOR: Dict,
46
+ DIFFUSER: Dict,
47
+ DENOISER: Dict,
48
+ ):
49
+ """Initializes a PoseDiffusion model.
50
+
51
+ Args:
52
+ pose_encoding_type (str):
53
+ Defines the encoding type for extrinsics and intrinsics
54
+ Currently, only `"absT_quaR_logFL"` is supported -
55
+ a concatenation of the translation vector,
56
+ rotation quaternion, and logarithm of focal length.
57
+ IMAGE_FEATURE_EXTRACTOR (Dict):
58
+ Configuration for the image feature extractor.
59
+ DIFFUSER (Dict):
60
+ Configuration for the diffuser.
61
+ DENOISER (Dict):
62
+ Configuration for the denoiser.
63
+ """
64
+
65
+ super().__init__()
66
+
67
+ self.pose_encoding_type = pose_encoding_type
68
+
69
+ self.image_feature_extractor = instantiate(
70
+ IMAGE_FEATURE_EXTRACTOR, _recursive_=False
71
+ )
72
+ self.diffuser = instantiate(DIFFUSER, _recursive_=False)
73
+
74
+ denoiser = instantiate(DENOISER, _recursive_=False)
75
+ self.diffuser.model = denoiser
76
+
77
+ self.target_dim = denoiser.target_dim
78
+
79
+ def forward(
80
+ self,
81
+ image: torch.Tensor,
82
+ gt_cameras: Optional[CamerasBase] = None,
83
+ sequence_name: Optional[List[str]] = None,
84
+ cond_fn=None,
85
+ cond_start_step=0,
86
+ ):
87
+ """
88
+ Forward pass of the PoseDiffusionModel.
89
+
90
+ Args:
91
+ image (torch.Tensor):
92
+ Input image tensor, Bx3xHxW.
93
+ gt_cameras (Optional[CamerasBase], optional):
94
+ Camera object. Defaults to None.
95
+ sequence_name (Optional[List[str]], optional):
96
+ List of sequence names. Defaults to None.
97
+ cond_fn ([type], optional):
98
+ Conditional function. Wrapper for GGS or other functions.
99
+ cond_start_step (int, optional):
100
+ The sampling step to start using conditional function.
101
+
102
+ Returns:
103
+ PerspectiveCameras: PyTorch3D camera object.
104
+ """
105
+
106
+ z = self.image_feature_extractor(image)
107
+
108
+ z = z.unsqueeze(0)
109
+
110
+ B, N, _ = z.shape
111
+ target_shape = [B, N, self.target_dim]
112
+
113
+ # sampling
114
+ pose_encoding, pose_encoding_diffusion_samples = self.diffuser.sample(
115
+ shape=target_shape,
116
+ z=z,
117
+ cond_fn=cond_fn,
118
+ cond_start_step=cond_start_step,
119
+ )
120
+
121
+ # convert the encoded representation to PyTorch3D cameras
122
+ pred_cameras = pose_encoding_to_camera(
123
+ pose_encoding, pose_encoding_type=self.pose_encoding_type
124
+ )
125
+
126
+ return pred_cameras
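
A rough end-to-end smoke-test sketch with random weights, assuming only the cfgs/ folder from this commit; loading cfg.ckpt as app.py does is what makes the predictions meaningful:

import torch
import hydra
from hydra.utils import instantiate

with hydra.initialize(config_path="./cfgs/"):
    cfg = hydra.compose(config_name="fast")

model = instantiate(cfg.MODEL, _recursive_=False).eval()
images = torch.rand(8, 3, cfg.image_size, cfg.image_size)   # 8 dummy frames
with torch.no_grad():
    cameras = model(image=images, cond_fn=None, cond_start_step=cfg.GGS.start_step)
print(cameras.R.shape, cameras.T.shape, cameras.focal_length.shape)
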
packages.txt ADDED
File without changes
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ torch==1.13.0
2
+ torchvision==0.14.0
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ hydra-core
2
+ omegaconf
3
+ opencv-python
4
+ einops
5
+ git+https://github.com/facebookresearch/PoseDiffusion.git
util/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
util/camera_transform.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from pytorch3d.transforms.rotation_conversions import (
9
+ matrix_to_quaternion,
10
+ quaternion_to_matrix,
11
+ )
12
+ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
13
+
14
+
15
+ def pose_encoding_to_camera(
16
+ pose_encoding,
17
+ pose_encoding_type="absT_quaR_logFL",
18
+ log_focal_length_bias=1.8,
19
+ min_focal_length=0.1,
20
+ max_focal_length=20,
21
+ ):
22
+ """
23
+ Args:
24
+ pose_encoding: A tensor of shape `BxNxC`, containing a batch of
25
+ `BxN` `C`-dimensional pose encodings.
26
+ pose_encoding_type: The type of pose encoding,
27
+ only "absT_quaR_logFL" is supported.
28
+ """
29
+
30
+ batch_size, num_poses, _ = pose_encoding.shape
31
+ pose_encoding_reshaped = pose_encoding.reshape(
32
+ -1, pose_encoding.shape[-1]
33
+ ) # Reshape to BNxC
34
+
35
+ if pose_encoding_type == "absT_quaR_logFL":
36
+ # forced that 3 for absT, 4 for quaR, 2 logFL
37
+ # TODO: converted to 1 dim for logFL, consistent with our paper
38
+ abs_T = pose_encoding_reshaped[:, :3]
39
+ quaternion_R = pose_encoding_reshaped[:, 3:7]
40
+ R = quaternion_to_matrix(quaternion_R)
41
+
42
+ log_focal_length = pose_encoding_reshaped[:, 7:9]
43
+
44
+ # log_focal_length_bias was the hyperparameter
45
+ # to ensure the mean of logFL close to 0 during training
46
+ # Now converted back
47
+ focal_length = (log_focal_length + log_focal_length_bias).exp()
48
+
49
+ # clamp to avoid weird fl values
50
+ focal_length = torch.clamp(
51
+ focal_length, min=min_focal_length, max=max_focal_length
52
+ )
53
+ else:
54
+ raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
55
+
56
+ pred_cameras = PerspectiveCameras(
57
+ focal_length=focal_length,
58
+ R=R,
59
+ T=abs_T,
60
+ device=R.device,
61
+ )
62
+
63
+ return pred_cameras
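
An illustrative decoding of a random batch of "absT_quaR_logFL" encodings (3 translation + 4 quaternion + 2 log-focal-length values per camera), not part of the commit:

import torch
from util.camera_transform import pose_encoding_to_camera

pose_encoding = torch.randn(1, 8, 9)       # B x N x 9
cameras = pose_encoding_to_camera(pose_encoding)
print(cameras.R.shape, cameras.T.shape, cameras.focal_length.shape)
# torch.Size([8, 3, 3]) torch.Size([8, 3]) torch.Size([8, 2])
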
util/embedding.py ADDED
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import math
10
+ from pytorch3d.renderer import HarmonicEmbedding
11
+
12
+
13
+ class TimeStepEmbedding(nn.Module):
14
+ # learned from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py
15
+ def __init__(self, dim=256, max_period=10000):
16
+ super().__init__()
17
+ self.dim = dim
18
+ self.max_period = max_period
19
+
20
+ self.linear = nn.Sequential(
21
+ nn.Linear(dim, dim // 2),
22
+ nn.SiLU(),
23
+ nn.Linear(dim // 2, dim // 2),
24
+ )
25
+
26
+ self.out_dim = dim // 2
27
+
28
+ def _compute_freqs(self, half):
29
+ freqs = torch.exp(
30
+ -math.log(self.max_period)
31
+ * torch.arange(start=0, end=half, dtype=torch.float32)
32
+ / half
33
+ )
34
+ return freqs
35
+
36
+ def forward(self, timesteps):
37
+ half = self.dim // 2
38
+ freqs = self._compute_freqs(half).to(device=timesteps.device)
39
+ args = timesteps[:, None].float() * freqs[None]
40
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
41
+ if self.dim % 2:
42
+ embedding = torch.cat(
43
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
44
+ )
45
+
46
+ output = self.linear(embedding)
47
+ return output
48
+
49
+
50
+ class PoseEmbedding(nn.Module):
51
+ def __init__(self, target_dim, n_harmonic_functions=10, append_input=True):
52
+ super().__init__()
53
+
54
+ self._emb_pose = HarmonicEmbedding(
55
+ n_harmonic_functions=n_harmonic_functions, append_input=append_input
56
+ )
57
+
58
+ self.out_dim = self._emb_pose.get_output_dim(target_dim)
59
+
60
+ def forward(self, pose_encoding):
61
+ e_pose_encoding = self._emb_pose(pose_encoding)
62
+ return e_pose_encoding
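
A small sketch: both modules expose `.out_dim`, which Denoiser uses to size its first linear layer (128 for the time embedding, 9 * (2*10 + 1) = 189 for the pose encoding); the inputs below are arbitrary:

import torch
from util.embedding import TimeStepEmbedding, PoseEmbedding

time_embed = TimeStepEmbedding()
pose_embed = PoseEmbedding(target_dim=9)
print(time_embed(torch.tensor([5, 50])).shape)    # torch.Size([2, 128])
print(pose_embed(torch.randn(2, 8, 9)).shape)     # torch.Size([2, 8, 189])
print(time_embed.out_dim, pose_embed.out_dim)     # 128 189
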
util/geometry_guided_sampling.py ADDED
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from typing import Dict, List, Optional, Union
9
+ from util.camera_transform import pose_encoding_to_camera
10
+ from util.get_fundamental_matrix import get_fundamental_matrices
11
+ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
12
+
13
+
14
+ def geometry_guided_sampling(
15
+ model_mean: torch.Tensor,
16
+ t: int,
17
+ matches_dict: Dict,
18
+ GGS_cfg: Dict,
19
+ ):
20
+ # pre-process matches
21
+ b, c, h, w = matches_dict["img_shape"]
22
+ device = model_mean.device
23
+
24
+ def _to_device(tensor):
25
+ return torch.from_numpy(tensor).to(device)
26
+
27
+ kp1 = _to_device(matches_dict["kp1"])
28
+ kp2 = _to_device(matches_dict["kp2"])
29
+ i12 = _to_device(matches_dict["i12"])
30
+
31
+ pair_idx = i12[:, 0] * b + i12[:, 1]
32
+ pair_idx = pair_idx.long()
33
+
34
+ def _to_homogeneous(tensor):
35
+ return torch.nn.functional.pad(tensor, [0, 1], value=1)
36
+
37
+ kp1_homo = _to_homogeneous(kp1)
38
+ kp2_homo = _to_homogeneous(kp2)
39
+
40
+ i1, i2 = [
41
+ i.reshape(-1) for i in torch.meshgrid(torch.arange(b), torch.arange(b))
42
+ ]
43
+
44
+ processed_matches = {
45
+ "kp1_homo": kp1_homo,
46
+ "kp2_homo": kp2_homo,
47
+ "i1": i1,
48
+ "i2": i2,
49
+ "h": h,
50
+ "w": w,
51
+ "pair_idx": pair_idx,
52
+ }
53
+
54
+ # conduct GGS
55
+ model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)
56
+
57
+ # Optimize FL, R, and T separately
58
+ model_mean = GGS_optimize(
59
+ model_mean,
60
+ t,
61
+ processed_matches,
62
+ update_T=False,
63
+ update_R=False,
64
+ update_FL=True,
65
+ **GGS_cfg,
66
+ ) # only optimize FL
67
+
68
+ model_mean = GGS_optimize(
69
+ model_mean,
70
+ t,
71
+ processed_matches,
72
+ update_T=False,
73
+ update_R=True,
74
+ update_FL=False,
75
+ **GGS_cfg,
76
+ ) # only optimize R
77
+
78
+ model_mean = GGS_optimize(
79
+ model_mean,
80
+ t,
81
+ processed_matches,
82
+ update_T=True,
83
+ update_R=False,
84
+ update_FL=False,
85
+ **GGS_cfg,
86
+ ) # only optimize T
87
+
88
+ model_mean = GGS_optimize(model_mean, t, processed_matches, **GGS_cfg)
89
+ return model_mean
90
+
91
+
92
+ def GGS_optimize(
93
+ model_mean: torch.Tensor,
94
+ t: int,
95
+ processed_matches: Dict,
96
+ update_R: bool = True,
97
+ update_T: bool = True,
98
+ update_FL: bool = True,
99
+ # the args below come from **GGS_cfg
100
+ alpha: float = 0.0001,
101
+ learning_rate: float = 1e-2,
102
+ iter_num: int = 100,
103
+ sampson_max: int = 10,
104
+ min_matches: int = 10,
105
+ pose_encoding_type: str = "absT_quaR_logFL",
106
+ **kwargs,
107
+ ):
108
+ with torch.enable_grad():
109
+ model_mean.requires_grad_(True)
110
+
111
+ if update_R and update_T and update_FL:
112
+ iter_num = iter_num * 2
113
+
114
+ optimizer = torch.optim.SGD(
115
+ [model_mean], lr=learning_rate, momentum=0.9
116
+ )
117
+ batch_size = model_mean.shape[1]
118
+
119
+ for _ in range(iter_num):
120
+ valid_sampson, sampson_to_print = compute_sampson_distance(
121
+ model_mean,
122
+ t,
123
+ processed_matches,
124
+ update_R=update_R,
125
+ update_T=update_T,
126
+ update_FL=update_FL,
127
+ pose_encoding_type=pose_encoding_type,
128
+ sampson_max=sampson_max,
129
+ )
130
+
131
+ if min_matches > 0:
132
+ valid_match_per_frame = len(valid_sampson) / batch_size
133
+ if valid_match_per_frame < min_matches:
134
+ print(
135
+ "Drop this pair because of insufficient valid matches"
136
+ )
137
+ break
138
+
139
+ loss = valid_sampson.mean()
140
+ optimizer.zero_grad()
141
+ loss.backward()
142
+
143
+ grads = model_mean.grad
144
+ grad_norm = grads.norm()
145
+ grad_mask = (grads.abs() > 0).detach()
146
+ model_mean_norm = (model_mean * grad_mask).norm()
147
+
148
+ max_norm = alpha * model_mean_norm / learning_rate
149
+
150
+ total_norm = torch.nn.utils.clip_grad_norm_(model_mean, max_norm)
151
+ optimizer.step()
152
+
153
+ print(f"t={t:02d} | sampson={sampson_to_print:05f}")
154
+ model_mean = model_mean.detach()
155
+ return model_mean
156
+
157
+
158
+ def compute_sampson_distance(
159
+ model_mean: torch.Tensor,
160
+ t: int,
161
+ processed_matches: Dict,
162
+ update_R=True,
163
+ update_T=True,
164
+ update_FL=True,
165
+ pose_encoding_type: str = "absT_quaR_logFL",
166
+ sampson_max: int = 10,
167
+ ):
168
+ camera = pose_encoding_to_camera(model_mean, pose_encoding_type)
169
+
170
+ # pick the mean of the predicted focal length
171
+ camera.focal_length = camera.focal_length.mean(dim=0).repeat(
172
+ len(camera.focal_length), 1
173
+ )
174
+
175
+ if not update_R:
176
+ camera.R = camera.R.detach()
177
+
178
+ if not update_T:
179
+ camera.T = camera.T.detach()
180
+
181
+ if not update_FL:
182
+ camera.focal_length = camera.focal_length.detach()
183
+
184
+ kp1_homo, kp2_homo, i1, i2, he, wi, pair_idx = processed_matches.values()
185
+ F_2_to_1 = get_fundamental_matrices(
186
+ camera, he, wi, i1, i2, l2_normalize_F=False
187
+ )
188
+ F = F_2_to_1.permute(0, 2, 1) # y1^T F y2 = 0
189
+
190
+ def _sampson_distance(F, kp1_homo, kp2_homo, pair_idx):
191
+ left = torch.bmm(kp1_homo[:, None], F[pair_idx])
192
+ right = torch.bmm(F[pair_idx], kp2_homo[..., None])
193
+
194
+ bottom = (
195
+ left[:, :, 0].square()
196
+ + left[:, :, 1].square()
197
+ + right[:, 0, :].square()
198
+ + right[:, 1, :].square()
199
+ )
200
+ top = torch.bmm(left, kp2_homo[..., None]).square()
201
+
202
+ sampson = top[:, 0] / bottom
203
+ return sampson
204
+
205
+ sampson = _sampson_distance(
206
+ F,
207
+ kp1_homo.float(),
208
+ kp2_homo.float(),
209
+ pair_idx,
210
+ )
211
+
212
+ sampson_to_print = sampson.detach().clone().clamp(max=sampson_max).mean()
213
+ sampson = sampson[sampson < sampson_max]
214
+
215
+ return sampson, sampson_to_print
util/get_fundamental_matrix.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import pytorch3d
9
+ from pytorch3d.utils import opencv_from_cameras_projection
10
+ from pytorch3d.transforms.so3 import hat
11
+ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
12
+
13
+
14
+ def get_fundamental_matrices(
15
+ camera: CamerasBase,
16
+ height: int,
17
+ width: int,
18
+ index1: torch.LongTensor,
19
+ index2: torch.LongTensor,
20
+ l2_normalize_F=False,
21
+ ):
22
+ """Compute fundamental matrices for given camera parameters."""
23
+ batch_size = camera.R.shape[0]
24
+
25
+ # Convert to opencv / colmap / Hartley&Zisserman convention
26
+ image_size_t = (
27
+ torch.LongTensor([height, width])[None]
28
+ .repeat(batch_size, 1)
29
+ .to(camera.device)
30
+ )
31
+ R, t, K = opencv_from_cameras_projection(camera, image_size=image_size_t)
32
+
33
+ F, E = get_fundamental_matrix(
34
+ K[index1], R[index1], t[index1], K[index2], R[index2], t[index2]
35
+ )
36
+
37
+ if l2_normalize_F:
38
+ F_scale = torch.norm(F, dim=(1, 2))
39
+ F_scale = F_scale.clamp(min=0.0001)
40
+ F = F / F_scale[:, None, None]
41
+
42
+ return F
43
+
44
+
45
+ def get_fundamental_matrix(K1, R1, t1, K2, R2, t2):
46
+ E = get_essential_matrix(R1, t1, R2, t2)
47
+ F = K2.inverse().permute(0, 2, 1).matmul(E).matmul(K1.inverse())
48
+ return F, E # p2^T F p1 = 0
49
+
50
+
51
+ def get_essential_matrix(R1, t1, R2, t2):
52
+ R12 = R2.matmul(R1.permute(0, 2, 1))
53
+ t12 = t2 - R12.matmul(t1[..., None])[..., 0]
54
+ E_R = R12
55
+ E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0]
56
+ E = E_R.matmul(hat(E_t))
57
+ return E
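
A shape-level sketch (not part of the commit) of the pairwise fundamental matrices that geometry_guided_sampling consumes, using identity rotations and random translations purely for illustration:

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras
from util.get_fundamental_matrix import get_fundamental_matrices

cams = PerspectiveCameras(
    R=torch.eye(3)[None].repeat(3, 1, 1),
    T=torch.randn(3, 3),
    focal_length=torch.ones(3, 2),
)
index1 = torch.tensor([0, 0, 1])   # first image of each pair
index2 = torch.tensor([1, 2, 2])   # second image of each pair
F = get_fundamental_matrices(cams, height=224, width=224, index1=index1, index2=index2)
print(F.shape)   # torch.Size([3, 3, 3]): one 3x3 matrix per requested pair
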
util/load_img_folder.py ADDED
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import numpy as np
9
+ from PIL import Image
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from typing import (
13
+ Any,
14
+ ClassVar,
15
+ Dict,
16
+ Iterable,
17
+ List,
18
+ Optional,
19
+ Sequence,
20
+ Tuple,
21
+ Type,
22
+ TYPE_CHECKING,
23
+ Union,
24
+ )
25
+
26
+
27
+ def load_and_preprocess_images(
28
+ folder_path: str, image_size: int = 224, mode: str = "bilinear"
29
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
30
+ image_paths = [
31
+ os.path.join(folder_path, file)
32
+ for file in os.listdir(folder_path)
33
+ if file.lower().endswith((".png", ".jpg", ".jpeg"))
34
+ ]
35
+ image_paths.sort()
36
+
37
+ images = []
38
+ bboxes_xyxy = []
39
+ scales = []
40
+ for path in image_paths:
41
+ image = _load_image(path)
42
+ image, bbox_xyxy, min_hw = _center_crop_square(image)
43
+ minscale = image_size / min_hw
44
+
45
+ imre = F.interpolate(
46
+ torch.from_numpy(image)[None],
47
+ size=(image_size, image_size),
48
+ mode=mode,
49
+ align_corners=False if mode == "bilinear" else None,
50
+ )[0]
51
+
52
+ images.append(imre.numpy())
53
+ bboxes_xyxy.append(bbox_xyxy.numpy())
54
+ scales.append(minscale)
55
+
56
+ images_tensor = torch.from_numpy(np.stack(images))
57
+
58
+ # assume all the images have the same shape for GGS
59
+ image_info = {
60
+ "size": (min_hw, min_hw),
61
+ "bboxes_xyxy": np.stack(bboxes_xyxy),
62
+ "resized_scales": np.stack(scales),
63
+ }
64
+ return images_tensor, image_info
65
+
66
+
67
+ # helper functions
68
+
69
+
70
+ def _load_image(path) -> np.ndarray:
71
+ with Image.open(path) as pil_im:
72
+ im = np.array(pil_im.convert("RGB"))
73
+ im = im.transpose((2, 0, 1))
74
+ im = im.astype(np.float32) / 255.0
75
+ return im
76
+
77
+
78
+ def _center_crop_square(image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor, int]:
79
+ h, w = image.shape[1:]
80
+ min_dim = min(h, w)
81
+ top = (h - min_dim) // 2
82
+ left = (w - min_dim) // 2
83
+ cropped_image = image[:, top : top + min_dim, left : left + min_dim]
84
+
85
+ # bbox_xywh: the cropped region
86
+ bbox_xywh = torch.tensor([left, top, min_dim, min_dim])
87
+
88
+ # the format from xywh to xyxy
89
+ bbox_xyxy = _clamp_box_to_image_bounds_and_round(
90
+ _get_clamp_bbox(
91
+ bbox_xywh,
92
+ box_crop_context=0.0,
93
+ ),
94
+ image_size_hw=(h, w),
95
+ )
96
+ return cropped_image, bbox_xyxy, min_dim
97
+
98
+
99
+ def _get_clamp_bbox(
100
+ bbox: torch.Tensor,
101
+ box_crop_context: float = 0.0,
102
+ ) -> torch.Tensor:
103
+ # box_crop_context: rate of expansion for bbox
104
+ # returns possibly expanded bbox xyxy as float
105
+
106
+ bbox = bbox.clone() # do not edit bbox in place
107
+
108
+ # increase box size
109
+ if box_crop_context > 0.0:
110
+ c = box_crop_context
111
+ bbox = bbox.float()
112
+ bbox[0] -= bbox[2] * c / 2
113
+ bbox[1] -= bbox[3] * c / 2
114
+ bbox[2] += bbox[2] * c
115
+ bbox[3] += bbox[3] * c
116
+
117
+ if (bbox[2:] <= 1.0).any():
118
+ raise ValueError(
119
+ f"squashed image!! The bounding box contains no pixels."
120
+ )
121
+
122
+ bbox[2:] = torch.clamp(
123
+ bbox[2:], 2
124
+ ) # set min height, width to 2 along both axes
125
+ bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
126
+
127
+ return bbox_xyxy
128
+
129
+
130
+ def _bbox_xywh_to_xyxy(
131
+ xywh: torch.Tensor, clamp_size: Optional[int] = None
132
+ ) -> torch.Tensor:
133
+ xyxy = xywh.clone()
134
+ if clamp_size is not None:
135
+ xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
136
+ xyxy[2:] += xyxy[:2]
137
+ return xyxy
138
+
139
+
140
+ def _clamp_box_to_image_bounds_and_round(
141
+ bbox_xyxy: torch.Tensor,
142
+ image_size_hw: Tuple[int, int],
143
+ ) -> torch.LongTensor:
144
+ bbox_xyxy = bbox_xyxy.clone()
145
+ bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
146
+ bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
147
+ if not isinstance(bbox_xyxy, torch.LongTensor):
148
+ bbox_xyxy = bbox_xyxy.round().long()
149
+ return bbox_xyxy # pyre-ignore [7]
150
+
151
+
152
+ if __name__ == "__main__":
153
+ # Example usage:
154
+ folder_path = "path/to/your/folder"
155
+ image_size = 224
156
+ images_tensor = load_and_preprocess_images(folder_path, image_size)
157
+ print(images_tensor.shape)
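The `bboxes_xyxy` and `resized_scales` entries exist so that 2D points detected on the original frames can be mapped into the square, resized crops. A small illustrative sketch of that mapping (the numbers are hypothetical; the same shift-then-scale is what util/match_extraction.py applies to COLMAP keypoints):

import numpy as np

# hypothetical 640x480 frame, center-cropped to 480x480 and resized to 224x224
bbox_xyxy = np.array([80, 0, 560, 480])
resized_scale = 224 / 480  # image_size / min(h, w)

kp_original = np.array([300.0, 200.0])   # (x, y) on the original frame
kp_crop = kp_original - bbox_xyxy[:2]    # shift into the crop
kp_resized = kp_crop * resized_scale     # scale into the 224x224 crop
print(kp_resized)                        # ~[102.7, 93.3]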
util/match_extraction.py ADDED
@@ -0,0 +1,175 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import shutil
+ import tempfile
+ from pathlib import Path
+
+ import numpy as np
+ import pycolmap
+ from typing import Optional, List, Dict, Any, Tuple
+ from hloc import (
+     extract_features,
+     logger,
+     match_features,
+     pairs_from_exhaustive,
+ )
+ from hloc.triangulation import (
+     import_features,
+     import_matches,
+     estimation_and_geometric_verification,
+     parse_option_args,
+     OutputCapture,
+ )
+ from hloc.utils.database import (
+     COLMAPDatabase,
+     image_ids_to_pair_id,
+     pair_id_to_image_ids,
+ )
+ from hloc.reconstruction import create_empty_db, import_images, get_image_ids
+
+
+ def extract_match(image_folder_path: str, image_info: Dict):
+     # Now only supports SuperPoint features with LightGlue matching
+     with tempfile.TemporaryDirectory() as tmpdir:
+         tmp_mapping = os.path.join(tmpdir, "mapping")
+         os.makedirs(tmp_mapping)
+         for filename in os.listdir(image_folder_path):
+             if filename.lower().endswith(
+                 (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")
+             ):
+                 shutil.copy(
+                     os.path.join(image_folder_path, filename),
+                     os.path.join(tmp_mapping, filename),
+                 )
+         matches, keypoints = run_hloc(tmpdir)
+
+     # convert from the COLMAP format to the PyTorch3D one
+     kp1, kp2, i12 = colmap_keypoint_to_pytorch3d(matches, keypoints, image_info)
+
+     return kp1, kp2, i12
+
+
+ def colmap_keypoint_to_pytorch3d(matches, keypoints, image_info):
+     kp1, kp2, i12 = [], [], []
+     bbox_xyxy, scale = image_info["bboxes_xyxy"], image_info["resized_scales"]
+
+     for idx in keypoints:
+         # coordinate change from COLMAP to OpenCV
+         cur_keypoint = keypoints[idx] - 0.5
+
+         # go to the coordinate frame after cropping
+         # use idx - 1 here because COLMAP image ids start from 1 instead of 0
+         cur_keypoint = cur_keypoint - [
+             bbox_xyxy[idx - 1][0],
+             bbox_xyxy[idx - 1][1],
+         ]
+         cur_keypoint = cur_keypoint * scale[idx - 1]
+         keypoints[idx] = cur_keypoint
+
+     for (r_idx, q_idx), pair_match in matches.items():
+         if pair_match is not None:
+             kp1.append(keypoints[r_idx][pair_match[:, 0]])
+             kp2.append(keypoints[q_idx][pair_match[:, 1]])
+
+             i12_pair = np.array([[r_idx - 1, q_idx - 1]])
+             i12.append(np.repeat(i12_pair, len(pair_match), axis=0))
+
+     if kp1:
+         kp1, kp2, i12 = map(np.concatenate, (kp1, kp2, i12), (0, 0, 0))
+     else:
+         kp1 = kp2 = i12 = None
+
+     return kp1, kp2, i12
+
+
+ def run_hloc(output_dir: str):
+     # learned from
+     # https://github.com/cvg/Hierarchical-Localization/blob/master/pipeline_SfM.ipynb
+
+     images = Path(output_dir)
+     outputs = Path(os.path.join(output_dir, "output"))
+     sfm_pairs = outputs / "pairs-sfm.txt"
+     sfm_dir = outputs / "sfm"
+     features = outputs / "features.h5"
+     matches = outputs / "matches.h5"
+
+     feature_conf = extract_features.confs[
+         "superpoint_inloc"
+     ]  # or superpoint_max
+     matcher_conf = match_features.confs["superpoint+lightglue"]
+
+     references = [
+         p.relative_to(images).as_posix()
+         for p in (images / "mapping/").iterdir()
+     ]
+
+     extract_features.main(
+         feature_conf, images, image_list=references, feature_path=features
+     )
+     pairs_from_exhaustive.main(sfm_pairs, image_list=references)
+     match_features.main(
+         matcher_conf, sfm_pairs, features=features, matches=matches
+     )
+
+     matches, keypoints = compute_matches_and_keypoints(
+         sfm_dir, images, sfm_pairs, features, matches, image_list=references
+     )
+
+     return matches, keypoints
+
+
+ def compute_matches_and_keypoints(
+     sfm_dir: Path,
+     image_dir: Path,
+     pairs: Path,
+     features: Path,
+     matches: Path,
+     camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO,
+     verbose: bool = False,
+     min_match_score: Optional[float] = None,
+     image_list: Optional[List[str]] = None,
+     image_options: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[Dict, Dict]:
+     # learned from
+     # https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/reconstruction.py
+
+     sfm_dir.mkdir(parents=True, exist_ok=True)
+     database = sfm_dir / "database.db"
+
+     create_empty_db(database)
+     import_images(image_dir, database, camera_mode, image_list, image_options)
+     image_ids = get_image_ids(database)
+     import_features(image_ids, database, features)
+     import_matches(image_ids, database, pairs, matches, min_match_score)
+     estimation_and_geometric_verification(database, pairs, verbose)
+
+     db = COLMAPDatabase.connect(database)
+
+     matches = dict(
+         (
+             pair_id_to_image_ids(pair_id),
+             _blob_to_array_safe(data, np.uint32, (-1, 2)),
+         )
+         for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
+     )
+
+     keypoints = dict(
+         (image_id, _blob_to_array_safe(data, np.float32, (-1, 2)))
+         for image_id, data in db.execute("SELECT image_id, data FROM keypoints")
+     )
+
+     db.close()
+
+     return matches, keypoints
+
+
+ def _blob_to_array_safe(blob, dtype, shape=(-1,)):
+     if blob is not None:
+         # np.frombuffer (rather than the deprecated np.fromstring) decodes the raw blob
+         return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+     else:
+         return blob
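Taken together, the two utilities above form the matching front end: the preprocessing step produces `image_info`, and `extract_match` converts the hloc/COLMAP output into correspondences in the resized-crop frame. A hedged usage sketch (the folder path is a placeholder, and hloc plus pycolmap must be installed):

from util.load_img_folder import load_and_preprocess_images
from util.match_extraction import extract_match

image_folder = "examples/scene1"  # placeholder path to a folder of .jpg/.png frames
images, image_info = load_and_preprocess_images(image_folder, image_size=224)

# kp1/kp2 are matched 2D points in crop coordinates; i12 holds the 0-based image index pairs
kp1, kp2, i12 = extract_match(image_folder, image_info)
if kp1 is not None:
    print(images.shape, kp1.shape, kp2.shape, i12.shape)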
util/metric.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import random
+ import numpy as np
+ import torch
+
+
+ def compute_ARE(rotation1, rotation2):
+     """Angular rotation error, in degrees, between two batches of (B, 3, 3) rotation matrices."""
+     if isinstance(rotation1, torch.Tensor):
+         rotation1 = rotation1.cpu().detach().numpy()
+     if isinstance(rotation2, torch.Tensor):
+         rotation2 = rotation2.cpu().detach().numpy()
+
+     R_rel = np.einsum("Bij,Bjk->Bik", rotation1.transpose(0, 2, 1), rotation2)
+     t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
+     theta = np.arccos(np.clip(t, -1, 1))
+     error = theta * 180 / np.pi
+     return np.minimum(error, np.abs(180 - error))
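A quick sanity check of compute_ARE (illustrative, not from this commit): the identity against a 30-degree rotation about the z-axis should give an angular error of about 30 degrees.

import numpy as np
from util.metric import compute_ARE

a = np.deg2rad(30.0)
Rz = np.array([[[np.cos(a), -np.sin(a), 0.0],
                [np.sin(a),  np.cos(a), 0.0],
                [0.0,        0.0,       1.0]]])  # batch of one rotation
identity = np.eye(3)[None]

print(compute_ARE(identity, Rz))  # ~[30.]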
util/utils.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import random
+
+ import numpy as np
+ import torch
+ import tempfile
+
+
+ def seed_all_random_engines(seed: int) -> None:
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     random.seed(seed)
weights/co3d_model_Apr16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7084b19cddce8dcc8f9197a8bbcf330fd0edf1c0c97b628c35180d8a18edbeb
+ size 155952931