# fffiloni's picture
# Migrated from GitHub
# c705408 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(eq=False)
class CoTrackerData:
    """
    Container bundling a video with its point-track annotations.

    The first three fields are required; the remaining ones default to
    ``None``. ``eq=False`` means no field-wise ``__eq__`` is generated, so
    instances compare (and hash) by identity.

    The shape notes on the fields are the file's own annotations and are
    not enforced anywhere in this class.
    """

    # Video frames, annotated as (B, S, C, H, W).
    video: torch.Tensor
    # Point tracks, annotated as (B, S, N, 2).
    trajectory: torch.Tensor
    # Visibility of the tracked points, annotated as (B, S, N).
    visibility: torch.Tensor

    # ---- optional data ----
    # Annotated as (B, S, N); populated by `collate_fn_train`.
    valid: Optional[torch.Tensor] = None
    # Annotated as (B, S, 1, H, W).
    segmentation: Optional[torch.Tensor] = None
    # Sequence name. NOTE: the collate functions in this file store a
    # list of names (one per sample) in this field.
    seq_name: Optional[str] = None
    # Query points in the TapVID evaluation format.
    query_points: Optional[torch.Tensor] = None
def collate_fn(batch):
    """
    Merge a sequence of ``CoTrackerData`` samples into one batched instance.

    ``video``, ``trajectory`` and ``visibility`` are always stacked along a
    new leading dimension. ``query_points`` and ``segmentation`` are stacked
    only when the first sample provides them (only ``batch[0]`` is
    inspected); otherwise they stay ``None``. ``seq_name`` becomes a list
    with one entry per sample. ``valid`` is not collated here, so the result
    keeps its default of ``None``.

    Args:
        batch: Indexable sequence of ``CoTrackerData`` samples.
    Returns:
        A single ``CoTrackerData`` holding the stacked tensors.
    """

    def _stack(field_name):
        # Stack one attribute across all samples along a new dim 0.
        return torch.stack([getattr(item, field_name) for item in batch], dim=0)

    required = {
        name: _stack(name) for name in ("video", "trajectory", "visibility")
    }

    # The first sample decides whether the optional tensors are present.
    head = batch[0]
    query_points = _stack("query_points") if head.query_points is not None else None
    segmentation = _stack("segmentation") if head.segmentation is not None else None

    names = [item.seq_name for item in batch]

    return CoTrackerData(
        segmentation=segmentation,
        seq_name=names,
        query_points=query_points,
        **required,
    )
def collate_fn_train(batch):
    """
    Merge ``(sample, gotit)`` pairs into one batched training sample.

    ``video``, ``trajectory``, ``visibility`` and ``valid`` of every sample
    are stacked along a new leading dimension, and ``seq_name`` becomes a
    list with one entry per sample. ``query_points`` and ``segmentation``
    are not collated and stay ``None``.

    Args:
        batch: Sequence of 2-element pairs ``(CoTrackerData, gotit)``.
            The second element is passed through untouched (presumably a
            per-sample success flag -- confirm against the dataset).
    Returns:
        A tuple ``(CoTrackerData, gotit)`` where ``gotit`` is the list of
        the second elements in batch order.
    """
    gotit = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def _stack(field_name):
        # Stack one attribute across all samples along a new dim 0.
        return torch.stack([getattr(s, field_name) for s in samples], dim=0)

    tensors = {
        name: _stack(name)
        for name in ("video", "trajectory", "visibility", "valid")
    }
    names = [s.seq_name for s in samples]

    return CoTrackerData(seq_name=names, **tensors), gotit
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort conversion of `t` via ``t.float().cuda()``.

    Note that the value is cast with ``.float()`` before being moved, so
    the result is not just a device transfer. Objects lacking either
    method (strings, lists, ``None``, ...) are returned unchanged. Only
    ``AttributeError`` is swallowed; any other error raised by the two
    calls propagates to the caller.

    Args:
        t: Input.
    Returns:
        The result of ``t.float().cuda()`` when both methods exist,
        otherwise `t` itself.
    """
    try:
        return t.float().cuda()
    except AttributeError:
        # Not tensor-like: hand the original object back untouched.
        return t
def dataclass_to_cuda_(obj):
    """
    Run every field of a dataclass instance through `try_to_cuda`, in place.

    Each field value is replaced by the result of `try_to_cuda`, so
    tensor-like values are cast with ``.float()`` and moved with
    ``.cuda()``, while values without those methods (e.g. strings, lists,
    ``None``) are written back unchanged.

    Args:
        obj: Input dataclass instance; it is mutated.
    Returns:
        The same `obj`, returned for convenience.
    """
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name)
        setattr(obj, field.name, try_to_cuda(current))
    return obj