guetLzy commited on
Commit
bc5f715
1 Parent(s): 1ba4791

Upload 25 files

Browse files
realesrgan/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
6
+ from .version import *
realesrgan/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (236 Bytes). View file
 
realesrgan/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.11 kB). View file
 
realesrgan/__pycache__/version.cpython-39.pyc ADDED
Binary file (228 Bytes). View file
 
realesrgan/archs/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import arch modules for registry
6
+ # scan all the files that end with '_arch.py' under the archs folder
7
+ arch_folder = osp.dirname(osp.abspath(__file__))
8
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9
+ # import all the arch modules
10
+ _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
realesrgan/archs/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (702 Bytes). View file
 
realesrgan/archs/__pycache__/discriminator_arch.cpython-39.pyc ADDED
Binary file (2.41 kB). View file
 
realesrgan/archs/__pycache__/srvgg_arch.cpython-39.pyc ADDED
Binary file (2.38 kB). View file
 
realesrgan/archs/discriminator_arch.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils import spectral_norm
5
+
6
+
7
+ @ARCH_REGISTRY.register()
8
+ class UNetDiscriminatorSN(nn.Module):
9
+ """Defines a U-Net discriminator with spectral normalization (SN)
10
+
11
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
12
+
13
+ Arg:
14
+ num_in_ch (int): Channel number of inputs. Default: 3.
15
+ num_feat (int): Channel number of base intermediate features. Default: 64.
16
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
17
+ """
18
+
19
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
20
+ super(UNetDiscriminatorSN, self).__init__()
21
+ self.skip_connection = skip_connection
22
+ norm = spectral_norm
23
+ # the first convolution
24
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
25
+ # downsample
26
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
27
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
28
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
29
+ # upsample
30
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
31
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
32
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
33
+ # extra convolutions
34
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
35
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
36
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
37
+
38
+ def forward(self, x):
39
+ # downsample
40
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
41
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
42
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
43
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
44
+
45
+ # upsample
46
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
47
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
48
+
49
+ if self.skip_connection:
50
+ x4 = x4 + x2
51
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
52
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
53
+
54
+ if self.skip_connection:
55
+ x5 = x5 + x1
56
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
57
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
58
+
59
+ if self.skip_connection:
60
+ x6 = x6 + x0
61
+
62
+ # extra convolutions
63
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
64
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
65
+ out = self.conv9(out)
66
+
67
+ return out
realesrgan/archs/srvgg_arch.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ @ARCH_REGISTRY.register()
7
+ class SRVGGNetCompact(nn.Module):
8
+ """A compact VGG-style network structure for super-resolution.
9
+
10
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
11
+ conducted on the HR feature space.
12
+
13
+ Args:
14
+ num_in_ch (int): Channel number of inputs. Default: 3.
15
+ num_out_ch (int): Channel number of outputs. Default: 3.
16
+ num_feat (int): Channel number of intermediate features. Default: 64.
17
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
18
+ upscale (int): Upsampling factor. Default: 4.
19
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
20
+ """
21
+
22
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
23
+ super(SRVGGNetCompact, self).__init__()
24
+ self.num_in_ch = num_in_ch
25
+ self.num_out_ch = num_out_ch
26
+ self.num_feat = num_feat
27
+ self.num_conv = num_conv
28
+ self.upscale = upscale
29
+ self.act_type = act_type
30
+
31
+ self.body = nn.ModuleList()
32
+ # the first conv
33
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
34
+ # the first activation
35
+ if act_type == 'relu':
36
+ activation = nn.ReLU(inplace=True)
37
+ elif act_type == 'prelu':
38
+ activation = nn.PReLU(num_parameters=num_feat)
39
+ elif act_type == 'leakyrelu':
40
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
41
+ self.body.append(activation)
42
+
43
+ # the body structure
44
+ for _ in range(num_conv):
45
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
46
+ # activation
47
+ if act_type == 'relu':
48
+ activation = nn.ReLU(inplace=True)
49
+ elif act_type == 'prelu':
50
+ activation = nn.PReLU(num_parameters=num_feat)
51
+ elif act_type == 'leakyrelu':
52
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
53
+ self.body.append(activation)
54
+
55
+ # the last conv
56
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
57
+ # upsample
58
+ self.upsampler = nn.PixelShuffle(upscale)
59
+
60
+ def forward(self, x):
61
+ out = x
62
+ for i in range(0, len(self.body)):
63
+ out = self.body[i](out)
64
+
65
+ out = self.upsampler(out)
66
+ # add the nearest upsampled image, so that the network learns the residual
67
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
68
+ out += base
69
+ return out
realesrgan/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import dataset modules for registry
6
+ # scan all the files that end with '_dataset.py' under the data folder
7
+ data_folder = osp.dirname(osp.abspath(__file__))
8
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
9
+ # import all the dataset modules
10
+ _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
realesrgan/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (709 Bytes). View file
 
