fffiloni's picture
Migrated from GitHub
c705408 verified
raw
history blame
22.1 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import torch
import signal
import socket
import sys
import json
import numpy as np
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
def sig_handler(signum, frame):
    """Signal handler used when training on a SLURM cluster.

    Reports the received signal and the host name, asks ``scontrol`` to
    requeue the job identified by the ``SLURM_JOB_ID`` environment variable,
    and then terminates the process with exit status -1.

    Args:
        signum: number of the signal that was delivered.
        frame: current stack frame (unused).

    Raises:
        KeyError: if ``SLURM_JOB_ID`` is not set in the environment.
        SystemExit: always, once the requeue command has been issued.
    """
    print("caught signal", signum)
    host = socket.gethostname()
    print(host, "USR1 signal caught.")
    # any additional cleanup belongs here, before the job is requeued
    job_id = os.environ["SLURM_JOB_ID"]
    print("requeuing job " + job_id)
    requeue_command = "scontrol requeue " + job_id
    os.system(requeue_command)
    sys.exit(-1)
def term_handler(signum, frame):
    """Handle SIGTERM by only announcing that the signal is being bypassed.

    Args:
        signum: number of the signal that was delivered (unused).
        frame: current stack frame (unused).
    """
    notice = "bypassing sigterm"
    print(notice, flush=True)
def fetch_optimizer(args, model):
    """Build the AdamW optimizer and its one-cycle learning-rate schedule.

    Args:
        args: namespace providing ``lr`` (peak learning rate), ``wdecay``
            (weight decay) and ``num_steps`` (length of the training schedule).
        model: module whose ``parameters()`` are optimized.

    Returns:
        A ``(optimizer, scheduler)`` tuple.
    """
    adamw = optim.AdamW(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=1e-8,
    )
    # the schedule is sized 100 steps longer than the nominal training length
    schedule_length = args.num_steps + 100
    one_cycle = optim.lr_scheduler.OneCycleLR(
        adamw,
        max_lr=args.lr,
        total_steps=schedule_length,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    return adamw, one_cycle
def forward_batch(batch, model, args):
    """Run one training forward pass and compute the tracking losses.

    Query points are built from the ground-truth trajectories: every track is
    queried at a frame where it is visible. The first ``N // 4`` tracks are
    queried from a randomly chosen visible frame, the remaining ones from the
    first frame of maximal visibility.

    Args:
        batch: object exposing ``video`` (unpacked as ``B, T, C, H, W`` with
            ``C == 3``), ``trajectory`` (unpacked as ``B, T, N, D``),
            ``visibility`` and ``valid``.
            NOTE(review): visibility/valid are indexed as (B, T, N) here --
            confirm against the dataset collate function.
        model: callable invoked as
            ``model(video=..., queries=..., iters=..., is_train=True)`` and
            returning ``(predictions, visibility, train_data)``.
        args: namespace providing ``train_iters``, ``sliding_window_len`` and
            ``sequence_len``.

    Returns:
        dict with two entries:
            ``"flow"``: ``{"predictions": ..., "loss": ...}``
            ``"visibility"``: ``{"predictions": ..., "loss": ...}``
        Predictions are the detached outputs for the first batch element; the
        visibility loss is scaled by 10.

    Raises:
        AssertionError: if the video does not have 3 channels or a selected
            query frame is not marked visible.
    """
    video = batch.video
    trajs_g = batch.trajectory
    vis_g = batch.visibility
    valids = batch.valid
    B, T, C, H, W = video.shape
    assert C == 3
    B, T, N, D = trajs_g.shape
    device = video.device

    # per track: index along the time axis where visibility is maximal
    # (torch.max returns the first such index)
    __, first_positive_inds = torch.max(vis_g, dim=1)
    # We want to make sure that during training the model sees visible points
    # that it does not need to track just yet: they are visible but queried from a later frame
    N_rand = N // 4
    # per batch element and track: indices of all frames where the track is visible
    nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]

    for b in range(B):
        # one random visible frame per track, concatenated to shape (1, N)
        rand_vis_inds = torch.cat(
            [
                nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
                for nonzero_row in nonzero_inds[b]
            ],
            dim=1,
        )
        # the first N_rand tracks use the random frame, the rest keep theirs
        first_positive_inds[b] = torch.cat(
            [rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
        )

    # sanity check: every selected query frame must be marked visible;
    # broadcasting (1, T, 1) against (B, 1, N) yields the (B, T, N) mask
    frame_ids = torch.arange(T, device=device)[None, :, None]
    assert torch.allclose(
        vis_g[frame_ids == first_positive_inds[:, None, :]],
        torch.ones(1, device=device),
    )

    # xys[b, n] = trajs_g[b, first_positive_inds[b, n], n]; gathering one frame
    # per track directly avoids materializing a (B, N, N, D) intermediate
    gather_index = first_positive_inds[:, None, :, None].expand(B, 1, N, D)
    xys = torch.gather(trajs_g, 1, gather_index).squeeze(1)

    # each query is (frame index, x, y)
    queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)

    predictions, visibility, train_data = model(
        video=video, queries=queries, iters=args.train_iters, is_train=True
    )
    coord_predictions, vis_predictions, valid_mask = train_data

    vis_gts = []
    traj_gts = []
    valids_gts = []

    # slice the ground truth into windows of length S that advance by S // 2
    S = args.sliding_window_len
    for ind in range(0, args.sequence_len - S // 2, S // 2):
        vis_gts.append(vis_g[:, ind : ind + S])
        traj_gts.append(trajs_g[:, ind : ind + S])
        valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])

    seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)

    output = {"flow": {"predictions": predictions[0].detach()}}
    output["flow"]["loss"] = seq_loss.mean()
    output["visibility"] = {
        "loss": vis_loss.mean() * 10.0,
        "predictions": visibility[0].detach(),
    }
    return output
