Spaces:
Sleeping
Sleeping
File size: 13,709 Bytes
854728f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
"""
Train MattingRefine
Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
Example:
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingrefine-resnet50-videomatte240k \
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
--epoch-end 1
"""
import argparse
import kornia
import torch
import os
import random
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image
from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingRefine
from model.utils import load_matched_state_dict
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)
parser.add_argument('--batch-size', type=int, default=4,
                    help='total batch size across all visible GPUs; must be a positive multiple of the GPU count')
parser.add_argument('--num-workers', type=int, default=16,
                    help='total DataLoader workers, divided evenly among the per-GPU processes')
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)
parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)
parser.add_argument('--checkpoint-interval', type=int, default=2000)
args = parser.parse_args()

# One training process is spawned per visible GPU, and the batch size is split
# evenly across them, so validate that split up front. parser.error (rather
# than assert) survives `python -O`, avoids a ZeroDivisionError when no GPU is
# visible, and reports the problem as a normal CLI usage error.
distributed_num_gpus = torch.cuda.device_count()
if distributed_num_gpus == 0:
    parser.error('no CUDA devices are visible; expose at least one GPU via CUDA_VISIBLE_DEVICES')
if args.batch_size <= 0 or args.batch_size % distributed_num_gpus != 0:
    parser.error(f'--batch-size ({args.batch_size}) must be a positive multiple of '
                 f'the number of visible GPUs ({distributed_num_gpus})')
# --------------- Main ---------------
def _build_train_dataloader(rank):
    """Build the training DataLoader for the DDP process with the given ``rank``.

    Each item unpacks as ``((pha, fgr), bgr)``. The zipped dataset is split
    into ``distributed_num_gpus`` contiguous shards of equal length (any
    remainder is dropped), and this process only ever samples from shard
    ``rank``; ``shuffle=True`` shuffles within that shard. The global
    ``--batch-size`` and ``--num-workers`` are divided evenly across processes.
    """
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    shard_len = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(rank * shard_len, (rank + 1) * shard_len))
    return DataLoader(dataset_train,
                      shuffle=True,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _build_valid_dataloader():
    """Build the validation DataLoader (only used by rank 0).

    Same ``((pha, fgr), bgr)`` structure as training, with lighter
    augmentation, wrapped in ``SampleDataset(..., 50)`` (presumably a
    50-item subset; see the dataset package to confirm).
    """
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    return DataLoader(dataset_valid,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _augment_batch(true_pha, true_fgr, true_bgr):
    """Synthesize the model inputs for one training batch.

    Returns ``(true_src, true_bgr)``:

    * ``true_src`` -- ``fgr * pha + bgr * (1 - pha)``. Before compositing,
      about 30% of samples have a scaled, randomly warped, box-blurred copy of
      their alpha subtracted from the background (a synthetic shadow). After
      compositing, about 40% of samples get additive Gaussian noise.
    * ``true_bgr`` -- the background fed to the model, perturbed separately:
      noise on the same ~40% of samples, color jitter on ~80%, and a small
      affine on ~30%. These writes use index assignment, so the tensor passed
      in as ``true_bgr`` is modified in place.
    """
    true_src = true_bgr.clone()

    # Augment with shadow
    aug_shadow_idx = torch.rand(len(true_src)) < 0.3
    if aug_shadow_idx.any():
        aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
        aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
        aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
        true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
        del aug_shadow  # release the GPU buffer before the remaining augmentations

    # Composite foreground onto source
    true_src = true_fgr * true_pha + true_src * (1 - true_pha)

    # Augment with noise (source and background draw independent noise)
    aug_noise_idx = torch.rand(len(true_src)) < 0.4
    if aug_noise_idx.any():
        true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
        true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)

    # Augment background with jitter
    aug_jitter_idx = torch.rand(len(true_src)) < 0.8
    if aug_jitter_idx.any():
        true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])

    # Augment background with affine
    aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
    if aug_affine_idx.any():
        true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])

    return true_src, true_bgr


def train_worker(rank, addr, port):
    """Run training in one DDP process; ``mp.spawn`` starts one per GPU.

    Args:
        rank: process rank, also used as this process's CUDA device index.
        addr: rendezvous host, exported as ``MASTER_ADDR``.
        port: rendezvous port as a string, exported as ``MASTER_PORT``.

    Rank 0 additionally runs validation, writes TensorBoard logs under
    ``log/<model-name>`` and saves checkpoints under ``checkpoint/<model-name>``.
    """
    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    dataloader_train = _build_train_dataloader(rank)
    if rank == 0:
        dataloader_valid = _build_valid_dataloader()

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.model_last_checkpoint is not None:
        # map_location keeps each rank's copy on its own GPU; without it every
        # process would deserialize onto the device the checkpoint was saved
        # from (typically cuda:0). Loading before the DDP wrap means the
        # weights DDP broadcasts from rank 0 at construction are the restored ones.
        state_dict = torch.load(args.model_last_checkpoint, map_location=f'cuda:{rank}')
        load_matched_state_dict(model, state_dict)
        del state_dict
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
            true_src, true_bgr = _augment_batch(true_pha, true_fgr, true_bgr)

            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                # Free this batch's GPU tensors before validation allocates its own.
                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up: flush pending TensorBoard events before tearing down the group.
    if rank == 0:
        writer.close()
    dist.destroy_process_group()
