Spaces:
Sleeping
Sleeping
""" | |
Train MattingRefine | |
Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm. | |
Select GPUs through CUDA_VISIBLE_DEVICES environment variable. | |
Example: | |
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \ | |
--dataset-name videomatte240k \ | |
--model-backbone resnet50 \ | |
--model-name mattingrefine-resnet50-videomatte240k \ | |
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \ | |
--epoch-end 1 | |
""" | |
import argparse | |
import kornia | |
import torch | |
import os | |
import random | |
from torch import nn | |
from torch import distributed as dist | |
from torch import multiprocessing as mp | |
from torch.nn import functional as F | |
from torch.cuda.amp import autocast, GradScaler | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.utils.data import DataLoader, Subset | |
from torch.optim import Adam | |
from torchvision.utils import make_grid | |
from tqdm import tqdm | |
from torchvision import transforms as T | |
from PIL import Image | |
from data_path import DATA_PATH | |
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset | |
from dataset import augmentation as A | |
from model import MattingRefine | |
from model.utils import load_matched_state_dict | |
# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())

parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

# --batch-size and --num-workers are totals across all GPUs; each worker gets an equal share.
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)

parser.add_argument('--checkpoint-interval', type=int, default=2000)

args = parser.parse_args()


# One training process is spawned per visible GPU.
distributed_num_gpus = torch.cuda.device_count()

# Validate with parser.error() instead of `assert`: asserts vanish under `python -O`,
# and a zero GPU count would otherwise surface as a bare ZeroDivisionError.
if distributed_num_gpus == 0:
    parser.error('no CUDA devices are visible; at least one GPU is required')
if args.batch_size <= 0 or args.batch_size % distributed_num_gpus != 0:
    parser.error(f'--batch-size ({args.batch_size}) must be a positive multiple of '
                 f'the number of visible GPUs ({distributed_num_gpus})')
# --------------- Main ---------------


def train_worker(rank, addr, port):
    """Train MattingRefine as one process of a single-node DistributedDataParallel job.

    Spawned once per visible GPU by ``mp.spawn`` in the ``__main__`` guard, so
    ``rank`` serves both as the process rank and as the CUDA device index.
    Rank 0 additionally runs validation, writes TensorBoard logs and saves
    checkpoints.

    Args:
        rank: Process rank in ``range(distributed_num_gpus)``; also the GPU used.
        addr: Rendezvous address, exported as ``MASTER_ADDR``.
        port: Rendezvous port (a string), exported as ``MASTER_PORT``.
    """
    # Distributed setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    # Make `rank` the current CUDA device before any CUDA work. Otherwise anything that
    # uses the *current* device (NCCL, the DataLoader pin-memory thread) opens an extra
    # context on GPU 0 from every rank.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    dataloader_train = _build_train_dataloader(rank)
    # Only rank 0 validates (see the logging section of the training loop).
    dataloader_valid = _build_valid_dataloader() if rank == 0 else None

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.model_last_checkpoint is not None:
        # map_location='cpu': without it every rank deserialises the tensors onto the GPU
        # they were saved from (typically cuda:0). Restoring before the DDP wrap means DDP
        # is constructed from the restored weights.
        # NOTE(review): assumes load_matched_state_dict copies values into the existing
        # (GPU) parameters the way nn.Module.load_state_dict does — confirm.
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint, map_location='cpu'))
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    # Per-module learning rates: the refiner trains fastest, backbone/ASPP slowest.
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints (rank 0 only)
    writer = None
    if rank == 0:
        os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
        writer = SummaryWriter(f'log/{args.model_name}')

    try:
        for epoch in range(args.epoch_start, args.epoch_end):
            for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
                # Global step, continuous across epochs.
                step = epoch * len(dataloader_train) + i

                true_pha = true_pha.to(rank, non_blocking=True)
                true_fgr = true_fgr.to(rank, non_blocking=True)
                true_bgr = true_bgr.to(rank, non_blocking=True)
                true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
                true_src, true_bgr = _augment_batch(true_pha, true_fgr, true_bgr)

                with autocast():
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                    loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

                # Mixed-precision backward and optimizer step through the GradScaler.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if rank == 0:
                    if (i + 1) % args.log_train_loss_interval == 0:
                        writer.add_scalar('loss', loss, step)

                    if (i + 1) % args.log_train_images_interval == 0:
                        writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                        writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                        writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                        writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                        writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                    # Presumably frees this batch's GPU memory before validation runs.
                    del true_pha, true_fgr, true_src, true_bgr
                    del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                    if (i + 1) % args.log_valid_interval == 0:
                        # The un-wrapped module is used so this rank-0-only pass bypasses DDP.
                        valid(model, dataloader_valid, writer, step)

                    if (step + 1) % args.checkpoint_interval == 0:
                        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

            if rank == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
    finally:
        # Flush the last TensorBoard events and tear down the process group even on error.
        if writer is not None:
            writer.close()
        dist.destroy_process_group()


