import math
import numpy as np
from scipy.ndimage import gaussian_filter, grey_dilation, grey_erosion
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
'supervised_training_iter',
'soc_adaptation_iter',
]
# ----------------------------------------------------------------------------------
# Tool Classes/Functions
# ----------------------------------------------------------------------------------
class GaussianBlurLayer(nn.Module):
""" Add Gaussian Blur to a 4D tensors
This layer takes a 4D tensor of {N, C, H, W} as input.
The Gaussian blur will be performed in given channel number (C) splitly.
"""
def __init__(self, channels, kernel_size):
"""
Arguments:
channels (int): Channel for input tensor
kernel_size (int): Size of the kernel used in blurring
"""
super(GaussianBlurLayer, self).__init__()
self.channels = channels
self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0, 'kernel_size of GaussianBlurLayer must be odd'
self.op = nn.Sequential(
nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=False, groups=channels)
)
self._init_kernel()
def forward(self, x):
"""
Arguments:
x (torch.Tensor): input 4D tensor
Returns:
torch.Tensor: Blurred version of the input
"""
        if x.dim() != 4:
            raise ValueError('\'GaussianBlurLayer\' requires a 4D tensor as input')
        if x.shape[1] != self.channels:
            raise ValueError('In \'GaussianBlurLayer\', the required channel ({0}) is '
                             'not the same as input ({1})'.format(self.channels, x.shape[1]))
return self.op(x)
def _init_kernel(self):
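        # sigma follows OpenCV's getGaussianKernel convention:
        # sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8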
sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
n = np.zeros((self.kernel_size, self.kernel_size))
i = math.floor(self.kernel_size / 2)
n[i, i] = 1
        kernel = gaussian_filter(n, sigma)
        # load the same Gaussian kernel into every channel of the depthwise conv
        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))
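
# A minimal usage sketch for `GaussianBlurLayer` (shapes are illustrative):
#
#   blur = GaussianBlurLayer(channels=1, kernel_size=3)
#   x = torch.rand(2, 1, 64, 64)   # NCHW input
#   y = blur(x)                    # same shape as `x`, blurred per channel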
# ----------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------
# MODNet Training Functions
# ----------------------------------------------------------------------------------
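# module-level blur operator shared by the training functions below
# (created on the GPU, so a CUDA device is assumed to be available)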
blurer = GaussianBlurLayer(1, 3).cuda()
def supervised_training_iter(
modnet, optimizer, image, trimap, gt_matte,
semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
""" Supervised training iteration of MODNet
This function trains MODNet for one iteration in a labeled dataset.
Arguments:
modnet (torch.nn.Module): instance of MODNet
optimizer (torch.optim.Optimizer): optimizer for supervised training
        image (torch.Tensor): input RGB image
            its pixel values should be normalized
        trimap (torch.Tensor): trimap used to calculate the losses
            its pixel values can be 0, 0.5, or 1
            (foreground=1, background=0, unknown=0.5)
        gt_matte (torch.Tensor): ground truth alpha matte
            its pixel values are in the range [0, 1]
semantic_scale (float): scale of the semantic loss
NOTE: please adjust according to your dataset
detail_scale (float): scale of the detail loss
NOTE: please adjust according to your dataset
matte_scale (float): scale of the matte loss
NOTE: please adjust according to your dataset
Returns:
semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
Example:
import torch
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter
bs = 16 # batch size
        lr = 0.01  # learning rate
epochs = 40 # total epochs
modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss = \
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
lr_scheduler.step()
"""
global blurer
# set the model to train mode and clear the optimizer
modnet.train()
optimizer.zero_grad()
# forward the model
pred_semantic, pred_detail, pred_matte = modnet(image, False)
    # compute the known-region mask from the trimap
    # (True at definite foreground/background pixels, False in the unknown region)
    boundaries = (trimap < 0.5) | (trimap > 0.5)
# calculate the semantic loss
gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
gt_semantic = blurer(gt_semantic)
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
semantic_loss = semantic_scale * semantic_loss
# calculate the detail loss
pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
gt_detail = torch.where(boundaries, trimap, gt_matte)
detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
detail_loss = detail_scale * detail_loss
    # calculate the matte loss (L1 + compositional, with 4x extra weight inside the trimap's unknown region)
pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
+ 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
matte_loss = matte_scale * matte_loss
# calculate the final loss, backward the loss, and update the model
loss = semantic_loss + detail_loss + matte_loss
loss.backward()
optimizer.step()
return semantic_loss, detail_loss, matte_loss
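
# A minimal sketch of the `CREATE_YOUR_DATALOADER` placeholder used in the
# docstrings here, assuming a dataset that yields normalized
# (image, trimap, gt_matte) tensors (all names below are hypothetical):
#
#   from torch.utils.data import DataLoader
#
#   def CREATE_YOUR_DATALOADER(bs):
#       dataset = YourMattingDataset()  # returns (image, trimap, gt_matte)
#       return DataLoader(dataset, batch_size=bs, shuffle=True)
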
def soc_adaptation_iter(
modnet, backup_modnet, optimizer, image,
soc_semantic_scale=100.0, soc_detail_scale=1.0):
""" Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet
This function fine-tunes MODNet for one iteration in an unlabeled dataset.
Note that SOC can only fine-tune a converged MODNet, i.e., MODNet that has been
trained in a labeled dataset.
Arguments:
modnet (torch.nn.Module): instance of MODNet
backup_modnet (torch.nn.Module): backup of the trained MODNet
optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
        image (torch.Tensor): input RGB image
            its pixel values should be normalized
soc_semantic_scale (float): scale of the SOC semantic loss
NOTE: please adjust according to your dataset
soc_detail_scale (float): scale of the SOC detail loss
NOTE: please adjust according to your dataset
Returns:
soc_semantic_loss (torch.Tensor): loss of the semantic SOC
soc_detail_loss (torch.Tensor): loss of the detail SOC
Example:
import copy
import torch
from src.models.modnet import MODNet
from src.trainer import soc_adaptation_iter
bs = 1 # batch size
        lr = 0.00001  # learning rate
epochs = 10 # total epochs
modnet = torch.nn.DataParallel(MODNet()).cuda()
modnet = LOAD_TRAINED_CKPT() # NOTE: please finish this function
optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
for epoch in range(0, epochs):
backup_modnet = copy.deepcopy(modnet)
            for idx, image in enumerate(dataloader):
soc_semantic_loss, soc_detail_loss = \
soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
"""
global blurer
# set the backup model to eval mode
backup_modnet.eval()
# set the main model to train mode and freeze its norm layers
modnet.train()
modnet.module.freeze_norm()
# clear the optimizer
optimizer.zero_grad()
# forward the main model
pred_semantic, pred_detail, pred_matte = modnet(image, False)
# forward the backup model
with torch.no_grad():
_, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)
# calculate the boundary mask from `pred_matte` and `pred_semantic`
pred_matte_fg = (pred_matte.detach() > 0.1).float()
pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
pred_fg = pred_matte_fg * pred_semantic_fg
n, c, h, w = pred_matte.shape
np_pred_fg = pred_fg.data.cpu().numpy()
np_boundaries = np.zeros([n, c, h, w])
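    # extract a transition band around the predicted foreground with a
    # morphological gradient (dilation minus erosion); the structuring
    # element spans roughly 5% of the mean image side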
for sdx in range(0, n):
sample_np_boundaries = np_boundaries[sdx, 0, ...]
sample_np_pred_fg = np_pred_fg[sdx, 0, ...]
side = int((h + w) / 2 * 0.05)
dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
eroded = grey_erosion(sample_np_pred_fg, size=(side, side))
sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
np_boundaries[sdx, 0, ...] = sample_np_boundaries
boundaries = torch.tensor(np_boundaries).float().cuda()
# sub-objectives consistency between `pred_semantic` and `pred_matte`
# generate pseudo ground truth for `pred_semantic`
downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
pseudo_gt_semantic = downsampled_pred_matte.detach()
pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()
# generate pseudo ground truth for `pred_matte`
pseudo_gt_matte = pred_semantic.detach()
pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()
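    # in both pseudo labels, responses <= 0.01 are zeroed out so that
    # near-background noise does not contribute to the consistency loss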
# calculate the SOC semantic loss
soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)
    # NOTE: using the formulas from our paper to calculate the following losses gives similar results
# sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
backup_detail_loss = torch.mean(backup_detail_loss)
    # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
backup_matte_loss = torch.mean(backup_matte_loss)
soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)
# calculate the final loss, backward the loss, and update the model
loss = soc_semantic_loss + soc_detail_loss
loss.backward()
optimizer.step()
return soc_semantic_loss, soc_detail_loss
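
# A minimal sketch of the `LOAD_TRAINED_CKPT` placeholder from the docstring
# above (the checkpoint path is hypothetical; the state dict is assumed to
# have been saved from the same DataParallel-wrapped model):
#
#   def LOAD_TRAINED_CKPT():
#       modnet.load_state_dict(torch.load('pretrained/your_modnet.ckpt'))
#       return modnet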
# ----------------------------------------------------------------------------------