# Copied from a web file view (uploader: YeOldHermit); duplicated from
# yangheng/Super-Resolution-Anime-Diffusion, commit 9da7c8d, 6.04 kB.
import os

from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import trange

from Dataloader import *
from .utils import image_quality
from .utils.cls import CyclicLR
from .utils.prepare_images import *
# Folder paths kept for reference; they are not used by the loaders below,
# which read from image database files (db_file / db_table) instead.
train_folder = './dataset/train'
test_folder = "./dataset/test"

# Training set: shuffled mini-batches of 6; max_images presumably caps the
# number of images read from the table.
img_dataset = ImageDBData(
    db_file='dataset/images.db',
    db_table="train_images_size_128_noise_1_rgb",
    max_images=24,
)
img_data = DataLoader(img_dataset, batch_size=6, shuffle=True, num_workers=6)
total_batch = len(img_data)  # mini-batches per epoch (used for the LR cycle and progress ratio)
print(len(img_dataset))

# Evaluation set: no image cap, one sample per batch, fixed order.
test_dataset = ImageDBData(
    db_file='dataset/test2.db',
    db_table="test_images_size_128_noise_1_rgb",
    max_images=None,
)
num_test = len(test_dataset)
test_data = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
# Pixel-wise L1 (mean absolute error) between prediction and target.
criteria = nn.L1Loss()

# CARN_V2 network; scale=2 is presumably the upscaling factor.
model = CARN_V2(
    color_channels=3,
    mid_channels=64,
    conv=nn.Conv2d,
    single_conv_size=3,
    single_conv_group=1,
    scale=2,
    activation=nn.LeakyReLU(0.1),
    SEBlock=True,
    repeat_blocks=3,
    atrous=(1, 1, 1),
)
model.total_parameters()
# model.initialize_weights_xavier_uniform()

# fp16 training is available in GPU only: convert to half precision, move to
# the GPU, then resume from the saved checkpoint (which must already exist).
model = network_to_half(model).cuda()
model.load_state_dict(torch.load("CARN_model_checkpoint.pt"))
learning_rate = 1e-4
weight_decay = 1e-6

optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    amsgrad=True,
)
# Alternatives / resume hooks, disabled:
# optimizer = optim.SGD(model.parameters(), momentum=0.9, nesterov=True, weight_decay=weight_decay, lr=learning_rate)
# optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0, verbose=False)
# optimizer.load_state_dict(torch.load("CARN_adam_checkpoint.pt"))

# -1 starts the schedule from scratch. To resume, load the value the training
# loop saves: last_iter = torch.load("CARN_scheduler_last_iter.pt")
last_iter = -1
# base_lr and max_lr are both 1e-4 (equal to learning_rate), so the triangular
# cycle presumably keeps the rate constant.
scheduler = CyclicLR(
    optimizer,
    base_lr=1e-4,
    max_lr=1e-4,
    step_size=3 * total_batch,
    mode="triangular",
    last_batch_iteration=last_iter,
)
# Training histories (one entry per optimizer step) and test histories (one
# entry per evaluation pass). Each is a separate list object.
# To resume a run, load them instead, e.g.:
#   train_loss = torch.load("train_loss.pt")   # likewise train_ssim.pt, train_psnr.pt,
#                                               # test_loss.pt, test_ssim.pt, test_psnr.pt
train_loss, train_ssim, train_psnr = [], [], []
test_loss, test_ssim, test_psnr = [], [], []
# Optimizer steps taken so far; drives the periodic checkpointing below.
counter = 0
# Number of passes (epochs) over the training loader.
iteration = 2
# Save a mid-epoch checkpoint every this many optimizer steps.
checkpoint_every = 500


def _save_training_checkpoint():
    """Persist model/optimizer state, training histories and scheduler position.

    Reads the module-level ``model``, ``optimizer``, ``scheduler``,
    ``train_loss``, ``train_ssim`` and ``train_psnr``. Overwrites the same
    files on every call, so only the most recent checkpoint is kept.
    """
    torch.save(model.state_dict(), 'CARN_model_checkpoint.pt')
    torch.save(optimizer.state_dict(), 'CARN_adam_checkpoint.pt')
    torch.save(train_loss, 'train_loss.pt')
    torch.save(train_ssim, "train_ssim.pt")
    torch.save(train_psnr, 'train_psnr.pt')
    torch.save(scheduler.last_batch_iteration, "CARN_scheduler_last_iter.pt")


