import sys
import os.path
import math
import argparse
import time
import random
import cv2
import numpy as np
from collections import OrderedDict
import logging

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from utils.utils_dist import get_dist_info, init_dist

from data.select_dataset import define_Dataset
from models.select_model import define_Model
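
# Note: the utils/, data/ and models/ packages imported above ship with the
# VRT (KAIR) codebase; the script is typically run from the repository root.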

'''
# --------------------------------------------
# training code for VRT
# --------------------------------------------
'''


def main(json_path='options/vrt/001_train_vrt_videosr_bi_reds_6frames.json'):

    '''
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    '''

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--dist', default=False)
    args = parser.parse_args()  # parse once instead of calling parse_args() twice

    opt = option.parse(args.opt, is_train=True)
    opt['dist'] = args.dist
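
    # Typical multi-GPU launch (a hedged example; the exact flags depend on
    # your environment and PyTorch version):
    #   python -m torch.distributed.launch --nproc_per_node=8 main_train_vrt.py \
    #       --opt options/vrt/001_train_vrt_videosr_bi_reds_6frames.json --dist True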

    # ----------------------------------------
    # distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    if opt['rank'] == 0:
        util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))
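    # (Directory creation is done on rank 0 only, avoiding mkdir races across processes.)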

    # ----------------------------------------
    # update opt
    # ----------------------------------------
    # -->-->-->-->-->-->-->-->-->-->-->-->-->-
    init_iter_G, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G',
                                                           pretrained_path=opt['path']['pretrained_netG'])
    init_iter_E, init_path_E = option.find_last_checkpoint(opt['path']['models'], net_type='E',
                                                           pretrained_path=opt['path']['pretrained_netE'])
    opt['path']['pretrained_netG'] = init_path_G
    opt['path']['pretrained_netE'] = init_path_E
    init_iter_optimizerG, init_path_optimizerG = option.find_last_checkpoint(opt['path']['models'],
                                                                             net_type='optimizerG')
    opt['path']['pretrained_optimizerG'] = init_path_optimizerG
    current_step = max(init_iter_G, init_iter_E, init_iter_optimizerG)
    # --<--<--<--<--<--<--<--<--<--<--<--<--<-
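    # If checkpoints exist in opt['path']['models'], training resumes from the
    # newest of netG/netE/optimizerG; otherwise current_step starts at 0 and
    # any given pretrained weights are loaded instead.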

    # ----------------------------------------
    # save opt to a '../option.json' file
    # ----------------------------------------
    if opt['rank'] == 0:
        option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    if opt['rank'] == 0:
        logger_name = 'train'
        utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name + '.log'))
        logger = logging.getLogger(logger_name)
        logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
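    # Note: seeding alone does not make training fully deterministic; that would
    # also require e.g. torch.backends.cudnn.deterministic = True and
    # torch.backends.cudnn.benchmark = False, at some cost in speed.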

    '''
    # ----------------------------------------
    # Step--2 (create dataloader)
    # ----------------------------------------
    '''

    # ----------------------------------------
    # 1) create_dataset
    # 2) create_dataloader for train and test
    # ----------------------------------------
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size']))
            if opt['rank'] == 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            if opt['dist']:
                train_sampler = DistributedSampler(train_set, shuffle=dataset_opt['dataloader_shuffle'],
                                                   drop_last=True, seed=seed)
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size'] // opt['num_gpu'],
                                          shuffle=False,
                                          num_workers=dataset_opt['dataloader_num_workers'] // opt['num_gpu'],
                                          drop_last=True,
                                          pin_memory=True,
                                          sampler=train_sampler)
            else:
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size'],
                                          shuffle=dataset_opt['dataloader_shuffle'],
                                          num_workers=dataset_opt['dataloader_num_workers'],
                                          drop_last=True,
                                          pin_memory=True)
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)
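
    # Under DDP each process builds its own DataLoader, so batch size and worker
    # count are divided by opt['num_gpu'] to keep the global values equal to the
    # ones given in the config.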

    '''
    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    '''

    model = define_Model(opt)
    model.init_train()
    if opt['rank'] == 0:
        logger.info(model.info_network())
        logger.info(model.info_params())
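
    # Note: init_train() above is expected to load the checkpoint/pretrained
    # weights resolved earlier and to set up the loss, optimizer and LR
    # scheduler (see the model definitions under models/).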

    '''
    # ----------------------------------------
    # Step--4 (main training)
    # ----------------------------------------
    '''

    for epoch in range(1000000):  # keep running
        if opt['dist']:
            train_sampler.set_epoch(epoch)  # reshuffle differently each epoch under DDP
        for train_data in train_loader:

            current_step += 1

            # -------------------------------
            # 1) update learning rate
            # -------------------------------
            model.update_learning_rate(current_step)

            # -------------------------------
            # 2) feed patch pairs
            # -------------------------------
            model.feed_data(train_data)

            # -------------------------------
            # 3) optimize parameters
            # -------------------------------
            model.optimize_parameters(current_step)

            # -------------------------------
            # 4) training information
            # -------------------------------
            if current_step % opt['train']['checkpoint_print'] == 0 and opt['rank'] == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step,
                                                                          model.current_learning_rate())
                for k, v in logs.items():  # merge log information into message
                    message += '{:s}: {:.3e} '.format(k, v)
                logger.info(message)

            # -------------------------------
            # 5) save model
            # -------------------------------
            if current_step % opt['train']['checkpoint_save'] == 0 and opt['rank'] == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            if opt['use_static_graph'] and (current_step == opt['train']['fix_iter'] - 1):
                current_step += 1
                model.update_learning_rate(current_step)
                model.save(current_step)
                current_step -= 1
                logger.info('Saving the model ahead of time because the computation graph changes at '
                            'fix_iter when use_static_graph=True (needed to work around a bug with '
                            'use_checkpoint=True in distributed training). PyTorch will terminate training '
                            'in the next iteration; simply resume with the same .json config file.')

            # -------------------------------
            # 6) testing
            # -------------------------------
            if current_step % opt['train']['checkpoint_test'] == 0 and opt['rank'] == 0:
                test_results = OrderedDict()
                test_results['psnr'] = []
                test_results['ssim'] = []
                test_results['psnr_y'] = []
                test_results['ssim_y'] = []
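                # Metrics are averaged over the frames of each clip (folder)
                # first, then across all clips after the loop.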
                for idx, test_data in enumerate(test_loader):
                    model.feed_data(test_data)
                    model.test()

                    visuals = model.current_visuals()
                    output = visuals['E']
                    gt = visuals['H'] if 'H' in visuals else None
                    folder = test_data['folder']

                    test_results_folder = OrderedDict()
                    test_results_folder['psnr'] = []
                    test_results_folder['ssim'] = []
                    test_results_folder['psnr_y'] = []
                    test_results_folder['ssim_y'] = []

                    for i in range(output.shape[0]):
                        # -----------------------
                        # save estimated image E
                        # -----------------------
                        img = output[i, ...].clamp_(0, 1).numpy()
                        if img.ndim == 3:
                            img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
                        img = (img * 255.0).round().astype(np.uint8)  # float32 to uint8
                        if opt['val']['save_img']:
                            save_dir = opt['path']['images']
                            util.mkdir(save_dir)
                            seq_ = os.path.basename(test_data['lq_path'][i][0]).split('.')[0]
                            os.makedirs(f'{save_dir}/{folder[0]}', exist_ok=True)
                            cv2.imwrite(f'{save_dir}/{folder[0]}/{seq_}_{current_step:d}.png', img)

                        # -----------------------
                        # calculate PSNR (only when a ground truth is available)
                        # -----------------------
                        if gt is not None:
                            img_gt = gt[i, ...].clamp_(0, 1).numpy()
                            if img_gt.ndim == 3:
                                img_gt = np.transpose(img_gt[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
                            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
                            img_gt = np.squeeze(img_gt)

                            test_results_folder['psnr'].append(util.calculate_psnr(img, img_gt, border=0))
                            test_results_folder['ssim'].append(util.calculate_ssim(img, img_gt, border=0))

                            if img_gt.ndim == 3:  # RGB image
                                img = util.bgr2ycbcr(img.astype(np.float32) / 255.) * 255.
                                img_gt = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255.
                                test_results_folder['psnr_y'].append(util.calculate_psnr(img, img_gt, border=0))
                                test_results_folder['ssim_y'].append(util.calculate_ssim(img, img_gt, border=0))
                            else:
                                test_results_folder['psnr_y'] = test_results_folder['psnr']
                                test_results_folder['ssim_y'] = test_results_folder['ssim']
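
                    # PSNR_Y/SSIM_Y are measured on the Y (luma) channel after a
                    # BGR -> YCbCr conversion, the usual convention in SR work.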
                    if gt is not None:
                        psnr = sum(test_results_folder['psnr']) / len(test_results_folder['psnr'])
                        ssim = sum(test_results_folder['ssim']) / len(test_results_folder['ssim'])
                        psnr_y = sum(test_results_folder['psnr_y']) / len(test_results_folder['psnr_y'])
                        ssim_y = sum(test_results_folder['ssim_y']) / len(test_results_folder['ssim_y'])

                        logger.info('Testing {:20s} ({:2d}/{}) - PSNR: {:.2f} dB; SSIM: {:.4f}; '
                                    'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.
                                    format(folder[0], idx, len(test_loader), psnr, ssim, psnr_y, ssim_y))
                        test_results['psnr'].append(psnr)
                        test_results['ssim'].append(ssim)
                        test_results['psnr_y'].append(psnr_y)
                        test_results['ssim_y'].append(ssim_y)
                    else:
                        logger.info('Testing {:20s} ({:2d}/{})'.format(folder[0], idx, len(test_loader)))

                # summarize psnr/ssim
                if gt is not None:
                    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
                    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
                    ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
                    ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
                    logger.info('<epoch:{:3d}, iter:{:8,d}> Average PSNR: {:.2f} dB; SSIM: {:.4f}; '
                                'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.format(
                                    epoch, current_step, ave_psnr, ave_ssim, ave_psnr_y, ave_ssim_y))

            if current_step > opt['train']['total_iter']:
                logger.info('Finish training.')
                model.save(current_step)
                sys.exit()


if __name__ == '__main__':
    main()