# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import torch
import signal
import socket
import sys
import json

import numpy as np
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler

from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite

from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset

from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss


# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
    """Requeue the current SLURM job and exit the process.

    Registered for SIGUSR1 in ``__main__``. Reads the job id from the
    ``SLURM_JOB_ID`` environment variable (raises ``KeyError`` if it is unset).
    """
    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")
    # do other stuff to cleanup here
    print("requeuing job " + os.environ["SLURM_JOB_ID"])
    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
    sys.exit(-1)


def term_handler(signum, frame):
    """Log SIGTERM without exiting (registered for SIGTERM in ``__main__``)."""
    print("bypassing sigterm", flush=True)


def fetch_optimizer(args, model):
    """Create the optimizer and learning rate scheduler.

    Returns:
        ``(optimizer, scheduler)``: AdamW over ``model.parameters()`` and a linear
        one-cycle schedule peaking at ``args.lr``. The schedule is 100 steps longer
        than ``args.num_steps``; the training loop runs slightly past ``num_steps``
        (it stops once ``total_steps > num_steps``).
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        args.lr,
        args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    return optimizer, scheduler


def forward_batch(batch, model, args):
    """Run the model on one training batch and compute its losses.

    Args:
        batch: object exposing ``video`` (B, T, 3, H, W), ``trajectory``
            (B, T, N, D), ``visibility`` and ``valid``. Per-track indexing
            ``visibility[b, :, i]`` presumes (B, T, N) — TODO confirm against
            the dataset.
        model: tracker called as ``model(video=, queries=, iters=, is_train=True)``.
        args: namespace providing ``train_iters``, ``sliding_window_len`` and
            ``sequence_len``.

    Returns:
        dict with ``"flow"`` (``predictions``, ``loss``) and ``"visibility"``
        (``predictions``, ``loss`` weighted by 10).
    """
    video = batch.video
    trajs_g = batch.trajectory
    vis_g = batch.visibility
    valids = batch.valid
    B, T, C, H, W = video.shape
    assert C == 3
    B, T, N, D = trajs_g.shape
    device = video.device

    # Index of the first frame in which each track is visible.
    __, first_positive_inds = torch.max(vis_g, dim=1)
    # We want to make sure that during training the model sees visible points
    # that it does not need to track just yet: they are visible but queried from a later frame
    N_rand = N // 4
    # For every track, the indices of the frames in which it is visible.
    nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]

    for b in range(B):
        # One random visible frame per track.
        rand_vis_inds = torch.cat(
            [
                nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
                for nonzero_row in nonzero_inds[b]
            ],
            dim=1,
        )
        # The first N_rand tracks are queried at a random visible frame; the rest
        # keep their first visible frame.
        first_positive_inds[b] = torch.cat(
            [rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
        )

    # Sanity check: every query frame must be a frame where the track is visible.
    ind_array_ = torch.arange(T, device=device)
    ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
    assert torch.allclose(
        vis_g[ind_array_ == first_positive_inds[:, None, :]],
        torch.ones(1, device=device),
    )
    # Pick each track's position at its own query frame (diagonal over tracks).
    gather = torch.gather(trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D))
    xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)

    # Queries are (t, x, y) per track.
    queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)

    predictions, visibility, train_data = model(
        video=video, queries=queries, iters=args.train_iters, is_train=True
    )
    coord_predictions, vis_predictions, valid_mask = train_data

    vis_gts = []
    traj_gts = []
    valids_gts = []

    # Ground truth is sliced into sliding windows of length S with stride S // 2.
    S = args.sliding_window_len
    for ind in range(0, args.sequence_len - S // 2, S // 2):
        vis_gts.append(vis_g[:, ind : ind + S])
        traj_gts.append(trajs_g[:, ind : ind + S])
        valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])

    seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)

    output = {"flow": {"predictions": predictions[0].detach()}}
    output["flow"]["loss"] = seq_loss.mean()
    output["visibility"] = {
        "loss": vis_loss.mean() * 10.0,
        "predictions": visibility[0].detach(),
    }
    return output


def run_test_eval(evaluator, model, dataloaders, writer, step):
    """Evaluate ``model`` on each ``(name, dataloader)`` pair and log to TensorBoard.

    Expects ``model`` to be wrapped twice (``model.module.module`` is the raw
    network). Leaves the model in eval mode; callers switch back to train mode.
    """
    model.eval()
    for ds_name, dataloader in dataloaders:
        visualize_every = 1
        grid_size = 5
        if ds_name == "dynamic_replica":
            visualize_every = 8
            grid_size = 0
        elif "tapvid" in ds_name:
            visualize_every = 5

        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=grid_size,
            local_grid_size=0,
            single_point=False,
            n_iters=6,
        )
        if torch.cuda.is_available():
            predictor.model = predictor.model.cuda()

        metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
            visualize_every=visualize_every,
        )

        if ds_name == "dynamic_replica" or ds_name == "kubric":
            metrics = {f"{ds_name}_avg_{k}": v for k, v in metrics["avg"].items()}

        if "tapvid" in ds_name:
            metrics = {
                f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
                f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
                f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
            }

        writer.add_scalars(f"Eval_{ds_name}", metrics, step)


class Logger:
    """Accumulate training metrics and write them to TensorBoard.

    Args:
        model: stored for reference only.
        scheduler: stored for reference only.
        log_dir: TensorBoard directory. When ``None`` (the historical default),
            falls back to ``<args.ckpt_path>/runs`` using the module-level ``args``
            that only exists when this file runs as a script.
    """

    SUM_FREQ = 100

    def __init__(self, model, scheduler, log_dir=None):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.log_dir = log_dir
        self.writer = SummaryWriter(log_dir=self._resolve_log_dir())

    def _resolve_log_dir(self):
        """Return the explicit log dir, or the legacy ``args``-based default."""
        if self.log_dir is not None:
            return self.log_dir
        return os.path.join(args.ckpt_path, "runs")

    def _print_training_status(self):
        """Log averaged running metrics and write them as scalars, then reset them."""
        metrics_data = [
            self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss.keys())
        ]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=self._resolve_log_dir())

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics, task):
        """Add ``metrics`` (keyed ``<metric>_<task>``) to the running sums."""
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            if task_key not in self.running_loss:
                self.running_loss[task_key] = 0.0

            self.running_loss[task_key] += metrics[key]

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write every ``key -> value`` of ``results`` as a scalar at the current step."""
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=self._resolve_log_dir())

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()


