白鹭先生 committed
Commit 73ca179 · Parent: d69c2b3
init
Files changed:
- .gitignore (+12 -0)
- app.py (+38 -0)
- img/00131_1.jpg (+0 -0)
- img/00294_1.jpg (+0 -0)
- img/00365_1.jpg (+0 -0)
- img/00415_1.jpg (+0 -0)
- img/00713_1.jpg (+0 -0)
- img/01026_1.jpg (+0 -0)
- img/02832_1.jpg (+0 -0)
- img/03092_1.jpg (+0 -0)
- img/03232_1.jpg (+0 -0)
- img/README.md (+8 -0)
- maskgan.py (+91 -0)
- model_data/README.md (+8 -0)
- utils/__init__.py (+1 -0)
- utils/dataloader.py (+101 -0)
- utils/utils.py (+164 -0)
- utils/utils_fit.py (+100 -0)
- utils/utils_metrics.py (+69 -0)
.gitignore ADDED
@@ -0,0 +1,12 @@
```
datasets
train_lines.txt
__pycache__
.vscode
flagged
gradio_queue.db
runs
model_data/ResNeSt.pt
model_data/G_FFHQ.pth
dataset/FFHQ
dataset/MFFHQ
dataset/MFFHQ.7z
```
app.py ADDED
@@ -0,0 +1,38 @@
```python
'''
Author: Egrt
Date: 2022-01-13 13:34:10
LastEditors: [egrt]
LastEditTime: 2022-05-04 12:26:53
FilePath: \MaskGAN\app.py
'''

from PIL import Image
from maskgan import MASKGAN
import gradio as gr
import os
maskgan = MASKGAN()

# -------- Model inference ---------- #
def inference(img):
    lr_shape = [112, 112]
    img = img.resize(tuple(lr_shape), Image.BICUBIC)
    r_image = maskgan.generate_1x1_image(img)
    return r_image

# -------- Page information ---------- #
title = "MaskGAN: Mask-Occluded Face Restoration with Unsupervised Fusion"
description = "Restores mask-occluded faces with a generative adversarial network, effectively recovering the occluded facial regions. @Intelligent Control and Image Processing Laboratory, Southwest University of Science and Technology"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>MaskGAN: Face Restoration Using Swin Transformer</a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"
example_img_dir = 'img'
example_img_name = os.listdir(example_img_dir)
examples = [[os.path.join(example_img_dir, image_path)] for image_path in example_img_name if image_path.endswith('.jpg')]
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples
).launch(debug=True)
```
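The inference path can be smoke-tested without launching the Gradio server. A minimal sketch, assuming the gitignored weights file model_data/G_FFHQ.pth is present locally and the img/ samples are in place:

```python
# Smoke test mirroring inference() above, without the web UI.
from PIL import Image
from maskgan import MASKGAN

maskgan = MASKGAN()
img = Image.open('img/00131_1.jpg')
img = img.resize((112, 112), Image.BICUBIC)   # same lr_shape as the demo
restored = maskgan.generate_1x1_image(img)
restored.save('restored_00131_1.jpg')
```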
img/00131_1.jpg ADDED
img/00294_1.jpg ADDED
img/00365_1.jpg ADDED
img/00415_1.jpg ADDED
img/00713_1.jpg ADDED
img/01026_1.jpg ADDED
img/02832_1.jpg ADDED
img/03092_1.jpg ADDED
img/03232_1.jpg ADDED
img/README.md ADDED
@@ -0,0 +1,8 @@
```
<!--
 * @Author: Egrt
 * @Date: 2022-04-09 11:02:07
 * @LastEditors: Egrt
 * @LastEditTime: 2022-04-09 11:02:07
 * @FilePath: \MaskGAN\img\README.md
-->
Stores the sample photos.
```
maskgan.py ADDED
@@ -0,0 +1,91 @@
```python
'''
Author: Egrt
Date: 2022-04-07 14:00:52
LastEditors: [egrt]
LastEditTime: 2022-05-04 11:47:21
FilePath: \MaskGAN\maskgan.py
'''
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from models.SwinIR import Generator
from utils.utils import cvtColor, preprocess_input


class MASKGAN(object):
    #-----------------------------------------#
    #   Note: remember to modify model_path
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_path points to the weight file under the logs folder
        #-----------------------------------------------#
        "model_path"        : 'model_data/G_FFHQ.pth',
        #-----------------------------------------------#
        #   Upsampling factor; must match the one used in training
        #-----------------------------------------------#
        "scale_factor"      : 1,
        #-----------------------------------------------#
        #   hr_shape
        #-----------------------------------------------#
        "hr_shape"          : [112, 112],
        #-------------------------------#
        #   Whether to use CUDA
        #   Set to False if there is no GPU
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialize MASKGAN
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        self.net = Generator(upscale=self.scale_factor, img_size=tuple(self.hr_shape),
                             window_size=7, img_range=1., depths=[6, 6, 6, 6],
                             embed_dim=96, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect')

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The checkpoint stores the whole module, so this load replaces the Generator built above
        self.net = torch.load(self.model_path, map_location=device)
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    def generate_1x1_image(self, image):
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Add the batch_size dimension and normalize
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), [2, 0, 1]), 0)

        with torch.no_grad():
            image_data = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                image_data = image_data.cuda()

            #---------------------------------------------------------#
            #   Feed the image into the network for prediction
            #---------------------------------------------------------#
            hr_image = self.net(image_data)[0]
            #---------------------------------------------------------#
            #   Convert the normalized result back to RGB
            #---------------------------------------------------------#
            hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
            hr_image = np.clip(hr_image * 255, 0, 255)

        hr_image = Image.fromarray(np.uint8(hr_image))
        return hr_image
```
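Because `__init__` copies `_defaults` into the instance dictionary before applying keyword arguments, any default can be overridden per instance. A sketch, with a hypothetical checkpoint name under logs/:

```python
# Override any _defaults entry via keyword arguments.
from maskgan import MASKGAN

model = MASKGAN(
    model_path='logs/G_Epoch50-GLoss0.1234-DLoss0.5678.pth',  # hypothetical checkpoint name
    cuda=True,                                                # move the network to the GPU
)
print(model.model_path, model.scale_factor, model.cuda)
```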
model_data/README.md ADDED
@@ -0,0 +1,8 @@
```
<!--
 * @Author: Egrt
 * @Date: 2022-04-11 20:53:42
 * @LastEditors: Egrt
 * @LastEditTime: 2022-04-11 20:53:42
 * @FilePath: \MaskGAN\model_data\README.md
-->
Stores the model weights.
```
utils/__init__.py ADDED
@@ -0,0 +1 @@
```python
from .utils import *
```
utils/dataloader.py ADDED
@@ -0,0 +1,101 @@
```python
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset

from .utils import cvtColor, preprocess_input

def look_image(image_name, image):
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    cv2.imshow(image_name, image)
    cv2.waitKey(0)


def get_new_img_size(width, height, img_min_side=600):
    if width <= height:
        f = float(img_min_side) / width
        resized_height = int(f * height)
        resized_width = int(img_min_side)
    else:
        f = float(img_min_side) / height
        resized_width = int(f * width)
        resized_height = int(img_min_side)

    return resized_width, resized_height

class MASKGANDataset(Dataset):
    def __init__(self, train_lines, lr_shape, hr_shape):
        super(MASKGANDataset, self).__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        return self.train_batches

    def __getitem__(self, index):
        index = index % self.train_batches
        image_list = self.train_lines[index].split(' ')
        image_origin = Image.open(image_list[0])
        image_masked = Image.open(image_list[1].split()[0])

        image_origin, image_masked = self.get_random_data(image_origin, image_masked, self.hr_shape)

        image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
        image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
        # look_image('origin', image_origin)
        # look_image('masked', image_masked)
        image_origin = np.transpose(preprocess_input(np.array(image_origin, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), [2, 0, 1])
        image_masked = np.transpose(preprocess_input(np.array(image_masked, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), [2, 0, 1])

        return np.array(image_masked), np.array(image_origin)

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        #------------------------------#
        #   Read the images and convert them to RGB
        #------------------------------#
        image_origin = cvtColor(image_origin)
        image_masked = cvtColor(image_masked)

        #------------------------------------------#
        #   HSV color-space jitter, applied identically to both images
        #------------------------------------------#
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)

        x = cv2.cvtColor(np.array(image_origin, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 0] += hue * 360   # apply the sampled hue shift (H channel is in degrees)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data_origin = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255

        x = cv2.cvtColor(np.array(image_masked, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 0] += hue * 360
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data_masked = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255

        return Image.fromarray(np.uint8(image_data_origin)), Image.fromarray(np.uint8(image_data_masked))


def MASKGAN_dataset_collate(batch):
    images_l = []
    images_h = []
    for img_l, img_h in batch:
        images_l.append(img_l)
        images_h.append(img_h)
    return np.array(images_l), np.array(images_h)
```
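Each annotation line is expected to hold the original and masked image paths separated by a space. A sketch of wiring the dataset into a standard PyTorch DataLoader with the collate function above, assuming a train_lines.txt in that format:

```python
# Hypothetical wiring of MASKGANDataset into a PyTorch DataLoader.
from torch.utils.data import DataLoader
from utils.dataloader import MASKGANDataset, MASKGAN_dataset_collate

with open('train_lines.txt') as f:             # each line: "<origin_path> <masked_path>"
    train_lines = f.readlines()

dataset = MASKGANDataset(train_lines, lr_shape=[112, 112], hr_shape=[112, 112])
gen = DataLoader(dataset, batch_size=4, shuffle=True,
                 collate_fn=MASKGAN_dataset_collate)

lr_batch, hr_batch = next(iter(gen))
print(lr_batch.shape, hr_batch.shape)           # (4, 3, 112, 112) numpy arrays
```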
utils/utils.py ADDED
@@ -0,0 +1,164 @@
```python
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
import cv2
import distutils.util

def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    with torch.no_grad():
        test_images = G_net(imgs_lr)

        fig, ax = plt.subplots(1, 3)

        for j in range(3):
            ax[j].get_xaxis().set_visible(False)
            ax[j].get_yaxis().set_visible(False)
        ax[0].cla()
        ax[0].imshow(np.transpose(np.clip(imgs_lr.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1, 2, 0]))

        ax[1].cla()
        ax[1].imshow(np.transpose(np.clip(test_images.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1, 2, 0]))

        ax[2].cla()
        ax[2].imshow(np.transpose(np.clip(imgs_hr.cpu().numpy()[0] * 0.5 + 0.5, 0, 1), [1, 2, 0]))

        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')
        plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
        plt.close('all')  # avoid memory leaks

#---------------------------------------------------------#
#   Convert the image to RGB to avoid errors when predicting on grayscale images.
#   The code only supports prediction on RGB images; all other image types are converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image

def preprocess_input(image, mean, std):
    image = (image / 255 - mean) / std
    return image

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def print_arguments(args):
    print("----------- Configuration Arguments -----------")
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")


def add_arguments(argname, type, default, help, argparser, **kwargs):
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)

def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k)
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Wrong kernel size')

    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # apply the same kernel to all batch images
        img = img.view(b * c, 1, ph, pw)
        kernel = kernel.view(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    else:
        img = img.view(1, b * c, ph, pw)
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)


def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual threshold on a 0-255 scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img


class USMSharp(torch.nn.Module):

    def __init__(self, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        blur = filter2D(img, self.kernel)
        residual = img - blur

        mask = torch.abs(residual) * 255 > threshold
        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel)
        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img

class USMSharp_npy():

    def __init__(self, radius=50, sigma=0):
        super(USMSharp_npy, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)

    def filt(self, img, weight=0.5, threshold=10):
        blur = cv2.filter2D(img, -1, self.kernel)
        residual = img - blur

        mask = np.abs(residual) * 255 > threshold
        mask = mask.astype(np.float32)
        soft_mask = cv2.filter2D(mask, -1, self.kernel)
        sharp = img + weight * residual
        sharp = np.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
```
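A short sketch exercising the tensor-based sharpener on random data, to show the expected (b, c, h, w) layout and [0, 1] value range:

```python
# USMSharp expects float images in [0, 1], laid out (b, c, h, w).
import torch
from utils.utils import USMSharp

usm = USMSharp(radius=50)                 # radius is bumped to 51 internally (odd kernel)
batch = torch.rand(2, 3, 112, 112)        # two random RGB images in [0, 1]
sharpened = usm(batch, weight=0.5, threshold=10)
print(sharpened.shape, sharpened.min().item(), sharpened.max().item())
```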
utils/utils_fit.py ADDED
@@ -0,0 +1,100 @@
```python
import torch
import torch.nn.functional as F
from models.SwinIR import compute_gradient_penalty
from tqdm import tqdm

from .utils import get_lr, show_result
from .utils_metrics import PSNR, SSIM


def fit_one_epoch(writer, G_model_train, D_model_train, G_model, D_model, VGG_feature_model, ResNeSt_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, Face_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0

    with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3, ncols=150) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            with torch.no_grad():
                lr_images, hr_images = batch
                lr_images, hr_images = torch.from_numpy(lr_images).type(torch.FloatTensor), torch.from_numpy(hr_images).type(torch.FloatTensor)
                y_real, y_fake = torch.ones(batch_size), torch.zeros(batch_size)
                if cuda:
                    lr_images, hr_images, y_real, y_fake = lr_images.cuda(), hr_images.cuda(), y_real.cuda(), y_fake.cuda()

            #-------------------------------------------------#
            #   Train the discriminator
            #-------------------------------------------------#
            D_optimizer.zero_grad()

            D_result_r = D_model_train(hr_images).squeeze()

            G_result = G_model_train(lr_images)
            D_result_f = D_model_train(G_result).squeeze()
            D_result_rf = D_result_r - D_result_f.mean()
            D_result_fr = D_result_f - D_result_r.mean()
            D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_real)
            D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_fake)
            gradient_penalty = compute_gradient_penalty(D_model_train, hr_images, G_result)
            D_train_loss = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_train_loss.backward()

            D_optimizer.step()

            #-------------------------------------------------#
            #   Train the generator
            #-------------------------------------------------#
            G_optimizer.zero_grad()

            G_result = G_model_train(lr_images)
            image_loss = L1_loss(G_result, hr_images)

            D_result_r = D_model_train(hr_images).squeeze()
            D_result_f = D_model_train(G_result).squeeze()
            D_result_rf = D_result_r - D_result_f.mean()
            D_result_fr = D_result_f - D_result_r.mean()
            D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_fake)
            D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_real)
            adversarial_loss = (D_train_loss_rf + D_train_loss_fr) / 2

            perception_loss = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
            # Downsample to fit the face-recognition network
            G_result_face = F.interpolate(G_result, size=(112, 112), mode='bicubic', align_corners=True)
            hr_images_face = F.interpolate(hr_images, size=(112, 112), mode='bicubic', align_corners=True)
            face_loss = torch.mean(1. - Face_loss(ResNeSt_model(G_result_face), ResNeSt_model(hr_images_face)))
            G_train_loss = 3.0 * image_loss + 1.0 * adversarial_loss + 0.9 * perception_loss + 2.5 * face_loss

            G_train_loss.backward()
            G_optimizer.step()

            G_total_loss += G_train_loss.item()
            D_total_loss += D_train_loss.item()

            with torch.no_grad():
                G_total_PSNR += PSNR(G_result, hr_images).item()
                G_total_SSIM += SSIM(G_result, hr_images).item()

            pbar.set_postfix(**{'G_loss'    : G_total_loss / (iteration + 1),
                                'D_loss'    : D_total_loss / (iteration + 1),
                                'G_PSNR'    : G_total_PSNR / (iteration + 1),
                                'G_SSIM'    : G_total_SSIM / (iteration + 1),
                                'lr'        : get_lr(G_optimizer)})
            pbar.update(1)

            if iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)

    writer.add_scalar('G_loss', G_total_loss / (iteration + 1), epoch + 1)
    writer.add_scalar('D_loss', D_total_loss / (iteration + 1), epoch + 1)
    writer.add_scalar('G_PSNR', G_total_PSNR / (iteration + 1), epoch + 1)
    writer.add_scalar('G_SSIM', G_total_SSIM / (iteration + 1), epoch + 1)
    writer.add_scalar('lr', get_lr(G_optimizer), epoch + 1)
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (G_total_loss / epoch_size, D_total_loss / epoch_size))
    print('Saving state, iter:', str(epoch + 1))
    # Save the model weights
    torch.save(G_model, 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
    torch.save(D_model, 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
```
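The adversarial terms above follow the relativistic average GAN pairing: each logit is compared against the mean logit of the opposite class, and the generator swaps the labels. A condensed sketch on dummy logits:

```python
# RaGAN losses on dummy logits, mirroring the pairing used in fit_one_epoch.
import torch

bce = torch.nn.BCEWithLogitsLoss()
d_real = torch.randn(4)                  # D(real) logits
d_fake = torch.randn(4)                  # D(fake) logits
y_real, y_fake = torch.ones(4), torch.zeros(4)

# Discriminator: real should beat the average fake, fake should trail the average real.
d_loss = (bce(d_real - d_fake.mean(), y_real) +
          bce(d_fake - d_real.mean(), y_fake)) / 2

# Generator: the same comparisons with the labels swapped.
g_loss = (bce(d_real - d_fake.mean(), y_fake) +
          bce(d_fake - d_real.mean(), y_real)) / 2
print(d_loss.item(), g_loss.item())
```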
utils/utils_metrics.py ADDED
@@ -0,0 +1,69 @@
```python
import torch
import torch.nn.functional as F
from math import exp
import numpy as np

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def SSIM(img1, img2, window_size=11, window=None, size_average=True, full=False):
    img1 = (img1 * 0.5 + 0.5) * 255
    img2 = (img2 * 0.5 + 0.5) * 255
    min_val = 0
    max_val = 255
    L = max_val - min_val
    img2 = torch.clamp(img2, 0.0, 255.0)

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret

def tf_log10(x):
    numerator = torch.log(x)
    denominator = torch.log(torch.tensor(10.0))
    return numerator / denominator

def PSNR(img1, img2):
    img1 = (img1 * 0.5 + 0.5) * 255
    img2 = (img2 * 0.5 + 0.5) * 255
    max_pixel = 255.0
    img2 = torch.clamp(img2, 0.0, 255.0)
    return 10.0 * tf_log10((max_pixel ** 2) / (torch.mean(torch.pow(img2 - img1, 2))))
```
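Both metrics assume inputs normalized to [-1, 1] (they rescale internally with * 0.5 + 0.5). A quick sanity check on random tensors; note that PSNR of an image against itself is infinite, so only SSIM is checked for the identical pair:

```python
# Sanity check: an identical pair gives SSIM ~1; a perturbed pair scores lower.
import torch
from utils.utils_metrics import PSNR, SSIM

a = torch.rand(1, 3, 112, 112) * 2 - 1          # fake image in [-1, 1]
b = (a + 0.05 * torch.randn_like(a)).clamp(-1, 1)

print(SSIM(a, a).item())                        # ~1.0
print(PSNR(a, b).item(), SSIM(a, b).item())
```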