Spaces:
Running
Running
import argparse | |
import logging | |
import shutil | |
import tarfile | |
from collections.abc import Iterable | |
from pathlib import Path | |
import h5py | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import PIL.Image | |
import torch | |
from omegaconf import OmegaConf | |
from ..geometry.wrappers import Camera, Pose | |
from ..models.cache_loader import CacheLoader | |
from ..settings import DATA_PATH | |
from ..utils.image import ImagePreprocessor, load_image | |
from ..utils.tools import fork_rng | |
from ..visualization.viz2d import plot_heatmaps, plot_image_grid | |
from .base_dataset import BaseDataset | |
from .utils import rotate_intrinsics, rotate_pose_inplane, scale_intrinsics | |
# Module-level logger (rebound to the parent package's logger when this file
# is executed as a script, see the `__main__` section).
logger = logging.getLogger(name=__name__)
# Directory next to this module holding the split scene lists and the
# optional fixed evaluation pair files.
scene_lists_path = Path(__file__).parent.joinpath("megadepth_scene_lists")
def sample_n(data, num, seed=None):
    """Return at most ``num`` entries of ``data``, drawn without replacement.

    If ``data`` has ``num`` entries or fewer it is returned as-is (same
    object, original order). Otherwise ``num`` distinct entries are picked
    with a ``np.random.RandomState`` seeded by ``seed``.
    """
    if len(data) <= num:
        return data
    rng = np.random.RandomState(seed)
    picked = rng.choice(len(data), num, replace=False)
    return data[picked]
class MegaDepth(BaseDataset):
    """MegaDepth dataset producing single views, pairs, or triplets.

    ``_init`` downloads the data if ``DATA_PATH / conf.data_dir`` does not
    exist; ``get_dataset`` selects the per-split dataset from ``conf.views``.
    """

    default_conf = {
        # paths
        "data_dir": "megadepth/",
        "depth_subpath": "depth_undistorted/",
        "image_subpath": "Undistorted_SfM/",
        "info_dir": "scene_info/",  # @TODO: intrinsics problem?
        # Training
        "train_split": "train_scenes_clean.txt",
        "train_num_per_scene": 500,
        # Validation
        "val_split": "valid_scenes_clean.txt",
        "val_num_per_scene": None,
        "val_pairs": None,
        # Test
        "test_split": "test_scenes_clean.txt",
        "test_num_per_scene": None,
        "test_pairs": None,
        # data sampling
        "views": 2,  # views per item: 1, 2 or 3
        "min_overlap": 0.3,  # only with D2-Net format
        "max_overlap": 1.0,  # only with D2-Net format
        "num_overlap_bins": 1,
        "sort_by_overlap": False,
        "triplet_enforce_overlap": False,  # only with views==3
        # image options
        "read_depth": True,
        "read_image": True,
        "grayscale": False,
        "preprocessing": ImagePreprocessor.default_conf,
        "p_rotate": 0.0,  # probability to rotate image by +/- 90°
        "reseed": False,
        "seed": 0,
        # features from cache
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
        },
    }

    def _init(self, conf):
        """Download the dataset if its root directory is missing."""
        if not (DATA_PATH / conf.data_dir).exists():
            logger.info("Downloading the MegaDepth dataset.")
            self.download()

    def download(self):
        """Download and extract images, depth maps and scene info.

        Archives are extracted into a temporary sibling directory (removed
        first if left over from a failed attempt), which is moved to
        ``data_dir`` only after all archives have been extracted.
        """
        data_dir = DATA_PATH / self.conf.data_dir
        tmp_dir = data_dir.parent / "megadepth_tmp"
        if tmp_dir.exists():  # The previous download failed.
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        url_base = "https://cvg-data.inf.ethz.ch/megadepth/"
        for tar_name, out_name in (
            ("Undistorted_SfM.tar.gz", self.conf.image_subpath),
            ("depth_undistorted.tar.gz", self.conf.depth_subpath),
            ("scene_info.tar.gz", self.conf.info_dir),
        ):
            tar_path = tmp_dir / tar_name
            torch.hub.download_url_to_file(url_base + tar_name, tar_path)
            with tarfile.open(tar_path) as tar:
                if hasattr(tarfile, "data_filter"):
                    # The archive comes from the network: reject absolute
                    # paths, links escaping tmp_dir and special files
                    # (PEP 706). Also avoids the Python 3.12+ warning about
                    # extracting without a filter.
                    tar.extractall(path=tmp_dir, filter="data")
                else:
                    tar.extractall(path=tmp_dir)
            tar_path.unlink()
            # Rename the extracted top-level folder to the configured name.
            shutil.move(tmp_dir / tar_name.split(".")[0], tmp_dir / out_name)
        shutil.move(tmp_dir, data_dir)

    def get_dataset(self, split):
        """Return the triplet dataset if ``views == 3``, else the pair dataset."""
        assert self.conf.views in [1, 2, 3]
        if self.conf.views == 3:
            return _TripletDataset(self.conf, split)
        return _PairDataset(self.conf, split)
