# INR-Harmon / processing.py
# Committed by WindVChen in 6710c89 ("Init"); original file size 16.4 kB.
import os
import time
import datetime
import torch
import torchvision
from utils import misc, metrics
# Best mean 'All' PSNR seen so far; `val` reads and updates it via `global best_psnr`
# to decide whether to write best.pth. Not restored when resuming from a checkpoint.
best_psnr = 0
def _loss_scalar(loss):
    """Return ``loss`` as a Python number for the running-average meters.

    Loss terms disabled for the current configuration stay as the integer
    placeholder ``0``; any other term is expected to be a scalar tensor.
    """
    return 0 if isinstance(loss, int) else loss.item()


def _unpack_train_batch(batch, opt):
    """Move one training batch to ``opt.device``.

    Returns ``(composite_image, real_image, mask, fg_INR_coordinates)``.
    With ``opt.INRDecode and opt.hr_train`` each image entry is a list of 4
    tensors: index 0 is the encoder input and indices 1-3 are the crops fed to
    the INR decoder at three resolutions; ``fg_INR_coordinates`` is then the
    list of those three crops' coordinate maps.
    """
    if opt.INRDecode and opt.hr_train:
        composite_image = [batch[f'composite_image{i}'].to(opt.device) for i in range(4)]
        real_image = [batch[f'real_image{i}'].to(opt.device) for i in range(4)]
        mask = [batch[f'mask{i}'].to(opt.device) for i in range(4)]
        coordinate_map = [batch[f'coordinate_map{i}'].to(opt.device) for i in range(4)]
        return composite_image, real_image, mask, coordinate_map[1:]
    return (batch['composite_image'].to(opt.device),
            batch['real_image'].to(opt.device),
            batch['mask'].to(opt.device),
            batch['fg_INR_coordinates'].to(opt.device))


def _compute_train_losses(outputs, real_image, mask, loss_fn, opt, gt_resizers):
    """Return ``(construct_loss, lut_loss, regularize_loss)`` for one training step.

    ``outputs`` is the model's ``(fg_content_bg_appearance_construct, fit_lut3d,
    lut_transform_image)`` triple. Without ``opt.INRDecode`` only the LUT loss is
    used and the other two terms are the placeholder ``0``. With ``opt.INRDecode``
    the construction loss is averaged over the three decoder scales of the LRIP
    module: against the RSC crops ``real_image[3 - n]`` / ``mask[3 - n]`` when
    ``opt.hr_train`` (see Section 3.4), otherwise against the full target resized
    by ``gt_resizers[n]``.
    """
    fg_content_bg_appearance_construct, fit_lut3d, lut_transform_image = outputs
    if not opt.INRDecode:
        return 0, loss_fn['masked_mse'](lut_transform_image, real_image, mask), 0

    construct_loss = 0
    if opt.hr_train:
        for n in range(3):
            construct_loss += loss_fn['masked_mse'](
                fg_content_bg_appearance_construct[n], real_image[3 - n], mask[3 - n])
        construct_loss /= 3
        lut_loss = loss_fn['masked_mse'](lut_transform_image, real_image[1], mask[1])
    else:
        for n, resize in enumerate(gt_resizers):
            construct_loss += loss_fn['MaskWeightedMSE'](
                fg_content_bg_appearance_construct[n], resize(real_image), resize(mask))
        construct_loss /= 3
        lut_loss = loss_fn['masked_mse'](lut_transform_image, real_image, mask)
    regularize_loss = loss_fn['regularize_LUT'](fit_lut3d)
    return construct_loss, lut_loss, regularize_loss


