# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
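        """Point-tracking dataset built on Dynamic Replica annotations.

        Args (descriptions inferred from how the fields are used below):
            root: dataset root directory.
            split: split name, used to locate ``frame_annotations_<split>.jgz``.
            traj_per_sample: number of trajectories randomly sampled per clip.
            crop_size: optional (H, W) center-crop size for frames and tracks.
            sample_len: clip length in frames; -1 keeps each sequence whole.
            only_first_n_samples: keep at most this many clips per sequence;
                -1 keeps all.
            rgbd_input: stored on the instance but not used in this loader.
        """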
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
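
        # Per-split frame annotations are shipped as a gzipped JSON file (.jgz).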
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
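
        # Group frame annotations by sequence, keeping only left-camera frames.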
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)
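
        # Chunk every sequence into non-overlapping clips of sample_len frames
        # (or a single clip spanning the whole sequence when sample_len <= 0).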
        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
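        """Center-crop each frame to self.crop_size and shift the 2D
        trajectories into the cropped coordinate frame."""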
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W
        # deterministic center crop (the offsets are fixed, not random)
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0
        return rgbs, trajs

    def __getitem__(self, index):
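        """Load one clip: per-frame RGB images, 2D point trajectories, and
        per-point visibility, packed into a CoTrackerData sample."""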
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape
        # randomly subsample up to traj_per_sample trajectories
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size
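
        # Mark points that fall outside the (possibly cropped) image bounds as occluded.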
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False
        # keep only points that are visible in more than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # match the number of trajectories that survived the visibility filter
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )
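

# A minimal usage sketch (not part of the original file). It assumes the
# Dynamic Replica annotations and per-frame trajectory files have been
# downloaded and extracted under ./dynamic_replica_data/<split>/; the path,
# split, and crop size below are placeholders to adapt to your setup.
if __name__ == "__main__":
    dataset = DynamicReplicaDataset(
        root="./dynamic_replica_data",  # hypothetical local path
        split="valid",
        traj_per_sample=256,
        crop_size=(384, 512),  # optional (H, W) center crop
        sample_len=40,  # clip length in frames; -1 keeps full sequences
    )
    print(f"{len(dataset)} clips")

    sample = dataset[0]
    # video: (T, 3, H, W); trajectory: (T, N, 2); visibility: (T, N)
    print(sample.video.shape, sample.trajectory.shape, sample.visibility.shape)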