# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch
import cv2
import imageio
import numpy as np

from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image


class CoTrackerDataset(torch.utils.data.Dataset):
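    """Base dataset for CoTracker training samples.

    Provides photometric, occlusion, and spatial augmentations, plus a
    __getitem__ wrapper that substitutes an all-zero sample when a
    subclass's getitem_helper fails, so batches can still be collated.
    """
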
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(CoTrackerDataset, self).__init__()
        np.random.seed(0)
        torch.manual_seed(0)

        self.data_root = data_root
        self.seq_len = seq_len
        self.traj_per_sample = traj_per_sample
        self.sample_vis_1st_frame = sample_vis_1st_frame
        self.use_augs = use_augs
        self.crop_size = crop_size

        # photometric augmentation
        self.photo_aug = ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
        )
        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
        self.blur_aug_prob = 0.25
        self.color_aug_prob = 0.25

        # occlusion augmentation (eraser)
        self.eraser_aug_prob = 0.5
        self.eraser_bounds = [2, 100]
        self.eraser_max = 10

        # occlusion augmentation (patch replacement)
        self.replace_aug_prob = 0.5
        self.replace_bounds = [2, 100]
        self.replace_max = 10

        # spatial augmentations
        self.pad_bounds = [0, 100]
        self.resize_lim = [0.25, 2.0]  # sample resizes from here
        self.resize_delta = 0.2
        self.max_crop_offset = 50
        self.do_flip = True
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.5
    def getitem_helper(self, index):
        # subclasses must implement this and return (sample, gotit)
        raise NotImplementedError
    def __getitem__(self, index):
        sample, gotit = self.getitem_helper(index)
        if not gotit:
            print("warning: sampling failed")
            # return a fake sample, so we can still collate
            sample = CoTrackerData(
                video=torch.zeros(
                    (self.seq_len, 3, self.crop_size[0], self.crop_size[1])
                ),
                trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
                visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
                valid=torch.zeros((self.seq_len, self.traj_per_sample)),
            )

        return sample, gotit
    def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
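        """Apply eraser and patch-replacement occlusions (per frame after the
        first) plus global color-jitter and blur, marking points that fall
        inside an occluded box as invisible."""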
        T, N, _ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        if eraser:
            ############ eraser transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        # fill the box with its own mean color
                        mean_color = np.mean(
                            rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
                        )
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        # points inside the erased box become invisible
                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:
            # color-jittered copies of the frames, used as replacement sources
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        # paste a random same-size patch from a random frame
                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        # points inside the replaced box become invisible
                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [
                np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
                for rgb in rgbs
            ]

        return rgbs, trajs, visibles
    def add_spatial_augs(self, rgbs, trajs, visibles):
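        """Pad, smoothly rescale/stretch over time, crop around the visible
        points with a drifting offset, and optionally flip, keeping the
        trajectories consistent with the transformed frames."""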
        T, N, _ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        rgbs = [rgb.astype(np.float32) for rgb in rgbs]

        ############ spatial transform ############

        # padding
        pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])

        rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
        trajs[:, :, 0] += pad_x0
        trajs[:, :, 1] += pad_y0
        H, W = rgbs[0].shape[:2]

        # scaling + stretching
        scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
        scale_x = scale
        scale_y = scale
        H_new = H
        W_new = W

        scale_delta_x = 0.0
        scale_delta_y = 0.0

        rgbs_scaled = []
        for s in range(S):
            if s == 1:
                scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
                scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
            elif s > 1:
                scale_delta_x = (
                    scale_delta_x * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
                scale_delta_y = (
                    scale_delta_y * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
            scale_x = scale_x + scale_delta_x
            scale_y = scale_y + scale_delta_y

            # bring h/w closer
            scale_xy = (scale_x + scale_y) * 0.5
            scale_x = scale_x * 0.5 + scale_xy * 0.5
            scale_y = scale_y * 0.5 + scale_xy * 0.5

            # don't get too crazy
            scale_x = np.clip(scale_x, 0.2, 2.0)
            scale_y = np.clip(scale_y, 0.2, 2.0)

            H_new = int(H * scale_y)
            W_new = int(W * scale_x)

            # make it at least slightly bigger than the crop area,
            # so that the random cropping can add diversity
            H_new = np.clip(H_new, self.crop_size[0] + 10, None)
            W_new = np.clip(W_new, self.crop_size[1] + 10, None)
            # recompute scale in case we clipped
            scale_x = (W_new - 1) / float(W - 1)
            scale_y = (H_new - 1) / float(H - 1)

            rgbs_scaled.append(
                cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
            )
            trajs[s, :, 0] *= scale_x
            trajs[s, :, 1] *= scale_y
        rgbs = rgbs_scaled

        ok_inds = visibles[0, :] > 0
        vis_trajs = trajs[:, ok_inds]  # S,?,2
        if vis_trajs.shape[1] > 0:
            mid_x = np.mean(vis_trajs[0, :, 0])
            mid_y = np.mean(vis_trajs[0, :, 1])
        else:
            mid_y = self.crop_size[0]
            mid_x = self.crop_size[1]

        x0 = int(mid_x - self.crop_size[1] // 2)
        y0 = int(mid_y - self.crop_size[0] // 2)

        offset_x = 0
        offset_y = 0

        for s in range(S):
            # on each frame, shift a bit more
            if s == 1:
                offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
                offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
            elif s > 1:
                offset_x = int(
                    offset_x * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
                offset_y = int(
                    offset_y * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
            x0 = x0 + offset_x
            y0 = y0 + offset_y

            H_new, W_new = rgbs[s].shape[:2]
            if H_new == self.crop_size[0]:
                y0 = 0
            else:
                y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)

            if W_new == self.crop_size[1]:
                x0 = 0
            else:
                x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)

            rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            trajs[s, :, 0] -= x0
            trajs[s, :, 1] -= y0

        H_new = self.crop_size[0]
        W_new = self.crop_size[1]

        # flip
        h_flipped = False
        v_flipped = False
        if self.do_flip:
            # h flip
            if np.random.rand() < self.h_flip_prob:
                h_flipped = True
                rgbs = [rgb[:, ::-1] for rgb in rgbs]
            # v flip
            if np.random.rand() < self.v_flip_prob:
                v_flipped = True
                rgbs = [rgb[::-1] for rgb in rgbs]
        if h_flipped:
            trajs[:, :, 0] = W_new - trajs[:, :, 0]
        if v_flipped:
            trajs[:, :, 1] = H_new - trajs[:, :, 1]

        return rgbs, trajs
    def crop(self, rgbs, trajs):
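        """Take one random crop_size crop shared by all frames and shift the
        trajectories into the crop's coordinate frame."""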
        T, N, _ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        ############ spatial transform ############
        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
        x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs
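

# A minimal sketch (an assumption, not part of the original file) of how the
# (sample, gotit) pairs returned by __getitem__ can be batched while dropping
# failed samples; the name `collate_fn_filtered` and this exact strategy are
# hypothetical, and a real training loop may handle failures differently.
def collate_fn_filtered(batch):
    # keep only samples whose getitem_helper reported success
    samples = [sample for sample, gotit in batch if gotit]
    if len(samples) == 0:
        # every sample in the batch failed; signal the caller to skip it
        return None, False
    # stack per-sample tensors into batched tensors (assumes CoTrackerData
    # is a plain container that also accepts batched tensors)
    return (
        CoTrackerData(
            video=torch.stack([s.video for s in samples]),
            trajectory=torch.stack([s.trajectory for s in samples]),
            visibility=torch.stack([s.visibility for s in samples]),
            valid=torch.stack([s.valid for s in samples]),
        ),
        True,
    )
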
class KubricMovifDataset(CoTrackerDataset):
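    """CoTracker dataset over Kubric MOVi-f videos, where each video lives in
    its own sub-directory containing a "frames" folder and a "<seq_name>.npy"
    annotation file with "coords" and "visibility" entries."""
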
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(KubricMovifDataset, self).__init__(
            data_root=data_root,
            crop_size=crop_size,
            seq_len=seq_len,
            traj_per_sample=traj_per_sample,
            sample_vis_1st_frame=sample_vis_1st_frame,
            use_augs=use_augs,
        )

        # gentler spatial augmentations than the base-class defaults
        self.pad_bounds = [0, 25]
        self.resize_lim = [0.75, 1.25]  # sample resizes from here
        self.resize_delta = 0.05
        self.max_crop_offset = 15
        self.seq_names = [
            fname
            for fname in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, fname))
        ]
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
    def getitem_helper(self, index):
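        """Load one video's frames and annotations, take a random temporal
        window of seq_len frames, augment (or just crop), and sample up to
        traj_per_sample tracks visible in the first or middle frame."""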
        gotit = True

        seq_name = self.seq_names[index]

        npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
        rgb_path = os.path.join(self.data_root, seq_name, "frames")

        img_paths = sorted(os.listdir(rgb_path))
        rgbs = []
        for i, img_path in enumerate(img_paths):
            rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))

        rgbs = np.stack(rgbs)
        annot_dict = np.load(npy_path, allow_pickle=True).item()
        traj_2d = annot_dict["coords"]
        visibility = annot_dict["visibility"]

        # random temporal crop to seq_len frames
        assert self.seq_len <= len(rgbs)
        if self.seq_len < len(rgbs):
            start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]

            rgbs = rgbs[start_ind : start_ind + self.seq_len]
            traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
            visibility = visibility[:, start_ind : start_ind + self.seq_len]

        # annotations store occlusion flags per point; invert to get
        # visibility and move the time axis first: (N, T, ...) -> (T, N, ...)
        traj_2d = np.transpose(traj_2d, (1, 0, 2))
        visibility = np.transpose(np.logical_not(visibility), (1, 0))
        if self.use_augs:
            rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
            rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
        else:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)

        # points pushed outside the crop are no longer visible
        visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        visibility = torch.from_numpy(visibility)
        traj_2d = torch.from_numpy(traj_2d)

        visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]

        if self.sample_vis_1st_frame:
            visibile_pts_inds = visibile_pts_first_frame_inds
        else:
            visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
                as_tuple=False
            )[:, 0]
            visibile_pts_inds = torch.cat(
                (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
            )
        point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
        if len(point_inds) < self.traj_per_sample:
            gotit = False

        visible_inds_sampled = visibile_pts_inds[point_inds]

        trajs = traj_2d[:, visible_inds_sampled].float()
        visibles = visibility[:, visible_inds_sampled]
        valids = torch.ones((self.seq_len, self.traj_per_sample))

        rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
        sample = CoTrackerData(
            video=rgbs,
            trajectory=trajs,
            visibility=visibles,
            valid=valids,
            seq_name=seq_name,
        )
        return sample, gotit
    def __len__(self):
        return len(self.seq_names)
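

# A minimal usage sketch (an assumption, not part of the original file):
# "./kubric_movif" is a hypothetical data root laid out as getitem_helper
# expects, i.e. one sub-directory per video containing a "frames" folder
# and a "<seq_name>.npy" annotation file.
if __name__ == "__main__":
    dataset = KubricMovifDataset(
        data_root="./kubric_movif",  # hypothetical path
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        use_augs=True,
    )
    sample, gotit = dataset[0]
    if gotit:
        print("video:", tuple(sample.video.shape))  # (24, 3, 384, 512)
        print("trajectory:", tuple(sample.trajectory.shape))  # (24, 768, 2)
        print("visibility:", tuple(sample.visibility.shape))  # (24, 768)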