File size: 2,274 Bytes
73ca179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
import distutils.util

def _first_sample_as_image(batch):
    """Return the first sample of ``batch`` as an array ready for ``imshow``.

    The sample is mapped with ``x * 0.5 + 0.5``, clipped to ``[0, 1]`` and its
    axes are reordered with ``[1, 2, 0]``.
    """
    # assumes batch is a (N, C, H, W) tensor normalised to [-1, 1] -- TODO confirm
    sample = batch.cpu().numpy()[0]
    return np.transpose(np.clip(sample * 0.5 + 0.5, 0, 1), [1, 2, 0])


def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a three-panel preview image for one epoch.

    ``G_net`` is applied to ``imgs_lr`` with gradients disabled, then the
    first sample of ``imgs_lr``, of the network output and of ``imgs_hr`` are
    drawn side by side (in that order) and written to
    ``results/train_out/epoch_<num_epoch>_results.png``. The output directory
    is created when it does not exist yet.

    Args:
        num_epoch: Value used in the figure caption and in the file name.
        G_net: Callable applied to ``imgs_lr`` (presumably the generator).
        imgs_lr: Input batch; must support ``.cpu().numpy()``.
        imgs_hr: Reference batch; must support ``.cpu().numpy()``.
    """
    import os  # local import so this block does not rely on a file-level one

    out_dir = "results/train_out"

    with torch.no_grad():
        test_images = G_net(imgs_lr)

        fig, ax = plt.subplots(1, 3)
        try:
            for axis, batch in zip(ax, (imgs_lr, test_images, imgs_hr)):
                # Hide ticks and frame labels; only the pictures matter.
                axis.get_xaxis().set_visible(False)
                axis.get_yaxis().set_visible(False)
                axis.cla()
                axis.imshow(_first_sample_as_image(batch))

            label = 'Epoch {0}'.format(num_epoch)
            fig.text(0.5, 0.04, label, ha='center')

            # savefig does not create missing directories itself.
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(out_dir + "/epoch_" + str(num_epoch) + "_results.png")
        finally:
            # Always release the figures, even on failure, to avoid leaking
            # memory across epochs.
            plt.close('all')

#---------------------------------------------------------#
#   Convert the image to RGB so that grayscale inputs do not
#   raise errors at prediction time. Only RGB images are
#   supported for prediction; every other image type is
#   converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` as-is when it is three-channel, else its RGB conversion.

    An input whose ``np.shape`` has exactly three dimensions with a third
    dimension of size 3 is returned unchanged. Anything else is converted via
    ``image.convert('RGB')`` (so it is presumably a PIL image).
    """
    dims = np.shape(image)
    already_rgb = len(dims) == 3 and dims[2] == 3
    return image if already_rgb else image.convert('RGB')

def preprocess_input(image, mean, std):
    """Normalise ``image``: divide by 255, subtract ``mean``, divide by ``std``.

    The input is not modified; the normalised result is returned.
    """
    scaled = image / 255
    centered = scaled - mean
    return centered / std

def get_lr(optimizer):
    """Return the ``'lr'`` entry of the first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    groups = iter(optimizer.param_groups)
    return next((group['lr'] for group in groups), None)

def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    The listing is framed by a header line and a footer line. ``args`` must
    work with ``vars()`` (e.g. an ``argparse.Namespace``).
    """
    settings = vars(args)
    print("-----------  Configuration Arguments -----------")
    for name in sorted(settings):
        print(name, settings[name], sep=": ")
    print("------------------------------------------------")


def _str_to_bool(value):
    """Parse a command-line truth string into a real ``bool``.

    Accepts the spellings ``distutils.util.strtobool`` accepted
    (case-insensitive): y/yes/t/true/on/1 and n/no/f/false/off/0.

    Raises:
        ValueError: If ``value`` is none of the accepted spellings. argparse
            turns this into a normal usage error.
    """
    normalized = str(value).strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value {!r}".format(value))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register the option ``--<argname>`` on ``argparser``.

    Args:
        argname: Option name without the leading ``--``.
        type: Converter applied to the command-line string. When it is
            ``bool``, a truth-string parser is used instead, because
            ``bool("false")`` is ``True``.
        default: Default value of the option.
        help: Help text; a "default: <value>" suffix is appended to it.
        argparser: Parser (or argument group) providing ``add_argument``.
        **kwargs: Forwarded unchanged to ``add_argument``.
    """
    # ``distutils`` was removed in Python 3.12, so booleans are parsed by a
    # local helper. It also returns a real bool instead of the int 1/0.
    converter = _str_to_bool if type is bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=converter,
        # The suffix is the runtime help text "default: <value>." in Chinese.
        help=help + ' 默认: %(default)s.',
        **kwargs
    )