class Lite(LightningLite):
    """LightningLite entry point that trains CoTracker2 on Kubric MOVi-F."""

    def run(self, args):
        """Build data, model and optimizer, optionally resume, then train.

        Rank 0 additionally owns the evaluation dataloaders, the visualizer, the
        TensorBoard logger, checkpoint saving and periodic evaluation.

        Args:
            args: parsed command-line namespace (see ``__main__``).
        """

        def seed_everything(seed: int):
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_everything(0)

        def seed_worker(worker_id):
            # Derive numpy/random seeds for each dataloader worker from torch's seed.
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(0)
        if self.global_rank == 0:
            eval_dataloaders = []
            if "dynamic_replica" in args.eval_datasets:
                eval_dataset = DynamicReplicaDataset(
                    sample_len=60, only_first_n_samples=1, rgbd_input=False
                )
                eval_dataloader_dr = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))

            if "tapvid_davis_first" in args.eval_datasets:
                data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
                eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
                eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))

            evaluator = Evaluator(args.ckpt_path)

            visualizer = Visualizer(
                save_dir=args.ckpt_path,
                pad_value=80,
                fps=1,
                show_first_frame=0,
                tracks_leave_trace=0,
            )

        if args.model_name == "cotracker":
            model = CoTracker2(
                stride=args.model_stride,
                window_len=args.sliding_window_len,
                add_space_attn=not args.remove_space_attn,
                num_virtual_tracks=args.num_virtual_tracks,
                model_resolution=args.crop_size,
            )
        else:
            raise ValueError(f"Model {args.model_name} doesn't exist")

        with open(args.ckpt_path + "/meta.json", "w") as file:
            json.dump(vars(args), file, sort_keys=True, indent=4)

        model.cuda()

        train_dataset = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_1st_frame=args.sample_vis_1st_frame,
            use_augs=not args.dont_use_augs,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            pin_memory=True,
            collate_fn=collate_fn_train,
            drop_last=True,
        )

        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
        print("LEN TRAIN LOADER", len(train_loader))
        optimizer, scheduler = fetch_optimizer(args, model)

        total_steps = 0
        if self.global_rank == 0:
            logger = Logger(model, scheduler, log_dir=os.path.join(args.ckpt_path, "runs"))

        # Resume from the newest non-final checkpoint in ckpt_path, if any.
        # The isdir check must use the full path: os.listdir returns bare names.
        folder_ckpts = [
            f
            for f in os.listdir(args.ckpt_path)
            if not os.path.isdir(os.path.join(args.ckpt_path, f))
            and f.endswith(".pth")
            and "final" not in f
        ]
        if len(folder_ckpts) > 0:
            # Names embed a zero-padded step count, so lexicographic order is step order.
            ckpt_path = sorted(folder_ckpts)[-1]
            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
            logging.info(f"Loading checkpoint {ckpt_path}")
            if "model" in ckpt:
                model.load_state_dict(ckpt["model"])
            else:
                model.load_state_dict(ckpt)
            if "optimizer" in ckpt:
                logging.info("Load optimizer")
                optimizer.load_state_dict(ckpt["optimizer"])
            if "scheduler" in ckpt:
                logging.info("Load scheduler")
                scheduler.load_state_dict(ckpt["scheduler"])
            if "total_steps" in ckpt:
                total_steps = ckpt["total_steps"]
                logging.info(f"Load total_steps {total_steps}")

        elif args.restore_ckpt is not None:
            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
            logging.info("Loading checkpoint...")

            strict = True
            state_dict = self.load(args.restore_ckpt)
            if "model" in state_dict:
                state_dict = state_dict["model"]

            # Strip the "module." prefix left by DDP-wrapped saves.
            if list(state_dict.keys())[0].startswith("module."):
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=strict)

            logging.info(f"Done loading checkpoint")
        model, optimizer = self.setup(model, optimizer, move_to_device=False)
        model.train()

        save_freq = args.save_freq
        scaler = GradScaler(enabled=args.mixed_precision)

        should_keep_training = True
        global_batch_num = 0
        epoch = -1

        while should_keep_training:
            epoch += 1
            for i_batch, batch in enumerate(tqdm(train_loader)):
                batch, gotit = batch
                # NOTE(review): skipping a batch on one rank only can desynchronize
                # the barrier/backward calls across DDP ranks.
                if not all(gotit):
                    print("batch is None")
                    continue
                dataclass_to_cuda_(batch)

                optimizer.zero_grad()

                assert model.training

                output = forward_batch(batch, model, args)

                loss = 0
                for k, v in output.items():
                    if "loss" in v:
                        loss += v["loss"]

                if self.global_rank == 0:
                    for k, v in output.items():
                        if "loss" in v:
                            logger.writer.add_scalar(
                                f"live_{k}_loss", v["loss"].item(), total_steps
                            )
                        if "metrics" in v:
                            logger.push(v["metrics"], k)
                    if total_steps % save_freq == save_freq - 1:
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=batch.trajectory.clone(),
                            filename="train_gt_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=output["flow"]["predictions"][None],
                            filename="train_pred_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                    if len(output) > 1:
                        logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
                    logger.writer.add_scalar(
                        f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
                    )
                    global_batch_num += 1

                self.barrier()

                self.backward(scaler.scale(loss))

                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                total_steps += 1
                if self.global_rank == 0:
                    # Save/evaluate at the end of each epoch (or on step 1 if requested).
                    if (i_batch >= len(train_loader) - 1) or (
                        total_steps == 1 and args.validate_at_start
                    ):
                        if (epoch + 1) % args.save_every_n_epoch == 0:
                            ckpt_iter = f"{total_steps:06d}"
                            save_path = Path(
                                f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                            )

                            save_dict = {
                                "model": model.module.module.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "scheduler": scheduler.state_dict(),
                                "total_steps": total_steps,
                            }

                            logging.info(f"Saving file {save_path}")
                            self.save(save_dict, save_path)

                        if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
                            args.validate_at_start and epoch == 0
                        ):
                            run_test_eval(
                                evaluator,
                                model,
                                eval_dataloaders,
                                logger.writer,
                                total_steps,
                            )
                            model.train()
                            torch.cuda.empty_cache()

                self.barrier()
                if total_steps > args.num_steps:
                    should_keep_training = False
                    break
        if self.global_rank == 0:
            print("FINISHED TRAINING")

            PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
            torch.save(model.module.module.state_dict(), PATH)
            run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
            logger.close()


