# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F

from typing import Tuple

from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid

class EvaluationPredictor(torch.nn.Module):
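    """Evaluation wrapper around a CoTracker2 model.

    Resizes input videos to ``interp_shape``, optionally augments queries
    with local and global support grids, and rescales the predicted tracks
    back to the original video resolution.
    """
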
    def __init__(
        self,
        cotracker_model: CoTracker2,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super(EvaluationPredictor, self).__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.eval()

    def forward(self, video, queries):
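        """Track ``queries`` through ``video``.

        Args:
            video: float tensor of shape (B, T, C, H, W).
            queries: float tensor of shape (B, N, 3); each row is (t, x, y),
                a start frame index followed by pixel coordinates.

        Returns:
            traj_e: predicted tracks, (B, T, N', 2), in the original video
                resolution (support-grid points are appended to N when
                ``single_point`` is False and ``grid_size`` > 0).
            vis_e: per-point visibility estimates, (B, T, N').
        """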
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3
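        # Resize every frame to the model's working resolution.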
        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        device = video.device
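        # Rescale query (x, y) coordinates to the resized resolution.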
        queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
        queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)

        if self.single_point:
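            # Process each query point independently so it can be padded
            # with its own support grids (see _process_one_point).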
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            for pind in range(N):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                traj_e_pind, vis_e_pind = self._process_one_point(video, query)
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                # Prepend t=0 so the support points are queried at the first frame.
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
                queries = torch.cat([queries, xy], dim=1)

            traj_e, vis_e, __ = self.model(
                video=video,
                queries=queries,
                iters=self.n_iters,
            )

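        # Map predicted tracks back to the original video resolution.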
        traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
        traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
        return traj_e, vis_e

    def _process_one_point(self, video, query):
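        """Track a single query point, padded with support grids.

        The point is augmented with a local grid centered on it and a global
        grid over the whole frame; the video is cropped so tracking starts at
        the queried frame.
        """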
        t = query[0, 0, 0].long()

        device = query.device
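        # Build a local support grid over a 50x50-pixel window centered on
        # the query point; the center argument is given in (y, x) order.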
        if self.local_grid_size > 0:
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )
            # Prepend t=0 so the local support points are queried at the
            # first frame of the cropped clip.
            xy_target = torch.cat(
                [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
            ).to(device)
            query = torch.cat([query, xy_target], dim=1)

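        # Add a global support grid spanning the whole frame.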
        if self.grid_size > 0:
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            # Prepend t=0 for the global support points as well.
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
            query = torch.cat([query, xy], dim=1)
        # Crop the video so tracking starts at the queried frame, and reset
        # the query timestamp to frame 0 of the cropped clip.
        query[0, 0, 0] = 0
        traj_e_pind, vis_e_pind, __ = self.model(
            video=video[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind
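
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes CoTracker2
# is constructible with its default arguments and uses a random tensor in
# place of a real video; shapes follow the conventions documented in
# forward() above.
if __name__ == "__main__":
    model = CoTracker2()  # assumption: default constructor arguments suffice
    predictor = EvaluationPredictor(model, grid_size=5, single_point=False)

    B, T, C, H, W = 1, 16, 3, 256, 256
    video = torch.rand(B, T, C, H, W)
    # Each query row is (t, x, y): the frame index to start tracking from,
    # followed by pixel coordinates in the original resolution.
    queries = torch.tensor([[[0.0, 100.0, 120.0], [2.0, 50.0, 60.0]]])

    with torch.no_grad():
        traj, vis = predictor(video, queries)
    # With single_point=False and grid_size=5, the 25 support-grid tracks are
    # appended after the 2 query tracks.
    print(traj.shape, vis.shape)  # (1, 16, 27, 2) and (1, 16, 27)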