""" Train MattingRefine Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm. Select GPUs through CUDA_VISIBLE_DEVICES environment variable. Example: CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \ --dataset-name videomatte240k \ --model-backbone resnet50 \ --model-name mattingrefine-resnet50-videomatte240k \ --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \ --epoch-end 1 """ import argparse import kornia import torch import os import random from torch import nn from torch import distributed as dist from torch import multiprocessing as mp from torch.nn import functional as F from torch.cuda.amp import autocast, GradScaler from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader, Subset from torch.optim import Adam from torchvision.utils import make_grid from tqdm import tqdm from torchvision import transforms as T from PIL import Image from data_path import DATA_PATH from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset from dataset import augmentation as A from model import MattingRefine from model.utils import load_matched_state_dict # --------------- Arguments --------------- parser = argparse.ArgumentParser() parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys()) parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) parser.add_argument('--model-backbone-scale', type=float, default=0.25) parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) parser.add_argument('--model-refine-thresholding', type=float, default=0.7) parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3]) parser.add_argument('--model-name', type=str, required=True) parser.add_argument('--model-last-checkpoint', type=str, default=None) parser.add_argument('--batch-size', type=int, default=4) parser.add_argument('--num-workers', type=int, default=16) parser.add_argument('--epoch-start', type=int, default=0) parser.add_argument('--epoch-end', type=int, required=True) parser.add_argument('--log-train-loss-interval', type=int, default=10) parser.add_argument('--log-train-images-interval', type=int, default=1000) parser.add_argument('--log-valid-interval', type=int, default=2000) parser.add_argument('--checkpoint-interval', type=int, default=2000) args = parser.parse_args() distributed_num_gpus = torch.cuda.device_count() assert args.batch_size % distributed_num_gpus == 0 # --------------- Main --------------- def train_worker(rank, addr, port): # Distributed Setup os.environ['MASTER_ADDR'] = addr os.environ['MASTER_PORT'] = port dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus) # Training DataLoader dataset_train = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'), ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), A.PairRandomHorizontalFlip(), A.PairRandomBoxBlur(0.1, 5), A.PairRandomSharpen(0.1), A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)), T.RandomHorizontalFlip(), A.RandomBoxBlur(0.1, 5), A.RandomSharpen(0.1), T.ColorJitter(0.15, 0.15, 0.15, 0.05), T.ToTensor() ])), ]) dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus) dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker)) dataloader_train = DataLoader(dataset_train, shuffle=True, pin_memory=True, drop_last=True, batch_size=args.batch_size // distributed_num_gpus, num_workers=args.num_workers // distributed_num_gpus) # Validation DataLoader if rank == 0: dataset_valid = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), T.ToTensor() ])), ]) dataset_valid = SampleDataset(dataset_valid, 50) dataloader_valid = DataLoader(dataset_valid, pin_memory=True, drop_last=True, batch_size=args.batch_size // distributed_num_gpus, num_workers=args.num_workers // distributed_num_gpus) # Model model = MattingRefine(args.model_backbone, args.model_backbone_scale, args.model_refine_mode, args.model_refine_sample_pixels, args.model_refine_thresholding, args.model_refine_kernel_size).to(rank) model = nn.SyncBatchNorm.convert_sync_batchnorm(model) model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) if args.model_last_checkpoint is not None: load_matched_state_dict(model, torch.load(args.model_last_checkpoint)) optimizer = Adam([ {'params': model.backbone.parameters(), 'lr': 5e-5}, {'params': model.aspp.parameters(), 'lr': 5e-5}, {'params': model.decoder.parameters(), 'lr': 1e-4}, {'params': model.refiner.parameters(), 'lr': 3e-4}, ]) scaler = GradScaler() # Logging and checkpoints if rank == 0: if not os.path.exists(f'checkpoint/{args.model_name}'): os.makedirs(f'checkpoint/{args.model_name}') writer = SummaryWriter(f'log/{args.model_name}') # Run loop for epoch in range(args.epoch_start, args.epoch_end): for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)): step = epoch * len(dataloader_train) + i true_pha = true_pha.to(rank, non_blocking=True) true_fgr = true_fgr.to(rank, non_blocking=True) true_bgr = true_bgr.to(rank, non_blocking=True) true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr) true_src = true_bgr.clone() # Augment with shadow aug_shadow_idx = torch.rand(len(true_src)) < 0.3 if aug_shadow_idx.any(): aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()) aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow) aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2) true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1) del aug_shadow del aug_shadow_idx # Composite foreground onto source true_src = true_fgr * true_pha + true_src * (1 - true_pha) # Augment with noise aug_noise_idx = torch.rand(len(true_src)) < 0.4 if aug_noise_idx.any(): true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) del aug_noise_idx # Augment background with jitter aug_jitter_idx = torch.rand(len(true_src)) < 0.8 if aug_jitter_idx.any(): true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx]) del aug_jitter_idx # Augment background with affine aug_affine_idx = torch.rand(len(true_bgr)) < 0.3 if aug_affine_idx.any(): true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx]) del aug_affine_idx with autocast(): pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr) loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() if rank == 0: if (i + 1) % args.log_train_loss_interval == 0: writer.add_scalar('loss', loss, step) if (i + 1) % args.log_train_images_interval == 0: writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step) writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step) writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step) writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step) writer.add_image('train_true_src', make_grid(true_src, nrow=5), step) del true_pha, true_fgr, true_src, true_bgr del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm if (i + 1) % args.log_valid_interval == 0: valid(model, dataloader_valid, writer, step) if (step + 1) % args.checkpoint_interval == 0: torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth') if rank == 0: torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth') # Clean up dist.destroy_process_group() # --------------- Utils --------------- def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg): true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:]) true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:]) true_msk_lg = true_pha_lg != 0 true_msk_sm = true_pha_sm != 0 return F.l1_loss(pred_pha_lg, true_pha_lg) + \ F.l1_loss(pred_pha_sm, true_pha_sm) + \ F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \ F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \ F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \ F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \ F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \ kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs()) def random_crop(*imgs): H_src, W_src = imgs[0].shape[2:] W_tgt = random.choice(range(1024, 2048)) // 4 * 4 H_tgt = random.choice(range(1024, 2048)) // 4 * 4 scale = max(W_tgt / W_src, H_tgt / H_src) results = [] for img in imgs: img = kornia.resize(img, (int(H_src * scale), int(W_src * scale))) img = kornia.center_crop(img, (H_tgt, W_tgt)) results.append(img) return results def valid(model, dataloader, writer, step): model.eval() loss_total = 0 loss_count = 0 with torch.no_grad(): for (true_pha, true_fgr), true_bgr in dataloader: batch_size = true_pha.size(0) true_pha = true_pha.cuda(non_blocking=True) true_fgr = true_fgr.cuda(non_blocking=True) true_bgr = true_bgr.cuda(non_blocking=True) true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr) loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr) loss_total += loss.cpu().item() * batch_size loss_count += batch_size writer.add_scalar('valid_loss', loss_total / loss_count, step) model.train() # --------------- Start --------------- if __name__ == '__main__': addr = 'localhost' port = str(random.choice(range(12300, 12400))) # pick a random port. mp.spawn(train_worker, nprocs=distributed_num_gpus, args=(addr, port), join=True)