def run_test_eval(evaluator, model, dataloaders, writer, step):
    """Evaluate the model on every given dataset and log the averaged metrics.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        evaluator: object providing ``evaluate_sequence(...)``.
        model: wrapped model; the underlying network is read from
            ``model.module.module``.
        dataloaders: iterable of ``(dataset_name, dataloader)`` pairs.
        writer: summary writer receiving the scalars via ``add_scalars``.
        step: global step under which the metrics are logged.
    """
    model.eval()
    for ds_name, dataloader in dataloaders:
        is_tapvid = "tapvid" in ds_name

        # per-dataset visualization frequency and support-grid size
        if ds_name == "dynamic_replica":
            visualize_every, grid_size = 8, 0
        elif is_tapvid:
            visualize_every, grid_size = 5, 5
        else:
            visualize_every, grid_size = 1, 5

        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=grid_size,
            local_grid_size=0,
            single_point=False,
            n_iters=6,
        )
        if torch.cuda.is_available():
            predictor.model = predictor.model.cuda()

        raw_metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
            visualize_every=visualize_every,
        )

        # flatten the "avg" entry into dataset-prefixed scalar names
        if ds_name in ("dynamic_replica", "kubric"):
            scalars = {
                f"{ds_name}_avg_{name}": value for name, value in raw_metrics["avg"].items()
            }
        elif is_tapvid:
            averages = raw_metrics["avg"]
            scalars = {
                f"{ds_name}_avg_OA": averages["occlusion_accuracy"],
                f"{ds_name}_avg_delta": averages["average_pts_within_thresh"],
                f"{ds_name}_avg_Jaccard": averages["average_jaccard"],
            }
        else:
            scalars = raw_metrics

        writer.add_scalars(f"Eval_{ds_name}", scalars, step)