class _PairDataset(torch.utils.data.Dataset): | |
def __init__(self, conf, split, load_sample=True): | |
self.root = DATA_PATH / conf.data_dir | |
assert self.root.exists(), self.root | |
self.split = split | |
self.conf = conf | |
split_conf = conf[split + "_split"] | |
if isinstance(split_conf, (str, Path)): | |
scenes_path = scene_lists_path / split_conf | |
scenes = scenes_path.read_text().rstrip("\n").split("\n") | |
elif isinstance(split_conf, Iterable): | |
scenes = list(split_conf) | |
else: | |
raise ValueError(f"Unknown split configuration: {split_conf}.") | |
scenes = sorted(set(scenes)) | |
if conf.load_features.do: | |
self.feature_loader = CacheLoader(conf.load_features) | |
self.preprocessor = ImagePreprocessor(conf.preprocessing) | |
self.images = {} | |
self.depths = {} | |
self.poses = {} | |
self.intrinsics = {} | |
self.valid = {} | |
# load metadata | |
self.info_dir = self.root / self.conf.info_dir | |
self.scenes = [] | |
for scene in scenes: | |
path = self.info_dir / (scene + ".npz") | |
try: | |
info = np.load(str(path), allow_pickle=True) | |
except Exception: | |
logger.warning( | |
"Cannot load scene info for scene %s at %s.", scene, path | |
) | |
continue | |
self.images[scene] = info["image_paths"] | |
self.depths[scene] = info["depth_paths"] | |
self.poses[scene] = info["poses"] | |
self.intrinsics[scene] = info["intrinsics"] | |
self.scenes.append(scene) | |
if load_sample: | |
self.sample_new_items(conf.seed) | |
assert len(self.items) > 0 | |
def sample_new_items(self, seed): | |
logger.info("Sampling new %s data with seed %d.", self.split, seed) | |
self.items = [] | |
split = self.split | |
num_per_scene = self.conf[self.split + "_num_per_scene"] | |
if isinstance(num_per_scene, Iterable): | |
num_pos, num_neg = num_per_scene | |
else: | |
num_pos = num_per_scene | |
num_neg = None | |
if split != "train" and self.conf[split + "_pairs"] is not None: | |
# Fixed validation or test pairs | |
assert num_pos is None | |
assert num_neg is None | |
assert self.conf.views == 2 | |
pairs_path = scene_lists_path / self.conf[split + "_pairs"] | |
for line in pairs_path.read_text().rstrip("\n").split("\n"): | |
im0, im1 = line.split(" ") | |
scene = im0.split("/")[0] | |
assert im1.split("/")[0] == scene | |
im0, im1 = [self.conf.image_subpath + im for im in [im0, im1]] | |
assert im0 in self.images[scene] | |
assert im1 in self.images[scene] | |
idx0 = np.where(self.images[scene] == im0)[0][0] | |
idx1 = np.where(self.images[scene] == im1)[0][0] | |
self.items.append((scene, idx0, idx1, 1.0)) | |
elif self.conf.views == 1: | |
for scene in self.scenes: | |
if scene not in self.images: | |
continue | |
valid = (self.images[scene] != None) | ( # noqa: E711 | |
self.depths[scene] != None # noqa: E711 | |
) | |
ids = np.where(valid)[0] | |
if num_pos and len(ids) > num_pos: | |
ids = np.random.RandomState(seed).choice( | |
ids, num_pos, replace=False | |
) | |
ids = [(scene, i) for i in ids] | |
self.items.extend(ids) | |
else: | |
for scene in self.scenes: | |
path = self.info_dir / (scene + ".npz") | |
assert path.exists(), path | |
info = np.load(str(path), allow_pickle=True) | |
valid = (self.images[scene] != None) & ( # noqa: E711 | |
self.depths[scene] != None # noqa: E711 | |
) | |
ind = np.where(valid)[0] | |
mat = info["overlap_matrix"][valid][:, valid] | |
if num_pos is not None: | |
# Sample a subset of pairs, binned by overlap. | |
num_bins = self.conf.num_overlap_bins | |
assert num_bins > 0 | |
bin_width = ( | |
self.conf.max_overlap - self.conf.min_overlap | |
) / num_bins | |
num_per_bin = num_pos // num_bins | |
pairs_all = [] | |
for k in range(num_bins): | |
bin_min = self.conf.min_overlap + k * bin_width | |
bin_max = bin_min + bin_width | |
pairs_bin = (mat > bin_min) & (mat <= bin_max) | |
pairs_bin = np.stack(np.where(pairs_bin), -1) | |
pairs_all.append(pairs_bin) | |
# Skip bins with too few samples | |
has_enough_samples = [len(p) >= num_per_bin * 2 for p in pairs_all] | |
num_per_bin_2 = num_pos // max(1, sum(has_enough_samples)) | |
pairs = [] | |
for pairs_bin, keep in zip(pairs_all, has_enough_samples): | |
if keep: | |
pairs.append(sample_n(pairs_bin, num_per_bin_2, seed)) | |
pairs = np.concatenate(pairs, 0) | |
else: | |
pairs = (mat > self.conf.min_overlap) & ( | |
mat <= self.conf.max_overlap | |
) | |
pairs = np.stack(np.where(pairs), -1) | |
pairs = [(scene, ind[i], ind[j], mat[i, j]) for i, j in pairs] | |
if num_neg is not None: | |
neg_pairs = np.stack(np.where(mat <= 0.0), -1) | |
neg_pairs = sample_n(neg_pairs, num_neg, seed) | |
pairs += [(scene, ind[i], ind[j], mat[i, j]) for i, j in neg_pairs] | |
self.items.extend(pairs) | |
if self.conf.views == 2 and self.conf.sort_by_overlap: | |
self.items.sort(key=lambda i: i[-1], reverse=True) | |
else: | |
np.random.RandomState(seed).shuffle(self.items) | |
def _read_view(self, scene, idx): | |
path = self.root / self.images[scene][idx] | |
# read pose data | |
K = self.intrinsics[scene][idx].astype(np.float32, copy=False) | |
T = self.poses[scene][idx].astype(np.float32, copy=False) | |
# read image | |
if self.conf.read_image: | |
img = load_image(self.root / self.images[scene][idx], self.conf.grayscale) | |
else: | |
size = PIL.Image.open(path).size[::-1] | |
img = torch.zeros( | |
[3 - 2 * int(self.conf.grayscale), size[0], size[1]] | |
).float() | |
# read depth | |
if self.conf.read_depth: | |
# depth_path = ( | |
# self.root / self.conf.depth_subpath / scene / (path.stem + ".h5") | |
# ) | |
depth_subpath = self.depths[scene][idx] | |
depth_id = depth_subpath.split('/')[-1][:-3] | |
assert depth_id == path.stem | |
depth_path = self.root / depth_subpath | |
with h5py.File(str(depth_path), "r") as f: | |
depth = f["/depth"].__array__().astype(np.float32, copy=False) | |
depth = torch.Tensor(depth)[None] | |
assert depth.shape[-2:] == img.shape[-2:] | |
else: | |
depth = None | |
# add random rotations | |
do_rotate = self.conf.p_rotate > 0.0 and self.split == "train" | |
if do_rotate: | |
p = self.conf.p_rotate | |
k = 0 | |
if np.random.rand() < p: | |
k = np.random.choice(2, 1, replace=False)[0] * 2 - 1 | |
img = np.rot90(img, k=-k, axes=(-2, -1)) | |
if self.conf.read_depth: | |
depth = np.rot90(depth, k=-k, axes=(-2, -1)).copy() | |
K = rotate_intrinsics(K, img.shape, k + 2) | |
T = rotate_pose_inplane(T, k + 2) | |
name = path.name | |
data = self.preprocessor(img) | |
if depth is not None: | |
data["depth"] = self.preprocessor(depth, interpolation="nearest")["image"][ | |
0 | |
] | |
K = scale_intrinsics(K, data["scales"]) | |
data = { | |
"name": name, | |
"scene": scene, | |
"T_w2cam": Pose.from_4x4mat(T), | |
"depth": depth, | |
"camera": Camera.from_calibration_matrix(K).float(), | |
**data, | |
} | |
if self.conf.load_features.do: | |
features = self.feature_loader({k: [v] for k, v in data.items()}) | |
if do_rotate and k != 0: | |
# ang = np.deg2rad(k * 90.) | |
kpts = features["keypoints"].copy() | |
x, y = kpts[:, 0].copy(), kpts[:, 1].copy() | |
w, h = data["image_size"] | |
if k == 1: | |
kpts[:, 0] = w - y | |
kpts[:, 1] = x | |
elif k == -1: | |
kpts[:, 0] = y | |
kpts[:, 1] = h - x | |
else: | |
raise ValueError | |
features["keypoints"] = kpts | |
data = {"cache": features, **data} | |
return data | |
def __getitem__(self, idx): | |
if self.conf.reseed: | |
with fork_rng(self.conf.seed + idx, False): | |
return self.getitem(idx) | |
else: | |
return self.getitem(idx) | |
def getitem(self, idx): | |
if self.conf.views == 2: | |
if isinstance(idx, list): | |
scene, idx0, idx1, overlap = idx | |
else: | |
scene, idx0, idx1, overlap = self.items[idx] | |
data0 = self._read_view(scene, idx0) | |
data1 = self._read_view(scene, idx1) | |
data = { | |
"view0": data0, | |
"view1": data1, | |
} | |
data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv() | |
data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv() | |
data["overlap_0to1"] = overlap | |
data["name"] = f"{scene}/{data0['name']}_{data1['name']}" | |
else: | |
assert self.conf.views == 1 | |
scene, idx0 = self.items[idx] | |
data = self._read_view(scene, idx0) | |
data["scene"] = scene | |
data["idx"] = idx | |
return data | |
def __len__(self): | |
return len(self.items) | |
class _TripletDataset(_PairDataset):
    """MegaDepth views served as triplets.

    Items are ``(scene, idx0, idx1, idx2, overlap01, overlap02, overlap12)``.
    ``__getitem__`` is inherited from ``_PairDataset``, so ``conf.reseed``
    applies to triplets as well.
    """

    def sample_new_items(self, seed):
        """(Re)build ``self.items`` as triplets for the current split.

        Fixed ``{split}_pairs`` files (non-train only) hold one
        ``im0 im1 im2`` triplet per line; otherwise up to
        ``{split}_num_per_scene`` triplets are sampled per scene (no limit if
        None). The items are finally shuffled with ``seed``.
        """
        logger.info("Sampling new triplets with seed %d", seed)
        self.items = []
        split = self.split
        num = self.conf[self.split + "_num_per_scene"]
        if split != "train" and self.conf[split + "_pairs"] is not None:
            self._add_fixed_triplets(split)
        else:
            for scene in self.scenes:
                self.items.extend(self._sample_scene_triplets(scene, num, seed))
        np.random.RandomState(seed).shuffle(self.items)

    def _add_fixed_triplets(self, split):
        """Append the triplets listed in the ``{split}_pairs`` file."""
        pairs_file = self.conf[split + "_pairs"]
        if Path(pairs_file).exists():
            pairs_path = Path(pairs_file)
        else:
            pairs_path = DATA_PATH / "configs" / pairs_file
        for line in pairs_path.read_text().rstrip("\n").split("\n"):
            im0, im1, im2 = line.split(" ")
            assert im0[:4] == im1[:4]
            scene = im1[:4]
            idx0 = self._image_index(scene, im0)
            idx1 = self._image_index(scene, im1)
            idx2 = self._image_index(scene, im2)
            self.items.append((scene, idx0, idx1, idx2, 1.0, 1.0, 1.0))

    def _image_index(self, scene, name):
        """Return the integer index of image ``name`` within ``scene``."""
        matches = np.where(self.images[scene] == name)[0]
        assert len(matches) > 0, (scene, name)
        return matches[0]

    def _sample_scene_triplets(self, scene, num, seed):
        """Sample triplets of views of ``scene`` whose overlaps are in range.

        With ``triplet_enforce_overlap`` all three pairwise overlaps must be
        in ``(min_overlap, max_overlap]``; otherwise only the pairs (k, i)
        and (k, j) sharing the first view are constrained.
        """
        path = self.info_dir / (scene + ".npz")
        assert path.exists(), path
        if self.conf.num_overlap_bins > 1:
            raise NotImplementedError("TODO")
        with np.load(str(path), allow_pickle=True) as info:
            overlap = info["overlap_matrix"]
        valid = (self.images[scene] != None) & (  # noqa: E711
            self.depths[scene] != None  # noqa: E711
        )
        ind = np.where(valid)[0]
        mat = overlap[valid][:, valid]
        good = (mat > self.conf.min_overlap) & (mat <= self.conf.max_overlap)

        triplets = []
        if self.conf.triplet_enforce_overlap:
            pairs = np.stack(np.where(good), -1)
            for i0, i1 in pairs:
                for i2 in pairs[pairs[:, 0] == i0, 1]:
                    if good[i1, i2]:
                        triplets.append((i0, i1, i2))
            if num is not None and len(triplets) > num:
                selected = np.random.RandomState(seed).choice(
                    len(triplets), num, replace=False
                )
                triplets = np.array(triplets)[selected]
        else:
            # we first enforce that each row has >1 pairs
            non_unique = good.sum(-1) > 1
            ind_r = np.where(non_unique)[0]
            good = good[non_unique]
            pairs = np.stack(np.where(good), -1)
            if num is not None and len(pairs) > num:
                selected = np.random.RandomState(seed).choice(
                    len(pairs), num, replace=False
                )
                pairs = pairs[selected]
            for n, (k, i) in enumerate(pairs):
                # We now sample a j from row k s.t. i != j
                possible_j = np.where(good[k])[0]
                possible_j = possible_j[possible_j != i]
                selected = np.random.RandomState(seed + n).choice(
                    len(possible_j), 1, replace=False
                )[0]
                triplets.append((ind_r[k], i, possible_j[selected]))

        return [
            (scene, ind[k], ind[i], ind[j], mat[k, i], mat[k, j], mat[i, j])
            for k, i, j in triplets
        ]

    def getitem(self, idx):
        """Build the three-view sample for item ``idx``."""
        scene, idx0, idx1, idx2, overlap01, overlap02, overlap12 = self.items[idx]
        data0 = self._read_view(scene, idx0)
        data1 = self._read_view(scene, idx1)
        data2 = self._read_view(scene, idx2)
        data = {
            "view0": data0,
            "view1": data1,
            "view2": data2,
        }
        data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv()
        data["T_0to2"] = data2["T_w2cam"] @ data0["T_w2cam"].inv()
        data["T_1to2"] = data2["T_w2cam"] @ data1["T_w2cam"].inv()
        data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv()
        data["T_2to0"] = data0["T_w2cam"] @ data2["T_w2cam"].inv()
        data["T_2to1"] = data1["T_w2cam"] @ data2["T_w2cam"].inv()
        data["overlap_0to1"] = overlap01
        data["overlap_0to2"] = overlap02
        data["overlap_1to2"] = overlap12
        data["scene"] = scene
        data["name"] = f"{scene}/{data0['name']}_{data1['name']}_{data2['name']}"
        return data