def _build_train_dataloader(rank):
    """Build the training DataLoader over this rank's contiguous shard of the training set."""
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Equal-size contiguous shard per rank (any remainder is unused); shuffling below
    # therefore happens within the shard only.
    shard_len = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(rank * shard_len, (rank + 1) * shard_len))
    return DataLoader(dataset_train,
                      shuffle=True,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _build_valid_dataloader():
    """Build the validation DataLoader (a fixed 50-sample subset, per SampleDataset's arguments)."""
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    return DataLoader(dataset_valid,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _augment_batch(true_pha, true_fgr, true_bgr):
    """Composite the source image for one batch and augment it and the background.

    Applies, per sample and at random: a shadow cast onto the background, then
    the foreground composite, noise on source and background, colour jitter on
    the background and a small affine shift of the background.

    Returns:
        ``(true_src, true_bgr)``. ``true_bgr`` is modified in place via index
        assignment and also returned.
    """
    true_src = true_bgr.clone()

    # Augment with shadow: darken the background by a blurred, transformed copy of alpha.
    aug_shadow_idx = torch.rand(len(true_src)) < 0.3
    if aug_shadow_idx.any():
        aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
        aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
        aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
        true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
        del aug_shadow
    del aug_shadow_idx

    # Composite foreground onto source
    true_src = true_fgr * true_pha + true_src * (1 - true_pha)

    # Augment with noise
    aug_noise_idx = torch.rand(len(true_src)) < 0.4
    if aug_noise_idx.any():
        true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
        true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
    del aug_noise_idx

    # Augment background with jitter
    aug_jitter_idx = torch.rand(len(true_src)) < 0.8
    if aug_jitter_idx.any():
        true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
    del aug_jitter_idx

    # Augment background with affine (a slight misalignment between bgr and src)
    aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
    if aug_affine_idx.any():
        true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
    del aug_affine_idx

    return true_src, true_bgr
# --------------- Utils ---------------


def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """Return the scalar MattingRefine training loss for one batch.

    ``*_lg`` tensors are at the full (ground-truth) resolution; ``*_sm`` tensors are
    at the coarse prediction's resolution. Coarse targets are produced by resizing
    the full-resolution targets to the matching prediction's spatial size.

    The loss is the sum of seven terms:
      * L1 on alpha, full and coarse;
      * L1 on Sobel gradients of alpha, full and coarse;
      * L1 on foreground, only where the true alpha is non-zero, full and coarse;
      * MSE between the predicted error map and the absolute coarse-alpha error,
        both resized to full resolution.
    """
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])

    # Foreground is supervised only where the ground-truth alpha is non-zero.
    fg_mask_lg = true_pha_lg != 0
    fg_mask_sm = true_pha_sm != 0

    full_res = true_pha_lg.shape[2:]
    err_pred = kornia.resize(pred_err_sm, full_res)
    # NOTE(review): err_target is not detached, so this term also back-propagates
    # into pred_pha_sm — confirm that is intended.
    err_target = kornia.resize(pred_pha_sm, full_res).sub(true_pha_lg).abs()

    terms = (
        F.l1_loss(pred_pha_lg, true_pha_lg),
        F.l1_loss(pred_pha_sm, true_pha_sm),
        F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)),
        F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)),
        F.l1_loss(pred_fgr_lg * fg_mask_lg, true_fgr_lg * fg_mask_lg),
        F.l1_loss(pred_fgr_sm * fg_mask_sm, true_fgr_sm * fg_mask_sm),
        F.mse_loss(err_pred, err_target),
    )
    # Left-to-right accumulation, matching a plain `a + b + c + ...` chain.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total
def random_crop(*imgs):
    """Resize and center-crop a group of image batches to one shared random size.

    A target width, then a target height, are each drawn from ``range(1024, 2048)``
    and rounded down to a multiple of 4. Every input is scaled by the same factor
    on both axes, just enough to cover the target, then center-cropped to it.

    Args:
        *imgs: 4-D tensors whose last two dims are height and width. The scale is
            computed from the first tensor's size and applied to all of them, so
            pass tensors of equal spatial size to keep the crops aligned.

    Returns:
        A list with one cropped tensor per input, in input order, each with
        spatial size ``(H_tgt, W_tgt)``.
    """
    H_src, W_src = imgs[0].shape[2:]
    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(W_tgt / W_src, H_tgt / H_src)
    # int() truncates, and src * (tgt / src) does not always round-trip exactly in
    # floating point (e.g. 49 * (1 / 49) == 0.9999999999999999). Clamp so the
    # resized image always covers the crop window.
    H_resized = max(H_tgt, int(H_src * scale))
    W_resized = max(W_tgt, int(W_src * scale))
    return [
        kornia.center_crop(kornia.resize(img, (H_resized, W_resized)), (H_tgt, W_tgt))
        for img in imgs
    ]
def valid(model, dataloader, writer, step):
    """Evaluate ``model`` on ``dataloader`` and log the per-sample mean loss.

    Each batch's source image is composited from ground truth as
    ``pha * fgr + (1 - pha) * bgr`` and scored with ``compute_loss``. The model is
    put in eval mode for the pass and always returned to train mode afterwards,
    even if evaluation raises.

    Args:
        model: Module called as ``model(src, bgr)`` and returning six outputs.
        dataloader: Yields ``((true_pha, true_fgr), true_bgr)`` batches.
        writer: TensorBoard writer that receives the ``valid_loss`` scalar.
        step: Global training step used as the scalar's x value.
    """
    # Use the model's own device rather than the implicit current CUDA device.
    device = next(model.parameters()).device
    model.eval()
    loss_total = 0
    loss_count = 0
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)
                true_pha = true_pha.to(device, non_blocking=True)
                true_fgr = true_fgr.to(device, non_blocking=True)
                true_bgr = true_bgr.to(device, non_blocking=True)
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
                # Weight by batch size so the result is a per-sample mean.
                loss_total += loss.item() * batch_size
                loss_count += batch_size
    finally:
        model.train()
    # An empty loader (e.g. drop_last with fewer samples than one batch) would
    # otherwise raise ZeroDivisionError here.
    if loss_count > 0:
        writer.add_scalar('valid_loss', loss_total / loss_count, step)
# --------------- Start ---------------


if __name__ == '__main__':
    # All workers rendezvous on this machine. A random port in [12300, 12400)
    # lowers the chance of clashing with another job started on the same host.
    master_port = str(random.randrange(12300, 12400))
    mp.spawn(
        train_worker,
        args=('localhost', master_port),
        nprocs=distributed_num_gpus,
        join=True,
    )