def train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt):
    """Train ``model`` for ``opt.epochs`` epochs, validating or checkpointing after each epoch.

    Args:
        train_loader: loader of dict batches (keys as read in ``_unpack_train_batch``).
        val_loader: loader passed to ``val`` after each epoch (unless full-res RSC training).
        model: network returning ``(fg_content_bg_appearance_construct, fit_lut3d,
            lut_transform_image)``.
        optimizer, scheduler: both stepped once per training iteration.
        loss_fn: dict of loss callables keyed 'masked_mse', 'MaskWeightedMSE', 'regularize_LUT'.
        logger: receives progress messages.
        opt: options (epochs, device, pretrained, INRDecode, hr_train, INR_input_size,
            print_freq, wandb, isFullRes, save_path).
    """
    total_step = opt.epochs * len(train_loader)
    step_time_log = misc.AverageMeter()
    loss_log = misc.AverageMeter(':6f')
    loss_fg_content_bg_appearance_construct_log = misc.AverageMeter(':6f')
    loss_lut_transform_image_log = misc.AverageMeter(':6f')
    loss_lut_regularize_log = misc.AverageMeter(':6f')

    start_epoch = 0
    # Resume from a checkpoint written by this function (model/optimizer/scheduler/last_epoch).
    if opt.pretrained is not None:
        logger.info(f"Load pretrained weight from {opt.pretrained}")
        # The weights go into a CPU copy of the model anyway, so load them on CPU instead of
        # the device they were saved from (which may be absent on this machine).
        load_state = torch.load(opt.pretrained, map_location='cpu')
        model = model.cpu()
        model.load_state_dict(load_state['model'])
        model = model.to(opt.device)
        # Optimizer.load_state_dict moves the loaded state to the device of the matching params.
        optimizer.load_state_dict(load_state['optimizer'])
        scheduler.load_state_dict(load_state['scheduler'])
        start_epoch = load_state['last_epoch'] + 1

    # Ground-truth resizers for the three decoder scales (INR_input_size // 4, // 2, // 1),
    # needed only when training the INR decoder without RSC crops; built once, not per step.
    if opt.INRDecode and not opt.hr_train:
        gt_resizers = [torchvision.transforms.Resize(opt.INR_input_size // 2 ** (2 - n)) for n in range(3)]
    else:
        gt_resizers = []

    for epoch in range(start_epoch, opt.epochs):
        model.train()
        time_ckp = time.time()
        for step, batch in enumerate(train_loader):
            current_step = epoch * len(train_loader) + step + 1
            composite_image, real_image, mask, fg_INR_coordinates = _unpack_train_batch(batch, opt)
            outputs = model(composite_image, mask, fg_INR_coordinates)
            loss_construct, loss_lut_trans, loss_lut_reg = _compute_train_losses(
                outputs, real_image, mask, loss_fn, opt, gt_resizers)
            loss = loss_construct + loss_lut_trans + loss_lut_reg

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Step time includes data loading: the checkpoint is reset at the end of each step.
            step_time_log.update(time.time() - time_ckp)
            loss_fg_content_bg_appearance_construct_log.update(_loss_scalar(loss_construct))
            loss_lut_transform_image_log.update(_loss_scalar(loss_lut_trans))
            loss_lut_regularize_log.update(_loss_scalar(loss_lut_reg))
            loss_log.update(loss.item())

            if current_step % opt.print_freq == 0:
                remain_secs = (total_step - current_step) * step_time_log.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))
                log_msg = (
                    f'Epoch: [{epoch}/{opt.epochs}]\t'
                    f'Step: [{step}/{len(train_loader)}]\t'
                    f'StepTime {step_time_log.val:.3f} ({step_time_log.avg:.3f})\t'
                    f'lr {optimizer.param_groups[0]["lr"]}\t'
                    f'Loss {loss_log.val:.4f} ({loss_log.avg:.4f})\t'
                    f'Loss_fg_bg_cons {loss_fg_content_bg_appearance_construct_log.val:.4f} ({loss_fg_content_bg_appearance_construct_log.avg:.4f})\t'
                    f'Loss_lut_trans {loss_lut_transform_image_log.val:.4f} ({loss_lut_transform_image_log.avg:.4f})\t'
                    f'Loss_lut_reg {loss_lut_regularize_log.val:.4f} ({loss_lut_regularize_log.avg:.4f})\t'
                    f'Remaining Time {remain_time} ({finish_time})'
                )
                logger.info(log_msg)
                if opt.wandb:
                    import wandb  # optional dependency, only needed when logging to W&B
                    wandb.log(
                        {'Train/Epoch': epoch, 'Train/lr': optimizer.param_groups[0]['lr'], 'Train/Step': current_step,
                         'Train/Loss': loss_log.val,
                         'Train/Loss_fg_bg_cons': loss_fg_content_bg_appearance_construct_log.val,
                         'Train/Loss_lut_trans': loss_lut_transform_image_log.val,
                         'Train/Loss_lut_reg': loss_lut_regularize_log.val,
                         })
            time_ckp = time.time()

        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'last_epoch': epoch,
                 'scheduler': scheduler.state_dict()}
        # Full-resolution validation cannot batch images of differing sizes and may run out of
        # memory alongside training, so with isFullRes + hr_train we only checkpoint here and
        # leave evaluation to inference.py.
        if opt.isFullRes and opt.hr_train:
            if epoch % 5 == 0:
                torch.save(state, os.path.join(opt.save_path, f"epoch{epoch}.pth"))
            else:
                torch.save(state, os.path.join(opt.save_path, "last.pth"))
        else:
            val(val_loader, model, logger, opt, state)
