""" Train MattingBase You can download pretrained DeepLabV3 weights from Example: CUDA_VISIBLE_DEVICES=0 python train_base.py \ --dataset-name videomatte240k \ --model-backbone resnet50 \ --model-name mattingbase-resnet50-videomatte240k \ --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \ --epoch-end 8 """ import argparse import kornia import torch import os import random from torch import nn from torch.nn import functional as F from torch.cuda.amp import autocast, GradScaler from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader from torch.optim import Adam from torchvision.utils import make_grid from tqdm import tqdm from torchvision import transforms as T from PIL import Image from data_path import DATA_PATH from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset from dataset import augmentation as A from model import MattingBase from model.utils import load_matched_state_dict # --------------- Arguments --------------- parser = argparse.ArgumentParser() parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys()) parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) parser.add_argument('--model-name', type=str, required=True) parser.add_argument('--model-pretrain-initialization', type=str, default=None) parser.add_argument('--model-last-checkpoint', type=str, default=None) parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--num-workers', type=int, default=16) parser.add_argument('--epoch-start', type=int, default=0) parser.add_argument('--epoch-end', type=int, required=True) parser.add_argument('--log-train-loss-interval', type=int, default=10) parser.add_argument('--log-train-images-interval', type=int, default=2000) parser.add_argument('--log-valid-interval', type=int, default=5000) parser.add_argument('--checkpoint-interval', type=int, default=5000) args = parser.parse_args() # --------------- Loading --------------- def train(): # Training DataLoader dataset_train = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'), ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)), A.PairRandomHorizontalFlip(), A.PairRandomBoxBlur(0.1, 5), A.PairRandomSharpen(0.1), A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)), T.RandomHorizontalFlip(), A.RandomBoxBlur(0.1, 5), A.RandomSharpen(0.1), T.ColorJitter(0.15, 0.15, 0.15, 0.05), T.ToTensor() ])), ]) dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True) # Validation DataLoader dataset_valid = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), T.ToTensor() ])), ]) dataset_valid = SampleDataset(dataset_valid, 50) dataloader_valid = DataLoader(dataset_valid, pin_memory=True, batch_size=args.batch_size, num_workers=args.num_workers) # Model model = MattingBase(args.model_backbone).cuda() if args.model_last_checkpoint is not None: load_matched_state_dict(model, torch.load(args.model_last_checkpoint)) elif args.model_pretrain_initialization is not None: model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state']) optimizer = Adam([ {'params': model.backbone.parameters(), 'lr': 1e-4}, {'params': model.aspp.parameters(), 'lr': 5e-4}, {'params': model.decoder.parameters(), 'lr': 5e-4} ]) scaler = GradScaler() # Logging and checkpoints if not os.path.exists(f'checkpoint/{args.model_name}'): os.makedirs(f'checkpoint/{args.model_name}') writer = SummaryWriter(f'log/{args.model_name}') # Run loop for epoch in range(args.epoch_start, args.epoch_end): for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)): step = epoch * len(dataloader_train) + i true_pha = true_pha.cuda(non_blocking=True) true_fgr = true_fgr.cuda(non_blocking=True) true_bgr = true_bgr.cuda(non_blocking=True) true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr) true_src = true_bgr.clone() # Augment with shadow aug_shadow_idx = torch.rand(len(true_src)) < 0.3 if aug_shadow_idx.any(): aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()) aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow) aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2) true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1) del aug_shadow del aug_shadow_idx # Composite foreground onto source true_src = true_fgr * true_pha + true_src * (1 - true_pha) # Augment with noise aug_noise_idx = torch.rand(len(true_src)) < 0.4 if aug_noise_idx.any(): true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) del aug_noise_idx # Augment background with jitter aug_jitter_idx = torch.rand(len(true_src)) < 0.8 if aug_jitter_idx.any(): true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx]) del aug_jitter_idx # Augment background with affine aug_affine_idx = torch.rand(len(true_bgr)) < 0.3 if aug_affine_idx.any(): true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx]) del aug_affine_idx with autocast(): pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3] loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() if (i + 1) % args.log_train_loss_interval == 0: writer.add_scalar('loss', loss, step) if (i + 1) % args.log_train_images_interval == 0: writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step) writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step) writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step) writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step) writer.add_image('train_true_src', make_grid(true_src, nrow=5), step) writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step) del true_pha, true_fgr, true_bgr del pred_pha, pred_fgr, pred_err if (i + 1) % args.log_valid_interval == 0: valid(model, dataloader_valid, writer, step) if (step + 1) % args.checkpoint_interval == 0: torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth') torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth') # --------------- Utils --------------- def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr): true_err = torch.abs(pred_pha.detach() - true_pha) true_msk = true_pha != 0 return F.l1_loss(pred_pha, true_pha) + \ F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \ F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \ F.mse_loss(pred_err, true_err) def random_crop(*imgs): w = random.choice(range(256, 512)) h = random.choice(range(256, 512)) results = [] for img in imgs: img = kornia.resize(img, (max(h, w), max(h, w))) img = kornia.center_crop(img, (h, w)) results.append(img) return results def valid(model, dataloader, writer, step): model.eval() loss_total = 0 loss_count = 0 with torch.no_grad(): for (true_pha, true_fgr), true_bgr in dataloader: batch_size = true_pha.size(0) true_pha = true_pha.cuda(non_blocking=True) true_fgr = true_fgr.cuda(non_blocking=True) true_bgr = true_bgr.cuda(non_blocking=True) true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3] loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr) loss_total += loss.cpu().item() * batch_size loss_count += batch_size writer.add_scalar('valid_loss', loss_total / loss_count, step) model.train() # --------------- Start --------------- if __name__ == '__main__': train()