import os
import time
import datetime

import torch
import torchvision

from utils import misc, metrics

best_psnr = 0
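
# `misc.AverageMeter` comes from this repo's `utils` package. For reference, here is a
# minimal sketch of the interface this script relies on (`update`, `.val`, `.avg`),
# assuming standard running-average semantics; the actual implementation may differ.
class _AverageMeterSketch:
    """Tracks the most recent value and the running average of a series."""

    def __init__(self, fmt=':f'):
        self.fmt = fmt    # print format, e.g. ':6f'
        self.val = 0.0    # most recent value
        self.sum = 0.0    # running sum of values
        self.count = 0    # number of updates
        self.avg = 0.0    # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
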
def train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt):
    total_step = opt.epochs * len(train_loader)

    step_time_log = misc.AverageMeter()
    loss_log = misc.AverageMeter(':6f')
    loss_fg_content_bg_appearance_construct_log = misc.AverageMeter(':6f')
    loss_lut_transform_image_log = misc.AverageMeter(':6f')
    loss_lut_regularize_log = misc.AverageMeter(':6f')

    start_epoch = 0
"Load pretrained checkpoints" | |
if opt.pretrained is not None: | |
logger.info(f"Load pretrained weight from {opt.pretrained}") | |
load_state = torch.load(opt.pretrained) | |
model = model.cpu() | |
model.load_state_dict(load_state['model']) | |
model = model.to(opt.device) | |
optimizer.load_state_dict(load_state['optimizer']) | |
scheduler.load_state_dict(load_state['scheduler']) | |
start_epoch = load_state['last_epoch'] + 1 | |
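    # The checkpoint layout mirrors the `state` dict saved after each epoch below:
    # {'model', 'optimizer', 'scheduler', 'last_epoch'}.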
    for epoch in range(start_epoch, opt.epochs):
        model.train()
        time_ckp = time.time()
        for step, batch in enumerate(train_loader):
            current_step = epoch * len(train_loader) + step + 1

            if opt.INRDecode and opt.hr_train:
                "Lists with 4 elements each: [input to the encoder, crops at three resolutions for the INR decoder]"
                composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
                real_image = [batch[f'real_image{name}'].to(opt.device) for name in range(4)]
                mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
                coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
                # The first coordinate map matches the encoder input; only the three
                # cropped resolutions are passed on to the INR decoder.
                fg_INR_coordinates = coordinate_map[1:]
            else:
                composite_image = batch['composite_image'].to(opt.device)
                real_image = batch['real_image'].to(opt.device)
                mask = batch['mask'].to(opt.device)
                fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
            fg_content_bg_appearance_construct, fit_lut3d, lut_transform_image = model(
                composite_image, mask, fg_INR_coordinates)
            if opt.INRDecode:
                loss_fg_content_bg_appearance_construct = 0

                """
                Our LRIP module requires three layers at different resolutions, so
                `loss_fg_content_bg_appearance_construct` is accumulated across those layers.
                Besides, when leveraging `hr_train`, i.e., the RSC strategy (see Section 3.4),
                `real_image` and `mask` are lists whose entries are crops at different resolutions.
                """
                if opt.hr_train:
                    for n in range(3):
                        # Pair the n-th decoder output with the (3 - n)-th crop.
                        loss_fg_content_bg_appearance_construct += loss_fn['masked_mse'](
                            fg_content_bg_appearance_construct[n], real_image[3 - n], mask[3 - n])
                    loss_fg_content_bg_appearance_construct /= 3
                    loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image[1], mask[1])
                else:
                    for n in range(3):
                        # Resize the targets to match each decoder layer's resolution.
                        loss_fg_content_bg_appearance_construct += loss_fn['MaskWeightedMSE'](
                            fg_content_bg_appearance_construct[n],
                            torchvision.transforms.Resize(opt.INR_input_size // 2 ** (3 - n - 1))(real_image),
                            torchvision.transforms.Resize(opt.INR_input_size // 2 ** (3 - n - 1))(mask))
                    loss_fg_content_bg_appearance_construct /= 3
                    loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image, mask)
                loss_lut_regularize = loss_fn['regularize_LUT'](fit_lut3d)
            else:
                loss_fg_content_bg_appearance_construct = 0
                loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image, mask)
                loss_lut_regularize = 0
            loss = loss_fg_content_bg_appearance_construct + loss_lut_transform_image + loss_lut_regularize

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Note: the LR scheduler is stepped once per training iteration, not per epoch.
            scheduler.step()
            step_time_log.update(time.time() - time_ckp)

            # Disabled loss terms stay plain ints (0), so guard before calling `.item()`.
            loss_fg_content_bg_appearance_construct_log.update(
                0 if isinstance(loss_fg_content_bg_appearance_construct, int)
                else loss_fg_content_bg_appearance_construct.item())
            loss_lut_transform_image_log.update(
                0 if isinstance(loss_lut_transform_image, int) else loss_lut_transform_image.item())
            loss_lut_regularize_log.update(
                0 if isinstance(loss_lut_regularize, int) else loss_lut_regularize.item())
            loss_log.update(loss.item())
            if current_step % opt.print_freq == 0:
                remain_secs = (total_step - current_step) * step_time_log.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))
                log_msg = f'Epoch: [{epoch}/{opt.epochs}]\t' \
                          f'Step: [{step}/{len(train_loader)}]\t' \
                          f'StepTime {step_time_log.val:.3f} ({step_time_log.avg:.3f})\t' \
                          f'lr {optimizer.param_groups[0]["lr"]}\t' \
                          f'Loss {loss_log.val:.4f} ({loss_log.avg:.4f})\t' \
                          f'Loss_fg_bg_cons {loss_fg_content_bg_appearance_construct_log.val:.4f} ' \
                          f'({loss_fg_content_bg_appearance_construct_log.avg:.4f})\t' \
                          f'Loss_lut_trans {loss_lut_transform_image_log.val:.4f} ({loss_lut_transform_image_log.avg:.4f})\t' \
                          f'Loss_lut_reg {loss_lut_regularize_log.val:.4f} ({loss_lut_regularize_log.avg:.4f})\t' \
                          f'Remaining Time {remain_time} ({finish_time})'
                logger.info(log_msg)

                if opt.wandb:
                    import wandb
                    wandb.log(
                        {'Train/Epoch': epoch, 'Train/lr': optimizer.param_groups[0]['lr'],
                         'Train/Step': current_step,
                         'Train/Loss': loss_log.val,
                         'Train/Loss_fg_bg_cons': loss_fg_content_bg_appearance_construct_log.val,
                         'Train/Loss_lut_trans': loss_lut_transform_image_log.val,
                         'Train/Loss_lut_reg': loss_lut_regularize_log.val,
                         })

            time_ckp = time.time()
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'last_epoch': epoch,
                 'scheduler': scheduler.state_dict()}

        """
        Validation at the original (full) resolution has no consistent resolution across
        images (so they cannot be batched) and may also cause out-of-memory errors when
        combined with the training phase. Hence, when `opt.isFullRes` is True we only save
        the model here and leave the evaluation to `inference.py`.
        """
        if opt.isFullRes and opt.hr_train:
            if epoch % 5 == 0:
                torch.save(state, os.path.join(opt.save_path, f"epoch{epoch}.pth"))
            else:
                torch.save(state, os.path.join(opt.save_path, "last.pth"))
        else:
            val(val_loader, model, logger, opt, state)
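
# The `loss_fn` entries used in `train` ('masked_mse', 'MaskWeightedMSE',
# 'regularize_LUT') are constructed elsewhere in the repo. As a rough, hypothetical
# sketch only (not the repo's actual implementation), a mask-weighted MSE that
# averages the squared error over foreground pixels could look like this:
def _masked_mse_sketch(pred, gt, mask, eps=1e-6):
    """MSE restricted to the masked region; expects `mask` of shape (B, 1, H, W)."""
    sq_err = (pred - gt) ** 2 * mask                       # zero out background error
    denom = mask.sum(dim=(1, 2, 3)) * pred.shape[1] + eps  # foreground pixels x channels
    return (sq_err.sum(dim=(1, 2, 3)) / denom).mean()
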
def val(val_loader, model, logger, opt, state):
    global best_psnr
    current_process = 10
    model.eval()

    # Per-dataset metric accumulators for the INR branch and the LUT branch.
    categories = ['HAdobe5k', 'HCOCO', 'Hday2night', 'HFlickr', 'All']
    metric_log = {c: {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0} for c in categories}
    lut_metric_log = {c: {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0} for c in categories}
    for step, batch in enumerate(val_loader):
        composite_image = batch['composite_image'].to(opt.device)
        real_image = batch['real_image'].to(opt.device)
        mask = batch['mask'].to(opt.device)
        category = batch['category']

        fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
        bg_INR_coordinates = batch['bg_INR_coordinates'].to(opt.device)
        fg_transfer_INR_RGB = batch['fg_transfer_INR_RGB'].to(opt.device)

        with torch.no_grad():
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image, mask, fg_INR_coordinates, bg_INR_coordinates)

        if opt.INRDecode:
            pred_fg_image = fg_content_bg_appearance_construct[-1]
        else:
            pred_fg_image = None

        # Reshape the flattened INR RGB samples back into an image grid.
        fg_transfer_INR_RGB = misc.lin2img(
            fg_transfer_INR_RGB, val_loader.dataset.INR_dataset.size) if fg_transfer_INR_RGB is not None else None
"For INR" | |
mask_INR = torchvision.transforms.Resize(opt.INR_input_size)(mask) | |
if not opt.INRDecode: | |
pred_harmonized_image = None | |
else: | |
pred_harmonized_image = pred_fg_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.)) | |
lut_transform_image = lut_transform_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.)) | |
"Save the output images. For every 10 epochs, save more results, otherwise, save little. Thus save storage." | |
if state['last_epoch'] % 10 == 0: | |
misc.visualize(real_image, composite_image, mask, pred_fg_image, | |
pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False, | |
wandb=opt.wandb, isAll=True, step=step) | |
elif step == 0: | |
misc.visualize(real_image, composite_image, mask, pred_fg_image, | |
pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False, | |
wandb=opt.wandb, step=step) | |
        if opt.INRDecode:
            mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
                                                         misc.normalize(fg_transfer_INR_RGB, opt, 'inv'), mask_INR)
        lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
                                                                     misc.normalize(real_image, opt, 'inv'), mask)
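        # `metrics.calc_metrics` presumably returns per-sample score lists; they are
        # accumulated per source dataset (and under 'All') below.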
        for idx in range(len(category)):
            if opt.INRDecode:
                metric_log[category[idx]]['Samples'] += 1
                metric_log[category[idx]]['MSE'] += mse[idx]
                metric_log[category[idx]]['fMSE'] += fmse[idx]
                metric_log[category[idx]]['PSNR'] += psnr[idx]
                metric_log[category[idx]]['SSIM'] += ssim[idx]
                metric_log['All']['Samples'] += 1
                metric_log['All']['MSE'] += mse[idx]
                metric_log['All']['fMSE'] += fmse[idx]
                metric_log['All']['PSNR'] += psnr[idx]
                metric_log['All']['SSIM'] += ssim[idx]

            lut_metric_log[category[idx]]['Samples'] += 1
            lut_metric_log[category[idx]]['MSE'] += lut_mse[idx]
            lut_metric_log[category[idx]]['fMSE'] += lut_fmse[idx]
            lut_metric_log[category[idx]]['PSNR'] += lut_psnr[idx]
            lut_metric_log[category[idx]]['SSIM'] += lut_ssim[idx]
            lut_metric_log['All']['Samples'] += 1
            lut_metric_log['All']['MSE'] += lut_mse[idx]
            lut_metric_log['All']['fMSE'] += lut_fmse[idx]
            lut_metric_log['All']['PSNR'] += lut_psnr[idx]
            lut_metric_log['All']['SSIM'] += lut_ssim[idx]

        if (step + 1) / len(val_loader) * 100 >= current_process:
            logger.info(f'Processing: {current_process}%')
            current_process += 10
    logger.info('=========================')
    for key in metric_log.keys():
        if opt.INRDecode:
            msg = f"{key}-'MSE': {metric_log[key]['MSE'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'fMSE': {metric_log[key]['fMSE'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'PSNR': {metric_log[key]['PSNR'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'SSIM': {metric_log[key]['SSIM'] / metric_log[key]['Samples']:.4f}\n" \
                  f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
        else:
            msg = f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
        logger.info(msg)
        if opt.wandb:
            import wandb
            if opt.INRDecode:
                wandb.log(
                    {f'Val/{key}/Epoch': state['last_epoch'],
                     f'Val/{key}/MSE': metric_log[key]['MSE'] / metric_log[key]['Samples'],
                     f'Val/{key}/fMSE': metric_log[key]['fMSE'] / metric_log[key]['Samples'],
                     f'Val/{key}/PSNR': metric_log[key]['PSNR'] / metric_log[key]['Samples'],
                     f'Val/{key}/SSIM': metric_log[key]['SSIM'] / metric_log[key]['Samples'],
                     f'Val/{key}/LUT_MSE': lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_fMSE': lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_PSNR': lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_SSIM': lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']
                     })
            else:
                wandb.log(
                    {f'Val/{key}/Epoch': state['last_epoch'],
                     f'Val/{key}/LUT_MSE': lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_fMSE': lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_PSNR': lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples'],
                     f'Val/{key}/LUT_SSIM': lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']
                     })

    logger.info('=========================')
    # Track the best validation PSNR on the branch actually being evaluated:
    # the LUT branch when the INR decoder is disabled, the INR branch otherwise.
    if not opt.INRDecode:
        cur_psnr = lut_metric_log['All']['PSNR'] / lut_metric_log['All']['Samples']
    else:
        cur_psnr = metric_log['All']['PSNR'] / metric_log['All']['Samples']

    if cur_psnr > best_psnr:
        logger.info("Best Save!")
        best_psnr = cur_psnr
        torch.save(state, os.path.join(opt.save_path, "best.pth"))
    else:
        logger.info("Last Save!")
        torch.save(state, os.path.join(opt.save_path, "last.pth"))
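
# The entry point that builds `opt`, the dataloaders, the model, the optimizer/scheduler,
# and the `loss_fn` dict lives elsewhere in this repo; a hypothetical call would be:
#   train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt)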