# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import gzip

import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    # per-frame trajectory annotation; accessed below as a mapping with a "path" key
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    """Loads left-camera clips from the Dynamic Replica dataset together with
    2D point trajectories and per-frame visibility flags, returned as
    `CoTrackerData` samples."""

    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(
                zipfile, List[DynamicReplicaFrameAnnotation]
            )
        # group frame annotations by sequence, keeping only the left camera
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        # split every sequence into consecutive clips of `sample_len` frames
        # (or keep the whole sequence when sample_len <= 0)
        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # simple center crop: place the window at the image center, or keep the
        # top-left corner when the frame is already smaller than the crop size
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [
            rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            for rgb in rgbs
        ]

        # shift trajectories into the cropped coordinate frame
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(
                self.root, self.split, sample[i].trajectories["path"]
            )
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape

        # randomly subsample at most `traj_per_sample` trajectories
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        # mark points that fall outside the (possibly cropped) image as invisible
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # keep only points that are visible in more than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10

        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()

        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # one validity flag per frame for every kept trajectory
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )
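

if __name__ == "__main__":
    # Minimal usage sketch: constructs the dataset and inspects one returned sample.
    # The root path and parameter values below are placeholders/example choices;
    # point `root` at a local Dynamic Replica download laid out as
    # <root>/<split>/frame_annotations_<split>.jgz plus per-frame trajectory files.
    dataset = DynamicReplicaDataset(
        root="/path/to/dynamic_replica",  # placeholder path, adjust locally
        split="valid",
        traj_per_sample=256,
        crop_size=(384, 512),  # example (H, W) crop
        sample_len=24,  # example clip length in frames
    )
    print(f"number of clips: {len(dataset)}")

    item = dataset[0]
    # video: (T, 3, H, W) float tensor; trajectory: (T, N, 2); visibility: (T, N)
    print(item.seq_name, item.video.shape, item.trajectory.shape, item.visibility.shape)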