# NOTE: web-page banner captured together with this file ("Spaces: Running on L40S"); not part of the source.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from collections import defaultdict
from typing import Optional

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.utils.visualizer import Visualizer
class Evaluator:
    """
    A class defining the CoTracker evaluator.

    It runs a tracking model over a dataloader, optionally visualizes the
    predicted tracks, and accumulates per-sequence metrics together with a
    running average stored under the ``"avg"`` key of the metrics dict.
    """

    def __init__(self, exp_dir) -> None:
        """Create ``exp_dir`` if needed and set up visualization bookkeeping.

        Args:
            exp_dir: Directory for experiment outputs. It is also handed to
                the ``Visualizer`` as ``save_dir`` in ``evaluate_sequence``.
        """
        # Visualization
        self.exp_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
        self.visualize_dir = os.path.join(exp_dir, "visualisations")

    def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
        """Compute metrics for one sample and fold them into ``metrics``.

        Args:
            metrics: Dict mutated in place. The sample's metrics are stored
                under ``sample.seq_name[0]`` and the running mean over all
                stored sequences under ``"avg"``.
            sample: Batch object; this method reads ``trajectory``,
                ``visibility``, ``seq_name`` and, depending on the dataset,
                ``query_points`` or ``video``.
            pred_trajectory: Either the predicted trajectory tensor or a
                ``(trajectory, visibility)`` tuple.
            dataset_name: Selects the metric family. Names containing
                ``"tapvid"`` use ``compute_tapvid_metrics``;
                ``"dynamic_replica"`` and ``"pointodyssey"`` use thresholded
                accuracy and survival. Any other name leaves ``metrics``
                untouched.
        """
        if isinstance(pred_trajectory, tuple):
            pred_trajectory, pred_visibility = pred_trajectory
        else:
            pred_visibility = None

        if "tapvid" in dataset_name:
            out_metrics = self._tapvid_metrics(
                sample, pred_trajectory, pred_visibility, dataset_name
            )
            # The TAP-Vid averages are stored as returned by ``np.mean``.
            cast_to_float = False
        elif dataset_name in ("dynamic_replica", "pointodyssey"):
            out_metrics = self._threshold_accuracy_metrics(sample, pred_trajectory)
            cast_to_float = True
        else:
            return

        self._record_metrics(metrics, sample.seq_name[0], out_metrics, cast_to_float)

    @staticmethod
    def _tapvid_metrics(sample, pred_trajectory, pred_visibility, dataset_name):
        """Convert tensors to the numpy layout passed to ``compute_tapvid_metrics``."""
        _, _, N, _ = sample.trajectory.shape
        thr = 0.9

        if pred_visibility is None:
            logging.warning("visibility is NONE")
            pred_visibility = torch.zeros_like(sample.visibility)
        if pred_visibility.dtype != torch.bool:
            # Non-boolean visibility scores are binarized at ``thr``.
            pred_visibility = pred_visibility > thr

        query_points = sample.query_points.clone().cpu().numpy()

        # Keep only the first N predicted points so the prediction matches
        # the number of ground-truth tracks.
        pred_visibility = pred_visibility[:, :, :N]
        pred_trajectory = pred_trajectory[:, :, :N]

        # Swap dims 1 and 2 of every tensor before handing it to the metric.
        gt_tracks = sample.trajectory.clone().permute(0, 2, 1, 3).cpu().numpy()
        gt_occluded = (
            torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
        )
        pred_occluded = (
            torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
        )
        pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()

        return compute_tapvid_metrics(
            query_points,
            gt_occluded,
            gt_tracks,
            pred_occluded,
            pred_tracks,
            query_mode="strided" if "strided" in dataset_name else "first",
        )

    @staticmethod
    def _threshold_accuracy_metrics(sample, pred_trajectory):
        """Compute thresholded accuracy (occluded / visible / all) and survival."""
        B, T, N = sample.visibility.shape
        H, W = sample.video.shape[-2:]
        device = sample.video.device

        # ``1 - visibility`` is not defined for bool tensors, so work on a
        # float copy; this is a no-op when visibility is already float.
        visibility = sample.visibility.float()

        thrs = [1, 2, 4, 8, 16]
        sur_thr = 50

        # Rescale pixel coordinates so that the image spans 0..255 on each
        # axis before measuring distances.
        sx_ = (W - 1) / 255.0
        sy_ = (H - 1) / 255.0
        sc_pt = torch.tensor([sx_, sy_], dtype=torch.float32, device=device).reshape(1, 1, 2)

        # Only frames strictly after the index returned by ``torch.max`` over
        # time (used here as each point's first visible frame) are scored.
        _, first_visible_inds = torch.max(visibility, dim=1)
        frame_ids = torch.arange(T, device=device)[None, :, None]
        start_tracking_mask = frame_ids > first_visible_inds.unsqueeze(1)  # B,T,N

        # The distances and masks do not depend on the threshold: compute once.
        dists = torch.norm(
            pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
            dim=-1,
        )  # B,T,N
        occ_mask = (1 - visibility) * start_tracking_mask
        vis_mask = visibility * start_tracking_mask

        out_metrics = {}
        d_vis_sum = d_occ_sum = d_sum_all = 0.0
        for thr in thrs:
            within_thr = (dists < thr).float()

            d_occ = reduce_masked_mean(within_thr, occ_mask).item() * 100.0
            d_occ_sum += d_occ
            out_metrics[f"accuracy_occ_{thr}"] = d_occ

            d_vis = reduce_masked_mean(within_thr, vis_mask).item() * 100.0
            d_vis_sum += d_vis
            out_metrics[f"accuracy_vis_{thr}"] = d_vis

            d_all = reduce_masked_mean(within_thr, start_tracking_mask).item() * 100.0
            d_sum_all += d_all
            out_metrics[f"accuracy_{thr}"] = d_all

        # A track "survives" at a frame if it has never been visible and more
        # than ``sur_thr`` away from the ground truth up to that frame.
        dist_ok = 1 - (dists > sur_thr).float() * visibility  # B,T,N
        survival = torch.cumprod(dist_ok, dim=1)  # B,T,N
        out_metrics["survival"] = torch.mean(survival).item() * 100.0

        out_metrics["accuracy_occ"] = d_occ_sum / len(thrs)
        out_metrics["accuracy_vis"] = d_vis_sum / len(thrs)
        out_metrics["accuracy"] = d_sum_all / len(thrs)
        return out_metrics

    @staticmethod
    def _record_metrics(metrics, seq_name, out_metrics, cast_to_float):
        """Store ``out_metrics`` under ``seq_name`` and refresh the ``"avg"`` entry."""
        metrics[seq_name] = out_metrics
        avg = metrics.setdefault("avg", {})
        for metric_name in out_metrics:
            mean = np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
            avg[metric_name] = float(mean) if cast_to_float else mean

        logging.info(f"Metrics: {out_metrics}")
        logging.info(f"avg: {metrics['avg']}")
        print("metrics", out_metrics)
        print("avg", metrics["avg"])

    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        dataset_name: str,
        train_mode=False,
        visualize_every: int = 1,
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = 0,
    ):
        """Run ``model`` over ``test_dataloader`` and return the metrics dict.

        Gradients are disabled for the whole call: evaluation needs no
        autograd graph, and the TAP-Vid metric path converts predictions to
        numpy, which fails for tensors that require grad.

        Args:
            model: Callable invoked as ``model(video, queries)``. For
                ``"strided"`` datasets it must return a
                ``(trajectory, visibility)`` tuple.
            test_dataloader: Yields samples, or ``(sample, gotit)`` tuples
                where a batch is skipped unless all ``gotit`` flags are true.
            dataset_name: Controls query construction, the strided
                backward pass, visualization file names and the metrics used.
            train_mode: When false and the model has a ``sequence_len``
                attribute, batches with no visible point in the first
                ``sequence_len`` frames are skipped.
            visualize_every: Visualize every n-th batch; ``0`` disables
                visualization.
            writer: Optional writer forwarded to the visualizer.
            step: Step value forwarded to the visualizer.

        Returns:
            Dict filled by ``compute_metrics``: one entry per evaluated
            sequence plus the running average under ``"avg"``.
        """
        metrics = {}
        vis = Visualizer(
            save_dir=self.exp_dir,
            fps=7,
        )

        # Device choice does not change between batches.
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

        for ind, sample in enumerate(tqdm(test_dataloader)):
            if isinstance(sample, tuple):
                sample, gotit = sample
                if not all(gotit):
                    print("batch is None")
                    continue

            if use_cuda:
                dataclass_to_cuda_(sample)

            if (
                not train_mode
                and hasattr(model, "sequence_len")
                and (sample.visibility[:, : model.sequence_len].sum() == 0)
            ):
                print(f"skipping batch {ind}")
                continue

            if "tapvid" in dataset_name:
                # Swap the last two query columns
                # (presumably (t, y, x) -> (t, x, y); confirm against the dataset).
                queries = sample.query_points.clone().float()
                queries = torch.stack(
                    [
                        queries[:, :, 0],
                        queries[:, :, 2],
                        queries[:, :, 1],
                    ],
                    dim=2,
                ).to(device)
            else:
                # Query every point at frame 0 using its first-frame position.
                queries = torch.cat(
                    [
                        torch.zeros_like(sample.trajectory[:, 0, :, :1]),
                        sample.trajectory[:, 0],
                    ],
                    dim=2,
                ).to(device)

            pred_tracks = model(sample.video, queries)

            if "strided" in dataset_name:
                # Run the model a second time on the time-reversed video with
                # mirrored query frame indices, then use that (flipped back)
                # result wherever the forward prediction is exactly zero.
                inv_video = sample.video.flip(1).clone()
                inv_queries = queries.clone()
                inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

                pred_trj, pred_vsb = pred_tracks
                inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
                inv_pred_trj = inv_pred_trj.flip(1)
                inv_pred_vsb = inv_pred_vsb.flip(1)

                mask = pred_trj == 0
                pred_trj[mask] = inv_pred_trj[mask]
                pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
                pred_tracks = pred_trj, pred_vsb

            if dataset_name in ("badja", "fastcapture"):
                seq_name = sample.seq_name[0]
            else:
                seq_name = str(ind)

            # ``visualize_every == 0`` means "never" instead of dividing by zero.
            if visualize_every and ind % visualize_every == 0:
                vis.visualize(
                    sample.video,
                    pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
                    filename=dataset_name + "_" + seq_name,
                    writer=writer,
                    step=step,
                )

            self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
        return metrics