# AniDoc / cotracker/datasets/dr_dataset.py
# fffiloni: Migrated from GitHub (c705408, verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple
from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
@dataclass
class ImageAnnotation:
    """Path and dimensions of one image referenced by a frame annotation."""

    # Location of the jpg file, given relative to dataset_root.
    path: str
    # Image dimensions, ordered as (H, W).
    size: Tuple[int, int]
@dataclass
class DynamicReplicaFrameAnnotation:
    """Per-frame annotation record, populated from the json annotation file."""

    # Sequence identifier; usable as a join key with `SequenceAnnotation`.
    sequence_name: str
    # Continuous, 0-based index of the frame inside its sequence.
    frame_number: int
    # Time of the frame in seconds, measured from the start of the video.
    frame_timestamp: float

    # Image file reference (path and H x W size) for this frame.
    image: ImageAnnotation
    # Optional free-form metadata.
    meta: Optional[Dict[str, Any]] = None

    # Camera identifier; DynamicReplicaDataset keeps only frames where this
    # equals "left".
    camera_name: Optional[str] = None
    # NOTE(review): annotated as str, but DynamicReplicaDataset.__getitem__
    # reads `trajectories["path"]`, so loaded values appear to be mappings
    # with a "path" entry -- confirm against load_dataclass before changing
    # the annotation.
    trajectories: Optional[str] = None
class DynamicReplicaDataset(data.Dataset):
    """Point-tracking clips built from gzipped per-frame annotations.

    At construction time the annotation file
    ``<root>/<split>/frame_annotations_<split>.jgz`` is read, frames whose
    ``camera_name`` is ``"left"`` are grouped by ``sequence_name``, and every
    sequence is cut into consecutive clips. Each item is one clip returned as
    a ``CoTrackerData`` holding the video, the 2D trajectories and their
    visibility.
    """

    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        """Read the annotation file and build the list of clips.

        Args:
            root: Dataset root directory.
            split: Sub-directory of ``root`` and suffix of the annotation
                file name.
            traj_per_sample: Maximum number of trajectories randomly drawn
                per clip (before the visibility filter).
            crop_size: Optional ``(height, width)`` of the center crop
                applied to every frame; ``None`` disables cropping.
            sample_len: Number of frames per clip. Values ``<= 0`` turn every
                sequence into a single clip. The last clip of a sequence may
                be shorter than ``sample_len``.
            only_first_n_samples: When ``> 0``, keep at most this many clips
                *per sequence* (the counter restarts for each sequence).
            rgbd_input: Stored on the instance; not read anywhere in this
                class.
        """
        super().__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size

        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])

        # Group the left-camera frames by sequence, preserving file order.
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for frames in seq_annot.values():
            seq_len = len(frames)
            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0
            for ref_idx in range(0, seq_len, step):
                self.sample_list.append(frames[ref_idx : ref_idx + step])
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        """Return the number of clips."""
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        """Center-crop the frames to ``self.crop_size`` and shift the tracks.

        Args:
            rgbs: Sequence of per-frame arrays whose first two axes are
                height and width.
            trajs: Array of shape ``(T, N, 2+)`` with x in channel 0 and y in
                channel 1. It is modified in place.

        Returns:
            ``(rgbs, trajs)``: the list of cropped frames and the same
            ``trajs`` array with the crop offset subtracted.
        """
        num_frames, _, _ = trajs.shape
        height, width = rgbs[0].shape[:2]
        assert len(rgbs) == num_frames

        crop_h, crop_w = self.crop_size[0], self.crop_size[1]
        # Deterministic center crop; an axis is left untouched when the
        # requested size is not smaller than the frame along that axis.
        y0 = 0 if crop_h >= height else (height - crop_h) // 2
        x0 = 0 if crop_w >= width else (width - crop_w) // 2

        rgbs = [rgb[y0 : y0 + crop_h, x0 : x0 + crop_w] for rgb in rgbs]
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0
        return rgbs, trajs

    def __getitem__(self, index):
        """Load one clip and return it as ``CoTrackerData``.

        For every frame of the clip the file at
        ``<root>/<split>/<trajectories["path"]>`` is loaded and its
        ``"img"``, ``"traj_2d"`` and ``"verts_inds_vis"`` entries are used.
        Up to ``traj_per_sample`` trajectories are drawn at random, points
        outside the (optionally cropped) frame are marked invisible, and only
        trajectories visible in more than 10 frames are kept. Clips with 10
        or fewer frames therefore yield zero trajectories.

        Returns:
            ``CoTrackerData`` with ``video`` of shape ``(T, 3, H, W)``
            (float), ``trajectory`` of shape ``(T, N_kept, 2)``, and
            ``visibility`` and ``valid`` of shape ``(T, N_kept)``.
        """
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        for frame_annot in sample:
            traj_path = os.path.join(self.root, self.split, frame_annot.trajectories["path"])
            # NOTE(security): torch.load unpickles the file; only use it on
            # dataset files from a trusted source.
            traj = torch.load(traj_path)
            visibilities.append(traj["verts_inds_vis"].numpy())
            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, num_tracks, _ = traj_2d.shape

        # Randomly subsample trajectories. The permutation comes from torch's
        # RNG and is converted to NumPy before indexing NumPy arrays.
        sampled_inds = torch.randperm(num_tracks)[: self.traj_per_sample].numpy()
        traj_2d = traj_2d[:, sampled_inds]
        visibility = visibility[:, sampled_inds]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            # Use the real size of the cropped frames: it differs from
            # crop_size when the requested crop exceeds the frame.
            H, W, _ = rgbs[0].shape

        # Points that fall outside the frame cannot be visible.
        xs, ys = traj_2d[:, :, 0], traj_2d[:, :, 1]
        visibility[xs > W - 1] = False
        visibility[xs < 0] = False
        visibility[ys > H - 1] = False
        visibility[ys < 0] = False

        # Keep only the trajectories that are visible in more than 10 frames.
        keep = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, keep])
        visibility = torch.from_numpy(visibility[:, keep])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # One validity flag per returned trajectory point, so the shape
            # matches `trajectory` / `visibility` after sampling + filtering.
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )