""" | |
Train MattingBase | |
You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch> | |
Example: | |
CUDA_VISIBLE_DEVICES=0 python train_base.py \ | |
--dataset-name videomatte240k \ | |
--model-backbone resnet50 \ | |
--model-name mattingbase-resnet50-videomatte240k \ | |
--model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \ | |
--epoch-end 8 | |
""" | |
import argparse
import kornia
import torch
import os
import random

from torch import nn
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image

from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingBase
from model.utils import load_matched_state_dict

# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-pretrain-initialization', type=str, default=None)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=2000)
parser.add_argument('--log-valid-interval', type=int, default=5000)
parser.add_argument('--checkpoint-interval', type=int, default=5000)

args = parser.parse_args()

# --------------- Loading ---------------


def train():
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])

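    # Per-module learning rates: the pretrained backbone uses a smaller learning rate
    # than the ASPP module and the decoder.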
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.aspp.parameters(), 'lr': 5e-4},
        {'params': model.decoder.parameters(), 'lr': 5e-4}
    ])
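    # GradScaler scales the loss for the backward pass so FP16 gradients produced under
    # autocast do not underflow.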
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

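            # Augmentations: cast a fake shadow onto the background before compositing, add
            # independent noise to the composite and the captured background, and color-jitter /
            # slightly shift only the background so it is not a pixel-perfect match of the source.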
            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

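            # Forward and loss in mixed precision; model(src, bgr) may return more than three
            # outputs, [:3] keeps the predicted alpha, foreground, and error map.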
            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)

            del true_pha, true_fgr, true_bgr
            del pred_pha, pred_fgr, pred_err

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')


# --------------- Utils ---------------


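# Loss: L1 on the alpha matte, L1 on its Sobel gradients, L1 on the foreground inside the
# alpha > 0 region, and MSE between the predicted error map and the actual alpha error.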
def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
    true_err = torch.abs(pred_pha.detach() - true_pha)
    true_msk = true_pha != 0
    return F.l1_loss(pred_pha, true_pha) + \
           F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
           F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
           F.mse_loss(pred_err, true_err)


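# Crop every tensor to the same randomly chosen size (256-511 px per side): resize to a square
# matching the larger side, then center-crop to (h, w).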
def random_crop(*imgs):
    w = random.choice(range(256, 512))
    h = random.choice(range(256, 512))
    results = []
    for img in imgs:
        img = kornia.resize(img, (max(h, w), max(h, w)))
        img = kornia.center_crop(img, (h, w))
        results.append(img)
    return results


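# Evaluate on the sub-sampled validation set and log the mean loss to TensorBoard.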
def valid(model, dataloader, writer, step):
    model.eval()
    loss_total = 0
    loss_count = 0
    with torch.no_grad():
        for (true_pha, true_fgr), true_bgr in dataloader:
            batch_size = true_pha.size(0)
            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
            loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
            loss_total += loss.cpu().item() * batch_size
            loss_count += batch_size
    writer.add_scalar('valid_loss', loss_total / loss_count, step)
    model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    train()