|
import itertools |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from torch.nn import functional as F |
|
import distutils.util |
|
|
|
def _tensor_to_image(batch):
    """Convert the first sample of an image batch tensor into an HWC array for imshow."""
    # .detach() so tensors that still require grad can be converted to numpy.
    first = batch.detach().cpu().numpy()[0]
    # x * 0.5 + 0.5 then clip to [0, 1]: presumably undoes a [-1, 1] normalisation
    # (TODO confirm against the dataset transform).
    first = np.clip(first * 0.5 + 0.5, 0, 1)
    # Channels-first (C, H, W) -> channels-last (H, W, C), as matplotlib expects.
    return np.transpose(first, [1, 2, 0])


def show_result(num_epoch, G_net, imgs_lr, imgs_hr, save_dir="results/train_out"):
    """Save a side-by-side figure: low-res input, generator output, high-res target.

    Only the first sample of each batch is plotted. The figure is written to
    ``<save_dir>/epoch_<num_epoch>_results.png``, and ``save_dir`` is created
    if it does not exist.

    Args:
        num_epoch: Epoch number used in the figure caption and the file name.
        G_net: Generator network; called once as ``G_net(imgs_lr)`` under
            ``torch.no_grad()``. NOTE(review): its train/eval mode is not
            changed here.
        imgs_lr: Low-resolution input batch tensor, indexed as (N, C, H, W).
        imgs_hr: High-resolution target batch tensor, indexed as (N, C, H, W).
        save_dir: Output directory (defaults to the previously hard-coded path).
    """
    import os

    with torch.no_grad():
        test_images = G_net(imgs_lr)

    fig, axes = plt.subplots(1, 3)
    try:
        for ax, batch in zip(axes, (imgs_lr, test_images, imgs_hr)):
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.imshow(_tensor_to_image(batch))

        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')

        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, "epoch_" + str(num_epoch) + "_results.png"))
    finally:
        # Close only this figure (not every open figure), even if plotting/saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
def cvtColor(image):
    """Return ``image`` as 3-channel data, converting via ``image.convert('RGB')`` if needed.

    If ``image`` already has shape (H, W, 3), it is returned unchanged (same
    object). Otherwise ``image.convert('RGB')`` is returned; presumably
    ``image`` is then a PIL image (TODO confirm against callers). A non-3-channel
    numpy array has no ``convert`` and raises AttributeError.
    """
    # Compute the shape once: for a PIL image np.shape() goes through np.asarray,
    # which copies the whole pixel buffer.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
|
|
def preprocess_input(image, mean, std):
    """Scale pixel values from [0, 255] to [0, 1], then standardise.

    Computes ``(image / 255 - mean) / std``. ``mean`` and ``std`` may be scalars
    or anything that broadcasts against ``image`` (e.g. per-channel arrays).
    """
    scaled = image / 255
    centred = scaled - mean
    return centred / std
|
|
|
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    learning_rates = (group['lr'] for group in optimizer.param_groups)
    return next(learning_rates, None)
|
|
|
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    ``args`` is anything accepted by ``vars()`` (typically an
    ``argparse.Namespace``). The listing is framed by a header and a footer line.
    """
    settings = vars(args)
    print("----------- Configuration Arguments -----------")
    for name in sorted(settings):
        print("%s: %s" % (name, settings[name]))
    print("------------------------------------------------")
|
|
|
|
|
def _strtobool(value):
    """Parse a command-line truth string into a bool.

    Accepts the same spellings as the old ``distutils.util.strtobool``
    (case-insensitive): y/yes/t/true/on/1 -> True and n/no/f/false/off/0 -> False.
    Any other string raises ValueError, which argparse reports as a usage error.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (value,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register ``--<argname>`` on ``argparser``, appending the default to its help.

    Args:
        argname: Option name without the leading ``--``.
        type: Conversion callable. ``bool`` is replaced by a string parser, because
            ``bool("False")`` is True and so plain ``bool`` cannot parse flags.
        default: Default value (shown in the help via ``%(default)s``).
        help: Help text; None is treated as empty.
        argparser: The ``argparse.ArgumentParser`` to add the option to.
        **kwargs: Forwarded unchanged to ``add_argument``.
    """
    # distutils (formerly used here for strtobool) is removed in Python 3.12 (PEP 632).
    type = _strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=(help or '') + ' 默认: %(default)s.',
                           **kwargs)
|
|
|
|