白鹭先生 committed on
Commit 73ca179
1 Parent(s): d69c2b3
.gitignore ADDED
@@ -0,0 +1,12 @@
+ datasets
+ train_lines.txt
+ __pycache__
+ .vscode
+ flagged
+ gradio_queue.db
+ runs
+ model_data/ResNeSt.pt
+ model_data/G_FFHQ.pth
+ dataset/FFHQ
+ dataset/MFFHQ
+ dataset/MFFHQ.7z
app.py ADDED
@@ -0,0 +1,38 @@
+ '''
+ Author: Egrt
+ Date: 2022-01-13 13:34:10
+ LastEditors: [egrt]
+ LastEditTime: 2022-05-04 12:26:53
+ FilePath: \MaskGAN\app.py
+ '''
+
+ from PIL import Image
+ from maskgan import MASKGAN
+ import gradio as gr
+ import os
+ maskgan = MASKGAN()
+
+ # -------- Model inference -------- #
+ def inference(img):
+     lr_shape = [112, 112]
+     img = img.resize(tuple(lr_shape), Image.BICUBIC)
+     r_image = maskgan.generate_1x1_image(img)
+     return r_image
+
+ # -------- Web page info -------- #
+ title = "MaskGAN: Unsupervised Restoration of Mask-Occluded Faces"
+ description = "Restores mask-occluded faces with a generative adversarial network, effectively recovering the face in the occluded region. @Intelligent Control and Image Processing Laboratory, Southwest University of Science and Technology"
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>MaskGAN: Face Restoration Using Swin Transformer</a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"
+ example_img_dir = 'img'
+ example_img_name = os.listdir(example_img_dir)
+ examples = [[os.path.join(example_img_dir, image_path)] for image_path in example_img_name if image_path.endswith('.jpg')]
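+ # NOTE: gr.inputs / gr.outputs and the enable_queue flag below follow the
+ # Gradio 2.x-era API this commit was written against; later Gradio versions
+ # renamed these interfaces.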
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil", label="Input")],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     enable_queue=True,
+     examples=examples
+ ).launch(debug=True)
img/00131_1.jpg ADDED
img/00294_1.jpg ADDED
img/00365_1.jpg ADDED
img/00415_1.jpg ADDED
img/00713_1.jpg ADDED
img/01026_1.jpg ADDED
img/02832_1.jpg ADDED
img/03092_1.jpg ADDED
img/03232_1.jpg ADDED
img/README.md ADDED
@@ -0,0 +1,8 @@
+ <!--
+ * @Author: Egrt
+ * @Date: 2022-04-09 11:02:07
+ * @LastEditors: Egrt
+ * @LastEditTime: 2022-04-09 11:02:07
+ * @FilePath: \MaskGAN\img\README.md
+ -->
+ Stores the example photos.
maskgan.py ADDED
@@ -0,0 +1,91 @@
+ '''
+ Author: Egrt
+ Date: 2022-04-07 14:00:52
+ LastEditors: [egrt]
+ LastEditTime: 2022-05-04 11:47:21
+ FilePath: \MaskGAN\maskgan.py
+ '''
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ from PIL import Image
+ from models.SwinIR import Generator
+ from utils.utils import cvtColor, preprocess_input
+
+
+ class MASKGAN(object):
+     #-----------------------------------------#
+     #   Remember to update model_path
+     #-----------------------------------------#
+     _defaults = {
+         #-----------------------------------------------#
+         #   model_path points to the weight file in the logs folder
+         #-----------------------------------------------#
+         "model_path"    : 'model_data/G_FFHQ.pth',
+         #-----------------------------------------------#
+         #   Upsampling factor, same as used in training
+         #-----------------------------------------------#
+         "scale_factor"  : 1,
+         #-----------------------------------------------#
+         #   hr_shape
+         #-----------------------------------------------#
+         "hr_shape"      : [112, 112],
+         #-------------------------------#
+         #   Whether to use CUDA
+         #   Set to False if there is no GPU
+         #-------------------------------#
+         "cuda"          : False,
+     }
+
+     #---------------------------------------------------#
+     #   Initialize MASKGAN
+     #---------------------------------------------------#
+     def __init__(self, **kwargs):
+         self.__dict__.update(self._defaults)
+         for name, value in kwargs.items():
+             setattr(self, name, value)
+         self.generate()
+
+     def generate(self):
+         self.net = Generator(upscale=self.scale_factor, img_size=tuple(self.hr_shape),
+                              window_size=7, img_range=1., depths=[6, 6, 6, 6],
+                              embed_dim=96, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect')
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # The checkpoint stores the entire serialized model, so this load
+         # replaces the Generator instance constructed above.
+         self.net = torch.load(self.model_path, map_location=device)
+         self.net = self.net.eval()
+         print('{} model loaded.'.format(self.model_path))
+
+         if self.cuda:
+             self.net = torch.nn.DataParallel(self.net)
+             cudnn.benchmark = True
+             self.net = self.net.cuda()
+
+     def generate_1x1_image(self, image):
+         #---------------------------------------------------------#
+         #   Convert the image to RGB here to avoid errors when
+         #   predicting on grayscale input. The code only supports
+         #   RGB prediction; all other image types are converted to RGB.
+         #---------------------------------------------------------#
+         image = cvtColor(image)
+         #---------------------------------------------------------#
+         #   Add the batch_size dimension and normalize
+         #---------------------------------------------------------#
+         image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1]), 0)
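+         # With mean=std=0.5 this maps pixels from [0, 255] to [-1, 1]; the
+         # inverse transform (* 0.5 + 0.5) is applied to the network output below.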
+
+         with torch.no_grad():
+             image_data = torch.from_numpy(image_data).type(torch.FloatTensor)
+             if self.cuda:
+                 image_data = image_data.cuda()
+
+             #---------------------------------------------------------#
+             #   Feed the image into the network for prediction
+             #---------------------------------------------------------#
+             hr_image = self.net(image_data)[0]
+             #---------------------------------------------------------#
+             #   Convert the normalized result back to RGB format
+             #---------------------------------------------------------#
+             hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
+             hr_image = np.clip(hr_image * 255, 0, 255)
+
+         hr_image = Image.fromarray(np.uint8(hr_image))
+         return hr_image
model_data/README.md ADDED
@@ -0,0 +1,8 @@
+ <!--
+ * @Author: Egrt
+ * @Date: 2022-04-11 20:53:42
+ * @LastEditors: Egrt
+ * @LastEditTime: 2022-04-11 20:53:42
+ * @FilePath: \MaskGAN\model_data\README.md
+ -->
+ Stores the model weights.
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .utils import *
utils/dataloader.py ADDED
@@ -0,0 +1,101 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data.dataset import Dataset
+
+ from .utils import cvtColor, preprocess_input
+
+ def look_image(image_name, image):
+     # Debug helper: show a PIL (RGB) image with OpenCV, which expects BGR
+     image = np.array(image)
+     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+     cv2.imshow(image_name, image)
+     cv2.waitKey(0)
+
+
+ def get_new_img_size(width, height, img_min_side=600):
+     # Resize so the shorter side equals img_min_side, keeping the aspect ratio
+     if width <= height:
+         f = float(img_min_side) / width
+         resized_height = int(f * height)
+         resized_width = int(img_min_side)
+     else:
+         f = float(img_min_side) / height
+         resized_width = int(f * width)
+         resized_height = int(img_min_side)
+
+     return resized_width, resized_height
+
+ class MASKGANDataset(Dataset):
+     def __init__(self, train_lines, lr_shape, hr_shape):
+         super(MASKGANDataset, self).__init__()
+
+         self.train_lines   = train_lines
+         self.train_batches = len(train_lines)
+
+         self.lr_shape = lr_shape
+         self.hr_shape = hr_shape
+
+     def __len__(self):
+         return self.train_batches
+
+     def __getitem__(self, index):
+         index = index % self.train_batches
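+         # Each training line is expected to hold "origin_path masked_path":
+         # the clean face image followed by its mask-occluded counterpart.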
+         image_list   = self.train_lines[index].split(' ')
+         image_origin = Image.open(image_list[0])
+         image_masked = Image.open(image_list[1].split()[0])
+
+         image_origin, image_masked = self.get_random_data(image_origin, image_masked, self.hr_shape)
+
+         image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
+         image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
+         # look_image('origin', image_origin)
+         # look_image('masked', image_masked)
+         image_origin = np.transpose(preprocess_input(np.array(image_origin, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
+         image_masked = np.transpose(preprocess_input(np.array(image_masked, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
+
+         return np.array(image_masked), np.array(image_origin)
+
+     def rand(self, a=0, b=1):
+         return np.random.rand()*(b-a) + a
+
+     def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
+         #------------------------------#
+         #   Read the images and convert them to RGB
+         #------------------------------#
+         image_origin = cvtColor(image_origin)
+         image_masked = cvtColor(image_masked)
+
+         #------------------------------------------#
+         #   Color jitter in HSV space; the same
+         #   factors are applied to both images so
+         #   the pair stays aligned
+         #------------------------------------------#
+         hue = self.rand(-hue, hue)
+         sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
+         val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
+
+         x = cv2.cvtColor(np.array(image_origin,np.float32)/255, cv2.COLOR_RGB2HSV)
+         x[..., 0] += hue * 360  # hue shift
+         x[..., 1] *= sat
+         x[..., 2] *= val
+         x[x[:,:, 0]>360, 0] = 360
+         x[:, :, 1:][x[:, :, 1:]>1] = 1
+         x[x<0] = 0
+         image_data_origin = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
+
+         x = cv2.cvtColor(np.array(image_masked,np.float32)/255, cv2.COLOR_RGB2HSV)
+         x[..., 0] += hue * 360  # hue shift
+         x[..., 1] *= sat
+         x[..., 2] *= val
+         x[x[:,:, 0]>360, 0] = 360
+         x[:, :, 1:][x[:, :, 1:]>1] = 1
+         x[x<0] = 0
+         image_data_masked = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
+
+         return Image.fromarray(np.uint8(image_data_origin)), Image.fromarray(np.uint8(image_data_masked))
+
+
+ def MASKGAN_dataset_collate(batch):
+     images_l = []
+     images_h = []
+     for img_l, img_h in batch:
+         images_l.append(img_l)
+         images_h.append(img_h)
+     return np.array(images_l), np.array(images_h)
utils/utils.py ADDED
@@ -0,0 +1,164 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ from torch.nn import functional as F
+ import cv2
+ import distutils.util
+
+ def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
+     with torch.no_grad():
+         test_images = G_net(imgs_lr)
+
+         fig, ax = plt.subplots(1, 3)
+
+         for j in range(3):
+             ax[j].get_xaxis().set_visible(False)
+             ax[j].get_yaxis().set_visible(False)
+         ax[0].cla()
+         ax[0].imshow(np.transpose(np.clip(imgs_lr.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1,2,0]))
+
+         ax[1].cla()
+         ax[1].imshow(np.transpose(np.clip(test_images.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1,2,0]))
+
+         ax[2].cla()
+         ax[2].imshow(np.transpose(np.clip(imgs_hr.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1,2,0]))
+
+         label = 'Epoch {0}'.format(num_epoch)
+         fig.text(0.5, 0.04, label, ha='center')
+         plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
+         plt.close('all')  # avoid memory leaks
+
+ #---------------------------------------------------------#
+ #   Convert the image to RGB to avoid errors when
+ #   predicting on grayscale input. The code only supports
+ #   RGB prediction; all other image types are converted to RGB.
+ #---------------------------------------------------------#
+ def cvtColor(image):
+     if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+         return image
+     else:
+         image = image.convert('RGB')
+         return image
+
+ def preprocess_input(image, mean, std):
+     image = (image/255 - mean)/std
+     return image
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group['lr']
+
+ def print_arguments(args):
+     print("----------- Configuration Arguments -----------")
+     for arg, value in sorted(vars(args).items()):
+         print("%s: %s" % (arg, value))
+     print("------------------------------------------------")
+
+
+ def add_arguments(argname, type, default, help, argparser, **kwargs):
+     type = distutils.util.strtobool if type == bool else type
+     argparser.add_argument("--" + argname,
+                            default=default,
+                            type=type,
+                            help=help + ' Default: %(default)s.',
+                            **kwargs)
+
+ def filter2D(img, kernel):
+     """PyTorch version of cv2.filter2D
+
+     Args:
+         img (Tensor): (b, c, h, w)
+         kernel (Tensor): (b, k, k)
+     """
+     k = kernel.size(-1)
+     b, c, h, w = img.size()
+     if k % 2 == 1:
+         img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
+     else:
+         raise ValueError('Wrong kernel size: only odd kernels are supported')
+
+     ph, pw = img.size()[-2:]
+
+     if kernel.size(0) == 1:
+         # apply the same kernel to all batch images
+         img = img.view(b * c, 1, ph, pw)
+         kernel = kernel.view(1, 1, k, k)
+         return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
+     else:
+         img = img.view(1, b * c, ph, pw)
+         kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
+         return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
+
+
+ def usm_sharp(img, weight=0.5, radius=50, threshold=10):
+     """USM sharpening.
+
+     Input image: I; Blurry image: B.
+     1. sharp = I + weight * (I - B)
+     2. Mask = 1 if abs(I - B) > threshold, else: 0
+     3. Blur the mask to get a soft mask
+     4. Out = Mask * sharp + (1 - Mask) * I
+
+     Args:
+         img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+         weight (float): Sharp weight. Default: 0.5.
+         radius (float): Kernel size of Gaussian blur. Default: 50.
+         threshold (int): Residual threshold (on a [0, 255] scale) above
+             which a pixel is treated as an edge to sharpen. Default: 10.
+     """
+     if radius % 2 == 0:
+         radius += 1
+     blur = cv2.GaussianBlur(img, (radius, radius), 0)
+     residual = img - blur
+     mask = np.abs(residual) * 255 > threshold
+     mask = mask.astype('float32')
+     soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+     sharp = img + weight * residual
+     sharp = np.clip(sharp, 0, 1)
+     return soft_mask * sharp + (1 - soft_mask) * img
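+
+ # Example usage (hypothetical file names), following the float32 [0, 1]
+ # BGR convention documented above:
+ #   img = cv2.imread('face.jpg').astype(np.float32) / 255.
+ #   out = usm_sharp(img, weight=0.5, radius=50, threshold=10)
+ #   cv2.imwrite('face_sharp.jpg', (out * 255).round().astype(np.uint8))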
+
+
+ class USMSharp(torch.nn.Module):
+
+     def __init__(self, radius=50, sigma=0):
+         super(USMSharp, self).__init__()
+         if radius % 2 == 0:
+             radius += 1
+         self.radius = radius
+         kernel = cv2.getGaussianKernel(radius, sigma)
+         kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
+         self.register_buffer('kernel', kernel)
+
+     def forward(self, img, weight=0.5, threshold=10):
+         blur = filter2D(img, self.kernel)
+         residual = img - blur
+
+         mask = torch.abs(residual) * 255 > threshold
+         mask = mask.float()
+         soft_mask = filter2D(mask, self.kernel)
+         sharp = img + weight * residual
+         sharp = torch.clip(sharp, 0, 1)
+         return soft_mask * sharp + (1 - soft_mask) * img
+
+ class USMSharp_npy():
+
+     def __init__(self, radius=50, sigma=0):
+         super(USMSharp_npy, self).__init__()
+         if radius % 2 == 0:
+             radius += 1
+         self.radius = radius
+         kernel = cv2.getGaussianKernel(radius, sigma)
+         self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)
+
+     def filt(self, img, weight=0.5, threshold=10):
+         blur = cv2.filter2D(img, -1, self.kernel)
+         residual = img - blur
+
+         mask = np.abs(residual) * 255 > threshold
+         mask = mask.astype(np.float32)
+         soft_mask = cv2.filter2D(mask, -1, self.kernel)
+         sharp = img + weight * residual
+         sharp = np.clip(sharp, 0, 1)
+         return soft_mask * sharp + (1 - soft_mask) * img
+
utils/utils_fit.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ import torch.nn.functional as F
+ from models.SwinIR import compute_gradient_penalty
+ from tqdm import tqdm
+
+ from .utils import get_lr, show_result
+ from .utils_metrics import PSNR, SSIM
+
+
+
+ def fit_one_epoch(writer, G_model_train, D_model_train, G_model, D_model, VGG_feature_model, ResNeSt_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, Face_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
+     G_total_loss = 0
+     D_total_loss = 0
+     G_total_PSNR = 0
+     G_total_SSIM = 0
+
+     with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3, ncols=150) as pbar:
+         for iteration, batch in enumerate(gen):
+             if iteration >= epoch_size:
+                 break
+
+             with torch.no_grad():
+                 lr_images, hr_images = batch
+                 lr_images, hr_images = torch.from_numpy(lr_images).type(torch.FloatTensor), torch.from_numpy(hr_images).type(torch.FloatTensor)
+                 y_real, y_fake = torch.ones(batch_size), torch.zeros(batch_size)
+                 if cuda:
+                     lr_images, hr_images, y_real, y_fake = lr_images.cuda(), hr_images.cuda(), y_real.cuda(), y_fake.cuda()
+
+             #-------------------------------------------------#
+             #   Train the discriminator
+             #-------------------------------------------------#
+             D_optimizer.zero_grad()
+
+             D_result_r = D_model_train(hr_images).squeeze()
+
+             G_result = G_model_train(lr_images)
+             D_result_f = D_model_train(G_result).squeeze()
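+             # Relativistic average GAN (RaGAN) logits: each sample is scored
+             # relative to the mean score of the opposite set, rather than
+             # against an absolute real/fake label.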
+             D_result_rf = D_result_r - D_result_f.mean()
+             D_result_fr = D_result_f - D_result_r.mean()
+             D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_real)
+             D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_fake)
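+             # Gradient penalty regularizes the discriminator; the factor 10
+             # below is the weight commonly used for this term.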
+             gradient_penalty = compute_gradient_penalty(D_model_train, hr_images, G_result)
+             D_train_loss = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
+             D_train_loss.backward()
+
+             D_optimizer.step()
+
+             #-------------------------------------------------#
+             #   Train the generator
+             #-------------------------------------------------#
+             G_optimizer.zero_grad()
+
+             G_result = G_model_train(lr_images)
+             image_loss = L1_loss(G_result, hr_images)
+
+             D_result_r = D_model_train(hr_images).squeeze()
+             D_result_f = D_model_train(G_result).squeeze()
+             D_result_rf = D_result_r - D_result_f.mean()
+             D_result_fr = D_result_f - D_result_r.mean()
+             D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_fake)
+             D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_real)
+             adversarial_loss = (D_train_loss_rf + D_train_loss_fr) / 2
+
+             perception_loss = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
+             # Downsample to match the face recognition network's input size
+             G_result_face = F.interpolate(G_result, size=(112, 112), mode='bicubic', align_corners=True)
+             hr_images_face = F.interpolate(hr_images, size=(112, 112), mode='bicubic', align_corners=True)
+             face_loss = torch.mean(1. - Face_loss(ResNeSt_model(G_result_face), ResNeSt_model(hr_images_face)))
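+             # Weighted sum of the four generator objectives:
+             # pixel-wise L1, RaGAN adversarial, VGG perceptual, and face identity.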
+             G_train_loss = 3.0 * image_loss + 1.0 * adversarial_loss + 0.9 * perception_loss + 2.5 * face_loss
+
+             G_train_loss.backward()
+             G_optimizer.step()
+
+             G_total_loss += G_train_loss.item()
+             D_total_loss += D_train_loss.item()
+
+             with torch.no_grad():
+                 G_total_PSNR += PSNR(G_result, hr_images).item()
+                 G_total_SSIM += SSIM(G_result, hr_images).item()
+
+             pbar.set_postfix(**{'G_loss' : G_total_loss / (iteration + 1),
+                                 'D_loss' : D_total_loss / (iteration + 1),
+                                 'G_PSNR' : G_total_PSNR / (iteration + 1),
+                                 'G_SSIM' : G_total_SSIM / (iteration + 1),
+                                 'lr'     : get_lr(G_optimizer)})
+             pbar.update(1)
+
+             if iteration % save_interval == 0:
+                 show_result(epoch + 1, G_model_train, lr_images, hr_images)
+     writer.add_scalar('G_loss', G_total_loss / (iteration + 1), epoch + 1)
+     writer.add_scalar('D_loss', D_total_loss / (iteration + 1), epoch + 1)
+     writer.add_scalar('G_PSNR', G_total_PSNR / (iteration + 1), epoch + 1)
+     writer.add_scalar('G_SSIM', G_total_SSIM / (iteration + 1), epoch + 1)
+     writer.add_scalar('lr', get_lr(G_optimizer), epoch + 1)
+     print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
+     print('G Loss: %.4f || D Loss: %.4f ' % (G_total_loss / epoch_size, D_total_loss / epoch_size))
+     print('Saving state, iter:', str(epoch + 1))
+     # Save the model weights
+     torch.save(G_model, 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
+     torch.save(D_model, 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
utils/utils_metrics.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn.functional as F
+ from math import exp
+
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+     return gauss/gauss.sum()
+
+ def create_window(window_size, channel=1):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+     return window
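+
+ # SSIM below follows the standard formulation: an 11x11 Gaussian window
+ # (sigma=1.5) estimates local means and variances of the two images.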
+ def SSIM(img1, img2, window_size=11, window=None, size_average=True, full=False):
+     # De-normalize from [-1, 1] back to [0, 255]
+     img1 = (img1 * 0.5 + 0.5) * 255
+     img2 = (img2 * 0.5 + 0.5) * 255
+     min_val = 0
+     max_val = 255
+     L = max_val - min_val
+     img2 = torch.clamp(img2, 0.0, 255.0)
+
+     padd = 0
+     (_, channel, height, width) = img1.size()
+     if window is None:
+         real_size = min(window_size, height, width)
+         window = create_window(real_size, channel=channel).to(img1.device)
+
+     mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+     mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
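+     # Standard SSIM stabilizing constants with K1=0.01, K2=0.03 and
+     # dynamic range L=255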
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+ def tf_log10(x):
+     numerator = torch.log(x)
+     denominator = torch.log(torch.tensor(10.0))
+     return numerator / denominator
+
+ def PSNR(img1, img2):
+     # De-normalize from [-1, 1] back to [0, 255]
+     img1 = (img1 * 0.5 + 0.5) * 255
+     img2 = (img2 * 0.5 + 0.5) * 255
+     max_pixel = 255.0
+     img2 = torch.clamp(img2, 0.0, 255.0)
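+     # PSNR = 10 * log10(MAX^2 / MSE) on the de-normalized [0, 255] values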
+     return 10.0 * tf_log10((max_pixel ** 2) / (torch.mean(torch.pow(img2 - img1, 2))))