class Logger:
    """Accumulates running training metrics and writes them to TensorBoard.

    Metrics pushed through :meth:`push` are summed per key and reported
    (logged and written as scalars) roughly every ``SUM_FREQ`` pushes.
    """

    # number of pushes each reported value is averaged over
    SUM_FREQ = 100

    def __init__(self, model, scheduler, log_dir=None):
        """Create the logger and open its summary writer.

        Args:
            model: model being trained (stored, not otherwise used here).
            scheduler: learning-rate scheduler (stored, not otherwise used here).
            log_dir: directory for the TensorBoard event files. When omitted,
                ``<args.ckpt_path>/runs`` is used, where ``args`` is the
                module-level namespace created by the script entry point.
        """
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.log_dir = log_dir
        self.writer = None
        self._ensure_writer()

    def _ensure_writer(self):
        """Return the summary writer, creating it first if there is none."""
        if self.writer is None:
            log_dir = self.log_dir
            if log_dir is None:
                # fall back to the module-level ``args`` set by the script
                log_dir = os.path.join(args.ckpt_path, "runs")
            self.writer = SummaryWriter(log_dir=log_dir)
        return self.writer

    def _print_training_status(self):
        """Log the averaged running metrics, write them out and zero them."""
        metrics_data = [
            self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss)
        ]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        writer = self._ensure_writer()
        for k in self.running_loss:
            writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics, task):
        """Add one step of metrics for ``task`` to the running sums.

        Args:
            metrics: mapping from metric name to a numeric value.
            task: suffix appended to each metric name (``"<name>_<task>"``).
        """
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            if task_key not in self.running_loss:
                self.running_loss[task_key] = 0.0
            self.running_loss[task_key] += metrics[key]

        # report on steps SUM_FREQ - 1, 2 * SUM_FREQ - 1, ...; the very first
        # report therefore covers one push fewer than SUM_FREQ
        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write every ``name -> value`` pair of ``results`` as a scalar."""
        writer = self._ensure_writer()
        for key in results:
            writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        """Close the summary writer if one is open."""
        if self.writer is not None:
            self.writer.close()
class Lite(LightningLite):
    """LightningLite wrapper that runs CoTracker training."""

    def run(self, args):
        """Train the model described by ``args`` and evaluate it periodically.

        Steps: seed all RNGs, build the evaluation loaders (rank 0 only), the
        model, the training loader and the optimizer; resume from the newest
        checkpoint in ``args.ckpt_path`` (or from ``args.restore_ckpt``); then
        train until more than ``args.num_steps`` optimizer steps were taken.
        Rank 0 saves checkpoints and runs evaluation at the end of an epoch,
        and writes ``<model_name>_final.pth`` when training finishes.

        Args:
            args: parsed command-line namespace of the training script.

        Raises:
            ValueError: if ``args.model_name`` is not a known model.
        """

        def seed_everything(seed: int):
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_everything(0)

        def seed_worker(worker_id):
            # derive the numpy / random seeds of a dataloader worker from torch's
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(0)

        # evaluation and visualization objects only exist on rank 0
        if self.global_rank == 0:
            eval_dataloaders = []
            if "dynamic_replica" in args.eval_datasets:
                eval_dataset = DynamicReplicaDataset(
                    sample_len=60, only_first_n_samples=1, rgbd_input=False
                )
                eval_dataloader_dr = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))

            if "tapvid_davis_first" in args.eval_datasets:
                data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
                eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
                eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))

            evaluator = Evaluator(args.ckpt_path)

            visualizer = Visualizer(
                save_dir=args.ckpt_path,
                pad_value=80,
                fps=1,
                show_first_frame=0,
                tracks_leave_trace=0,
            )

        if args.model_name == "cotracker":
            model = CoTracker2(
                stride=args.model_stride,
                window_len=args.sliding_window_len,
                add_space_attn=not args.remove_space_attn,
                num_virtual_tracks=args.num_virtual_tracks,
                model_resolution=args.crop_size,
            )
        else:
            raise ValueError(f"Model {args.model_name} doesn't exist")

        # record the run configuration next to the checkpoints
        with open(args.ckpt_path + "/meta.json", "w") as file:
            json.dump(vars(args), file, sort_keys=True, indent=4)

        model.cuda()

        train_dataset = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_1st_frame=args.sample_vis_1st_frame,
            use_augs=not args.dont_use_augs,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            pin_memory=True,
            collate_fn=collate_fn_train,
            drop_last=True,
        )

        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
        print("LEN TRAIN LOADER", len(train_loader))
        optimizer, scheduler = fetch_optimizer(args, model)

        total_steps = 0
        if self.global_rank == 0:
            logger = Logger(model, scheduler)

        # intermediate checkpoints found in the checkpoint directory; entries
        # must be tested relative to that directory, not the working directory
        folder_ckpts = [
            f
            for f in os.listdir(args.ckpt_path)
            if not os.path.isdir(os.path.join(args.ckpt_path, f))
            and f.endswith(".pth")
            and "final" not in f
        ]
        if len(folder_ckpts) > 0:
            # resume from the checkpoint whose name sorts last
            ckpt_path = sorted(folder_ckpts)[-1]
            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
            logging.info(f"Loading checkpoint {ckpt_path}")
            if "model" in ckpt:
                model.load_state_dict(ckpt["model"])
            else:
                model.load_state_dict(ckpt)
            if "optimizer" in ckpt:
                logging.info("Load optimizer")
                optimizer.load_state_dict(ckpt["optimizer"])
            if "scheduler" in ckpt:
                logging.info("Load scheduler")
                scheduler.load_state_dict(ckpt["scheduler"])
            if "total_steps" in ckpt:
                total_steps = ckpt["total_steps"]
                logging.info(f"Load total_steps {total_steps}")

        elif args.restore_ckpt is not None:
            # start from explicitly given weights (model state only)
            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
            logging.info("Loading checkpoint...")

            state_dict = self.load(args.restore_ckpt)
            if "model" in state_dict:
                state_dict = state_dict["model"]

            # drop the "module." prefix if the first key carries it
            if next(iter(state_dict)).startswith("module."):
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=True)

            logging.info("Done loading checkpoint")

        model, optimizer = self.setup(model, optimizer, move_to_device=False)
        model.train()

        save_freq = args.save_freq
        scaler = GradScaler(enabled=args.mixed_precision)

        should_keep_training = True
        epoch = -1

        while should_keep_training:
            epoch += 1
            for i_batch, batch in enumerate(tqdm(train_loader)):
                batch, gotit = batch
                if not all(gotit):
                    # at least one sample of this batch could not be loaded
                    print("batch is None")
                    continue
                dataclass_to_cuda_(batch)

                optimizer.zero_grad()

                assert model.training

                output = forward_batch(batch, model, args)

                loss = 0
                for k, v in output.items():
                    if "loss" in v:
                        loss += v["loss"]

                if self.global_rank == 0:
                    for k, v in output.items():
                        if "loss" in v:
                            logger.writer.add_scalar(
                                f"live_{k}_loss", v["loss"].item(), total_steps
                            )
                        if "metrics" in v:
                            logger.push(v["metrics"], k)

                    # periodically visualize ground-truth and predicted tracks
                    if total_steps % save_freq == save_freq - 1:
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=batch.trajectory.clone(),
                            filename="train_gt_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=output["flow"]["predictions"][None],
                            filename="train_pred_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                    if len(output) > 1:
                        logger.writer.add_scalar("live_total_loss", loss.item(), total_steps)
                    logger.writer.add_scalar(
                        "learning_rate", optimizer.param_groups[0]["lr"], total_steps
                    )

                self.barrier()

                self.backward(scaler.scale(loss))

                # unscale before clipping so the threshold applies to true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                total_steps += 1

                if self.global_rank == 0:
                    # end of epoch, or first step when validation at start is requested
                    if (i_batch >= len(train_loader) - 1) or (
                        total_steps == 1 and args.validate_at_start
                    ):
                        if (epoch + 1) % args.save_every_n_epoch == 0:
                            # zero-pad the step count to six digits
                            ckpt_iter = str(total_steps).zfill(6)
                            save_path = Path(
                                f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                            )

                            save_dict = {
                                "model": model.module.module.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "scheduler": scheduler.state_dict(),
                                "total_steps": total_steps,
                            }

                            logging.info(f"Saving file {save_path}")
                            self.save(save_dict, save_path)

                        if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
                            args.validate_at_start and epoch == 0
                        ):
                            run_test_eval(
                                evaluator,
                                model,
                                eval_dataloaders,
                                logger.writer,
                                total_steps,
                            )
                            # run_test_eval leaves the model in eval mode
                            model.train()
                            torch.cuda.empty_cache()

                self.barrier()
                if total_steps > args.num_steps:
                    should_keep_training = False
                    break

        if self.global_rank == 0:
            print("FINISHED TRAINING")

            PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
            torch.save(model.module.module.state_dict(), PATH)
            run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
            logger.close()
if __name__ == "__main__":
    # SLURM sends SIGUSR1 ahead of preemption: requeue the job on it, and
    # ignore SIGTERM so the process is not killed before the requeue happens
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="cotracker", help="model name")
    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
    # required: the checkpoint directory is created and written to unconditionally
    parser.add_argument("--ckpt_path", required=True, help="path to save checkpoints")
    parser.add_argument(
        "--batch_size", type=int, default=4, help="batch size used during training."
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
    parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
    parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
    parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="evaluate during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--save_every_n_epoch",
        type=int,
        default=1,
        help="save checkpoints during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--validate_at_start",
        action="store_true",
        help="whether to run evaluation before training starts",
    )
    parser.add_argument(
        "--save_freq",
        type=int,
        default=100,
        help="frequency of trajectory visualization during training",
    )
    parser.add_argument(
        "--traj_per_sample",
        type=int,
        default=768,
        help="the number of trajectories to sample for training",
    )
    # required: the training dataset path is always derived from this root
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help="path to all the datasets (train and eval)",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=4,
        help="number of iterative updates the model performs in each training forward pass.",
    )
    parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        default=["tapvid_davis_first"],
        help="what datasets to use for evaluation",
    )
    parser.add_argument(
        "--remove_space_attn",
        action="store_true",
        help="remove space attention from CoTracker",
    )
    parser.add_argument(
        "--num_virtual_tracks",
        type=int,
        default=None,
        help="number of virtual tracks used by CoTracker",
    )
    parser.add_argument(
        "--dont_use_augs",
        action="store_true",
        help="don't apply augmentations during training",
    )
    parser.add_argument(
        "--sample_vis_1st_frame",
        action="store_true",
        help="only sample trajectories with points visible on the first frame",
    )
    parser.add_argument(
        "--sliding_window_len",
        type=int,
        default=8,
        help="length of the CoTracker sliding window",
    )
    parser.add_argument(
        "--model_stride",
        type=int,
        default=8,
        help="stride of the CoTracker feature network",
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=[384, 512],
        help="crop videos to this resolution during training",
    )
    parser.add_argument(
        "--eval_max_seq_len",
        type=int,
        default=1000,
        help="maximum length of evaluation videos",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    # make sure the checkpoint directory exists before anything is written to it
    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)

    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        strategy=DDPStrategy(find_unused_parameters=False),
        devices="auto",
        accelerator="gpu",
        precision=32,
        num_nodes=args.num_nodes,
    ).run(args)