realesrgan/data/__pycache__/realesrgan_dataset.cpython-39.pyc ADDED
Binary file (5.68 kB). View file
 
realesrgan/data/__pycache__/realesrgan_paired_dataset.cpython-39.pyc ADDED
Binary file (4.08 kB). View file
 
realesrgan/data/realesrgan_dataset.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ import torch
9
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
10
+ from basicsr.data.transforms import augment
11
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
12
+ from basicsr.utils.registry import DATASET_REGISTRY
13
+ from torch.utils import data as data
14
+
15
+
16
+ @DATASET_REGISTRY.register()
17
+ class RealESRGANDataset(data.Dataset):
18
+ """Dataset used for Real-ESRGAN model:
19
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
20
+
21
+ It loads gt (Ground-Truth) images, and augments them.
22
+ It also generates blur kernels and sinc kernels for generating low-quality images.
23
+ Note that the low-quality images are processed in tensors on GPUS for faster processing.
24
+
25
+ Args:
26
+ opt (dict): Config for train datasets. It contains the following keys:
27
+ dataroot_gt (str): Data root path for gt.
28
+ meta_info (str): Path for meta information file.
29
+ io_backend (dict): IO backend type and other kwarg.
30
+ use_hflip (bool): Use horizontal flips.
31
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
32
+ Please see more options in the codes.
33
+ """
34
+
35
+ def __init__(self, opt):
36
+ super(RealESRGANDataset, self).__init__()
37
+ self.opt = opt
38
+ self.file_client = None
39
+ self.io_backend_opt = opt['io_backend']
40
+ self.gt_folder = opt['dataroot_gt']
41
+
42
+ # file client (lmdb io backend)
43
+ if self.io_backend_opt['type'] == 'lmdb':
44
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
45
+ self.io_backend_opt['client_keys'] = ['gt']
46
+ if not self.gt_folder.endswith('.lmdb'):
47
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
48
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
49
+ self.paths = [line.split('.')[0] for line in fin]
50
+ else:
51
+ # disk backend with meta_info
52
+ # Each line in the meta_info describes the relative path to an image
53
+ with open(self.opt['meta_info']) as fin:
54
+ paths = [line.strip().split(' ')[0] for line in fin]
55
+ self.paths = [os.path.join(self.gt_folder, v) for v in paths]
56
+
57
+ # blur settings for the first degradation
58
+ self.blur_kernel_size = opt['blur_kernel_size']
59
+ self.kernel_list = opt['kernel_list']
60
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
61
+ self.blur_sigma = opt['blur_sigma']
62
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
63
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
64
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
65
+
66
+ # blur settings for the second degradation
67
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
68
+ self.kernel_list2 = opt['kernel_list2']
69
+ self.kernel_prob2 = opt['kernel_prob2']
70
+ self.blur_sigma2 = opt['blur_sigma2']
71
+ self.betag_range2 = opt['betag_range2']
72
+ self.betap_range2 = opt['betap_range2']
73
+ self.sinc_prob2 = opt['sinc_prob2']
74
+
75
+ # a final sinc filter
76
+ self.final_sinc_prob = opt['final_sinc_prob']
77
+
78
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
79
+ # TODO: kernel range is now hard-coded, should be in the configure file
80
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
81
+ self.pulse_tensor[10, 10] = 1
82
+
83
+ def __getitem__(self, index):
84
+ if self.file_client is None:
85
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
86
+
87
+ # -------------------------------- Load gt images -------------------------------- #
88
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
89
+ gt_path = self.paths[index]
90
+ # avoid errors caused by high latency in reading files
91
+ retry = 3
92
+ while retry > 0:
93
+ try:
94
+ img_bytes = self.file_client.get(gt_path, 'gt')
95
+ except (IOError, OSError) as e:
96
+ logger = get_root_logger()
97
+ logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
98
+ # change another file to read
99
+ index = random.randint(0, self.__len__())
100
+ gt_path = self.paths[index]
101
+ time.sleep(1) # sleep 1s for occasional server congestion
102
+ else:
103
+ break
104
+ finally:
105
+ retry -= 1
106
+ img_gt = imfrombytes(img_bytes, float32=True)
107
+
108
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
109
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
110
+
111
+ # crop or pad to 400
112
+ # TODO: 400 is hard-coded. You may change it accordingly
113
+ h, w = img_gt.shape[0:2]
114
+ crop_pad_size = 400
115
+ # pad
116
+ if h < crop_pad_size or w < crop_pad_size:
117
+ pad_h = max(0, crop_pad_size - h)
118
+ pad_w = max(0, crop_pad_size - w)
119
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
120
+ # crop
121
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
122
+ h, w = img_gt.shape[0:2]
123
+ # randomly choose top and left coordinates
124
+ top = random.randint(0, h - crop_pad_size)
125
+ left = random.randint(0, w - crop_pad_size)
126
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
127
+
128
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
129
+ kernel_size = random.choice(self.kernel_range)
130
+ if np.random.uniform() < self.opt['sinc_prob']:
131
+ # this sinc filter setting is for kernels ranging from [7, 21]
132
+ if kernel_size < 13:
133
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
134
+ else:
135
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
136
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
137
+ else:
138
+ kernel = random_mixed_kernels(
139
+ self.kernel_list,
140
+ self.kernel_prob,
141
+ kernel_size,
142
+ self.blur_sigma,
143
+ self.blur_sigma, [-math.pi, math.pi],
144
+ self.betag_range,
145
+ self.betap_range,
146
+ noise_range=None)
147
+ # pad kernel
148
+ pad_size = (21 - kernel_size) // 2
149
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
150
+
151
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
152
+ kernel_size = random.choice(self.kernel_range)
153
+ if np.random.uniform() < self.opt['sinc_prob2']:
154
+ if kernel_size < 13:
155
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
156
+ else:
157
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
158
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
159
+ else:
160
+ kernel2 = random_mixed_kernels(
161
+ self.kernel_list2,
162
+ self.kernel_prob2,
163
+ kernel_size,
164
+ self.blur_sigma2,
165
+ self.blur_sigma2, [-math.pi, math.pi],
166
+ self.betag_range2,
167
+ self.betap_range2,
168
+ noise_range=None)
169
+
170
+ # pad kernel
171
+ pad_size = (21 - kernel_size) // 2
172
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
173
+
174
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
175
+ if np.random.uniform() < self.opt['final_sinc_prob']:
176
+ kernel_size = random.choice(self.kernel_range)
177
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
178
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
179
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
180
+ else:
181
+ sinc_kernel = self.pulse_tensor
182
+
183
+ # BGR to RGB, HWC to CHW, numpy to tensor
184
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
185
+ kernel = torch.FloatTensor(kernel)
186
+ kernel2 = torch.FloatTensor(kernel2)
187
+
188
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
189
+ return return_d
190
+
191
+ def __len__(self):
192
+ return len(self.paths)
realesrgan/data/realesrgan_paired_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
3
+ from basicsr.data.transforms import augment, paired_random_crop
4
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
5
+ from basicsr.utils.registry import DATASET_REGISTRY
6
+ from torch.utils import data as data
7
+ from torchvision.transforms.functional import normalize
8
+
9
+
10
+ @DATASET_REGISTRY.register()
11
+ class RealESRGANPairedDataset(data.Dataset):
12
+ """Paired image dataset for image restoration.
13
+
14
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
15
+
16
+ There are three modes:
17
+ 1. 'lmdb': Use lmdb files.
18
+ If opt['io_backend'] == lmdb.
19
+ 2. 'meta_info': Use meta information file to generate paths.
20
+ If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
21
+ 3. 'folder': Scan folders to generate paths.
22
+ The rest.
23
+
24
+ Args:
25
+ opt (dict): Config for train datasets. It contains the following keys:
26
+ dataroot_gt (str): Data root path for gt.
27
+ dataroot_lq (str): Data root path for lq.
28
+ meta_info (str): Path for meta information file.
29
+ io_backend (dict): IO backend type and other kwarg.
30
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
31
+ Default: '{}'.
32
+ gt_size (int): Cropped patched size for gt patches.
33
+ use_hflip (bool): Use horizontal flips.
34
+ use_rot (bool): Use rotation (use vertical flip and transposing h
35
+ and w for implementation).
36
+
37
+ scale (bool): Scale, which will be added automatically.
38
+ phase (str): 'train' or 'val'.
39
+ """
40
+
41
+ def __init__(self, opt):
42
+ super(RealESRGANPairedDataset, self).__init__()
43
+ self.opt = opt
44
+ self.file_client = None
45
+ self.io_backend_opt = opt['io_backend']
46
+ # mean and std for normalizing the input images
47
+ self.mean = opt['mean'] if 'mean' in opt else None
48
+ self.std = opt['std'] if 'std' in opt else None
49
+
50
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
51
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
52
+
53
+ # file client (lmdb io backend)
54
+ if self.io_backend_opt['type'] == 'lmdb':
55
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
56
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
57
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
58
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
59
+ # disk backend with meta_info
60
+ # Each line in the meta_info describes the relative path to an image
61
+ with open(self.opt['meta_info']) as fin:
62
+ paths = [line.strip() for line in fin]
63
+ self.paths = []
64
+ for path in paths:
65
+ gt_path, lq_path = path.split(', ')
66
+ gt_path = os.path.join(self.gt_folder, gt_path)
67
+ lq_path = os.path.join(self.lq_folder, lq_path)
68
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
69
+ else:
70
+ # disk backend
71
+ # it will scan the whole folder to get meta info
72
+ # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
73
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
74
+
75
+ def __getitem__(self, index):
76
+ if self.file_client is None:
77
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
78
+
79
+ scale = self.opt['scale']
80
+
81
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
82
+ # image range: [0, 1], float32.
83
+ gt_path = self.paths[index]['gt_path']
84
+ img_bytes = self.file_client.get(gt_path, 'gt')
85
+ img_gt = imfrombytes(img_bytes, float32=True)
86
+ lq_path = self.paths[index]['lq_path']
87
+ img_bytes = self.file_client.get(lq_path, 'lq')
88
+ img_lq = imfrombytes(img_bytes, float32=True)
89
+
90
+ # augmentation for training
91
+ if self.opt['phase'] == 'train':
92
+ gt_size = self.opt['gt_size']
93
+ # random crop
94
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
95
+ # flip, rotation
96
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
97
+
98
+ # BGR to RGB, HWC to CHW, numpy to tensor
99
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
100
+ # normalize
101
+ if self.mean is not None or self.std is not None:
102
+ normalize(img_lq, self.mean, self.std, inplace=True)
103
+ normalize(img_gt, self.mean, self.std, inplace=True)
104
+
105
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
realesrgan/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import model modules for registry
6
+ # scan all the files that end with '_model.py' under the model folder
7
+ model_folder = osp.dirname(osp.abspath(__file__))
8
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
9
+ # import all the model modules
10
+ _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
realesrgan/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (708 Bytes). View file
 