# --------------- Utils ---------------
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """Return the total MattingRefine training loss as a scalar tensor.

    At both the full ("lg") and coarse ("sm") resolution it sums:
    L1 on alpha, L1 on the Sobel response of alpha, and L1 on foreground
    restricted to pixels where the true alpha is non-zero. A final MSE term
    supervises ``pred_err_sm`` (upsampled to full resolution) with the actual
    absolute error of the upsampled coarse alpha against the true alpha.
    Ground truth is resized to the spatial size (``shape[2:]``) of the
    matching coarse prediction.
    """
    size_lg = true_pha_lg.shape[2:]

    # Ground truth brought down to the coarse predictions' resolution.
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])

    def edge_l1(pred_pha, true_pha):
        # Compare alpha edges via the Sobel filter response.
        return F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha))

    def masked_fgr_l1(pred_fgr, true_fgr, true_pha):
        # Foreground color is only supervised where the true alpha is non-zero.
        keep = true_pha != 0
        return F.l1_loss(pred_fgr * keep, true_fgr * keep)

    alpha_lg = F.l1_loss(pred_pha_lg, true_pha_lg)
    alpha_sm = F.l1_loss(pred_pha_sm, true_pha_sm)
    edge_lg = edge_l1(pred_pha_lg, true_pha_lg)
    edge_sm = edge_l1(pred_pha_sm, true_pha_sm)
    fgr_lg = masked_fgr_l1(pred_fgr_lg, true_fgr_lg, true_pha_lg)
    fgr_sm = masked_fgr_l1(pred_fgr_sm, true_fgr_sm, true_pha_sm)

    actual_err_lg = kornia.resize(pred_pha_sm, size_lg).sub(true_pha_lg).abs()
    err_term = F.mse_loss(kornia.resize(pred_err_sm, size_lg), actual_err_lg)

    # Summed left to right in the same order as the terms are listed above.
    return alpha_lg + alpha_sm + edge_lg + edge_sm + fgr_lg + fgr_sm + err_term
def random_crop(*imgs):
    """Resize and center-crop every tensor in ``imgs`` to one shared random size.

    A target width, then a target height, are drawn from ``range(1024, 2048)``
    and rounded down to a multiple of 4. Each tensor is resized by the same
    factor (the smaller factor that still makes it cover the target in both
    dimensions) and then center-cropped to exactly ``(H_tgt, W_tgt)``.

    Assumes every tensor shares the spatial size of ``imgs[0]``, read from
    ``shape[2:]`` (i.e. batched ``(B, C, H, W)`` input). Returns a list in
    input order.
    """
    H_src, W_src = imgs[0].shape[2:]
    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(W_tgt / W_src, H_tgt / H_src)
    # int() truncates, and (tgt / src) * src can land one ulp below tgt in
    # floating point, which would leave the resized image a pixel smaller than
    # the crop window. Clamp so the resized image always covers the crop.
    H_resized = max(int(H_src * scale), H_tgt)
    W_resized = max(int(W_src * scale), W_tgt)
    return [
        kornia.center_crop(kornia.resize(img, (H_resized, W_resized)), (H_tgt, W_tgt))
        for img in imgs
    ]
def valid(model, dataloader, writer, step):
    """Evaluate ``model`` on ``dataloader`` and log the mean loss.

    Each batch's source image is composited as ``pha * fgr + (1 - pha) * bgr``
    and scored with ``compute_loss`` under ``torch.no_grad()`` in eval mode.
    The per-sample-weighted mean is written to ``writer`` as ``valid_loss`` at
    ``step``. Nothing is logged if the loader yields no batches, which avoids
    a division by zero. The model's previous train/eval mode is restored
    afterwards, even if evaluation raises.

    NOTE(review): inputs are moved with ``.cuda()`` (the current CUDA
    device), which assumes that is the device the model lives on -- rank 0 /
    cuda:0 in this script; confirm if called from elsewhere.
    """
    was_training = model.training
    model.eval()
    loss_total = 0.0
    loss_count = 0
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)
                true_pha = true_pha.cuda(non_blocking=True)
                true_fgr = true_fgr.cuda(non_blocking=True)
                true_bgr = true_bgr.cuda(non_blocking=True)
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
                loss_total += loss.item() * batch_size
                loss_count += batch_size
    finally:
        model.train(was_training)
    if loss_count > 0:
        writer.add_scalar('valid_loss', loss_total / loss_count, step)
# --------------- Start ---------------
def _find_free_port(host='localhost'):
    """Return, as a string, a TCP port on ``host`` that is currently unused.

    Binding to port 0 lets the OS pick a free port. The port could still be
    taken between closing this socket and the process group binding it, but
    that is far less likely than a collision when drawing from a small fixed
    range.
    """
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return str(sock.getsockname()[1])


if __name__ == '__main__':
    addr = 'localhost'
    port = _find_free_port(addr)  # kept as str: train_worker stores it in os.environ
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=(addr, port),
             join=True)
|