# save_image in the test pass fails if the output directory is missing.
os.makedirs("check_test_imgs", exist_ok=True)

ibar = trange(iteration, ascii=True, maxinterval=1,
              postfix={"avg_loss": 0, "train_ssim": 0, "test_ssim": 0})
for _ in ibar:
    # +++++++++++++++++++++++++++++++++++++
    # Train
    # -------------------------------------
    # Undo the eval() switch made for the previous epoch's test pass.
    model.train()
    for index, batch in enumerate(img_data):
        scheduler.batch_step()
        lr_img, hr_img = batch
        # Match the half-precision conversion applied via network_to_half.
        lr_img = lr_img.cuda().half()
        hr_img = hr_img.cuda()

        optimizer.zero_grad()
        # Compute the loss in fp32 against the full-precision target.
        outputs = model(lr_img).float()
        loss = criteria(outputs, hr_img)
        # Plain torch.optim.Adam has no ``backward`` method; only an fp16
        # wrapper such as the commented-out FP16_Optimizer provides one.
        # NOTE(review): without that wrapper the fp16 weights are updated
        # directly with no loss scaling. Re-enable the wrapper if training
        # becomes unstable.
        if hasattr(optimizer, "backward"):
            optimizer.backward(loss)
        else:
            loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()
        counter += 1

        # Metrics need no gradient history.
        detached_outputs = outputs.detach()
        loss_value = loss.item()
        step_ssim = image_quality.msssim(detached_outputs, hr_img).item()
        step_psnr = image_quality.psnr(detached_outputs, hr_img).item()
        ibar.set_postfix(ratio=index / total_batch, loss=loss_value,
                         ssim=step_ssim, batch=index,
                         psnr=step_psnr,
                         lr=scheduler.current_lr)
        train_loss.append(loss_value)
        train_ssim.append(step_ssim)
        train_psnr.append(step_psnr)

        # counter was already incremented, so this fires after steps
        # checkpoint_every, 2 * checkpoint_every, ...
        if counter % checkpoint_every == 0:
            _save_training_checkpoint()

    # +++++++++++++++++++++++++++++++++++++
    # End of one epoch
    # -------------------------------------
    # Same contents as the mid-epoch checkpoint, including the scheduler
    # position, so a resume from here continues the LR schedule correctly.
    _save_training_checkpoint()

    # +++++++++++++++++++++++++++++++++++++
    # Test
    # -------------------------------------
    # Inference behaviour for any mode-dependent layers (e.g. BatchNorm or
    # Dropout, if CARN_V2 contains them).
    model.eval()
    with torch.no_grad():
        epoch_test_ssim = []
        epoch_test_loss = []
        epoch_test_psnr = []
        for index, test_batch in enumerate(test_data):
            lr_img, hr_img = test_batch
            # Same input precision as training. This is harmless if
            # network_to_half already casts its inputs (not visible here).
            lr_img = lr_img.cuda().half()
            hr_img = hr_img.cuda()
            lr_img_up = model(lr_img).float()
            loss = criteria(lr_img_up, hr_img)
            save_image([lr_img_up[0], hr_img[0]], f"check_test_imgs/{index}.png")
            epoch_test_loss.append(loss.item())
            epoch_test_ssim.append(image_quality.msssim(lr_img_up, hr_img).item())
            epoch_test_psnr.append(image_quality.psnr(lr_img_up, hr_img).item())
    test_ssim.append(np.mean(epoch_test_ssim))
    test_loss.append(np.mean(epoch_test_loss))
    test_psnr.append(np.mean(epoch_test_psnr))
    torch.save(test_loss, 'test_loss.pt')
    torch.save(test_ssim, "test_ssim.pt")
    torch.save(test_psnr, "test_psnr.pt")
# import subprocess
# subprocess.call(["shutdown", "/s"])