def visualize(args):
    """Plot the images and depth maps of a few MegaDepth samples.

    Args:
        args: parsed CLI namespace with ``split``, ``num_items``, ``dpi``
            and ``dotlist`` (OmegaConf dot-list overrides).
    """
    conf = {
        "min_overlap": 0.1,
        "max_overlap": 0.7,
        "num_overlap_bins": 3,
        "sort_by_overlap": False,
        "train_num_per_scene": 5,
        "batch_size": 1,
        "num_workers": 0,
        "prefetch_factor": None,
        "val_num_per_scene": None,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = MegaDepth(conf)
    loader = dataset.get_data_loader(args.split)
    # The message needs a placeholder for the length argument.
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images, depths = [], []
        views = range(dataset.conf.views)
        for _, data in zip(range(args.num_items), loader):
            images.append([data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in views])
            depths.append([data[f"view{i}"]["depth"][0] for i in views])

    axes = plot_image_grid(images, dpi=args.dpi)
    for i, item_depths in enumerate(depths):
        plot_heatmaps(item_depths, axes=axes[i])
    plt.show()
if __name__ == "__main__":
    # Rebind the module-level `logger` to the one exported by the parent
    # package for script runs.
    from .. import logger  # overwrite the logger

    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ("--split", str, "val"),
        ("--num_items", int, 4),
        ("--dpi", int, 100),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    cli.add_argument("dotlist", nargs="*")
    visualize(cli.parse_intermixed_args())