if __name__ == "__main__":
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="cotracker", help="model name")
    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
    # Required: the script always writes meta.json, checkpoints and logs here.
    parser.add_argument("--ckpt_path", required=True, help="path to save checkpoints")
    parser.add_argument(
        "--batch_size", type=int, default=4, help="batch size used during training."
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")

    parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
    parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
    parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="evaluate during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--save_every_n_epoch",
        type=int,
        default=1,
        help="save checkpoints during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--validate_at_start",
        action="store_true",
        help="whether to run evaluation before training starts",
    )
    parser.add_argument(
        "--save_freq",
        type=int,
        default=100,
        help="frequency of trajectory visualization during training",
    )
    parser.add_argument(
        "--traj_per_sample",
        type=int,
        default=768,
        help="the number of trajectories to sample for training",
    )
    # Required: the training dataset path is always built from this root.
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help="path to all the datasets (train and eval)",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=4,
        help="number of iterative updates to the tracks in each forward pass.",
    )
    parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        default=["tapvid_davis_first"],
        help="what datasets to use for evaluation",
    )

    parser.add_argument(
        "--remove_space_attn",
        action="store_true",
        help="remove space attention from CoTracker",
    )
    parser.add_argument(
        "--num_virtual_tracks",
        type=int,
        default=None,
        help="number of virtual tracks used by CoTracker",
    )
    parser.add_argument(
        "--dont_use_augs",
        action="store_true",
        help="don't apply augmentations during training",
    )
    parser.add_argument(
        "--sample_vis_1st_frame",
        action="store_true",
        help="only sample trajectories with points visible on the first frame",
    )
    parser.add_argument(
        "--sliding_window_len",
        type=int,
        default=8,
        help="length of the CoTracker sliding window",
    )
    parser.add_argument(
        "--model_stride",
        type=int,
        default=8,
        help="stride of the CoTracker feature network",
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=[384, 512],
        help="crop videos to this resolution during training",
    )
    # NOTE(review): not referenced anywhere else in this file.
    parser.add_argument(
        "--eval_max_seq_len",
        type=int,
        default=1000,
        help="maximum length of evaluation videos",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        strategy=DDPStrategy(find_unused_parameters=False),
        devices="auto",
        accelerator="gpu",
        precision=32,
        num_nodes=args.num_nodes,
    ).run(args)