realesrgan/models/__pycache__/realesrgan_model.cpython-39.pyc ADDED
Binary file (6.67 kB). View file
 
realesrgan/models/__pycache__/realesrnet_model.cpython-39.pyc ADDED
Binary file (5.31 kB). View file
 
realesrgan/models/realesrgan_model.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.srgan_model import SRGANModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from collections import OrderedDict
11
+ from torch.nn import functional as F
12
+
13
+
14
+ @MODEL_REGISTRY.register()
15
+ class RealESRGANModel(SRGANModel):
16
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
17
+
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with GAN training.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ super(RealESRGANModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
+
29
+ @torch.no_grad()
30
+ def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
+ # initialize
38
+ b, c, h, w = self.lq.size()
39
+ if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
+ _, c, h, w = self.gt.size()
43
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
+ self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
+ # do dequeue and enqueue
47
+ # shuffle
48
+ idx = torch.randperm(self.queue_size)
49
+ self.queue_lr = self.queue_lr[idx]
50
+ self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
+
58
+ self.lq = lq_dequeue
59
+ self.gt = gt_dequeue
60
+ else:
61
+ # only do enqueue
62
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
63
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
64
+ self.queue_ptr = self.queue_ptr + b
65
+
66
+ @torch.no_grad()
67
+ def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
+ # training data synthesis
72
+ self.gt = data['gt'].to(self.device)
73
+ self.gt_usm = self.usm_sharpener(self.gt)
74
+
75
+ self.kernel1 = data['kernel1'].to(self.device)
76
+ self.kernel2 = data['kernel2'].to(self.device)
77
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
78
+
79
+ ori_h, ori_w = self.gt.size()[2:4]
80
+
81
+ # ----------------------- The first degradation process ----------------------- #
82
+ # blur
83
+ out = filter2D(self.gt_usm, self.kernel1)
84
+ # random resize
85
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
86
+ if updown_type == 'up':
87
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
88
+ elif updown_type == 'down':
89
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
90
+ else:
91
+ scale = 1
92
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
93
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
94
+ # add noise
95
+ gray_noise_prob = self.opt['gray_noise_prob']
96
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
97
+ out = random_add_gaussian_noise_pt(
98
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
99
+ else:
100
+ out = random_add_poisson_noise_pt(
101
+ out,
102
+ scale_range=self.opt['poisson_scale_range'],
103
+ gray_prob=gray_noise_prob,
104
+ clip=True,
105
+ rounds=False)
106
+ # JPEG compression
107
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
108
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
109
+ out = self.jpeger(out, quality=jpeg_p)
110
+
111
+ # ----------------------- The second degradation process ----------------------- #
112
+ # blur
113
+ if np.random.uniform() < self.opt['second_blur_prob']:
114
+ out = filter2D(out, self.kernel2)
115
+ # random resize
116
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
117
+ if updown_type == 'up':
118
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
119
+ elif updown_type == 'down':
120
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
121
+ else:
122
+ scale = 1
123
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
124
+ out = F.interpolate(
125
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
126
+ # add noise
127
+ gray_noise_prob = self.opt['gray_noise_prob2']
128
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
129
+ out = random_add_gaussian_noise_pt(
130
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
131
+ else:
132
+ out = random_add_poisson_noise_pt(
133
+ out,
134
+ scale_range=self.opt['poisson_scale_range2'],
135
+ gray_prob=gray_noise_prob,
136
+ clip=True,
137
+ rounds=False)
138
+
139
+ # JPEG compression + the final sinc filter
140
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
141
+ # as one operation.
142
+ # We consider two orders:
143
+ # 1. [resize back + sinc filter] + JPEG compression
144
+ # 2. JPEG compression + [resize back + sinc filter]
145
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
146
+ if np.random.uniform() < 0.5:
147
+ # resize back + the final sinc filter
148
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
149
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
150
+ out = filter2D(out, self.sinc_kernel)
151
+ # JPEG compression
152
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
153
+ out = torch.clamp(out, 0, 1)
154
+ out = self.jpeger(out, quality=jpeg_p)
155
+ else:
156
+ # JPEG compression
157
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
158
+ out = torch.clamp(out, 0, 1)
159
+ out = self.jpeger(out, quality=jpeg_p)
160
+ # resize back + the final sinc filter
161
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
162
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
163
+ out = filter2D(out, self.sinc_kernel)
164
+
165
+ # clamp and round
166
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
167
+
168
+ # random crop
169
+ gt_size = self.opt['gt_size']
170
+ (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
171
+ self.opt['scale'])
172
+
173
+ # training pair pool
174
+ self._dequeue_and_enqueue()
175
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
176
+ self.gt_usm = self.usm_sharpener(self.gt)
177
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
178
+ else:
179
+ # for paired training or validation
180
+ self.lq = data['lq'].to(self.device)
181
+ if 'gt' in data:
182
+ self.gt = data['gt'].to(self.device)
183
+ self.gt_usm = self.usm_sharpener(self.gt)
184
+
185
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
186
+ # do not use the synthetic process during validation
187
+ self.is_train = False
188
+ super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
189
+ self.is_train = True
190
+
191
+ def optimize_parameters(self, current_iter):
192
+ # usm sharpening
193
+ l1_gt = self.gt_usm
194
+ percep_gt = self.gt_usm
195
+ gan_gt = self.gt_usm
196
+ if self.opt['l1_gt_usm'] is False:
197
+ l1_gt = self.gt
198
+ if self.opt['percep_gt_usm'] is False:
199
+ percep_gt = self.gt
200
+ if self.opt['gan_gt_usm'] is False:
201
+ gan_gt = self.gt
202
+
203
+ # optimize net_g
204
+ for p in self.net_d.parameters():
205
+ p.requires_grad = False
206
+
207
+ self.optimizer_g.zero_grad()
208
+ self.output = self.net_g(self.lq)
209
+
210
+ l_g_total = 0
211
+ loss_dict = OrderedDict()
212
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
213
+ # pixel loss
214
+ if self.cri_pix:
215
+ l_g_pix = self.cri_pix(self.output, l1_gt)
216
+ l_g_total += l_g_pix
217
+ loss_dict['l_g_pix'] = l_g_pix
218
+ # perceptual loss
219
+ if self.cri_perceptual:
220
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
221
+ if l_g_percep is not None:
222
+ l_g_total += l_g_percep
223
+ loss_dict['l_g_percep'] = l_g_percep
224
+ if l_g_style is not None:
225
+ l_g_total += l_g_style
226
+ loss_dict['l_g_style'] = l_g_style
227
+ # gan loss
228
+ fake_g_pred = self.net_d(self.output)
229
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
230
+ l_g_total += l_g_gan
231
+ loss_dict['l_g_gan'] = l_g_gan
232
+
233
+ l_g_total.backward()
234
+ self.optimizer_g.step()
235
+
236
+ # optimize net_d
237
+ for p in self.net_d.parameters():
238
+ p.requires_grad = True
239
+
240
+ self.optimizer_d.zero_grad()
241
+ # real
242
+ real_d_pred = self.net_d(gan_gt)
243
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
244
+ loss_dict['l_d_real'] = l_d_real
245
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
246
+ l_d_real.backward()
247
+ # fake
248
+ fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
249
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
250
+ loss_dict['l_d_fake'] = l_d_fake
251
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
252
+ l_d_fake.backward()
253
+ self.optimizer_d.step()
254
+
255
+ if self.ema_decay > 0:
256
+ self.model_ema(decay=self.ema_decay)
257
+
258
+ self.log_dict = self.reduce_loss_dict(loss_dict)
realesrgan/models/realesrnet_model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.sr_model import SRModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from torch.nn import functional as F
11
+
12
+
13
+ @MODEL_REGISTRY.register()
14
+ class RealESRNetModel(SRModel):
15
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
16
+
17
+ It is trained without GAN losses.
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with GAN training.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ super(RealESRNetModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
+
29
+ @torch.no_grad()
30
+ def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
+ # initialize
38
+ b, c, h, w = self.lq.size()
39
+ if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
+ _, c, h, w = self.gt.size()
43
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
+ self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
+ # do dequeue and enqueue
47
+ # shuffle
48
+ idx = torch.randperm(self.queue_size)
49
+ self.queue_lr = self.queue_lr[idx]
50
+ self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
+
58
+ self.lq = lq_dequeue
59
+ self.gt = gt_dequeue
60
+ else:
61
+ # only do enqueue
62
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
63
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
64
+ self.queue_ptr = self.queue_ptr + b
65
+
66
+ @torch.no_grad()
67
+ def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
+ # training data synthesis
72
+ self.gt = data['gt'].to(self.device)
73
+ # USM sharpen the GT images
74
+ if self.opt['gt_usm'] is True:
75
+ self.gt = self.usm_sharpener(self.gt)
76
+
77
+ self.kernel1 = data['kernel1'].to(self.device)
78
+ self.kernel2 = data['kernel2'].to(self.device)
79
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
80
+
81
+ ori_h, ori_w = self.gt.size()[2:4]
82
+
83
+ # ----------------------- The first degradation process ----------------------- #
84
+ # blur
85
+ out = filter2D(self.gt, self.kernel1)
86
+ # random resize
87
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
88
+ if updown_type == 'up':
89
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
90
+ elif updown_type == 'down':
91
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
92
+ else:
93
+ scale = 1
94
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
95
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
96
+ # add noise
97
+ gray_noise_prob = self.opt['gray_noise_prob']
98
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
99
+ out = random_add_gaussian_noise_pt(
100
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
101
+ else:
102
+ out = random_add_poisson_noise_pt(
103
+ out,
104
+ scale_range=self.opt['poisson_scale_range'],
105
+ gray_prob=gray_noise_prob,
106
+ clip=True,
107
+ rounds=False)
108
+ # JPEG compression
109
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
110
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
111
+ out = self.jpeger(out, quality=jpeg_p)
112
+
113
+ # ----------------------- The second degradation process ----------------------- #
114
+ # blur
115
+ if np.random.uniform() < self.opt['second_blur_prob']:
116
+ out = filter2D(out, self.kernel2)
117
+ # random resize
118
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
119
+ if updown_type == 'up':
120
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
121
+ elif updown_type == 'down':
122
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
123
+ else:
124
+ scale = 1
125
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
126
+ out = F.interpolate(
127
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
128
+ # add noise
129
+ gray_noise_prob = self.opt['gray_noise_prob2']
130
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
131
+ out = random_add_gaussian_noise_pt(
132
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
133
+ else:
134
+ out = random_add_poisson_noise_pt(
135
+ out,
136
+ scale_range=self.opt['poisson_scale_range2'],
137
+ gray_prob=gray_noise_prob,
138
+ clip=True,
139
+ rounds=False)
140
+
141
+ # JPEG compression + the final sinc filter
142
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
143
+ # as one operation.
144
+ # We consider two orders:
145
+ # 1. [resize back + sinc filter] + JPEG compression
146
+ # 2. JPEG compression + [resize back + sinc filter]
147
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
148
+ if np.random.uniform() < 0.5:
149
+ # resize back + the final sinc filter
150
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
151
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
152
+ out = filter2D(out, self.sinc_kernel)
153
+ # JPEG compression
154
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
155
+ out = torch.clamp(out, 0, 1)
156
+ out = self.jpeger(out, quality=jpeg_p)
157
+ else:
158
+ # JPEG compression
159
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
160
+ out = torch.clamp(out, 0, 1)
161
+ out = self.jpeger(out, quality=jpeg_p)
162
+ # resize back + the final sinc filter
163
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
164
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
165
+ out = filter2D(out, self.sinc_kernel)
166
+
167
+ # clamp and round
168
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
169
+
170
+ # random crop
171
+ gt_size = self.opt['gt_size']
172
+ self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
173
+
174
+ # training pair pool
175
+ self._dequeue_and_enqueue()
176
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
177
+ else:
178
+ # for paired training or validation
179
+ self.lq = data['lq'].to(self.device)
180
+ if 'gt' in data:
181
+ self.gt = data['gt'].to(self.device)
182
+ self.gt_usm = self.usm_sharpener(self.gt)
183
+
184
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
185
+ # do not use the synthetic process during validation
186
+ self.is_train = False
187
+ super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
188
+ self.is_train = True
realesrgan/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import os.path as osp
3
+ from basicsr.train import train_pipeline
4
+
5
+ import realesrgan.archs
6
+ import realesrgan.data
7
+ import realesrgan.models
8
+
9
+ if __name__ == '__main__':
10
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11
+ train_pipeline(root_path)
realesrgan/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import queue
6
+ import threading
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from torch.nn import functional as F
10
+
11
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+
14
+ class RealESRGANer():
15
+ """A helper class for upsampling images with RealESRGAN.
16
+
17
+ Args:
18
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
19
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
20
+ model (nn.Module): The defined network. Default: None.
21
+ tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
22
+ input images into tiles, and then process each of them. Finally, they will be merged into one image.
23
+ 0 denotes for do not use tile. Default: 0.
24
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
25
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
26
+ half (float): Whether to use half precision during inference. Default: False.
27
+ """
28
+
29
+ def __init__(self,
30
+ scale,
31
+ model_path,
32
+ dni_weight=None,
33
+ model=None,
34
+ tile=0,
35
+ tile_pad=10,
36
+ pre_pad=10,
37
+ half=False,
38
+ device=None,
39
+ gpu_id=None):
40
+ self.scale = scale
41
+ self.tile_size = tile
42
+ self.tile_pad = tile_pad
43
+ self.pre_pad = pre_pad
44
+ self.mod_scale = None
45
+ self.half = half
46
+
47
+ # initialize model
48
+ if gpu_id:
49
+ self.device = torch.device(
50
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
51
+ else:
52
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
53
+
54
+ if isinstance(model_path, list):
55
+ # dni
56
+ assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
57
+ loadnet = self.dni(model_path[0], model_path[1], dni_weight)
58
+ else:
59
+ # if the model_path starts with https, it will first download models to the folder: weights
60
+ if model_path.startswith('https://'):
61
+ model_path = load_file_from_url(
62
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
63
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
64
+
65
+ # prefer to use params_ema
66
+ if 'params_ema' in loadnet:
67
+ keyname = 'params_ema'
68
+ else:
69
+ keyname = 'params'
70
+ model.load_state_dict(loadnet[keyname], strict=True)
71
+
72
+ model.eval()
73
+ self.model = model.to(self.device)
74
+ if self.half:
75
+ self.model = self.model.half()
76
+
77
+ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
78
+ """Deep network interpolation.
79
+
80
+ ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
81
+ """
82
+ net_a = torch.load(net_a, map_location=torch.device(loc))
83
+ net_b = torch.load(net_b, map_location=torch.device(loc))
84
+ for k, v_a in net_a[key].items():
85
+ net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
86
+ return net_a
87
+
88
+ def pre_process(self, img):
89
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
90
+ """
91
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
92
+ self.img = img.unsqueeze(0).to(self.device)
93
+ if self.half:
94
+ self.img = self.img.half()
95
+
96
+ # pre_pad
97
+ if self.pre_pad != 0:
98
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
99
+ # mod pad for divisible borders
100
+ if self.scale == 2:
101
+ self.mod_scale = 2
102
+ elif self.scale == 1:
103
+ self.mod_scale = 4
104
+ if self.mod_scale is not None:
105
+ self.mod_pad_h, self.mod_pad_w = 0, 0
106
+ _, _, h, w = self.img.size()
107
+ if (h % self.mod_scale != 0):
108
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
109
+ if (w % self.mod_scale != 0):
110
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
111
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
112
+
113
+ def process(self):
114
+ # model inference
115
+ self.output = self.model(self.img)
116
+
117
+ def tile_process(self):
118
+ """It will first crop input images to tiles, and then process each tile.
119
+ Finally, all the processed tiles are merged into one images.
120
+
121
+ Modified from: https://github.com/ata4/esrgan-launcher
122
+ """
123
+ batch, channel, height, width = self.img.shape
124
+ output_height = height * self.scale
125
+ output_width = width * self.scale
126
+ output_shape = (batch, channel, output_height, output_width)
127
+
128
+ # start with black image
129
+ self.output = self.img.new_zeros(output_shape)
130
+ tiles_x = math.ceil(width / self.tile_size)
131
+ tiles_y = math.ceil(height / self.tile_size)
132
+
133
+ # loop over all tiles
134
+ for y in range(tiles_y):
135
+ for x in range(tiles_x):
136
+ # extract tile from input image
137
+ ofs_x = x * self.tile_size
138
+ ofs_y = y * self.tile_size
139
+ # input tile area on total image
140
+ input_start_x = ofs_x
141
+ input_end_x = min(ofs_x + self.tile_size, width)
142
+ input_start_y = ofs_y
143
+ input_end_y = min(ofs_y + self.tile_size, height)
144
+
145
+ # input tile area on total image with padding
146
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
147
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
148
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
149
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
150
+
151
+ # input tile dimensions
152
+ input_tile_width = input_end_x - input_start_x
153
+ input_tile_height = input_end_y - input_start_y
154
+ tile_idx = y * tiles_x + x + 1
155
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
156
+
157
+ # upscale tile
158
+ try:
159
+ with torch.no_grad():
160
+ output_tile = self.model(input_tile)
161
+ except RuntimeError as error:
162
+ print('Error', error)
163
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
164
+
165
+ # output tile area on total image
166
+ output_start_x = input_start_x * self.scale
167
+ output_end_x = input_end_x * self.scale
168
+ output_start_y = input_start_y * self.scale
169
+ output_end_y = input_end_y * self.scale
170
+
171
+ # output tile area without padding
172
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
173
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
174
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
175
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
176
+
177
+ # put tile into output image
178
+ self.output[:, :, output_start_y:output_end_y,
179
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
180
+ output_start_x_tile:output_end_x_tile]
181
+
182
+ def post_process(self):
183
+ # remove extra pad
184
+ if self.mod_scale is not None:
185
+ _, _, h, w = self.output.size()
186
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
187
+ # remove prepad
188
+ if self.pre_pad != 0:
189
+ _, _, h, w = self.output.size()
190
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
191
+ return self.output
192
+
193
+ @torch.no_grad()
194
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
195
+ h_input, w_input = img.shape[0:2]
196
+ # img: numpy
197
+ img = img.astype(np.float32)
198
+ if np.max(img) > 256: # 16-bit image
199
+ max_range = 65535
200
+ print('\tInput is a 16-bit image')
201
+ else:
202
+ max_range = 255
203
+ img = img / max_range
204
+ if len(img.shape) == 2: # gray image
205
+ img_mode = 'L'
206
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
207
+ elif img.shape[2] == 4: # RGBA image with alpha channel
208
+ img_mode = 'RGBA'
209
+ alpha = img[:, :, 3]
210
+ img = img[:, :, 0:3]
211
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
212
+ if alpha_upsampler == 'realesrgan':
213
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
214
+ else:
215
+ img_mode = 'RGB'
216
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
217
+
218
+ # ------------------- process image (without the alpha channel) ------------------- #
219
+ self.pre_process(img)
220
+ if self.tile_size > 0:
221
+ self.tile_process()
222
+ else:
223
+ self.process()
224
+ output_img = self.post_process()
225
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
226
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
227
+ if img_mode == 'L':
228
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
229
+
230
+ # ------------------- process the alpha channel if necessary ------------------- #
231
+ if img_mode == 'RGBA':
232
+ if alpha_upsampler == 'realesrgan':
233
+ self.pre_process(alpha)
234
+ if self.tile_size > 0:
235
+ self.tile_process()
236
+ else:
237
+ self.process()
238
+ output_alpha = self.post_process()
239
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
240
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
241
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
242
+ else: # use the cv2 resize for alpha channel
243
+ h, w = alpha.shape[0:2]
244
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
245
+
246
+ # merge the alpha channel
247
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
248
+ output_img[:, :, 3] = output_alpha
249
+
250
+ # ------------------------------ return ------------------------------ #
251
+ if max_range == 65535: # 16-bit image
252
+ output = (output_img * 65535.0).round().astype(np.uint16)
253
+ else:
254
+ output = (output_img * 255.0).round().astype(np.uint8)
255
+
256
+ if outscale is not None and outscale != float(self.scale):
257
+ output = cv2.resize(
258
+ output, (
259
+ int(w_input * outscale),
260
+ int(h_input * outscale),
261
+ ), interpolation=cv2.INTER_LANCZOS4)
262
+
263
+ return output, img_mode
264
+
265
+
266
+ class PrefetchReader(threading.Thread):
267
+ """Prefetch images.
268
+
269
+ Args:
270
+ img_list (list[str]): A image list of image paths to be read.
271
+ num_prefetch_queue (int): Number of prefetch queue.
272
+ """
273
+
274
+ def __init__(self, img_list, num_prefetch_queue):
275
+ super().__init__()
276
+ self.que = queue.Queue(num_prefetch_queue)
277
+ self.img_list = img_list
278
+
279
+ def run(self):
280
+ for img_path in self.img_list:
281
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
282
+ self.que.put(img)
283
+
284
+ self.que.put(None)
285
+
286
+ def __next__(self):
287
+ next_item = self.que.get()
288
+ if next_item is None:
289
+ raise StopIteration
290
+ return next_item
291
+
292
+ def __iter__(self):
293
+ return self
294
+
295
+
296
+ class IOConsumer(threading.Thread):
297
+
298
+ def __init__(self, opt, que, qid):
299
+ super().__init__()
300
+ self._queue = que
301
+ self.qid = qid
302
+ self.opt = opt
303
+
304
+ def run(self):
305
+ while True:
306
+ msg = self._queue.get()
307
+ if isinstance(msg, str) and msg == 'quit':
308
+ break
309
+
310
+ output = msg['output']
311
+ save_path = msg['save_path']
312
+ cv2.imwrite(save_path, output)
313
+ print(f'IO worker {self.qid} is done.')
realesrgan/version.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # GENERATED VERSION FILE
2
+ # TIME: Thu Sep 14 18:29:16 2023
3
+ __version__ = '0.3.0'
4
+ __gitsha__ = '5ca1078'
5
+ version_info = (0, 3, 0)