def _new_metric_log():
    """Return an empty metric accumulator keyed by sub-dataset category plus the aggregate 'All'."""
    return {key: {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}
            for key in ('HAdobe5k', 'HCOCO', 'Hday2night', 'HFlickr', 'All')}


def _accumulate_metrics(metric_log, category, mse, fmse, psnr, ssim):
    """Add each sample's metrics to its own category entry and to 'All'.

    ``category[idx]`` names the category of sample ``idx``; ``mse``/``fmse``/``psnr``/``ssim``
    are indexed per sample in the same order.
    """
    for idx, name in enumerate(category):
        for key in (name, 'All'):
            entry = metric_log[key]
            entry['Samples'] += 1
            entry['MSE'] += mse[idx]
            entry['fMSE'] += fmse[idx]
            entry['PSNR'] += psnr[idx]
            entry['SSIM'] += ssim[idx]


def _mean_metrics(entry):
    """Return the per-sample mean of each metric in ``entry``, or ``None`` if it saw no samples."""
    samples = entry['Samples']
    if samples == 0:
        return None
    return {name: entry[name] / samples for name in ('MSE', 'fMSE', 'PSNR', 'SSIM')}


def val(val_loader, model, logger, opt, state):
    """Evaluate ``model`` on ``val_loader``, log per-category metrics and checkpoint ``state``.

    MSE / fMSE / PSNR / SSIM (from ``metrics.calc_metrics``) are accumulated per category and
    over 'All' for the LUT-transformed output and, when ``opt.INRDecode``, for the INR decoder
    output. Categories with no samples are skipped in the report instead of dividing by zero.
    ``state`` is saved to ``best.pth`` when the mean 'All' PSNR (INR output if ``opt.INRDecode``,
    otherwise LUT output) beats the module-level ``best_psnr``; otherwise to ``last.pth``.

    Args:
        val_loader: yields dicts with composite_image, real_image, mask, category,
            fg_INR_coordinates, bg_INR_coordinates and fg_transfer_INR_RGB.
        model: network returning ``(fg_content_bg_appearance_construct, _, lut_transform_image)``.
        logger: receives progress and metric messages.
        opt: options (device, INRDecode, INR_input_size, wandb, save_path, ...).
        state: checkpoint dict to save; ``state['last_epoch']`` also controls visualization
            (every batch when divisible by 10, otherwise only the first batch).
    """
    global best_psnr
    current_process = 10
    model.eval()
    metric_log = _new_metric_log()
    lut_metric_log = _new_metric_log()

    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            composite_image = batch['composite_image'].to(opt.device)
            real_image = batch['real_image'].to(opt.device)
            mask = batch['mask'].to(opt.device)
            category = batch['category']
            fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
            bg_INR_coordinates = batch['bg_INR_coordinates'].to(opt.device)
            fg_transfer_INR_RGB = batch['fg_transfer_INR_RGB'].to(opt.device)

            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image, mask, fg_INR_coordinates, bg_INR_coordinates)
            # The INR prediction is the last decoder output; it only exists with INRDecode.
            pred_fg_image = fg_content_bg_appearance_construct[-1] if opt.INRDecode else None
            fg_transfer_INR_RGB = misc.lin2img(fg_transfer_INR_RGB, val_loader.dataset.INR_dataset.size)
            # Mask at the INR resolution, used for the INR metrics.
            mask_INR = torchvision.transforms.Resize(opt.INR_input_size)(mask)

            # Keep predictions inside the foreground (mask > 100/255) and the ground truth elsewhere.
            fg_region = mask > 100 / 255.
            bg_region = ~fg_region
            if opt.INRDecode:
                pred_harmonized_image = pred_fg_image * fg_region + real_image * bg_region
            else:
                pred_harmonized_image = None
            lut_transform_image = lut_transform_image * fg_region + real_image * bg_region

            # Save many results every 10 epochs and only the first batch otherwise, to limit storage.
            if state['last_epoch'] % 10 == 0:
                misc.visualize(real_image, composite_image, mask, pred_fg_image,
                               pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                               wandb=opt.wandb, isAll=True, step=step)
            elif step == 0:
                misc.visualize(real_image, composite_image, mask, pred_fg_image,
                               pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                               wandb=opt.wandb, step=step)

            if opt.INRDecode:
                mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
                                                             misc.normalize(fg_transfer_INR_RGB, opt, 'inv'),
                                                             mask_INR)
                _accumulate_metrics(metric_log, category, mse, fmse, psnr, ssim)
            lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(
                misc.normalize(lut_transform_image, opt, 'inv'), misc.normalize(real_image, opt, 'inv'), mask)
            _accumulate_metrics(lut_metric_log, category, lut_mse, lut_fmse, lut_psnr, lut_ssim)

            # `while`, so every 10% mark crossed is reported even when one batch spans several.
            while (step + 1) / len(val_loader) * 100 >= current_process:
                logger.info(f'Processing: {current_process}')
                current_process += 10

    logger.info('=========================')
    for key in metric_log:
        lut_mean = _mean_metrics(lut_metric_log[key])
        if lut_mean is None:
            # No validation samples in this category: nothing to average or report.
            continue
        lut_msg = (f"{key}-'LUT_MSE': {lut_mean['MSE']:.2f}\n"
                   f"{key}-'LUT_fMSE': {lut_mean['fMSE']:.2f}\n"
                   f"{key}-'LUT_PSNR': {lut_mean['PSNR']:.2f}\n"
                   f"{key}-'LUT_SSIM': {lut_mean['SSIM']:.4f}\n")
        # INR and LUT accumulators are updated for the same samples, so mean is non-None here.
        mean = _mean_metrics(metric_log[key]) if opt.INRDecode else None
        if mean is not None:
            msg = (f"{key}-'MSE': {mean['MSE']:.2f}\n"
                   f"{key}-'fMSE': {mean['fMSE']:.2f}\n"
                   f"{key}-'PSNR': {mean['PSNR']:.2f}\n"
                   f"{key}-'SSIM': {mean['SSIM']:.4f}\n") + lut_msg
        else:
            msg = lut_msg
        logger.info(msg)

        if opt.wandb:
            import wandb  # optional dependency, only needed when logging to W&B
            log_dict = {f'Val/{key}/Epoch': state['last_epoch']}
            if mean is not None:
                log_dict.update({f'Val/{key}/MSE': mean['MSE'],
                                 f'Val/{key}/fMSE': mean['fMSE'],
                                 f'Val/{key}/PSNR': mean['PSNR'],
                                 f'Val/{key}/SSIM': mean['SSIM']})
            log_dict.update({f'Val/{key}/LUT_MSE': lut_mean['MSE'],
                             f'Val/{key}/LUT_fMSE': lut_mean['fMSE'],
                             f'Val/{key}/LUT_PSNR': lut_mean['PSNR'],
                             f'Val/{key}/LUT_SSIM': lut_mean['SSIM']})
            wandb.log(log_dict)
    logger.info('=========================')

    # Model selection uses the INR output when it is decoded, otherwise the LUT output.
    overall = _mean_metrics((metric_log if opt.INRDecode else lut_metric_log)['All'])
    if overall is not None and overall['PSNR'] > best_psnr:
        logger.info("Best Save!")
        best_psnr = overall['PSNR']
        torch.save(state, os.path.join(opt.save_path, "best.pth"))
    else:
        logger.info("Last Save!")
        torch.save(state, os.path.join(opt.save_path, "last.pth"))