lixiang46 committed on
Commit
66fd925
1 Parent(s): 7132521

init faceid

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +3 -3
  3. annotator/canny/__init__.py +0 -6
  4. annotator/midas/LICENSE +0 -21
  5. annotator/midas/__init__.py +0 -35
  6. annotator/midas/api.py +0 -169
  7. annotator/midas/midas/__init__.py +0 -0
  8. annotator/midas/midas/base_model.py +0 -16
  9. annotator/midas/midas/blocks.py +0 -342
  10. annotator/midas/midas/dpt_depth.py +0 -109
  11. annotator/midas/midas/midas_net.py +0 -76
  12. annotator/midas/midas/midas_net_custom.py +0 -128
  13. annotator/midas/midas/transforms.py +0 -234
  14. annotator/midas/midas/vit.py +0 -491
  15. annotator/midas/utils.py +0 -189
  16. annotator/util.py +0 -129
  17. app.py +84 -155
  18. assets/title.md +3 -3
  19. basicsr/__init__.py +0 -11
  20. basicsr/archs/__init__.py +0 -24
  21. basicsr/archs/arch_util.py +0 -313
  22. basicsr/archs/basicvsr_arch.py +0 -336
  23. basicsr/archs/basicvsrpp_arch.py +0 -417
  24. basicsr/archs/dfdnet_arch.py +0 -169
  25. basicsr/archs/dfdnet_util.py +0 -162
  26. basicsr/archs/discriminator_arch.py +0 -150
  27. basicsr/archs/duf_arch.py +0 -276
  28. basicsr/archs/ecbsr_arch.py +0 -275
  29. basicsr/archs/edsr_arch.py +0 -61
  30. basicsr/archs/edvr_arch.py +0 -382
  31. basicsr/archs/hifacegan_arch.py +0 -260
  32. basicsr/archs/hifacegan_util.py +0 -255
  33. basicsr/archs/inception.py +0 -307
  34. basicsr/archs/rcan_arch.py +0 -135
  35. basicsr/archs/ridnet_arch.py +0 -180
  36. basicsr/archs/rrdbnet_arch.py +0 -119
  37. basicsr/archs/spynet_arch.py +0 -96
  38. basicsr/archs/srresnet_arch.py +0 -65
  39. basicsr/archs/srvgg_arch.py +0 -70
  40. basicsr/archs/stylegan2_arch.py +0 -799
  41. basicsr/archs/stylegan2_bilinear_arch.py +0 -614
  42. basicsr/archs/swinir_arch.py +0 -956
  43. basicsr/archs/tof_arch.py +0 -172
  44. basicsr/archs/vgg_arch.py +0 -161
  45. basicsr/data/__init__.py +0 -101
  46. basicsr/data/data_sampler.py +0 -48
  47. basicsr/data/data_util.py +0 -315
  48. basicsr/data/degradations.py +0 -764
  49. basicsr/data/ffhq_dataset.py +0 -80
  50. basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt +0 -0
.gitattributes CHANGED
@@ -37,3 +37,5 @@ image/bird.png filter=lfs diff=lfs merge=lfs -text
  image/dog.png filter=lfs diff=lfs merge=lfs -text
  image/woman_1.png filter=lfs diff=lfs merge=lfs -text
  image/woman_2.png filter=lfs diff=lfs merge=lfs -text
+ image/image1.png filter=lfs diff=lfs merge=lfs -text
+ image/image2.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: Kolors-Controlnet
- emoji: 🏞️
+ title: Kolors-FaceID
+ emoji: 🥸
  colorFrom: purple
- colorTo: green
+ colorTo: yellow
  sdk: gradio
  sdk_version: 4.38.1
  app_file: app.py
annotator/canny/__init__.py DELETED
@@ -1,6 +0,0 @@
- import cv2
-
-
- class CannyDetector:
-     def __call__(self, img, low_threshold, high_threshold):
-         return cv2.Canny(img, low_threshold, high_threshold)
annotator/midas/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
annotator/midas/__init__.py DELETED
@@ -1,35 +0,0 @@
- # Midas Depth Estimation
- # From https://github.com/isl-org/MiDaS
- # MIT LICENSE
-
- import cv2
- import numpy as np
- import torch
-
- from einops import rearrange
- from .api import MiDaSInference
-
-
- class MidasDetector:
-     def __init__(self):
-         self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
-         self.rng = np.random.RandomState(0)
-
-     def __call__(self, input_image):
-         assert input_image.ndim == 3
-         image_depth = input_image
-         with torch.no_grad():
-             image_depth = torch.from_numpy(image_depth).float().cuda()
-             image_depth = image_depth / 127.5 - 1.0
-             image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
-             depth = self.model(image_depth)[0]
-
-             depth -= torch.min(depth)
-             depth /= torch.max(depth)
-             depth = depth.cpu().numpy()
-             depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
-
-             return depth_image
-
-
-
annotator/midas/api.py DELETED
@@ -1,169 +0,0 @@
- # based on https://github.com/isl-org/MiDaS
-
- import cv2
- import os
- import torch
- import torch.nn as nn
- from torchvision.transforms import Compose
-
- from .midas.dpt_depth import DPTDepthModel
- from .midas.midas_net import MidasNet
- from .midas.midas_net_custom import MidasNet_small
- from .midas.transforms import Resize, NormalizeImage, PrepareForNet
- from annotator.util import annotator_ckpts_path
-
-
- ISL_PATHS = {
-     "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
-     "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
-     "midas_v21": "",
-     "midas_v21_small": "",
- }
-
- remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt"
-
-
- def disabled_train(self, mode=True):
-     """Overwrite model.train with this function to make sure train/eval mode
-     does not change anymore."""
-     return self
-
-
- def load_midas_transform(model_type):
-     # https://github.com/isl-org/MiDaS/blob/master/run.py
-     # load transform only
-     if model_type == "dpt_large":  # DPT-Large
-         net_w, net_h = 384, 384
-         resize_mode = "minimal"
-         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
-     elif model_type == "dpt_hybrid":  # DPT-Hybrid
-         net_w, net_h = 384, 384
-         resize_mode = "minimal"
-         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
-     elif model_type == "midas_v21":
-         net_w, net_h = 384, 384
-         resize_mode = "upper_bound"
-         normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
-     elif model_type == "midas_v21_small":
-         net_w, net_h = 256, 256
-         resize_mode = "upper_bound"
-         normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
-     else:
-         assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
-
-     transform = Compose(
-         [
-             Resize(
-                 net_w,
-                 net_h,
-                 resize_target=None,
-                 keep_aspect_ratio=True,
-                 ensure_multiple_of=32,
-                 resize_method=resize_mode,
-                 image_interpolation_method=cv2.INTER_CUBIC,
-             ),
-             normalization,
-             PrepareForNet(),
-         ]
-     )
-
-     return transform
-
-
- def load_model(model_type):
-     # https://github.com/isl-org/MiDaS/blob/master/run.py
-     # load network
-     model_path = ISL_PATHS[model_type]
-     if model_type == "dpt_large":  # DPT-Large
-         model = DPTDepthModel(
-             path=model_path,
-             backbone="vitl16_384",
-             non_negative=True,
-         )
-         net_w, net_h = 384, 384
-         resize_mode = "minimal"
-         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
-     elif model_type == "dpt_hybrid":  # DPT-Hybrid
-         if not os.path.exists(model_path):
-             from basicsr.utils.download_util import load_file_from_url
-             load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
-
-         model = DPTDepthModel(
-             path=model_path,
-             backbone="vitb_rn50_384",
-             non_negative=True,
-         )
-         net_w, net_h = 384, 384
-         resize_mode = "minimal"
-         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
-     elif model_type == "midas_v21":
-         model = MidasNet(model_path, non_negative=True)
-         net_w, net_h = 384, 384
-         resize_mode = "upper_bound"
-         normalization = NormalizeImage(
-             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-         )
-
-     elif model_type == "midas_v21_small":
-         model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
-                                non_negative=True, blocks={'expand': True})
-         net_w, net_h = 256, 256
-         resize_mode = "upper_bound"
-         normalization = NormalizeImage(
-             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-         )
-
-     else:
-         print(f"model_type '{model_type}' not implemented, use: --model_type large")
-         assert False
-
-     transform = Compose(
-         [
-             Resize(
-                 net_w,
-                 net_h,
-                 resize_target=None,
-                 keep_aspect_ratio=True,
-                 ensure_multiple_of=32,
-                 resize_method=resize_mode,
-                 image_interpolation_method=cv2.INTER_CUBIC,
-             ),
-             normalization,
-             PrepareForNet(),
-         ]
-     )
-
-     return model.eval(), transform
-
-
- class MiDaSInference(nn.Module):
-     MODEL_TYPES_TORCH_HUB = [
-         "DPT_Large",
-         "DPT_Hybrid",
-         "MiDaS_small"
-     ]
-     MODEL_TYPES_ISL = [
-         "dpt_large",
-         "dpt_hybrid",
-         "midas_v21",
-         "midas_v21_small",
-     ]
-
-     def __init__(self, model_type):
-         super().__init__()
-         assert (model_type in self.MODEL_TYPES_ISL)
-         model, _ = load_model(model_type)
-         self.model = model
-         self.model.train = disabled_train
-
-     def forward(self, x):
-         with torch.no_grad():
-             prediction = self.model(x)
-         return prediction
-
annotator/midas/midas/__init__.py DELETED
File without changes
annotator/midas/midas/base_model.py DELETED
@@ -1,16 +0,0 @@
- import torch
-
-
- class BaseModel(torch.nn.Module):
-     def load(self, path):
-         """Load model from file.
-
-         Args:
-             path (str): file path
-         """
-         parameters = torch.load(path, map_location=torch.device('cpu'))
-
-         if "optimizer" in parameters:
-             parameters = parameters["model"]
-
-         self.load_state_dict(parameters)
annotator/midas/midas/blocks.py DELETED
@@ -1,342 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from .vit import (
5
- _make_pretrained_vitb_rn50_384,
6
- _make_pretrained_vitl16_384,
7
- _make_pretrained_vitb16_384,
8
- forward_vit,
9
- )
10
-
11
- def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
- if backbone == "vitl16_384":
13
- pretrained = _make_pretrained_vitl16_384(
14
- use_pretrained, hooks=hooks, use_readout=use_readout
15
- )
16
- scratch = _make_scratch(
17
- [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
- ) # ViT-L/16 - 85.0% Top1 (backbone)
19
- elif backbone == "vitb_rn50_384":
20
- pretrained = _make_pretrained_vitb_rn50_384(
21
- use_pretrained,
22
- hooks=hooks,
23
- use_vit_only=use_vit_only,
24
- use_readout=use_readout,
25
- )
26
- scratch = _make_scratch(
27
- [256, 512, 768, 768], features, groups=groups, expand=expand
28
- ) # ViT-H/16 - 85.0% Top1 (backbone)
29
- elif backbone == "vitb16_384":
30
- pretrained = _make_pretrained_vitb16_384(
31
- use_pretrained, hooks=hooks, use_readout=use_readout
32
- )
33
- scratch = _make_scratch(
34
- [96, 192, 384, 768], features, groups=groups, expand=expand
35
- ) # ViT-B/16 - 84.6% Top1 (backbone)
36
- elif backbone == "resnext101_wsl":
37
- pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
- scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
- elif backbone == "efficientnet_lite3":
40
- pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
- scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
- else:
43
- print(f"Backbone '{backbone}' not implemented")
44
- assert False
45
-
46
- return pretrained, scratch
47
-
48
-
49
- def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
- scratch = nn.Module()
51
-
52
- out_shape1 = out_shape
53
- out_shape2 = out_shape
54
- out_shape3 = out_shape
55
- out_shape4 = out_shape
56
- if expand==True:
57
- out_shape1 = out_shape
58
- out_shape2 = out_shape*2
59
- out_shape3 = out_shape*4
60
- out_shape4 = out_shape*8
61
-
62
- scratch.layer1_rn = nn.Conv2d(
63
- in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
- )
65
- scratch.layer2_rn = nn.Conv2d(
66
- in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
- )
68
- scratch.layer3_rn = nn.Conv2d(
69
- in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
- )
71
- scratch.layer4_rn = nn.Conv2d(
72
- in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
- )
74
-
75
- return scratch
76
-
77
-
78
- def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
- efficientnet = torch.hub.load(
80
- "rwightman/gen-efficientnet-pytorch",
81
- "tf_efficientnet_lite3",
82
- pretrained=use_pretrained,
83
- exportable=exportable
84
- )
85
- return _make_efficientnet_backbone(efficientnet)
86
-
87
-
88
- def _make_efficientnet_backbone(effnet):
89
- pretrained = nn.Module()
90
-
91
- pretrained.layer1 = nn.Sequential(
92
- effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
- )
94
- pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
- pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
- pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
-
98
- return pretrained
99
-
100
-
101
- def _make_resnet_backbone(resnet):
102
- pretrained = nn.Module()
103
- pretrained.layer1 = nn.Sequential(
104
- resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
- )
106
-
107
- pretrained.layer2 = resnet.layer2
108
- pretrained.layer3 = resnet.layer3
109
- pretrained.layer4 = resnet.layer4
110
-
111
- return pretrained
112
-
113
-
114
- def _make_pretrained_resnext101_wsl(use_pretrained):
115
- resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
- return _make_resnet_backbone(resnet)
117
-
118
-
119
-
120
- class Interpolate(nn.Module):
121
- """Interpolation module.
122
- """
123
-
124
- def __init__(self, scale_factor, mode, align_corners=False):
125
- """Init.
126
-
127
- Args:
128
- scale_factor (float): scaling
129
- mode (str): interpolation mode
130
- """
131
- super(Interpolate, self).__init__()
132
-
133
- self.interp = nn.functional.interpolate
134
- self.scale_factor = scale_factor
135
- self.mode = mode
136
- self.align_corners = align_corners
137
-
138
- def forward(self, x):
139
- """Forward pass.
140
-
141
- Args:
142
- x (tensor): input
143
-
144
- Returns:
145
- tensor: interpolated data
146
- """
147
-
148
- x = self.interp(
149
- x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
- )
151
-
152
- return x
153
-
154
-
155
- class ResidualConvUnit(nn.Module):
156
- """Residual convolution module.
157
- """
158
-
159
- def __init__(self, features):
160
- """Init.
161
-
162
- Args:
163
- features (int): number of features
164
- """
165
- super().__init__()
166
-
167
- self.conv1 = nn.Conv2d(
168
- features, features, kernel_size=3, stride=1, padding=1, bias=True
169
- )
170
-
171
- self.conv2 = nn.Conv2d(
172
- features, features, kernel_size=3, stride=1, padding=1, bias=True
173
- )
174
-
175
- self.relu = nn.ReLU(inplace=True)
176
-
177
- def forward(self, x):
178
- """Forward pass.
179
-
180
- Args:
181
- x (tensor): input
182
-
183
- Returns:
184
- tensor: output
185
- """
186
- out = self.relu(x)
187
- out = self.conv1(out)
188
- out = self.relu(out)
189
- out = self.conv2(out)
190
-
191
- return out + x
192
-
193
-
194
- class FeatureFusionBlock(nn.Module):
195
- """Feature fusion block.
196
- """
197
-
198
- def __init__(self, features):
199
- """Init.
200
-
201
- Args:
202
- features (int): number of features
203
- """
204
- super(FeatureFusionBlock, self).__init__()
205
-
206
- self.resConfUnit1 = ResidualConvUnit(features)
207
- self.resConfUnit2 = ResidualConvUnit(features)
208
-
209
- def forward(self, *xs):
210
- """Forward pass.
211
-
212
- Returns:
213
- tensor: output
214
- """
215
- output = xs[0]
216
-
217
- if len(xs) == 2:
218
- output += self.resConfUnit1(xs[1])
219
-
220
- output = self.resConfUnit2(output)
221
-
222
- output = nn.functional.interpolate(
223
- output, scale_factor=2, mode="bilinear", align_corners=True
224
- )
225
-
226
- return output
227
-
228
-
229
-
230
-
231
- class ResidualConvUnit_custom(nn.Module):
232
- """Residual convolution module.
233
- """
234
-
235
- def __init__(self, features, activation, bn):
236
- """Init.
237
-
238
- Args:
239
- features (int): number of features
240
- """
241
- super().__init__()
242
-
243
- self.bn = bn
244
-
245
- self.groups=1
246
-
247
- self.conv1 = nn.Conv2d(
248
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
- )
250
-
251
- self.conv2 = nn.Conv2d(
252
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
- )
254
-
255
- if self.bn==True:
256
- self.bn1 = nn.BatchNorm2d(features)
257
- self.bn2 = nn.BatchNorm2d(features)
258
-
259
- self.activation = activation
260
-
261
- self.skip_add = nn.quantized.FloatFunctional()
262
-
263
- def forward(self, x):
264
- """Forward pass.
265
-
266
- Args:
267
- x (tensor): input
268
-
269
- Returns:
270
- tensor: output
271
- """
272
-
273
- out = self.activation(x)
274
- out = self.conv1(out)
275
- if self.bn==True:
276
- out = self.bn1(out)
277
-
278
- out = self.activation(out)
279
- out = self.conv2(out)
280
- if self.bn==True:
281
- out = self.bn2(out)
282
-
283
- if self.groups > 1:
284
- out = self.conv_merge(out)
285
-
286
- return self.skip_add.add(out, x)
287
-
288
- # return out + x
289
-
290
-
291
- class FeatureFusionBlock_custom(nn.Module):
292
- """Feature fusion block.
293
- """
294
-
295
- def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
- """Init.
297
-
298
- Args:
299
- features (int): number of features
300
- """
301
- super(FeatureFusionBlock_custom, self).__init__()
302
-
303
- self.deconv = deconv
304
- self.align_corners = align_corners
305
-
306
- self.groups=1
307
-
308
- self.expand = expand
309
- out_features = features
310
- if self.expand==True:
311
- out_features = features//2
312
-
313
- self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
-
315
- self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
- self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
-
318
- self.skip_add = nn.quantized.FloatFunctional()
319
-
320
- def forward(self, *xs):
321
- """Forward pass.
322
-
323
- Returns:
324
- tensor: output
325
- """
326
- output = xs[0]
327
-
328
- if len(xs) == 2:
329
- res = self.resConfUnit1(xs[1])
330
- output = self.skip_add.add(output, res)
331
- # output += res
332
-
333
- output = self.resConfUnit2(output)
334
-
335
- output = nn.functional.interpolate(
336
- output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
- )
338
-
339
- output = self.out_conv(output)
340
-
341
- return output
342
-
annotator/midas/midas/dpt_depth.py DELETED
@@ -1,109 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from .base_model import BaseModel
- from .blocks import (
-     FeatureFusionBlock,
-     FeatureFusionBlock_custom,
-     Interpolate,
-     _make_encoder,
-     forward_vit,
- )
-
-
- def _make_fusion_block(features, use_bn):
-     return FeatureFusionBlock_custom(
-         features,
-         nn.ReLU(False),
-         deconv=False,
-         bn=use_bn,
-         expand=False,
-         align_corners=True,
-     )
-
-
- class DPT(BaseModel):
-     def __init__(
-         self,
-         head,
-         features=256,
-         backbone="vitb_rn50_384",
-         readout="project",
-         channels_last=False,
-         use_bn=False,
-     ):
-
-         super(DPT, self).__init__()
-
-         self.channels_last = channels_last
-
-         hooks = {
-             "vitb_rn50_384": [0, 1, 8, 11],
-             "vitb16_384": [2, 5, 8, 11],
-             "vitl16_384": [5, 11, 17, 23],
-         }
-
-         # Instantiate backbone and reassemble blocks
-         self.pretrained, self.scratch = _make_encoder(
-             backbone,
-             features,
-             False,  # Set to true of you want to train from scratch, uses ImageNet weights
-             groups=1,
-             expand=False,
-             exportable=False,
-             hooks=hooks[backbone],
-             use_readout=readout,
-         )
-
-         self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
-         self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
-
-         self.scratch.output_conv = head
-
-
-     def forward(self, x):
-         if self.channels_last == True:
-             x.contiguous(memory_format=torch.channels_last)
-
-         layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
-
-         layer_1_rn = self.scratch.layer1_rn(layer_1)
-         layer_2_rn = self.scratch.layer2_rn(layer_2)
-         layer_3_rn = self.scratch.layer3_rn(layer_3)
-         layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-         path_4 = self.scratch.refinenet4(layer_4_rn)
-         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
-         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
-         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
-         out = self.scratch.output_conv(path_1)
-
-         return out
-
-
- class DPTDepthModel(DPT):
-     def __init__(self, path=None, non_negative=True, **kwargs):
-         features = kwargs["features"] if "features" in kwargs else 256
-
-         head = nn.Sequential(
-             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
-             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
-             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
-             nn.ReLU(True) if non_negative else nn.Identity(),
-             nn.Identity(),
-         )
-
-         super().__init__(head, **kwargs)
-
-         if path is not None:
-             self.load(path)
-
-     def forward(self, x):
-         return super().forward(x).squeeze(dim=1)
-
annotator/midas/midas/midas_net.py DELETED
@@ -1,76 +0,0 @@
- """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
- This file contains code that is adapted from
- https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
- """
- import torch
- import torch.nn as nn
-
- from .base_model import BaseModel
- from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
-
-
- class MidasNet(BaseModel):
-     """Network for monocular depth estimation.
-     """
-
-     def __init__(self, path=None, features=256, non_negative=True):
-         """Init.
-
-         Args:
-             path (str, optional): Path to saved model. Defaults to None.
-             features (int, optional): Number of features. Defaults to 256.
-             backbone (str, optional): Backbone network for encoder. Defaults to resnet50
-         """
-         print("Loading weights: ", path)
-
-         super(MidasNet, self).__init__()
-
-         use_pretrained = False if path is None else True
-
-         self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
-
-         self.scratch.refinenet4 = FeatureFusionBlock(features)
-         self.scratch.refinenet3 = FeatureFusionBlock(features)
-         self.scratch.refinenet2 = FeatureFusionBlock(features)
-         self.scratch.refinenet1 = FeatureFusionBlock(features)
-
-         self.scratch.output_conv = nn.Sequential(
-             nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
-             Interpolate(scale_factor=2, mode="bilinear"),
-             nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(True),
-             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
-             nn.ReLU(True) if non_negative else nn.Identity(),
-         )
-
-         if path:
-             self.load(path)
-
-     def forward(self, x):
-         """Forward pass.
-
-         Args:
-             x (tensor): input data (image)
-
-         Returns:
-             tensor: depth
-         """
-
-         layer_1 = self.pretrained.layer1(x)
-         layer_2 = self.pretrained.layer2(layer_1)
-         layer_3 = self.pretrained.layer3(layer_2)
-         layer_4 = self.pretrained.layer4(layer_3)
-
-         layer_1_rn = self.scratch.layer1_rn(layer_1)
-         layer_2_rn = self.scratch.layer2_rn(layer_2)
-         layer_3_rn = self.scratch.layer3_rn(layer_3)
-         layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-         path_4 = self.scratch.refinenet4(layer_4_rn)
-         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
-         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
-         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
-         out = self.scratch.output_conv(path_1)
-
-         return torch.squeeze(out, dim=1)
annotator/midas/midas/midas_net_custom.py DELETED
@@ -1,128 +0,0 @@
1
- """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
- This file contains code that is adapted from
3
- https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
- """
5
- import torch
6
- import torch.nn as nn
7
-
8
- from .base_model import BaseModel
9
- from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
-
11
-
12
- class MidasNet_small(BaseModel):
13
- """Network for monocular depth estimation.
14
- """
15
-
16
- def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
- blocks={'expand': True}):
18
- """Init.
19
-
20
- Args:
21
- path (str, optional): Path to saved model. Defaults to None.
22
- features (int, optional): Number of features. Defaults to 256.
23
- backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
- """
25
- print("Loading weights: ", path)
26
-
27
- super(MidasNet_small, self).__init__()
28
-
29
- use_pretrained = False if path else True
30
-
31
- self.channels_last = channels_last
32
- self.blocks = blocks
33
- self.backbone = backbone
34
-
35
- self.groups = 1
36
-
37
- features1=features
38
- features2=features
39
- features3=features
40
- features4=features
41
- self.expand = False
42
- if "expand" in self.blocks and self.blocks['expand'] == True:
43
- self.expand = True
44
- features1=features
45
- features2=features*2
46
- features3=features*4
47
- features4=features*8
48
-
49
- self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
-
51
- self.scratch.activation = nn.ReLU(False)
52
-
53
- self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
- self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
- self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
- self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
-
58
-
59
- self.scratch.output_conv = nn.Sequential(
60
- nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
- Interpolate(scale_factor=2, mode="bilinear"),
62
- nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
- self.scratch.activation,
64
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
- nn.ReLU(True) if non_negative else nn.Identity(),
66
- nn.Identity(),
67
- )
68
-
69
- if path:
70
- self.load(path)
71
-
72
-
73
- def forward(self, x):
74
- """Forward pass.
75
-
76
- Args:
77
- x (tensor): input data (image)
78
-
79
- Returns:
80
- tensor: depth
81
- """
82
- if self.channels_last==True:
83
- print("self.channels_last = ", self.channels_last)
84
- x.contiguous(memory_format=torch.channels_last)
85
-
86
-
87
- layer_1 = self.pretrained.layer1(x)
88
- layer_2 = self.pretrained.layer2(layer_1)
89
- layer_3 = self.pretrained.layer3(layer_2)
90
- layer_4 = self.pretrained.layer4(layer_3)
91
-
92
- layer_1_rn = self.scratch.layer1_rn(layer_1)
93
- layer_2_rn = self.scratch.layer2_rn(layer_2)
94
- layer_3_rn = self.scratch.layer3_rn(layer_3)
95
- layer_4_rn = self.scratch.layer4_rn(layer_4)
96
-
97
-
98
- path_4 = self.scratch.refinenet4(layer_4_rn)
99
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
-
103
- out = self.scratch.output_conv(path_1)
104
-
105
- return torch.squeeze(out, dim=1)
106
-
107
-
108
-
109
- def fuse_model(m):
110
- prev_previous_type = nn.Identity()
111
- prev_previous_name = ''
112
- previous_type = nn.Identity()
113
- previous_name = ''
114
- for name, module in m.named_modules():
115
- if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
- # print("FUSED ", prev_previous_name, previous_name, name)
117
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
- elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
- # print("FUSED ", prev_previous_name, previous_name)
120
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
- # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
- # print("FUSED ", previous_name, name)
123
- # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
-
125
- prev_previous_type = previous_type
126
- prev_previous_name = previous_name
127
- previous_type = type(module)
128
- previous_name = name
annotator/midas/midas/transforms.py DELETED
@@ -1,234 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- import math
4
-
5
-
6
- def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
- """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
-
9
- Args:
10
- sample (dict): sample
11
- size (tuple): image size
12
-
13
- Returns:
14
- tuple: new size
15
- """
16
- shape = list(sample["disparity"].shape)
17
-
18
- if shape[0] >= size[0] and shape[1] >= size[1]:
19
- return sample
20
-
21
- scale = [0, 0]
22
- scale[0] = size[0] / shape[0]
23
- scale[1] = size[1] / shape[1]
24
-
25
- scale = max(scale)
26
-
27
- shape[0] = math.ceil(scale * shape[0])
28
- shape[1] = math.ceil(scale * shape[1])
29
-
30
- # resize
31
- sample["image"] = cv2.resize(
32
- sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
- )
34
-
35
- sample["disparity"] = cv2.resize(
36
- sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
- )
38
- sample["mask"] = cv2.resize(
39
- sample["mask"].astype(np.float32),
40
- tuple(shape[::-1]),
41
- interpolation=cv2.INTER_NEAREST,
42
- )
43
- sample["mask"] = sample["mask"].astype(bool)
44
-
45
- return tuple(shape)
46
-
47
-
48
- class Resize(object):
49
- """Resize sample to given size (width, height).
50
- """
51
-
52
- def __init__(
53
- self,
54
- width,
55
- height,
56
- resize_target=True,
57
- keep_aspect_ratio=False,
58
- ensure_multiple_of=1,
59
- resize_method="lower_bound",
60
- image_interpolation_method=cv2.INTER_AREA,
61
- ):
62
- """Init.
63
-
64
- Args:
65
- width (int): desired output width
66
- height (int): desired output height
67
- resize_target (bool, optional):
68
- True: Resize the full sample (image, mask, target).
69
- False: Resize image only.
70
- Defaults to True.
71
- keep_aspect_ratio (bool, optional):
72
- True: Keep the aspect ratio of the input sample.
73
- Output sample might not have the given width and height, and
74
- resize behaviour depends on the parameter 'resize_method'.
75
- Defaults to False.
76
- ensure_multiple_of (int, optional):
77
- Output width and height is constrained to be multiple of this parameter.
78
- Defaults to 1.
79
- resize_method (str, optional):
80
- "lower_bound": Output will be at least as large as the given size.
81
- "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
- "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
- Defaults to "lower_bound".
84
- """
85
- self.__width = width
86
- self.__height = height
87
-
88
- self.__resize_target = resize_target
89
- self.__keep_aspect_ratio = keep_aspect_ratio
90
- self.__multiple_of = ensure_multiple_of
91
- self.__resize_method = resize_method
92
- self.__image_interpolation_method = image_interpolation_method
93
-
94
- def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
- y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
-
97
- if max_val is not None and y > max_val:
98
- y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
-
100
- if y < min_val:
101
- y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
-
103
- return y
104
-
105
- def get_size(self, width, height):
106
- # determine new height and width
107
- scale_height = self.__height / height
108
- scale_width = self.__width / width
109
-
110
- if self.__keep_aspect_ratio:
111
- if self.__resize_method == "lower_bound":
112
- # scale such that output size is lower bound
113
- if scale_width > scale_height:
114
- # fit width
115
- scale_height = scale_width
116
- else:
117
- # fit height
118
- scale_width = scale_height
119
- elif self.__resize_method == "upper_bound":
120
- # scale such that output size is upper bound
121
- if scale_width < scale_height:
122
- # fit width
123
- scale_height = scale_width
124
- else:
125
- # fit height
126
- scale_width = scale_height
127
- elif self.__resize_method == "minimal":
128
- # scale as least as possbile
129
- if abs(1 - scale_width) < abs(1 - scale_height):
130
- # fit width
131
- scale_height = scale_width
132
- else:
133
- # fit height
134
- scale_width = scale_height
135
- else:
136
- raise ValueError(
137
- f"resize_method {self.__resize_method} not implemented"
138
- )
139
-
140
- if self.__resize_method == "lower_bound":
141
- new_height = self.constrain_to_multiple_of(
142
- scale_height * height, min_val=self.__height
143
- )
144
- new_width = self.constrain_to_multiple_of(
145
- scale_width * width, min_val=self.__width
146
- )
147
- elif self.__resize_method == "upper_bound":
148
- new_height = self.constrain_to_multiple_of(
149
- scale_height * height, max_val=self.__height
150
- )
151
- new_width = self.constrain_to_multiple_of(
152
- scale_width * width, max_val=self.__width
153
- )
154
- elif self.__resize_method == "minimal":
155
- new_height = self.constrain_to_multiple_of(scale_height * height)
156
- new_width = self.constrain_to_multiple_of(scale_width * width)
157
- else:
158
- raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
-
160
- return (new_width, new_height)
161
-
162
- def __call__(self, sample):
163
- width, height = self.get_size(
164
- sample["image"].shape[1], sample["image"].shape[0]
165
- )
166
-
167
- # resize sample
168
- sample["image"] = cv2.resize(
169
- sample["image"],
170
- (width, height),
171
- interpolation=self.__image_interpolation_method,
172
- )
173
-
174
- if self.__resize_target:
175
- if "disparity" in sample:
176
- sample["disparity"] = cv2.resize(
177
- sample["disparity"],
178
- (width, height),
179
- interpolation=cv2.INTER_NEAREST,
180
- )
181
-
182
- if "depth" in sample:
183
- sample["depth"] = cv2.resize(
184
- sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
- )
186
-
187
- sample["mask"] = cv2.resize(
188
- sample["mask"].astype(np.float32),
189
- (width, height),
190
- interpolation=cv2.INTER_NEAREST,
191
- )
192
- sample["mask"] = sample["mask"].astype(bool)
193
-
194
- return sample
195
-
196
-
197
- class NormalizeImage(object):
198
- """Normlize image by given mean and std.
199
- """
200
-
201
- def __init__(self, mean, std):
202
- self.__mean = mean
203
- self.__std = std
204
-
205
- def __call__(self, sample):
206
- sample["image"] = (sample["image"] - self.__mean) / self.__std
207
-
208
- return sample
209
-
210
-
211
- class PrepareForNet(object):
212
- """Prepare sample for usage as network input.
213
- """
214
-
215
- def __init__(self):
216
- pass
217
-
218
- def __call__(self, sample):
219
- image = np.transpose(sample["image"], (2, 0, 1))
220
- sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
-
222
- if "mask" in sample:
223
- sample["mask"] = sample["mask"].astype(np.float32)
224
- sample["mask"] = np.ascontiguousarray(sample["mask"])
225
-
226
- if "disparity" in sample:
227
- disparity = sample["disparity"].astype(np.float32)
228
- sample["disparity"] = np.ascontiguousarray(disparity)
229
-
230
- if "depth" in sample:
231
- depth = sample["depth"].astype(np.float32)
232
- sample["depth"] = np.ascontiguousarray(depth)
233
-
234
- return sample
annotator/midas/midas/vit.py DELETED
@@ -1,491 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import timm
4
- import types
5
- import math
6
- import torch.nn.functional as F
7
-
8
-
9
- class Slice(nn.Module):
10
- def __init__(self, start_index=1):
11
- super(Slice, self).__init__()
12
- self.start_index = start_index
13
-
14
- def forward(self, x):
15
- return x[:, self.start_index :]
16
-
17
-
18
- class AddReadout(nn.Module):
19
- def __init__(self, start_index=1):
20
- super(AddReadout, self).__init__()
21
- self.start_index = start_index
22
-
23
- def forward(self, x):
24
- if self.start_index == 2:
25
- readout = (x[:, 0] + x[:, 1]) / 2
26
- else:
27
- readout = x[:, 0]
28
- return x[:, self.start_index :] + readout.unsqueeze(1)
29
-
30
-
31
- class ProjectReadout(nn.Module):
32
- def __init__(self, in_features, start_index=1):
33
- super(ProjectReadout, self).__init__()
34
- self.start_index = start_index
35
-
36
- self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
-
38
- def forward(self, x):
39
- readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
- features = torch.cat((x[:, self.start_index :], readout), -1)
41
-
42
- return self.project(features)
43
-
44
-
45
- class Transpose(nn.Module):
46
- def __init__(self, dim0, dim1):
47
- super(Transpose, self).__init__()
48
- self.dim0 = dim0
49
- self.dim1 = dim1
50
-
51
- def forward(self, x):
52
- x = x.transpose(self.dim0, self.dim1)
53
- return x
54
-
55
-
56
- def forward_vit(pretrained, x):
57
- b, c, h, w = x.shape
58
-
59
- glob = pretrained.model.forward_flex(x)
60
-
61
- layer_1 = pretrained.activations["1"]
62
- layer_2 = pretrained.activations["2"]
63
- layer_3 = pretrained.activations["3"]
64
- layer_4 = pretrained.activations["4"]
65
-
66
- layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
- layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
- layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
- layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
-
71
- unflatten = nn.Sequential(
72
- nn.Unflatten(
73
- 2,
74
- torch.Size(
75
- [
76
- h // pretrained.model.patch_size[1],
77
- w // pretrained.model.patch_size[0],
78
- ]
79
- ),
80
- )
81
- )
82
-
83
- if layer_1.ndim == 3:
84
- layer_1 = unflatten(layer_1)
85
- if layer_2.ndim == 3:
86
- layer_2 = unflatten(layer_2)
87
- if layer_3.ndim == 3:
88
- layer_3 = unflatten(layer_3)
89
- if layer_4.ndim == 3:
90
- layer_4 = unflatten(layer_4)
91
-
92
- layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
- layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
- layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
- layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
-
97
- return layer_1, layer_2, layer_3, layer_4
98
-
99
-
100
- def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
- posemb_tok, posemb_grid = (
102
- posemb[:, : self.start_index],
103
- posemb[0, self.start_index :],
104
- )
105
-
106
- gs_old = int(math.sqrt(len(posemb_grid)))
107
-
108
- posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
- posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
- posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
-
112
- posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
-
114
- return posemb
115
-
116
-
117
- def forward_flex(self, x):
118
- b, c, h, w = x.shape
119
-
120
- pos_embed = self._resize_pos_embed(
121
- self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
- )
123
-
124
- B = x.shape[0]
125
-
126
- if hasattr(self.patch_embed, "backbone"):
127
- x = self.patch_embed.backbone(x)
128
- if isinstance(x, (list, tuple)):
129
- x = x[-1] # last feature if backbone outputs list/tuple of features
130
-
131
- x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
-
133
- if getattr(self, "dist_token", None) is not None:
134
- cls_tokens = self.cls_token.expand(
135
- B, -1, -1
136
- ) # stole cls_tokens impl from Phil Wang, thanks
137
- dist_token = self.dist_token.expand(B, -1, -1)
138
- x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
- else:
140
- cls_tokens = self.cls_token.expand(
141
- B, -1, -1
142
- ) # stole cls_tokens impl from Phil Wang, thanks
143
- x = torch.cat((cls_tokens, x), dim=1)
144
-
145
- x = x + pos_embed
146
- x = self.pos_drop(x)
147
-
148
- for blk in self.blocks:
149
- x = blk(x)
150
-
151
- x = self.norm(x)
152
-
153
- return x
154
-
155
-
156
- activations = {}
157
-
158
-
159
- def get_activation(name):
160
- def hook(model, input, output):
161
- activations[name] = output
162
-
163
- return hook
164
-
165
-
166
- def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
- if use_readout == "ignore":
168
- readout_oper = [Slice(start_index)] * len(features)
169
- elif use_readout == "add":
170
- readout_oper = [AddReadout(start_index)] * len(features)
171
- elif use_readout == "project":
172
- readout_oper = [
173
- ProjectReadout(vit_features, start_index) for out_feat in features
174
- ]
175
- else:
176
- assert (
177
- False
178
- ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
-
180
- return readout_oper
181
-
182
-
183
- def _make_vit_b16_backbone(
184
- model,
185
- features=[96, 192, 384, 768],
186
- size=[384, 384],
187
- hooks=[2, 5, 8, 11],
188
- vit_features=768,
189
- use_readout="ignore",
190
- start_index=1,
191
- ):
192
- pretrained = nn.Module()
193
-
194
- pretrained.model = model
195
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
-
200
- pretrained.activations = activations
201
-
202
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
-
204
- # 32, 48, 136, 384
205
- pretrained.act_postprocess1 = nn.Sequential(
206
- readout_oper[0],
207
- Transpose(1, 2),
208
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
- nn.Conv2d(
210
- in_channels=vit_features,
211
- out_channels=features[0],
212
- kernel_size=1,
213
- stride=1,
214
- padding=0,
215
- ),
216
- nn.ConvTranspose2d(
217
- in_channels=features[0],
218
- out_channels=features[0],
219
- kernel_size=4,
220
- stride=4,
221
- padding=0,
222
- bias=True,
223
- dilation=1,
224
- groups=1,
225
- ),
226
- )
227
-
228
- pretrained.act_postprocess2 = nn.Sequential(
229
- readout_oper[1],
230
- Transpose(1, 2),
231
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
- nn.Conv2d(
233
- in_channels=vit_features,
234
- out_channels=features[1],
235
- kernel_size=1,
236
- stride=1,
237
- padding=0,
238
- ),
239
- nn.ConvTranspose2d(
240
- in_channels=features[1],
241
- out_channels=features[1],
242
- kernel_size=2,
243
- stride=2,
244
- padding=0,
245
- bias=True,
246
- dilation=1,
247
- groups=1,
248
- ),
249
- )
250
-
251
- pretrained.act_postprocess3 = nn.Sequential(
252
- readout_oper[2],
253
- Transpose(1, 2),
254
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
- nn.Conv2d(
256
- in_channels=vit_features,
257
- out_channels=features[2],
258
- kernel_size=1,
259
- stride=1,
260
- padding=0,
261
- ),
262
- )
263
-
264
- pretrained.act_postprocess4 = nn.Sequential(
265
- readout_oper[3],
266
- Transpose(1, 2),
267
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
- nn.Conv2d(
269
- in_channels=vit_features,
270
- out_channels=features[3],
271
- kernel_size=1,
272
- stride=1,
273
- padding=0,
274
- ),
275
- nn.Conv2d(
276
- in_channels=features[3],
277
- out_channels=features[3],
278
- kernel_size=3,
279
- stride=2,
280
- padding=1,
281
- ),
282
- )
283
-
284
- pretrained.model.start_index = start_index
285
- pretrained.model.patch_size = [16, 16]
286
-
287
- # We inject this function into the VisionTransformer instances so that
288
- # we can use it with interpolated position embeddings without modifying the library source.
289
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
- pretrained.model._resize_pos_embed = types.MethodType(
291
- _resize_pos_embed, pretrained.model
292
- )
293
-
294
- return pretrained
295
-
296
-
297
- def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
- model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
-
300
- hooks = [5, 11, 17, 23] if hooks == None else hooks
301
- return _make_vit_b16_backbone(
302
- model,
303
- features=[256, 512, 1024, 1024],
304
- hooks=hooks,
305
- vit_features=1024,
306
- use_readout=use_readout,
307
- )
308
-
309
-
310
- def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
- model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
-
313
- hooks = [2, 5, 8, 11] if hooks == None else hooks
314
- return _make_vit_b16_backbone(
315
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
- )
317
-
318
-
319
- def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
- model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
-
322
- hooks = [2, 5, 8, 11] if hooks == None else hooks
323
- return _make_vit_b16_backbone(
324
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
- )
326
-
327
-
328
- def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
- model = timm.create_model(
330
- "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
- )
332
-
333
- hooks = [2, 5, 8, 11] if hooks == None else hooks
334
- return _make_vit_b16_backbone(
335
- model,
336
- features=[96, 192, 384, 768],
337
- hooks=hooks,
338
- use_readout=use_readout,
339
- start_index=2,
340
- )
341
-
342
-
343
- def _make_vit_b_rn50_backbone(
344
- model,
345
- features=[256, 512, 768, 768],
346
- size=[384, 384],
347
- hooks=[0, 1, 8, 11],
348
- vit_features=768,
349
- use_vit_only=False,
350
- use_readout="ignore",
351
- start_index=1,
352
- ):
353
- pretrained = nn.Module()
354
-
355
- pretrained.model = model
356
-
357
- if use_vit_only == True:
358
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
- else:
361
- pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
- get_activation("1")
363
- )
364
- pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
- get_activation("2")
366
- )
367
-
368
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
-
371
- pretrained.activations = activations
372
-
373
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
-
375
- if use_vit_only == True:
376
- pretrained.act_postprocess1 = nn.Sequential(
377
- readout_oper[0],
378
- Transpose(1, 2),
379
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
- nn.Conv2d(
381
- in_channels=vit_features,
382
- out_channels=features[0],
383
- kernel_size=1,
384
- stride=1,
385
- padding=0,
386
- ),
387
- nn.ConvTranspose2d(
388
- in_channels=features[0],
389
- out_channels=features[0],
390
- kernel_size=4,
391
- stride=4,
392
- padding=0,
393
- bias=True,
394
- dilation=1,
395
- groups=1,
396
- ),
397
- )
398
-
399
- pretrained.act_postprocess2 = nn.Sequential(
400
- readout_oper[1],
401
- Transpose(1, 2),
402
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
- nn.Conv2d(
404
- in_channels=vit_features,
405
- out_channels=features[1],
406
- kernel_size=1,
407
- stride=1,
408
- padding=0,
409
- ),
410
- nn.ConvTranspose2d(
411
- in_channels=features[1],
412
- out_channels=features[1],
413
- kernel_size=2,
414
- stride=2,
415
- padding=0,
416
- bias=True,
417
- dilation=1,
418
- groups=1,
419
- ),
420
- )
421
- else:
422
- pretrained.act_postprocess1 = nn.Sequential(
423
- nn.Identity(), nn.Identity(), nn.Identity()
424
- )
425
- pretrained.act_postprocess2 = nn.Sequential(
426
- nn.Identity(), nn.Identity(), nn.Identity()
427
- )
428
-
429
- pretrained.act_postprocess3 = nn.Sequential(
430
- readout_oper[2],
431
- Transpose(1, 2),
432
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
- nn.Conv2d(
434
- in_channels=vit_features,
435
- out_channels=features[2],
436
- kernel_size=1,
437
- stride=1,
438
- padding=0,
439
- ),
440
- )
441
-
442
- pretrained.act_postprocess4 = nn.Sequential(
443
- readout_oper[3],
444
- Transpose(1, 2),
445
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
- nn.Conv2d(
447
- in_channels=vit_features,
448
- out_channels=features[3],
449
- kernel_size=1,
450
- stride=1,
451
- padding=0,
452
- ),
453
- nn.Conv2d(
454
- in_channels=features[3],
455
- out_channels=features[3],
456
- kernel_size=3,
457
- stride=2,
458
- padding=1,
459
- ),
460
- )
461
-
462
- pretrained.model.start_index = start_index
463
- pretrained.model.patch_size = [16, 16]
464
-
465
- # We inject this function into the VisionTransformer instances so that
466
- # we can use it with interpolated position embeddings without modifying the library source.
467
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
-
469
- # We inject this function into the VisionTransformer instances so that
470
- # we can use it with interpolated position embeddings without modifying the library source.
471
- pretrained.model._resize_pos_embed = types.MethodType(
472
- _resize_pos_embed, pretrained.model
473
- )
474
-
475
- return pretrained
476
-
477
-
478
- def _make_pretrained_vitb_rn50_384(
479
- pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
- ):
481
- model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
-
483
- hooks = [0, 1, 8, 11] if hooks == None else hooks
484
- return _make_vit_b_rn50_backbone(
485
- model,
486
- features=[256, 512, 768, 768],
487
- size=[384, 384],
488
- hooks=hooks,
489
- use_vit_only=use_vit_only,
490
- use_readout=use_readout,
491
- )
annotator/midas/utils.py DELETED
@@ -1,189 +0,0 @@
1
- """Utils for monoDepth."""
2
- import sys
3
- import re
4
- import numpy as np
5
- import cv2
6
- import torch
7
-
8
-
9
- def read_pfm(path):
10
- """Read pfm file.
11
-
12
- Args:
13
- path (str): path to file
14
-
15
- Returns:
16
- tuple: (data, scale)
17
- """
18
- with open(path, "rb") as file:
19
-
20
- color = None
21
- width = None
22
- height = None
23
- scale = None
24
- endian = None
25
-
26
- header = file.readline().rstrip()
27
- if header.decode("ascii") == "PF":
28
- color = True
29
- elif header.decode("ascii") == "Pf":
30
- color = False
31
- else:
32
- raise Exception("Not a PFM file: " + path)
33
-
34
- dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
- if dim_match:
36
- width, height = list(map(int, dim_match.groups()))
37
- else:
38
- raise Exception("Malformed PFM header.")
39
-
40
- scale = float(file.readline().decode("ascii").rstrip())
41
- if scale < 0:
42
- # little-endian
43
- endian = "<"
44
- scale = -scale
45
- else:
46
- # big-endian
47
- endian = ">"
48
-
49
- data = np.fromfile(file, endian + "f")
50
- shape = (height, width, 3) if color else (height, width)
51
-
52
- data = np.reshape(data, shape)
53
- data = np.flipud(data)
54
-
55
- return data, scale
56
-
57
-
58
- def write_pfm(path, image, scale=1):
59
- """Write pfm file.
60
-
61
- Args:
62
- path (str): pathto file
63
- image (array): data
64
- scale (int, optional): Scale. Defaults to 1.
65
- """
66
-
67
- with open(path, "wb") as file:
68
- color = None
69
-
70
- if image.dtype.name != "float32":
71
- raise Exception("Image dtype must be float32.")
72
-
73
- image = np.flipud(image)
74
-
75
- if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
- color = True
77
- elif (
78
- len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
- ): # greyscale
80
- color = False
81
- else:
82
- raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
-
84
- file.write("PF\n" if color else "Pf\n".encode())
85
- file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
-
87
- endian = image.dtype.byteorder
88
-
89
- if endian == "<" or endian == "=" and sys.byteorder == "little":
90
- scale = -scale
91
-
92
- file.write("%f\n".encode() % scale)
93
-
94
- image.tofile(file)
95
-
96
-
97
- def read_image(path):
98
- """Read image and output RGB image (0-1).
99
-
100
- Args:
101
- path (str): path to file
102
-
103
- Returns:
104
- array: RGB image (0-1)
105
- """
106
- img = cv2.imread(path)
107
-
108
- if img.ndim == 2:
109
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
-
111
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
-
113
- return img
114
-
115
-
116
- def resize_image(img):
117
- """Resize image and make it fit for network.
118
-
119
- Args:
120
- img (array): image
121
-
122
- Returns:
123
- tensor: data ready for network
124
- """
125
- height_orig = img.shape[0]
126
- width_orig = img.shape[1]
127
-
128
- if width_orig > height_orig:
129
- scale = width_orig / 384
130
- else:
131
- scale = height_orig / 384
132
-
133
- height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
- width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
-
136
- img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
-
138
- img_resized = (
139
- torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
- )
141
- img_resized = img_resized.unsqueeze(0)
142
-
143
- return img_resized
144
-
145
-
146
- def resize_depth(depth, width, height):
147
- """Resize depth map and bring to CPU (numpy).
148
-
149
- Args:
150
- depth (tensor): depth
151
- width (int): image width
152
- height (int): image height
153
-
154
- Returns:
155
- array: processed depth
156
- """
157
- depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
-
159
- depth_resized = cv2.resize(
160
- depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
- )
162
-
163
- return depth_resized
164
-
165
- def write_depth(path, depth, bits=1):
166
- """Write depth map to pfm and png file.
167
-
168
- Args:
169
- path (str): filepath without extension
170
- depth (array): depth
171
- """
172
- write_pfm(path + ".pfm", depth.astype(np.float32))
173
-
174
- depth_min = depth.min()
175
- depth_max = depth.max()
176
-
177
- max_val = (2**(8*bits))-1
178
-
179
- if depth_max - depth_min > np.finfo("float").eps:
180
- out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
- else:
182
- out = np.zeros(depth.shape, dtype=depth.type)
183
-
184
- if bits == 1:
185
- cv2.imwrite(path + ".png", out.astype("uint8"))
186
- elif bits == 2:
187
- cv2.imwrite(path + ".png", out.astype("uint16"))
188
-
189
- return
annotator/util.py DELETED
@@ -1,129 +0,0 @@
1
- import random
2
-
3
- import numpy as np
4
- import cv2
5
- import os
6
- import PIL
7
-
8
- annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
9
-
10
- def HWC3(x):
11
- assert x.dtype == np.uint8
12
- if x.ndim == 2:
13
- x = x[:, :, None]
14
- assert x.ndim == 3
15
- H, W, C = x.shape
16
- assert C == 1 or C == 3 or C == 4
17
- if C == 3:
18
- return x
19
- if C == 1:
20
- return np.concatenate([x, x, x], axis=2)
21
- if C == 4:
22
- color = x[:, :, 0:3].astype(np.float32)
23
- alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24
- y = color * alpha + 255.0 * (1.0 - alpha)
25
- y = y.clip(0, 255).astype(np.uint8)
26
- return y
27
-
28
-
29
- def resize_image(input_image, resolution, short = False, interpolation=None):
30
- if isinstance(input_image,PIL.Image.Image):
31
- mode = 'pil'
32
- W,H = input_image.size
33
-
34
- elif isinstance(input_image,np.ndarray):
35
- mode = 'cv2'
36
- H, W, _ = input_image.shape
37
-
38
- H = float(H)
39
- W = float(W)
40
- if short:
41
- k = float(resolution) / min(H, W) # k>1 upscales, k<1 downscales
42
- else:
43
- k = float(resolution) / max(H, W) # k>1 upscales, k<1 downscales
44
- H *= k
45
- W *= k
46
- H = int(np.round(H / 64.0)) * 64
47
- W = int(np.round(W / 64.0)) * 64
48
-
49
- if mode == 'cv2':
50
- if interpolation is None:
51
- interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
52
- img = cv2.resize(input_image, (W, H), interpolation=interpolation)
53
-
54
- elif mode == 'pil':
55
- if interpolation is None:
56
- interpolation = PIL.Image.LANCZOS if k > 1 else PIL.Image.BILINEAR
57
- img = input_image.resize((W, H), resample=interpolation)
58
-
59
- return img
60
-
61
- # def resize_image(input_image, resolution):
62
- # H, W, C = input_image.shape
63
- # H = float(H)
64
- # W = float(W)
65
- # k = float(resolution) / min(H, W)
66
- # H *= k
67
- # W *= k
68
- # H = int(np.round(H / 64.0)) * 64
69
- # W = int(np.round(W / 64.0)) * 64
70
- # img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
71
- # return img
72
-
73
-
74
- def nms(x, t, s):
75
- x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
76
-
77
- f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
78
- f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
79
- f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
80
- f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
81
-
82
- y = np.zeros_like(x)
83
-
84
- for f in [f1, f2, f3, f4]:
85
- np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
86
-
87
- z = np.zeros_like(y, dtype=np.uint8)
88
- z[y > t] = 255
89
- return z
90
-
91
-
92
- def make_noise_disk(H, W, C, F):
93
- noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
94
- noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
95
- noise = noise[F: F + H, F: F + W]
96
- noise -= np.min(noise)
97
- noise /= np.max(noise)
98
- if C == 1:
99
- noise = noise[:, :, None]
100
- return noise
101
-
102
-
103
- def min_max_norm(x):
104
- x -= np.min(x)
105
- x /= np.maximum(np.max(x), 1e-5)
106
- return x
107
-
108
-
109
- def safe_step(x, step=2):
110
- y = x.astype(np.float32) * float(step + 1)
111
- y = y.astype(np.int32).astype(np.float32) / float(step)
112
- return y
113
-
114
-
115
- def img2mask(img, H, W, low=10, high=90):
116
- assert img.ndim == 3 or img.ndim == 2
117
- assert img.dtype == np.uint8
118
-
119
- if img.ndim == 3:
120
- y = img[:, :, random.randrange(0, img.shape[2])]
121
- else:
122
- y = img
123
-
124
- y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
125
-
126
- if random.uniform(0, 1) < 0.5:
127
- y = 255 - y
128
-
129
- return y < np.percentile(y, random.randrange(low, high))
app.py CHANGED
@@ -2,158 +2,126 @@ import spaces
2
  import random
3
  import torch
4
  import cv2
 
5
  import gradio as gr
6
  import numpy as np
7
  from huggingface_hub import snapshot_download
8
  from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
9
- from diffusers.utils import load_image
10
- from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
11
  from kolors.models.modeling_chatglm import ChatGLMModel
12
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
13
- from kolors.models.controlnet import ControlNetModel
14
- from diffusers import AutoencoderKL
15
  from kolors.models.unet_2d_condition import UNet2DConditionModel
16
  from diffusers import EulerDiscreteScheduler
17
  from PIL import Image
18
- from annotator.midas import MidasDetector
19
- from annotator.util import resize_image, HWC3
20
 
21
 
22
  device = "cuda"
23
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
24
- ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
25
- ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")
26
 
27
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
28
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
29
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
30
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
31
  unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
32
- controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
33
- controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)
34
-
35
- pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
36
- vae=vae,
37
- controlnet = controlnet_depth,
38
- text_encoder=text_encoder,
39
- tokenizer=tokenizer,
40
- unet=unet,
41
- scheduler=scheduler,
42
- force_zeros_for_empty_prompt=False
 
 
43
  )
44
 
45
- pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
46
- vae=vae,
47
- controlnet = controlnet_canny,
48
- text_encoder=text_encoder,
49
- tokenizer=tokenizer,
50
- unet=unet,
51
- scheduler=scheduler,
52
- force_zeros_for_empty_prompt=False
53
- )
54
 
55
- @spaces.GPU
56
- def process_canny_condition(image, canny_threods=[100,200]):
57
- np_image = image.copy()
58
- np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1])
59
- np_image = np_image[:, :, None]
60
- np_image = np.concatenate([np_image, np_image, np_image], axis=2)
61
- np_image = HWC3(np_image)
62
- return Image.fromarray(np_image)
63
 
64
- model_midas = MidasDetector()
 
 
 
 
65
 
66
- @spaces.GPU
67
- def process_depth_condition_midas(img, res = 1024):
68
- h,w,_ = img.shape
69
- img = resize_image(HWC3(img), res)
70
- result = HWC3(model_midas(img))
71
- result = cv2.resize(result, (w,h))
72
- return Image.fromarray(result)
 
 
 
 
 
 
 
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
  MAX_IMAGE_SIZE = 1024
76
 
77
  @spaces.GPU
78
- def infer_depth(prompt,
79
  image = None,
80
  negative_prompt = "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯",
81
- seed = 397886929,
82
  randomize_seed = False,
83
  guidance_scale = 6.0,
84
- num_inference_steps = 50,
85
- controlnet_conditioning_scale = 0.7,
86
- control_guidance_end = 0.9,
87
- strength = 1.0
88
  ):
89
  if randomize_seed:
90
  seed = random.randint(0, MAX_SEED)
91
  generator = torch.Generator().manual_seed(seed)
92
- init_image = resize_image(image, MAX_IMAGE_SIZE)
93
- pipe = pipe_depth.to("cuda")
94
- condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMAGE_SIZE)
95
- image = pipe(
96
- prompt= prompt ,
97
- image = init_image,
98
- controlnet_conditioning_scale = controlnet_conditioning_scale,
99
- control_guidance_end = control_guidance_end,
100
- strength= strength ,
101
- control_image = condi_img,
102
- negative_prompt= negative_prompt ,
103
- num_inference_steps= num_inference_steps,
104
- guidance_scale= guidance_scale,
105
- num_images_per_prompt=1,
106
- generator=generator,
107
- ).images[0]
108
- return [condi_img, image], seed
109
 
110
- @spaces.GPU
111
- def infer_canny(prompt,
112
- image = None,
113
- negative_prompt = "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯",
114
- seed = 397886929,
115
- randomize_seed = False,
116
- guidance_scale = 6.0,
117
- num_inference_steps = 50,
118
- controlnet_conditioning_scale = 0.7,
119
- control_guidance_end = 0.9,
120
- strength = 1.0
121
- ):
122
- if randomize_seed:
123
- seed = random.randint(0, MAX_SEED)
124
- generator = torch.Generator().manual_seed(seed)
125
- init_image = resize_image(image, MAX_IMAGE_SIZE)
126
- pipe = pipe_canny.to("cuda")
127
- condi_img = process_canny_condition(np.array(init_image))
128
  image = pipe(
129
- prompt= prompt ,
130
- image = init_image,
131
- controlnet_conditioning_scale = controlnet_conditioning_scale,
132
- control_guidance_end = control_guidance_end,
133
- strength= strength ,
134
- control_image = condi_img,
135
- negative_prompt= negative_prompt ,
136
  num_inference_steps= num_inference_steps,
137
- guidance_scale= guidance_scale,
138
- num_images_per_prompt=1,
139
- generator=generator,
 
 
140
  ).images[0]
141
- return [condi_img, image], seed
142
 
143
- canny_examples = [
144
- ["一个漂亮的女孩,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K",
145
- "image/woman_1.png"],
146
- ["全景,一只可爱的白色小狗坐在杯子里,看向镜头,动漫风格,3d渲染,辛烷值渲染",
147
- "image/dog.png"]
148
- ]
149
 
150
- depth_examples = [
151
- ["新海诚风格,丰富的色彩,穿着绿色衬衫的女人站在田野里,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质",
152
- "image/woman_2.png"],
153
- ["一只颜色鲜艳的小鸟,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K",
154
- "image/bird.png"]
155
  ]
156
 
 
157
  css="""
158
  #col-left {
159
  margin: 0 auto;
@@ -190,7 +158,6 @@ with gr.Blocks(css=css) as Kolors:
190
  label="Negative prompt",
191
  placeholder="Enter a negative prompt",
192
  visible=True,
193
- value="nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯"
194
  )
195
  seed = gr.Slider(
196
  label="Seed",
@@ -206,73 +173,35 @@ with gr.Blocks(css=css) as Kolors:
206
  minimum=0.0,
207
  maximum=10.0,
208
  step=0.1,
209
- value=6.0,
210
  )
211
  num_inference_steps = gr.Slider(
212
  label="Number of inference steps",
213
  minimum=10,
214
  maximum=50,
215
  step=1,
216
- value=30,
217
- )
218
- with gr.Row():
219
- controlnet_conditioning_scale = gr.Slider(
220
- label="Controlnet Conditioning Scale",
221
- minimum=0.0,
222
- maximum=1.0,
223
- step=0.1,
224
- value=0.7,
225
- )
226
- control_guidance_end = gr.Slider(
227
- label="Control Guidance End",
228
- minimum=0.0,
229
- maximum=1.0,
230
- step=0.1,
231
- value=0.9,
232
- )
233
- with gr.Row():
234
- strength = gr.Slider(
235
- label="Strength",
236
- minimum=0.0,
237
- maximum=1.0,
238
- step=0.1,
239
- value=1.0,
240
  )
241
  with gr.Row():
242
- canny_button = gr.Button("Canny", elem_id="button")
243
- depth_button = gr.Button("Depth", elem_id="button")
244
 
245
  with gr.Column(elem_id="col-right"):
246
- result = gr.Gallery(label="Result", show_label=False, columns=2)
247
  seed_used = gr.Number(label="Seed Used")
248
 
249
  with gr.Row():
250
  gr.Examples(
251
- fn = infer_canny,
252
- examples = canny_examples,
253
  inputs = [prompt, image],
254
  outputs = [result, seed_used],
255
- label = "Canny"
256
- )
257
- with gr.Row():
258
- gr.Examples(
259
- fn = infer_depth,
260
- examples = depth_examples,
261
- inputs = [prompt, image],
262
- outputs = [result, seed_used],
263
- label = "Depth"
264
  )
265
 
266
- canny_button.click(
267
- fn = infer_canny,
268
- inputs = [prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
269
  outputs = [result, seed_used]
270
  )
271
 
272
- depth_button.click(
273
- fn = infer_depth,
274
- inputs = [prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
275
- outputs = [result, seed_used]
276
- )
277
 
278
  Kolors.queue().launch(debug=True)
 
2
  import random
3
  import torch
4
  import cv2
5
+ import insightface
6
  import gradio as gr
7
  import numpy as np
8
  from huggingface_hub import snapshot_download
9
  from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
10
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
 
11
  from kolors.models.modeling_chatglm import ChatGLMModel
12
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
13
+ from diffusers import AutoencoderKL
 
14
  from kolors.models.unet_2d_condition import UNet2DConditionModel
15
  from diffusers import EulerDiscreteScheduler
16
  from PIL import Image
17
+ from insightface.app import FaceAnalysis
18
+ from insightface.data import get_image as ins_get_image
19
 
20
 
21
  device = "cuda"
22
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
23
+ ckpt_dir_faceid = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus")
 
24
 
25
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
26
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
27
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
28
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
29
  unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
30
+ clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_faceid}/clip-vit-large-patch14-336', ignore_mismatched_sizes=True)
31
+ clip_image_encoder.to(device)
32
+ clip_image_processor = CLIPImageProcessor(size = 336, crop_size = 336)
33
+
34
+ pipe = StableDiffusionXLPipeline(
35
+ vae = vae,
36
+ text_encoder = text_encoder,
37
+ tokenizer = tokenizer,
38
+ unet = unet,
39
+ scheduler = scheduler,
40
+ face_clip_encoder = clip_image_encoder,
41
+ face_clip_processor = clip_image_processor,
42
+ force_zeros_for_empty_prompt = False,
43
  )
44
 
45
+ class FaceInfoGenerator():
46
+ def __init__(self, root_dir = "./"):
47
+ self.app = FaceAnalysis(name = 'antelopev2', root = root_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
48
+ self.app.prepare(ctx_id = 0, det_size = (640, 640))
 
 
 
 
 
49
 
50
+ def get_faceinfo_one_img(self, face_image):
51
+ face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
 
 
 
 
 
 
52
 
53
+ if len(face_info) == 0:
54
+ face_info = None
55
+ else:
56
+ face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
57
+ return face_info
58
 
59
+ def face_bbox_to_square(bbox):
60
+ ## l, t, r, b to square l, t, r, b
61
+ l,t,r,b = bbox
62
+ cent_x = (l + r) / 2
63
+ cent_y = (t + b) / 2
64
+ w, h = r - l, b - t
65
+ r = max(w, h) / 2
66
+
67
+ l0 = cent_x - r
68
+ r0 = cent_x + r
69
+ t0 = cent_y - r
70
+ b0 = cent_y + r
71
+
72
+ return [l0, t0, r0, b0]
73
 
74
  MAX_SEED = np.iinfo(np.int32).max
75
  MAX_IMAGE_SIZE = 1024
76
 
77
  @spaces.GPU
78
+ def infer(prompt,
79
  image = None,
80
  negative_prompt = "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯",
81
+ seed = 66,
82
  randomize_seed = False,
83
  guidance_scale = 6.0,
84
+ num_inference_steps = 50
 
 
 
85
  ):
86
  if randomize_seed:
87
  seed = random.randint(0, MAX_SEED)
88
  generator = torch.Generator().manual_seed(seed)
89
+ pipe = pipe.to(device)
90
+ pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device = device)
91
+ scale = 0.8
92
+ pipe.set_face_fidelity_scale(scale)
93
+
94
+ face_info_generator = FaceInfoGenerator(root_dir = "./")
95
+ face_info = face_info_generator.get_faceinfo_one_img(image)
96
+ face_bbox_square = face_bbox_to_square(face_info["bbox"])
97
+ crop_image = image.crop(face_bbox_square)
98
+ crop_image = crop_image.resize((336, 336))
99
+ crop_image = [crop_image]
100
+ face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
101
+ face_embeds = face_embeds.to(device, dtype = torch.float16)
 
 
 
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  image = pipe(
104
+ prompt = prompt,
105
+ negative_prompt = negative_prompt,
106
+ height = 1024,
107
+ width = 1024,
 
 
 
108
  num_inference_steps= num_inference_steps,
109
+ guidance_scale = guidance_scale,
110
+ num_images_per_prompt = 1,
111
+ generator = generator,
112
+ face_crop_image = crop_image,
113
+ face_insightface_embeds = face_embeds
114
  ).images[0]
 
115
 
116
+ return image, seed
117
+
 
 
 
 
118
 
119
+ examples = [
120
+ ["穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围", "image/image1.png"],
121
+ ["西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古", "image/image2.png"]
 
 
122
  ]
123
 
124
+
125
  css="""
126
  #col-left {
127
  margin: 0 auto;
 
158
  label="Negative prompt",
159
  placeholder="Enter a negative prompt",
160
  visible=True,
 
161
  )
162
  seed = gr.Slider(
163
  label="Seed",
 
173
  minimum=0.0,
174
  maximum=10.0,
175
  step=0.1,
176
+ value=5.0,
177
  )
178
  num_inference_steps = gr.Slider(
179
  label="Number of inference steps",
180
  minimum=10,
181
  maximum=50,
182
  step=1,
183
+ value=25,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  )
185
  with gr.Row():
186
+ button = gr.Button("Run", elem_id="button")
 
187
 
188
  with gr.Column(elem_id="col-right"):
189
+ result = gr.Image(label="Result", show_label=False)
190
  seed_used = gr.Number(label="Seed Used")
191
 
192
  with gr.Row():
193
  gr.Examples(
194
+ fn = infer,
195
+ examples = examples,
196
  inputs = [prompt, image],
197
  outputs = [result, seed_used],
 
 
 
 
 
 
 
 
 
198
  )
199
 
200
+ button.click(
201
+ fn = infer,
202
+ inputs = [prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
203
  outputs = [result, seed_used]
204
  )
205
 
 
 
 
 
 
206
 
207
  Kolors.queue().launch(debug=True)
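
The new app.py above chains InsightFace detection, a square face crop, and the face embedding into the Kolors IP-Adapter-FaceID pipeline. Below is a minimal sketch (not part of the commit) of that same flow outside Gradio; it assumes `pipe` and `ckpt_dir_faceid` are already set up as in app.py above, that the antelopev2 detector weights are available under ./models, and it uses one of the bundled example images.

```python
# Hedged sketch: exercising the FaceID preprocessing from the new app.py on its own.
# Assumes `pipe` and `ckpt_dir_faceid` exist as defined in app.py above.
import cv2
import numpy as np
import torch
from PIL import Image
from insightface.app import FaceAnalysis

# Same detector setup as FaceInfoGenerator in app.py.
app = FaceAnalysis(name='antelopev2', root='./',
                   providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

image = Image.open('image/image1.png').convert('RGB')
faces = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
assert faces, "no face detected"
# Keep the largest detected face, as get_faceinfo_one_img does.
face = sorted(faces, key=lambda f: (f['bbox'][2] - f['bbox'][0]) * (f['bbox'][3] - f['bbox'][1]))[-1]

# Square crop around the bbox center (same geometry as face_bbox_to_square),
# resized to the 336x336 input expected by the CLIP image processor.
l, t, r, b = face['bbox']
cx, cy, half = (l + r) / 2, (t + b) / 2, max(r - l, b - t) / 2
face_crop = image.crop((cx - half, cy - half, cx + half, cy + half)).resize((336, 336))

# Raw InsightFace identity embedding, as in infer().
face_embeds = torch.from_numpy(np.array([face['embedding']])).to('cuda', dtype=torch.float16)

pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device='cuda')
pipe.set_face_fidelity_scale(0.8)
result = pipe(
    prompt='a portrait photo, best quality',  # placeholder prompt
    height=1024, width=1024,
    num_inference_steps=25, guidance_scale=5.0,
    num_images_per_prompt=1,
    generator=torch.Generator().manual_seed(66),
    face_crop_image=[face_crop],
    face_insightface_embeds=face_embeds,
).images[0]
result.save('result.png')
```

The square crop keeps the 336x336 CLIP input centered on the face, while the raw InsightFace embedding carries the identity signal to the adapter.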
assets/title.md CHANGED
@@ -1,10 +1,10 @@
1
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
2
  <div>
3
- <h1>Kolors-Controlnet</h1>
4
- <span>Two ControlNet based on Kolors-Basemodel: Canny and Depth</span>
5
  <br>
6
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
7
- <a href="https://github.com/Kwai-Kolors/Kolors/tree/master/controlnet"><img src="https://img.shields.io/static/v1?label=Kolors Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
8
  <a href="https://kwai-kolors.github.io/"><img src="https://img.shields.io/static/v1?label=Team%20Page&message=Page&color=green"></a> &ensp;
9
  <a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Kolors&color=red"></a> &ensp;
10
  <a href="https://klingai.kuaishou.com/"><img src="https://img.shields.io/static/v1?label=Official Website&message=Page&color=green"></a>
 
1
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
2
  <div>
3
+ <h1>Kolors-FaceID</h1>
4
+ <span>Kolors-IP-Adapter-FaceID-Plus based on Kolors-Basemodel.</span>
5
  <br>
6
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
7
+ <a href="https://github.com/Kwai-Kolors/Kolors/tree/master/ipadapter_FaceID"><img src="https://img.shields.io/static/v1?label=Kolors Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
8
  <a href="https://kwai-kolors.github.io/"><img src="https://img.shields.io/static/v1?label=Team%20Page&message=Page&color=green"></a> &ensp;
9
  <a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Kolors&color=red"></a> &ensp;
10
  <a href="https://klingai.kuaishou.com/"><img src="https://img.shields.io/static/v1?label=Official Website&message=Page&color=green"></a>
basicsr/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- # https://github.com/xinntao/BasicSR
2
- # flake8: noqa
3
- from .archs import *
4
- from .data import *
5
- from .losses import *
6
- from .metrics import *
7
- from .models import *
8
- from .ops import *
9
- from .test import *
10
- from .train import *
11
- from .utils import *
basicsr/archs/__init__.py DELETED
@@ -1,24 +0,0 @@
1
- import importlib
2
- from copy import deepcopy
3
- from os import path as osp
4
-
5
- from basicsr.utils import get_root_logger, scandir
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
-
8
- __all__ = ['build_network']
9
-
10
- # automatically scan and import arch modules for registry
11
- # scan all the files under the 'archs' folder and collect files ending with '_arch.py'
12
- arch_folder = osp.dirname(osp.abspath(__file__))
13
- arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
14
- # import all the arch modules
15
- _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
16
-
17
-
18
- def build_network(opt):
19
- opt = deepcopy(opt)
20
- network_type = opt.pop('type')
21
- net = ARCH_REGISTRY.get(network_type)(**opt)
22
- logger = get_root_logger()
23
- logger.info(f'Network [{net.__class__.__name__}] is created.')
24
- return net
basicsr/archs/arch_util.py DELETED
@@ -1,313 +0,0 @@
1
- import collections.abc
2
- import math
3
- import torch
4
- import torchvision
5
- import warnings
6
- from distutils.version import LooseVersion
7
- from itertools import repeat
8
- from torch import nn as nn
9
- from torch.nn import functional as F
10
- from torch.nn import init as init
11
- from torch.nn.modules.batchnorm import _BatchNorm
12
-
13
- from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
- from basicsr.utils import get_root_logger
15
-
16
-
17
- @torch.no_grad()
18
- def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
- """Initialize network weights.
20
-
21
- Args:
22
- module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
- scale (float): Scale initialized weights, especially for residual
24
- blocks. Default: 1.
25
- bias_fill (float): The value to fill bias. Default: 0
26
- kwargs (dict): Other arguments for initialization function.
27
- """
28
- if not isinstance(module_list, list):
29
- module_list = [module_list]
30
- for module in module_list:
31
- for m in module.modules():
32
- if isinstance(m, nn.Conv2d):
33
- init.kaiming_normal_(m.weight, **kwargs)
34
- m.weight.data *= scale
35
- if m.bias is not None:
36
- m.bias.data.fill_(bias_fill)
37
- elif isinstance(m, nn.Linear):
38
- init.kaiming_normal_(m.weight, **kwargs)
39
- m.weight.data *= scale
40
- if m.bias is not None:
41
- m.bias.data.fill_(bias_fill)
42
- elif isinstance(m, _BatchNorm):
43
- init.constant_(m.weight, 1)
44
- if m.bias is not None:
45
- m.bias.data.fill_(bias_fill)
46
-
47
-
48
- def make_layer(basic_block, num_basic_block, **kwarg):
49
- """Make layers by stacking the same blocks.
50
-
51
- Args:
52
- basic_block (nn.module): nn.module class for basic block.
53
- num_basic_block (int): number of blocks.
54
-
55
- Returns:
56
- nn.Sequential: Stacked blocks in nn.Sequential.
57
- """
58
- layers = []
59
- for _ in range(num_basic_block):
60
- layers.append(basic_block(**kwarg))
61
- return nn.Sequential(*layers)
62
-
63
-
64
- class ResidualBlockNoBN(nn.Module):
65
- """Residual block without BN.
66
-
67
- Args:
68
- num_feat (int): Channel number of intermediate features.
69
- Default: 64.
70
- res_scale (float): Residual scale. Default: 1.
71
- pytorch_init (bool): If set to True, use pytorch default init,
72
- otherwise, use default_init_weights. Default: False.
73
- """
74
-
75
- def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
76
- super(ResidualBlockNoBN, self).__init__()
77
- self.res_scale = res_scale
78
- self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
79
- self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
80
- self.relu = nn.ReLU(inplace=True)
81
-
82
- if not pytorch_init:
83
- default_init_weights([self.conv1, self.conv2], 0.1)
84
-
85
- def forward(self, x):
86
- identity = x
87
- out = self.conv2(self.relu(self.conv1(x)))
88
- return identity + out * self.res_scale
89
-
90
-
91
- class Upsample(nn.Sequential):
92
- """Upsample module.
93
-
94
- Args:
95
- scale (int): Scale factor. Supported scales: 2^n and 3.
96
- num_feat (int): Channel number of intermediate features.
97
- """
98
-
99
- def __init__(self, scale, num_feat):
100
- m = []
101
- if (scale & (scale - 1)) == 0: # scale = 2^n
102
- for _ in range(int(math.log(scale, 2))):
103
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
104
- m.append(nn.PixelShuffle(2))
105
- elif scale == 3:
106
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
107
- m.append(nn.PixelShuffle(3))
108
- else:
109
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
110
- super(Upsample, self).__init__(*m)
111
-
112
-
113
- def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
114
- """Warp an image or feature map with optical flow.
115
-
116
- Args:
117
- x (Tensor): Tensor with size (n, c, h, w).
118
- flow (Tensor): Tensor with size (n, h, w, 2), normal value.
119
- interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
120
- padding_mode (str): 'zeros' or 'border' or 'reflection'.
121
- Default: 'zeros'.
122
- align_corners (bool): Before pytorch 1.3, the default value is
123
- align_corners=True. After pytorch 1.3, the default value is
124
- align_corners=False. Here, we use the True as default.
125
-
126
- Returns:
127
- Tensor: Warped image or feature map.
128
- """
129
- assert x.size()[-2:] == flow.size()[1:3]
130
- _, _, h, w = x.size()
131
- # create mesh grid
132
- grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
133
- grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
134
- grid.requires_grad = False
135
-
136
- vgrid = grid + flow
137
- # scale grid to [-1,1]
138
- vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
139
- vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
140
- vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
141
- output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
142
-
143
- # TODO, what if align_corners=False
144
- return output
145
-
146
-
147
- def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
148
- """Resize a flow according to ratio or shape.
149
-
150
- Args:
151
- flow (Tensor): Precomputed flow. shape [N, 2, H, W].
152
- size_type (str): 'ratio' or 'shape'.
153
- sizes (list[int | float]): the ratio for resizing or the final output
154
- shape.
155
- 1) The order of ratio should be [ratio_h, ratio_w]. For
156
- downsampling, the ratio should be smaller than 1.0 (i.e., ratio
157
- < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
158
- ratio > 1.0).
159
- 2) The order of output_size should be [out_h, out_w].
160
- interp_mode (str): The mode of interpolation for resizing.
161
- Default: 'bilinear'.
162
- align_corners (bool): Whether align corners. Default: False.
163
-
164
- Returns:
165
- Tensor: Resized flow.
166
- """
167
- _, _, flow_h, flow_w = flow.size()
168
- if size_type == 'ratio':
169
- output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
170
- elif size_type == 'shape':
171
- output_h, output_w = sizes[0], sizes[1]
172
- else:
173
- raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
174
-
175
- input_flow = flow.clone()
176
- ratio_h = output_h / flow_h
177
- ratio_w = output_w / flow_w
178
- input_flow[:, 0, :, :] *= ratio_w
179
- input_flow[:, 1, :, :] *= ratio_h
180
- resized_flow = F.interpolate(
181
- input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
182
- return resized_flow
183
-
184
-
185
- # TODO: may write a cpp file
186
- def pixel_unshuffle(x, scale):
187
- """ Pixel unshuffle.
188
-
189
- Args:
190
- x (Tensor): Input feature with shape (b, c, hh, hw).
191
- scale (int): Downsample ratio.
192
-
193
- Returns:
194
- Tensor: the pixel unshuffled feature.
195
- """
196
- b, c, hh, hw = x.size()
197
- out_channel = c * (scale**2)
198
- assert hh % scale == 0 and hw % scale == 0
199
- h = hh // scale
200
- w = hw // scale
201
- x_view = x.view(b, c, h, scale, w, scale)
202
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
203
-
204
-
205
- class DCNv2Pack(ModulatedDeformConvPack):
206
- """Modulated deformable conv for deformable alignment.
207
-
208
- Different from the official DCNv2Pack, which generates offsets and masks
209
- from the preceding features, this DCNv2Pack takes another different
210
- features to generate offsets and masks.
211
-
212
- ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
213
- """
214
-
215
- def forward(self, x, feat):
216
- out = self.conv_offset(feat)
217
- o1, o2, mask = torch.chunk(out, 3, dim=1)
218
- offset = torch.cat((o1, o2), dim=1)
219
- mask = torch.sigmoid(mask)
220
-
221
- offset_absmean = torch.mean(torch.abs(offset))
222
- if offset_absmean > 50:
223
- logger = get_root_logger()
224
- logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
225
-
226
- if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
227
- return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
228
- self.dilation, mask)
229
- else:
230
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
231
- self.dilation, self.groups, self.deformable_groups)
232
-
233
-
234
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
235
- # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
236
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
237
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
238
- def norm_cdf(x):
239
- # Computes standard normal cumulative distribution function
240
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
241
-
242
- if (mean < a - 2 * std) or (mean > b + 2 * std):
243
- warnings.warn(
244
- 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
245
- 'The distribution of values may be incorrect.',
246
- stacklevel=2)
247
-
248
- with torch.no_grad():
249
- # Values are generated by using a truncated uniform distribution and
250
- # then using the inverse CDF for the normal distribution.
251
- # Get upper and lower cdf values
252
- low = norm_cdf((a - mean) / std)
253
- up = norm_cdf((b - mean) / std)
254
-
255
- # Uniformly fill tensor with values from [low, up], then translate to
256
- # [2l-1, 2u-1].
257
- tensor.uniform_(2 * low - 1, 2 * up - 1)
258
-
259
- # Use inverse cdf transform for normal distribution to get truncated
260
- # standard normal
261
- tensor.erfinv_()
262
-
263
- # Transform to proper mean, std
264
- tensor.mul_(std * math.sqrt(2.))
265
- tensor.add_(mean)
266
-
267
- # Clamp to ensure it's in the proper range
268
- tensor.clamp_(min=a, max=b)
269
- return tensor
270
-
271
-
272
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
273
- r"""Fills the input Tensor with values drawn from a truncated
274
- normal distribution.
275
-
276
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
277
-
278
- The values are effectively drawn from the
279
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
280
- with values outside :math:`[a, b]` redrawn until they are within
281
- the bounds. The method used for generating the random values works
282
- best when :math:`a \leq \text{mean} \leq b`.
283
-
284
- Args:
285
- tensor: an n-dimensional `torch.Tensor`
286
- mean: the mean of the normal distribution
287
- std: the standard deviation of the normal distribution
288
- a: the minimum cutoff value
289
- b: the maximum cutoff value
290
-
291
- Examples:
292
- >>> w = torch.empty(3, 5)
293
- >>> nn.init.trunc_normal_(w)
294
- """
295
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
296
-
297
-
298
- # From PyTorch
299
- def _ntuple(n):
300
-
301
- def parse(x):
302
- if isinstance(x, collections.abc.Iterable):
303
- return x
304
- return tuple(repeat(x, n))
305
-
306
- return parse
307
-
308
-
309
- to_1tuple = _ntuple(1)
310
- to_2tuple = _ntuple(2)
311
- to_3tuple = _ntuple(3)
312
- to_4tuple = _ntuple(4)
313
- to_ntuple = _ntuple
basicsr/archs/basicvsr_arch.py DELETED
@@ -1,336 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
- from torch.nn import functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
7
- from .edvr_arch import PCDAlignment, TSAFusion
8
- from .spynet_arch import SpyNet
9
-
10
-
11
- @ARCH_REGISTRY.register()
12
- class BasicVSR(nn.Module):
13
- """A recurrent network for video SR. Now only x4 is supported.
14
-
15
- Args:
16
- num_feat (int): Number of channels. Default: 64.
17
- num_block (int): Number of residual blocks for each branch. Default: 15
18
- spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
19
- """
20
-
21
- def __init__(self, num_feat=64, num_block=15, spynet_path=None):
22
- super().__init__()
23
- self.num_feat = num_feat
24
-
25
- # alignment
26
- self.spynet = SpyNet(spynet_path)
27
-
28
- # propagation
29
- self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
30
- self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
31
-
32
- # reconstruction
33
- self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
34
- self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
35
- self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
36
- self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
37
- self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
38
-
39
- self.pixel_shuffle = nn.PixelShuffle(2)
40
-
41
- # activation functions
42
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
43
-
44
- def get_flow(self, x):
45
- b, n, c, h, w = x.size()
46
-
47
- x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
48
- x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
49
-
50
- flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
51
- flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
52
-
53
- return flows_forward, flows_backward
54
-
55
- def forward(self, x):
56
- """Forward function of BasicVSR.
57
-
58
- Args:
59
- x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
60
- """
61
- flows_forward, flows_backward = self.get_flow(x)
62
- b, n, _, h, w = x.size()
63
-
64
- # backward branch
65
- out_l = []
66
- feat_prop = x.new_zeros(b, self.num_feat, h, w)
67
- for i in range(n - 1, -1, -1):
68
- x_i = x[:, i, :, :, :]
69
- if i < n - 1:
70
- flow = flows_backward[:, i, :, :, :]
71
- feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
72
- feat_prop = torch.cat([x_i, feat_prop], dim=1)
73
- feat_prop = self.backward_trunk(feat_prop)
74
- out_l.insert(0, feat_prop)
75
-
76
- # forward branch
77
- feat_prop = torch.zeros_like(feat_prop)
78
- for i in range(0, n):
79
- x_i = x[:, i, :, :, :]
80
- if i > 0:
81
- flow = flows_forward[:, i - 1, :, :, :]
82
- feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
83
-
84
- feat_prop = torch.cat([x_i, feat_prop], dim=1)
85
- feat_prop = self.forward_trunk(feat_prop)
86
-
87
- # upsample
88
- out = torch.cat([out_l[i], feat_prop], dim=1)
89
- out = self.lrelu(self.fusion(out))
90
- out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
91
- out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
92
- out = self.lrelu(self.conv_hr(out))
93
- out = self.conv_last(out)
94
- base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
95
- out += base
96
- out_l[i] = out
97
-
98
- return torch.stack(out_l, dim=1)
99
-
100
-
101
- class ConvResidualBlocks(nn.Module):
102
- """Conv and residual block used in BasicVSR.
103
-
104
- Args:
105
- num_in_ch (int): Number of input channels. Default: 3.
106
- num_out_ch (int): Number of output channels. Default: 64.
107
- num_block (int): Number of residual blocks. Default: 15.
108
- """
109
-
110
- def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
111
- super().__init__()
112
- self.main = nn.Sequential(
113
- nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
114
- make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
115
-
116
- def forward(self, fea):
117
- return self.main(fea)
118
-
119
-
120
- @ARCH_REGISTRY.register()
121
- class IconVSR(nn.Module):
122
- """IconVSR, proposed also in the BasicVSR paper.
123
-
124
- Args:
125
- num_feat (int): Number of channels. Default: 64.
126
- num_block (int): Number of residual blocks for each branch. Default: 15.
127
- keyframe_stride (int): Keyframe stride. Default: 5.
128
- temporal_padding (int): Temporal padding. Default: 2.
129
- spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
130
- edvr_path (str): Path to the pretrained EDVR model. Default: None.
131
- """
132
-
133
- def __init__(self,
134
- num_feat=64,
135
- num_block=15,
136
- keyframe_stride=5,
137
- temporal_padding=2,
138
- spynet_path=None,
139
- edvr_path=None):
140
- super().__init__()
141
-
142
- self.num_feat = num_feat
143
- self.temporal_padding = temporal_padding
144
- self.keyframe_stride = keyframe_stride
145
-
146
- # keyframe_branch
147
- self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
148
- # alignment
149
- self.spynet = SpyNet(spynet_path)
150
-
151
- # propagation
152
- self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
153
- self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
154
-
155
- self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
156
- self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
157
-
158
- # reconstruction
159
- self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
160
- self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
161
- self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
162
- self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
163
-
164
- self.pixel_shuffle = nn.PixelShuffle(2)
165
-
166
- # activation functions
167
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
168
-
169
- def pad_spatial(self, x):
170
- """Apply padding spatially.
171
-
172
- Since the PCD module in EDVR requires that the resolution is a multiple
173
- of 4, we apply padding to the input LR images if their resolution is
174
- not divisible by 4.
175
-
176
- Args:
177
- x (Tensor): Input LR sequence with shape (n, t, c, h, w).
178
- Returns:
179
- Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
180
- """
181
- n, t, c, h, w = x.size()
182
-
183
- pad_h = (4 - h % 4) % 4
184
- pad_w = (4 - w % 4) % 4
185
-
186
- # padding
187
- x = x.view(-1, c, h, w)
188
- x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
189
-
190
- return x.view(n, t, c, h + pad_h, w + pad_w)
191
-
192
- def get_flow(self, x):
193
- b, n, c, h, w = x.size()
194
-
195
- x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
196
- x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
197
-
198
- flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
199
- flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
200
-
201
- return flows_forward, flows_backward
202
-
203
- def get_keyframe_feature(self, x, keyframe_idx):
204
- if self.temporal_padding == 2:
205
- x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
206
- elif self.temporal_padding == 3:
207
- x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
208
- x = torch.cat(x, dim=1)
209
-
210
- num_frames = 2 * self.temporal_padding + 1
211
- feats_keyframe = {}
212
- for i in keyframe_idx:
213
- feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
214
- return feats_keyframe
215
-
216
- def forward(self, x):
217
- b, n, _, h_input, w_input = x.size()
218
-
219
- x = self.pad_spatial(x)
220
- h, w = x.shape[3:]
221
-
222
- keyframe_idx = list(range(0, n, self.keyframe_stride))
223
- if keyframe_idx[-1] != n - 1:
224
- keyframe_idx.append(n - 1) # last frame is a keyframe
225
-
226
- # compute flow and keyframe features
227
- flows_forward, flows_backward = self.get_flow(x)
228
- feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
229
-
230
- # backward branch
231
- out_l = []
232
- feat_prop = x.new_zeros(b, self.num_feat, h, w)
233
- for i in range(n - 1, -1, -1):
234
- x_i = x[:, i, :, :, :]
235
- if i < n - 1:
236
- flow = flows_backward[:, i, :, :, :]
237
- feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
238
- if i in keyframe_idx:
239
- feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
240
- feat_prop = self.backward_fusion(feat_prop)
241
- feat_prop = torch.cat([x_i, feat_prop], dim=1)
242
- feat_prop = self.backward_trunk(feat_prop)
243
- out_l.insert(0, feat_prop)
244
-
245
- # forward branch
246
- feat_prop = torch.zeros_like(feat_prop)
247
- for i in range(0, n):
248
- x_i = x[:, i, :, :, :]
249
- if i > 0:
250
- flow = flows_forward[:, i - 1, :, :, :]
251
- feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
252
- if i in keyframe_idx:
253
- feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
254
- feat_prop = self.forward_fusion(feat_prop)
255
-
256
- feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
257
- feat_prop = self.forward_trunk(feat_prop)
258
-
259
- # upsample
260
- out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
261
- out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
262
- out = self.lrelu(self.conv_hr(out))
263
- out = self.conv_last(out)
264
- base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
265
- out += base
266
- out_l[i] = out
267
-
268
- return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
269
-
270
-
271
- class EDVRFeatureExtractor(nn.Module):
272
- """EDVR feature extractor used in IconVSR.
273
-
274
- Args:
275
- num_input_frame (int): Number of input frames.
276
- num_feat (int): Number of feature channels
277
- load_path (str): Path to the pretrained weights of EDVR. Default: None.
278
- """
279
-
280
- def __init__(self, num_input_frame, num_feat, load_path):
281
-
282
- super(EDVRFeatureExtractor, self).__init__()
283
-
284
- self.center_frame_idx = num_input_frame // 2
285
-
286
- # extract pyramid features
287
- self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
288
- self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
289
- self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
290
- self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
291
- self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
292
- self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
293
-
294
- # pcd and tsa module
295
- self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
296
- self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
297
-
298
- # activation function
299
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
300
-
301
- if load_path:
302
- self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
303
-
304
- def forward(self, x):
305
- b, n, c, h, w = x.size()
306
-
307
- # extract features for each frame
308
- # L1
309
- feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
310
- feat_l1 = self.feature_extraction(feat_l1)
311
- # L2
312
- feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
313
- feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
314
- # L3
315
- feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
316
- feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
317
-
318
- feat_l1 = feat_l1.view(b, n, -1, h, w)
319
- feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
320
- feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
321
-
322
- # PCD alignment
323
- ref_feat_l = [ # reference feature list
324
- feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
325
- feat_l3[:, self.center_frame_idx, :, :, :].clone()
326
- ]
327
- aligned_feat = []
328
- for i in range(n):
329
- nbr_feat_l = [ # neighboring feature list
330
- feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
331
- ]
332
- aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
333
- aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
334
-
335
- # TSA fusion
336
- return self.fusion(aligned_feat)
basicsr/archs/basicvsrpp_arch.py DELETED
@@ -1,417 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision
5
- import warnings
6
-
7
- from basicsr.archs.arch_util import flow_warp
8
- from basicsr.archs.basicvsr_arch import ConvResidualBlocks
9
- from basicsr.archs.spynet_arch import SpyNet
10
- from basicsr.ops.dcn import ModulatedDeformConvPack
11
- from basicsr.utils.registry import ARCH_REGISTRY
12
-
13
-
14
- @ARCH_REGISTRY.register()
15
- class BasicVSRPlusPlus(nn.Module):
16
- """BasicVSR++ network structure.
17
-
18
- Support either x4 upsampling or same size output. Since DCN is used in this
19
- model, it can only be used with CUDA enabled. If CUDA is not enabled,
20
- feature alignment will be skipped. Besides, we adopt the official DCN
21
- implementation and the version of torch need to be higher than 1.9.
22
-
23
- ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
24
-
25
- Args:
26
- mid_channels (int, optional): Channel number of the intermediate
27
- features. Default: 64.
28
- num_blocks (int, optional): The number of residual blocks in each
29
- propagation branch. Default: 7.
30
- max_residue_magnitude (int): The maximum magnitude of the offset
31
- residue (Eq. 6 in paper). Default: 10.
32
- is_low_res_input (bool, optional): Whether the input is low-resolution
33
- or not. If False, the output resolution is equal to the input
34
- resolution. Default: True.
35
- spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
36
- cpu_cache_length (int, optional): When the length of sequence is larger
37
- than this value, the intermediate features are sent to CPU. This
38
- saves GPU memory, but slows down the inference speed. You can
39
- increase this number if you have a GPU with large memory.
40
- Default: 100.
41
- """
42
-
43
- def __init__(self,
44
- mid_channels=64,
45
- num_blocks=7,
46
- max_residue_magnitude=10,
47
- is_low_res_input=True,
48
- spynet_path=None,
49
- cpu_cache_length=100):
50
-
51
- super().__init__()
52
- self.mid_channels = mid_channels
53
- self.is_low_res_input = is_low_res_input
54
- self.cpu_cache_length = cpu_cache_length
55
-
56
- # optical flow
57
- self.spynet = SpyNet(spynet_path)
58
-
59
- # feature extraction module
60
- if is_low_res_input:
61
- self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
62
- else:
63
- self.feat_extract = nn.Sequential(
64
- nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
65
- nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
66
- ConvResidualBlocks(mid_channels, mid_channels, 5))
67
-
68
- # propagation branches
69
- self.deform_align = nn.ModuleDict()
70
- self.backbone = nn.ModuleDict()
71
- modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
72
- for i, module in enumerate(modules):
73
- if torch.cuda.is_available():
74
- self.deform_align[module] = SecondOrderDeformableAlignment(
75
- 2 * mid_channels,
76
- mid_channels,
77
- 3,
78
- padding=1,
79
- deformable_groups=16,
80
- max_residue_magnitude=max_residue_magnitude)
81
- self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
82
-
83
- # upsampling module
84
- self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
85
-
86
- self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
87
- self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
88
-
89
- self.pixel_shuffle = nn.PixelShuffle(2)
90
-
91
- self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
92
- self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
93
- self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
94
-
95
- # activation function
96
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
97
-
98
- # check if the sequence is augmented by flipping
99
- self.is_mirror_extended = False
100
-
101
- if len(self.deform_align) > 0:
102
- self.is_with_alignment = True
103
- else:
104
- self.is_with_alignment = False
105
- warnings.warn('Deformable alignment module is not added. '
106
- 'Probably your CUDA is not configured correctly. DCN can only '
107
- 'be used with CUDA enabled. Alignment is skipped now.')
108
-
109
- def check_if_mirror_extended(self, lqs):
110
- """Check whether the input is a mirror-extended sequence.
111
-
112
- If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
113
-
114
- Args:
115
- lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
116
- """
117
-
118
- if lqs.size(1) % 2 == 0:
119
- lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
120
- if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
121
- self.is_mirror_extended = True
122
-
123
- def compute_flow(self, lqs):
124
- """Compute optical flow using SPyNet for feature alignment.
125
-
126
- Note that if the input is an mirror-extended sequence, 'flows_forward'
127
- is not needed, since it is equal to 'flows_backward.flip(1)'.
128
-
129
- Args:
130
- lqs (tensor): Input low quality (LQ) sequence with
131
- shape (n, t, c, h, w).
132
-
133
- Return:
134
- tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
135
- (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
136
- propagation (current to next).
137
- """
138
-
139
- n, t, c, h, w = lqs.size()
140
- lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
141
- lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
142
-
143
- flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
144
-
145
- if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
146
- flows_forward = flows_backward.flip(1)
147
- else:
148
- flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
149
-
150
- if self.cpu_cache:
151
- flows_backward = flows_backward.cpu()
152
- flows_forward = flows_forward.cpu()
153
-
154
- return flows_forward, flows_backward
155
-
156
- def propagate(self, feats, flows, module_name):
157
- """Propagate the latent features throughout the sequence.
158
-
159
- Args:
160
- feats (dict[list[tensor]]): Features from previous branches. Each
161
- component is a list of tensors with shape (n, c, h, w).
162
- flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
163
- module_name (str): The name of the propagation branches. Can either
164
- be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
165
-
166
- Return:
167
- dict(list[tensor]): A dictionary containing all the propagated \
168
- features. Each key in the dictionary corresponds to a \
169
- propagation branch, which is represented by a list of tensors.
170
- """
171
-
172
- n, t, _, h, w = flows.size()
173
-
174
- frame_idx = range(0, t + 1)
175
- flow_idx = range(-1, t)
176
- mapping_idx = list(range(0, len(feats['spatial'])))
177
- mapping_idx += mapping_idx[::-1]
178
-
179
- if 'backward' in module_name:
180
- frame_idx = frame_idx[::-1]
181
- flow_idx = frame_idx
182
-
183
- feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
184
- for i, idx in enumerate(frame_idx):
185
- feat_current = feats['spatial'][mapping_idx[idx]]
186
- if self.cpu_cache:
187
- feat_current = feat_current.cuda()
188
- feat_prop = feat_prop.cuda()
189
- # second-order deformable alignment
190
- if i > 0 and self.is_with_alignment:
191
- flow_n1 = flows[:, flow_idx[i], :, :, :]
192
- if self.cpu_cache:
193
- flow_n1 = flow_n1.cuda()
194
-
195
- cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
196
-
197
- # initialize second-order features
198
- feat_n2 = torch.zeros_like(feat_prop)
199
- flow_n2 = torch.zeros_like(flow_n1)
200
- cond_n2 = torch.zeros_like(cond_n1)
201
-
202
- if i > 1: # second-order features
203
- feat_n2 = feats[module_name][-2]
204
- if self.cpu_cache:
205
- feat_n2 = feat_n2.cuda()
206
-
207
- flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
208
- if self.cpu_cache:
209
- flow_n2 = flow_n2.cuda()
210
-
211
- flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
212
- cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
213
-
214
- # flow-guided deformable convolution
215
- cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
216
- feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
217
- feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
218
-
219
- # concatenate and residual blocks
220
- feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
221
- if self.cpu_cache:
222
- feat = [f.cuda() for f in feat]
223
-
224
- feat = torch.cat(feat, dim=1)
225
- feat_prop = feat_prop + self.backbone[module_name](feat)
226
- feats[module_name].append(feat_prop)
227
-
228
- if self.cpu_cache:
229
- feats[module_name][-1] = feats[module_name][-1].cpu()
230
- torch.cuda.empty_cache()
231
-
232
- if 'backward' in module_name:
233
- feats[module_name] = feats[module_name][::-1]
234
-
235
- return feats
236
-
237
- def upsample(self, lqs, feats):
238
- """Compute the output image given the features.
239
-
240
- Args:
241
- lqs (tensor): Input low quality (LQ) sequence with
242
- shape (n, t, c, h, w).
243
- feats (dict): The features from the propagation branches.
244
-
245
- Returns:
246
- Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
247
- """
248
-
249
- outputs = []
250
- num_outputs = len(feats['spatial'])
251
-
252
- mapping_idx = list(range(0, num_outputs))
253
- mapping_idx += mapping_idx[::-1]
254
-
255
- for i in range(0, lqs.size(1)):
256
- hr = [feats[k].pop(0) for k in feats if k != 'spatial']
257
- hr.insert(0, feats['spatial'][mapping_idx[i]])
258
- hr = torch.cat(hr, dim=1)
259
- if self.cpu_cache:
260
- hr = hr.cuda()
261
-
262
- hr = self.reconstruction(hr)
263
- hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
264
- hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
265
- hr = self.lrelu(self.conv_hr(hr))
266
- hr = self.conv_last(hr)
267
- if self.is_low_res_input:
268
- hr += self.img_upsample(lqs[:, i, :, :, :])
269
- else:
270
- hr += lqs[:, i, :, :, :]
271
-
272
- if self.cpu_cache:
273
- hr = hr.cpu()
274
- torch.cuda.empty_cache()
275
-
276
- outputs.append(hr)
277
-
278
- return torch.stack(outputs, dim=1)
279
-
280
- def forward(self, lqs):
281
- """Forward function for BasicVSR++.
282
-
283
- Args:
284
- lqs (tensor): Input low quality (LQ) sequence with
285
- shape (n, t, c, h, w).
286
-
287
- Returns:
288
- Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
289
- """
290
-
291
- n, t, c, h, w = lqs.size()
292
-
293
- # whether to cache the features in CPU
294
- self.cpu_cache = True if t > self.cpu_cache_length else False
295
-
296
- if self.is_low_res_input:
297
- lqs_downsample = lqs.clone()
298
- else:
299
- lqs_downsample = F.interpolate(
300
- lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
301
-
302
- # check whether the input is an extended sequence
303
- self.check_if_mirror_extended(lqs)
304
-
305
- feats = {}
306
- # compute spatial features
307
- if self.cpu_cache:
308
- feats['spatial'] = []
309
- for i in range(0, t):
310
- feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
311
- feats['spatial'].append(feat)
312
- torch.cuda.empty_cache()
313
- else:
314
- feats_ = self.feat_extract(lqs.view(-1, c, h, w))
315
- h, w = feats_.shape[2:]
316
- feats_ = feats_.view(n, t, -1, h, w)
317
- feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
318
-
319
- # compute optical flow using the low-res inputs
320
- assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
321
- 'The height and width of low-res inputs must be at least 64, '
322
- f'but got {h} and {w}.')
323
- flows_forward, flows_backward = self.compute_flow(lqs_downsample)
324
-
325
- # feature propagation
326
- for iter_ in [1, 2]:
327
- for direction in ['backward', 'forward']:
328
- module = f'{direction}_{iter_}'
329
-
330
- feats[module] = []
331
-
332
- if direction == 'backward':
333
- flows = flows_backward
334
- elif flows_forward is not None:
335
- flows = flows_forward
336
- else:
337
- flows = flows_backward.flip(1)
338
-
339
- feats = self.propagate(feats, flows, module)
340
- if self.cpu_cache:
341
- del flows
342
- torch.cuda.empty_cache()
343
-
344
- return self.upsample(lqs, feats)
345
-
346
-
347
- class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
348
- """Second-order deformable alignment module.
349
-
350
- Args:
351
- in_channels (int): Same as nn.Conv2d.
352
- out_channels (int): Same as nn.Conv2d.
353
- kernel_size (int or tuple[int]): Same as nn.Conv2d.
354
- stride (int or tuple[int]): Same as nn.Conv2d.
355
- padding (int or tuple[int]): Same as nn.Conv2d.
356
- dilation (int or tuple[int]): Same as nn.Conv2d.
357
- groups (int): Same as nn.Conv2d.
358
- bias (bool or str): If specified as `auto`, it will be decided by the
359
- norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
360
- False.
361
- max_residue_magnitude (int): The maximum magnitude of the offset
362
- residue (Eq. 6 in paper). Default: 10.
363
- """
364
-
365
- def __init__(self, *args, **kwargs):
366
- self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
367
-
368
- super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
369
-
370
- self.conv_offset = nn.Sequential(
371
- nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
372
- nn.LeakyReLU(negative_slope=0.1, inplace=True),
373
- nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
374
- nn.LeakyReLU(negative_slope=0.1, inplace=True),
375
- nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
376
- nn.LeakyReLU(negative_slope=0.1, inplace=True),
377
- nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
378
- )
379
-
380
- self.init_offset()
381
-
382
- def init_offset(self):
383
-
384
- def _constant_init(module, val, bias=0):
385
- if hasattr(module, 'weight') and module.weight is not None:
386
- nn.init.constant_(module.weight, val)
387
- if hasattr(module, 'bias') and module.bias is not None:
388
- nn.init.constant_(module.bias, bias)
389
-
390
- _constant_init(self.conv_offset[-1], val=0, bias=0)
391
-
392
- def forward(self, x, extra_feat, flow_1, flow_2):
393
- extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
394
- out = self.conv_offset(extra_feat)
395
- o1, o2, mask = torch.chunk(out, 3, dim=1)
396
-
397
- # offset
398
- offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
399
- offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
400
- offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
401
- offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
402
- offset = torch.cat([offset_1, offset_2], dim=1)
403
-
404
- # mask
405
- mask = torch.sigmoid(mask)
406
-
407
- return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
408
- self.dilation, mask)
409
-
410
-
411
- # if __name__ == '__main__':
412
- # spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
413
- # model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
414
- # input = torch.rand(1, 2, 3, 64, 64).cuda()
415
- # output = model(input)
416
- # print('===================')
417
- # print(output.shape)
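
For quick reference, the mirror-extension test used by check_if_mirror_extended above simply compares the first half of the clip with the time-reversed second half. A minimal stand-alone sketch of that check (illustrative only; the tensor sizes are made up and nothing below depends on BasicVSR++ itself):

import torch

def is_mirror_extended(lqs: torch.Tensor) -> bool:
    # lqs: (n, t, c, h, w). A clip padded as [frames, reversed frames] has an
    # even temporal length whose first half equals the flipped second half.
    if lqs.size(1) % 2 == 0:
        first, second = torch.chunk(lqs, 2, dim=1)
        return torch.norm(first - second.flip(1)).item() == 0
    return False

clip = torch.rand(1, 3, 3, 8, 8)                                    # 3-frame toy clip
print(is_mirror_extended(torch.cat([clip, clip.flip(1)], dim=1)))   # True
print(is_mirror_extended(torch.rand(1, 6, 3, 8, 8)))                # False

When a clip is detected as mirror-extended, compute_flow skips the second SPyNet pass and reuses the backward flows, halving the flow-estimation cost.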
 
basicsr/archs/dfdnet_arch.py DELETED
@@ -1,169 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.nn.utils.spectral_norm import spectral_norm
6
-
7
- from basicsr.utils.registry import ARCH_REGISTRY
8
- from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
9
- from .vgg_arch import VGGFeatureExtractor
10
-
11
-
12
- class SFTUpBlock(nn.Module):
13
- """Spatial feature transform (SFT) with upsampling block.
14
-
15
- Args:
16
- in_channel (int): Number of input channels.
17
- out_channel (int): Number of output channels.
18
- kernel_size (int): Kernel size in convolutions. Default: 3.
19
- padding (int): Padding in convolutions. Default: 1.
20
- """
21
-
22
- def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
23
- super(SFTUpBlock, self).__init__()
24
- self.conv1 = nn.Sequential(
25
- Blur(in_channel),
26
- spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
27
- nn.LeakyReLU(0.04, True),
28
- # The official codes use two LeakyReLU here, so 0.04 for equivalent
29
- )
30
- self.convup = nn.Sequential(
31
- nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
32
- spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
33
- nn.LeakyReLU(0.2, True),
34
- )
35
-
36
- # for SFT scale and shift
37
- self.scale_block = nn.Sequential(
38
- spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
39
- spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
40
- self.shift_block = nn.Sequential(
41
- spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
42
- spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
43
- # The official code uses a sigmoid for the shift block; the reason is not documented
44
-
45
- def forward(self, x, updated_feat):
46
- out = self.conv1(x)
47
- # SFT
48
- scale = self.scale_block(updated_feat)
49
- shift = self.shift_block(updated_feat)
50
- out = out * scale + shift
51
- # upsample
52
- out = self.convup(out)
53
- return out
54
-
55
-
56
- @ARCH_REGISTRY.register()
57
- class DFDNet(nn.Module):
58
- """DFDNet: Deep Face Dictionary Network.
59
-
60
- It only processes faces with 512x512 size.
61
-
62
- Args:
63
- num_feat (int): Number of feature channels.
64
- dict_path (str): Path to the facial component dictionary.
65
- """
66
-
67
- def __init__(self, num_feat, dict_path):
68
- super().__init__()
69
- self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
70
- # part_sizes: [80, 80, 50, 110]
71
- channel_sizes = [128, 256, 512, 512]
72
- self.feature_sizes = np.array([256, 128, 64, 32])
73
- self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
74
- self.flag_dict_device = False
75
-
76
- # dict
77
- self.dict = torch.load(dict_path)
78
-
79
- # vgg face extractor
80
- self.vgg_extractor = VGGFeatureExtractor(
81
- layer_name_list=self.vgg_layers,
82
- vgg_type='vgg19',
83
- use_input_norm=True,
84
- range_norm=True,
85
- requires_grad=False)
86
-
87
- # attention block for fusing dictionary features and input features
88
- self.attn_blocks = nn.ModuleDict()
89
- for idx, feat_size in enumerate(self.feature_sizes):
90
- for name in self.parts:
91
- self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
92
-
93
- # multi scale dilation block
94
- self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
95
-
96
- # upsampling and reconstruction
97
- self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
98
- self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
99
- self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
100
- self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
101
- self.upsample4 = nn.Sequential(
102
- spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
103
- UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
104
-
105
- def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
106
- """swap the features from the dictionary."""
107
- # get the original vgg features
108
- part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
109
- # resize original vgg features
110
- part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
111
- # use adaptive instance normalization to adjust color and illuminations
112
- dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
113
- # get similarity scores
114
- similarity_score = F.conv2d(part_resize_feat, dict_feat)
115
- similarity_score = F.softmax(similarity_score.view(-1), dim=0)
116
- # select the most similar features in the dict (after norm)
117
- select_idx = torch.argmax(similarity_score)
118
- swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
119
- # attention
120
- attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
121
- attn_feat = attn * swap_feat
122
- # update features
123
- updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
124
- return updated_feat
125
-
126
- def put_dict_to_device(self, x):
127
- if self.flag_dict_device is False:
128
- for k, v in self.dict.items():
129
- for kk, vv in v.items():
130
- self.dict[k][kk] = vv.to(x)
131
- self.flag_dict_device = True
132
-
133
- def forward(self, x, part_locations):
134
- """
135
- Only supports testing with batch size = 1.
136
-
137
- Args:
138
- x (Tensor): Input faces with shape (b, c, 512, 512).
139
- part_locations (list[Tensor]): Part locations.
140
- """
141
- self.put_dict_to_device(x)
142
- # extract vggface features
143
- vgg_features = self.vgg_extractor(x)
144
- # update vggface features using the dictionary for each part
145
- updated_vgg_features = []
146
- batch = 0 # only supports testing with batch size = 1
147
- for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
148
- dict_features = self.dict[f'{f_size}']
149
- vgg_feat = vgg_features[vgg_layer]
150
- updated_feat = vgg_feat.clone()
151
-
152
- # swap features from dictionary
153
- for part_idx, part_name in enumerate(self.parts):
154
- location = (part_locations[part_idx][batch] // (512 / f_size)).int()
155
- updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
156
- f_size)
157
-
158
- updated_vgg_features.append(updated_feat)
159
-
160
- vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
161
- # use updated vgg features to modulate the upsampled features with
162
- # SFT (Spatial Feature Transform) scaling and shifting manner.
163
- upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
164
- upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
165
- upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
166
- upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
167
- out = self.upsample4(upsampled_feat)
168
-
169
- return out
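
To make the dictionary lookup in swap_feat above concrete: each dictionary entry is used as a convolution kernel against the (resized and AdaIN-normalised) part feature, and the entry with the highest softmax-normalised correlation is selected. A toy sketch with invented shapes, not tied to the real facial-component dictionary:

import torch
import torch.nn.functional as F

part_feat = torch.rand(1, 8, 5, 5)     # query part feature, shape (1, c, h, w)
dict_feats = torch.rand(10, 8, 5, 5)   # 10 candidate dictionary entries

scores = F.conv2d(part_feat, dict_feats)         # (1, 10, 1, 1) correlation per entry
similarity = F.softmax(scores.view(-1), dim=0)   # normalised similarity scores
select_idx = torch.argmax(similarity)            # index of the closest entry
print(select_idx.item(), similarity[select_idx].item())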
 
basicsr/archs/dfdnet_util.py DELETED
@@ -1,162 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.autograd import Function
5
- from torch.nn.utils.spectral_norm import spectral_norm
6
-
7
-
8
- class BlurFunctionBackward(Function):
9
-
10
- @staticmethod
11
- def forward(ctx, grad_output, kernel, kernel_flip):
12
- ctx.save_for_backward(kernel, kernel_flip)
13
- grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
14
- return grad_input
15
-
16
- @staticmethod
17
- def backward(ctx, gradgrad_output):
18
- kernel, _ = ctx.saved_tensors
19
- grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
20
- return grad_input, None, None
21
-
22
-
23
- class BlurFunction(Function):
24
-
25
- @staticmethod
26
- def forward(ctx, x, kernel, kernel_flip):
27
- ctx.save_for_backward(kernel, kernel_flip)
28
- output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
29
- return output
30
-
31
- @staticmethod
32
- def backward(ctx, grad_output):
33
- kernel, kernel_flip = ctx.saved_tensors
34
- grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
35
- return grad_input, None, None
36
-
37
-
38
- blur = BlurFunction.apply
39
-
40
-
41
- class Blur(nn.Module):
42
-
43
- def __init__(self, channel):
44
- super().__init__()
45
- kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
46
- kernel = kernel.view(1, 1, 3, 3)
47
- kernel = kernel / kernel.sum()
48
- kernel_flip = torch.flip(kernel, [2, 3])
49
-
50
- self.kernel = kernel.repeat(channel, 1, 1, 1)
51
- self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
52
-
53
- def forward(self, x):
54
- return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
55
-
56
-
57
- def calc_mean_std(feat, eps=1e-5):
58
- """Calculate mean and std for adaptive_instance_normalization.
59
-
60
- Args:
61
- feat (Tensor): 4D tensor.
62
- eps (float): A small value added to the variance to avoid
63
- divide-by-zero. Default: 1e-5.
64
- """
65
- size = feat.size()
66
- assert len(size) == 4, 'The input feature should be 4D tensor.'
67
- n, c = size[:2]
68
- feat_var = feat.view(n, c, -1).var(dim=2) + eps
69
- feat_std = feat_var.sqrt().view(n, c, 1, 1)
70
- feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
71
- return feat_mean, feat_std
72
-
73
-
74
- def adaptive_instance_normalization(content_feat, style_feat):
75
- """Adaptive instance normalization.
76
-
77
- Adjust the reference features to have similar color and illumination
78
- to those in the degraded features.
79
-
80
- Args:
81
- content_feat (Tensor): The reference feature.
82
- style_feat (Tensor): The degraded features.
83
- """
84
- size = content_feat.size()
85
- style_mean, style_std = calc_mean_std(style_feat)
86
- content_mean, content_std = calc_mean_std(content_feat)
87
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
88
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
89
-
90
-
91
- def AttentionBlock(in_channel):
92
- return nn.Sequential(
93
- spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
94
- spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
95
-
96
-
97
- def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
98
- """Conv block used in MSDilationBlock."""
99
-
100
- return nn.Sequential(
101
- spectral_norm(
102
- nn.Conv2d(
103
- in_channels,
104
- out_channels,
105
- kernel_size=kernel_size,
106
- stride=stride,
107
- dilation=dilation,
108
- padding=((kernel_size - 1) // 2) * dilation,
109
- bias=bias)),
110
- nn.LeakyReLU(0.2),
111
- spectral_norm(
112
- nn.Conv2d(
113
- out_channels,
114
- out_channels,
115
- kernel_size=kernel_size,
116
- stride=stride,
117
- dilation=dilation,
118
- padding=((kernel_size - 1) // 2) * dilation,
119
- bias=bias)),
120
- )
121
-
122
-
123
- class MSDilationBlock(nn.Module):
124
- """Multi-scale dilation block."""
125
-
126
- def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
127
- super(MSDilationBlock, self).__init__()
128
-
129
- self.conv_blocks = nn.ModuleList()
130
- for i in range(4):
131
- self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
132
- self.conv_fusion = spectral_norm(
133
- nn.Conv2d(
134
- in_channels * 4,
135
- in_channels,
136
- kernel_size=kernel_size,
137
- stride=1,
138
- padding=(kernel_size - 1) // 2,
139
- bias=bias))
140
-
141
- def forward(self, x):
142
- out = []
143
- for i in range(4):
144
- out.append(self.conv_blocks[i](x))
145
- out = torch.cat(out, 1)
146
- out = self.conv_fusion(out) + x
147
- return out
148
-
149
-
150
- class UpResBlock(nn.Module):
151
-
152
- def __init__(self, in_channel):
153
- super(UpResBlock, self).__init__()
154
- self.body = nn.Sequential(
155
- nn.Conv2d(in_channel, in_channel, 3, 1, 1),
156
- nn.LeakyReLU(0.2, True),
157
- nn.Conv2d(in_channel, in_channel, 3, 1, 1),
158
- )
159
-
160
- def forward(self, x):
161
- out = x + self.body(x)
162
- return out
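
The adaptive_instance_normalization helper above is standard AdaIN: whiten the content features per channel, then re-colour them with the style statistics. A compact self-contained sketch with toy shapes:

import torch

def adain(content, style, eps=1e-5):
    # Normalise the content statistics per channel, then apply the style statistics.
    n, c = content.shape[:2]
    c_flat, s_flat = content.view(n, c, -1), style.view(n, c, -1)
    c_mean, c_std = c_flat.mean(-1), (c_flat.var(-1) + eps).sqrt()
    s_mean, s_std = s_flat.mean(-1), (s_flat.var(-1) + eps).sqrt()
    shape = (n, c, 1, 1)
    normalized = (content - c_mean.view(shape)) / c_std.view(shape)
    return normalized * s_std.view(shape) + s_mean.view(shape)

content, style = torch.rand(2, 4, 16, 16), torch.rand(2, 4, 16, 16)
print(adain(content, style).shape)  # torch.Size([2, 4, 16, 16])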
 
basicsr/archs/discriminator_arch.py DELETED
@@ -1,150 +0,0 @@
1
- from torch import nn as nn
2
- from torch.nn import functional as F
3
- from torch.nn.utils import spectral_norm
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
-
7
-
8
- @ARCH_REGISTRY.register()
9
- class VGGStyleDiscriminator(nn.Module):
10
- """VGG style discriminator with input size 128 x 128 or 256 x 256.
11
-
12
- It is used to train SRGAN, ESRGAN, and VideoGAN.
13
-
14
- Args:
15
- num_in_ch (int): Channel number of inputs. Default: 3.
16
- num_feat (int): Channel number of base intermediate features. Default: 64.
17
- """
18
-
19
- def __init__(self, num_in_ch, num_feat, input_size=128):
20
- super(VGGStyleDiscriminator, self).__init__()
21
- self.input_size = input_size
22
- assert self.input_size == 128 or self.input_size == 256, (
23
- f'input size must be 128 or 256, but received {input_size}')
24
-
25
- self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
26
- self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
27
- self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
28
-
29
- self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
30
- self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
31
- self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
32
- self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
33
-
34
- self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
35
- self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
36
- self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
37
- self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
38
-
39
- self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
40
- self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
41
- self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
42
- self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
43
-
44
- self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
45
- self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
46
- self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
47
- self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
48
-
49
- if self.input_size == 256:
50
- self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
51
- self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
52
- self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
53
- self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
54
-
55
- self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
56
- self.linear2 = nn.Linear(100, 1)
57
-
58
- # activation function
59
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
60
-
61
- def forward(self, x):
62
- assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
63
-
64
- feat = self.lrelu(self.conv0_0(x))
65
- feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
66
-
67
- feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
68
- feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
69
-
70
- feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
71
- feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
72
-
73
- feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
74
- feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
75
-
76
- feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
77
- feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
78
-
79
- if self.input_size == 256:
80
- feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
81
- feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
82
-
83
- # spatial size: (4, 4)
84
- feat = feat.view(feat.size(0), -1)
85
- feat = self.lrelu(self.linear1(feat))
86
- out = self.linear2(feat)
87
- return out
88
-
89
-
90
- @ARCH_REGISTRY.register(suffix='basicsr')
91
- class UNetDiscriminatorSN(nn.Module):
92
- """Defines a U-Net discriminator with spectral normalization (SN)
93
-
94
- It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
95
-
96
- Args:
97
- num_in_ch (int): Channel number of inputs. Default: 3.
98
- num_feat (int): Channel number of base intermediate features. Default: 64.
99
- skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
100
- """
101
-
102
- def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
103
- super(UNetDiscriminatorSN, self).__init__()
104
- self.skip_connection = skip_connection
105
- norm = spectral_norm
106
- # the first convolution
107
- self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
108
- # downsample
109
- self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
110
- self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
111
- self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
112
- # upsample
113
- self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
114
- self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
115
- self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
116
- # extra convolutions
117
- self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
118
- self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
119
- self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
120
-
121
- def forward(self, x):
122
- # downsample
123
- x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
124
- x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
125
- x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
126
- x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
127
-
128
- # upsample
129
- x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
130
- x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
131
-
132
- if self.skip_connection:
133
- x4 = x4 + x2
134
- x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
135
- x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
136
-
137
- if self.skip_connection:
138
- x5 = x5 + x1
139
- x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
140
- x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
141
-
142
- if self.skip_connection:
143
- x6 = x6 + x0
144
-
145
- # extra convolutions
146
- out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
147
- out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
148
- out = self.conv9(out)
149
-
150
- return out
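
A short usage sketch for the U-Net discriminator above, assuming a separate basicsr installation that still ships this module (it is removed from this Space by the present commit). The output is a per-pixel realness map at the input resolution:

import torch
from basicsr.archs.discriminator_arch import UNetDiscriminatorSN  # assumes basicsr is installed

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
fake = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    realness = disc(fake)
print(realness.shape)  # torch.Size([1, 1, 128, 128])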
 
basicsr/archs/duf_arch.py DELETED
@@ -1,276 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from torch import nn as nn
4
- from torch.nn import functional as F
5
-
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
-
8
-
9
- class DenseBlocksTemporalReduce(nn.Module):
10
- """A concatenation of 3 dense blocks with reduction in temporal dimension.
11
-
12
- Note that the output temporal dimension is 6 fewer than the input temporal dimension, since there are 3 blocks.
13
-
14
- Args:
15
- num_feat (int): Number of channels in the blocks. Default: 64.
16
- num_grow_ch (int): Growing factor of the dense blocks. Default: 32
17
- adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
18
- Set to false if you want to train from scratch. Default: False.
19
- """
20
-
21
- def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
22
- super(DenseBlocksTemporalReduce, self).__init__()
23
- if adapt_official_weights:
24
- eps = 1e-3
25
- momentum = 1e-3
26
- else: # pytorch default values
27
- eps = 1e-05
28
- momentum = 0.1
29
-
30
- self.temporal_reduce1 = nn.Sequential(
31
- nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
32
- nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
33
- nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
34
- nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
35
-
36
- self.temporal_reduce2 = nn.Sequential(
37
- nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
38
- nn.Conv3d(
39
- num_feat + num_grow_ch,
40
- num_feat + num_grow_ch, (1, 1, 1),
41
- stride=(1, 1, 1),
42
- padding=(0, 0, 0),
43
- bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
44
- nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
45
-
46
- self.temporal_reduce3 = nn.Sequential(
47
- nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
48
- nn.Conv3d(
49
- num_feat + 2 * num_grow_ch,
50
- num_feat + 2 * num_grow_ch, (1, 1, 1),
51
- stride=(1, 1, 1),
52
- padding=(0, 0, 0),
53
- bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
54
- nn.ReLU(inplace=True),
55
- nn.Conv3d(
56
- num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
57
-
58
- def forward(self, x):
59
- """
60
- Args:
61
- x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
62
-
63
- Returns:
64
- Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
65
- """
66
- x1 = self.temporal_reduce1(x)
67
- x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
68
-
69
- x2 = self.temporal_reduce2(x1)
70
- x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
71
-
72
- x3 = self.temporal_reduce3(x2)
73
- x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
74
-
75
- return x3
76
-
77
-
78
- class DenseBlocks(nn.Module):
79
- """ A concatenation of N dense blocks.
80
-
81
- Args:
82
- num_feat (int): Number of channels in the blocks. Default: 64.
83
- num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
84
- num_block (int): Number of dense blocks. The values are:
85
- DUF-S (16 layers): 3
86
- DUF-M (28 layers): 9
87
- DUF-L (52 layers): 21
88
- adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
89
- Set to false if you want to train from scratch. Default: False.
90
- """
91
-
92
- def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
93
- super(DenseBlocks, self).__init__()
94
- if adapt_official_weights:
95
- eps = 1e-3
96
- momentum = 1e-3
97
- else: # pytorch default values
98
- eps = 1e-05
99
- momentum = 0.1
100
-
101
- self.dense_blocks = nn.ModuleList()
102
- for i in range(0, num_block):
103
- self.dense_blocks.append(
104
- nn.Sequential(
105
- nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
106
- nn.Conv3d(
107
- num_feat + i * num_grow_ch,
108
- num_feat + i * num_grow_ch, (1, 1, 1),
109
- stride=(1, 1, 1),
110
- padding=(0, 0, 0),
111
- bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
112
- nn.ReLU(inplace=True),
113
- nn.Conv3d(
114
- num_feat + i * num_grow_ch,
115
- num_grow_ch, (3, 3, 3),
116
- stride=(1, 1, 1),
117
- padding=(1, 1, 1),
118
- bias=True)))
119
-
120
- def forward(self, x):
121
- """
122
- Args:
123
- x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
124
-
125
- Returns:
126
- Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
127
- """
128
- for i in range(0, len(self.dense_blocks)):
129
- y = self.dense_blocks[i](x)
130
- x = torch.cat((x, y), 1)
131
- return x
132
-
133
-
134
- class DynamicUpsamplingFilter(nn.Module):
135
- """Dynamic upsampling filter used in DUF.
136
-
137
- Reference: https://github.com/yhjo09/VSR-DUF
138
-
139
- It only supports input with 3 channels and applies the same filters to all 3 channels.
140
-
141
- Args:
142
- filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
143
- """
144
-
145
- def __init__(self, filter_size=(5, 5)):
146
- super(DynamicUpsamplingFilter, self).__init__()
147
- if not isinstance(filter_size, tuple):
148
- raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
149
- if len(filter_size) != 2:
150
- raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
151
- # generate a local expansion filter, similar to im2col
152
- self.filter_size = filter_size
153
- filter_prod = np.prod(filter_size)
154
- expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
155
- self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
156
-
157
- def forward(self, x, filters):
158
- """Forward function for DynamicUpsamplingFilter.
159
-
160
- Args:
161
- x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
162
- filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
163
- filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
164
- upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
165
- e.g., for x 4 upsampling, upsampling_square= 4*4 = 16
166
-
167
- Returns:
168
- Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
169
- """
170
- n, filter_prod, upsampling_square, h, w = filters.size()
171
- kh, kw = self.filter_size
172
- expanded_input = F.conv2d(
173
- x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
174
- expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
175
- 2) # (n, h, w, 3, filter_prod)
176
- filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square]
177
- out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
178
- return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
179
-
180
-
181
- @ARCH_REGISTRY.register()
182
- class DUF(nn.Module):
183
- """Network architecture for DUF
184
-
185
- ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
186
-
187
- Reference: https://github.com/yhjo09/VSR-DUF
188
-
189
- For all the models below, 'adapt_official_weights' is only necessary when
190
- loading the weights converted from the official TensorFlow weights.
191
- Please set it to False if you are training the model from scratch.
192
-
193
- There are three models with different model sizes: DUF16Layers, DUF28Layers,
194
- and DUF52Layers. This class is the base class for these models.
195
-
196
- Args:
197
- scale (int): The upsampling factor. Default: 4.
198
- num_layer (int): The number of layers. Default: 52.
199
- adapt_official_weights (bool): Whether to adapt the weights
200
- translated from the official implementation. Set to false if you
201
- want to train from scratch. Default: False.
202
- """
203
-
204
- def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
205
- super(DUF, self).__init__()
206
- self.scale = scale
207
- if adapt_official_weights:
208
- eps = 1e-3
209
- momentum = 1e-3
210
- else: # pytorch default values
211
- eps = 1e-05
212
- momentum = 0.1
213
-
214
- self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
215
- self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
216
-
217
- if num_layer == 16:
218
- num_block = 3
219
- num_grow_ch = 32
220
- elif num_layer == 28:
221
- num_block = 9
222
- num_grow_ch = 16
223
- elif num_layer == 52:
224
- num_block = 21
225
- num_grow_ch = 16
226
- else:
227
- raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
228
-
229
- self.dense_block1 = DenseBlocks(
230
- num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
231
- adapt_official_weights=adapt_official_weights) # T = 7
232
- self.dense_block2 = DenseBlocksTemporalReduce(
233
- 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
234
- channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
235
- self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
236
- self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
237
-
238
- self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
239
- self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
240
-
241
- self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
242
- self.conv3d_f2 = nn.Conv3d(
243
- 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
244
-
245
- def forward(self, x):
246
- """
247
- Args:
248
- x (Tensor): Input with shape (b, 7, c, h, w)
249
-
250
- Returns:
251
- Tensor: Output with shape (b, c, h * scale, w * scale)
252
- """
253
- num_batches, num_imgs, _, h, w = x.size()
254
-
255
- x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
256
- x_center = x[:, :, num_imgs // 2, :, :]
257
-
258
- x = self.conv3d1(x)
259
- x = self.dense_block1(x)
260
- x = self.dense_block2(x)
261
- x = F.relu(self.bn3d2(x), inplace=True)
262
- x = F.relu(self.conv3d2(x), inplace=True)
263
-
264
- # residual image
265
- res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
266
-
267
- # filter
268
- filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
269
- filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
270
-
271
- # dynamic filter
272
- out = self.dynamic_filter(x_center, filter_)
273
- out += res.squeeze_(2)
274
- out = F.pixel_shuffle(out, self.scale)
275
-
276
- return out
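
To see the DynamicUpsamplingFilter mechanics in isolation, here is a toy version of the same dynamic filtering step: 5x5 neighbourhoods are extracted with an identity (im2col-style) kernel and then weighted by per-pixel filters. The filters are random below purely for illustration; in DUF they are predicted by the network:

import torch
import torch.nn.functional as F

n, h, w, up_sq = 1, 8, 8, 16                                     # 4x upsampling -> 16 sub-pixel filters
x = torch.rand(n, 3, h, w)                                       # toy RGB input
filters = torch.softmax(torch.rand(n, 25, up_sq, h, w), dim=1)   # per-pixel 5x5 filters

expansion = torch.eye(25).view(25, 1, 5, 5).repeat(3, 1, 1, 1)   # identity "im2col" kernel
patches = F.conv2d(x, expansion, padding=2, groups=3)            # (n, 75, h, w)
patches = patches.view(n, 3, 25, h, w).permute(0, 3, 4, 1, 2)    # (n, h, w, 3, 25)
out = torch.matmul(patches, filters.permute(0, 3, 4, 1, 2))      # (n, h, w, 3, 16)
out = out.permute(0, 3, 4, 1, 2).reshape(n, 3 * up_sq, h, w)     # (n, 48, h, w)
print(F.pixel_shuffle(out, 4).shape)                             # torch.Size([1, 3, 32, 32])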
 
basicsr/archs/ecbsr_arch.py DELETED
@@ -1,275 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
-
7
-
8
- class SeqConv3x3(nn.Module):
9
- """The re-parameterizable block used in the ECBSR architecture.
10
-
11
- ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
12
-
13
- Reference: https://github.com/xindongzhang/ECBSR
14
-
15
- Args:
16
- seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
17
- in_channels (int): Channel number of input.
18
- out_channels (int): Channel number of output.
19
- depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
20
- """
21
-
22
- def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
23
- super(SeqConv3x3, self).__init__()
24
- self.seq_type = seq_type
25
- self.in_channels = in_channels
26
- self.out_channels = out_channels
27
-
28
- if self.seq_type == 'conv1x1-conv3x3':
29
- self.mid_planes = int(out_channels * depth_multiplier)
30
- conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
31
- self.k0 = conv0.weight
32
- self.b0 = conv0.bias
33
-
34
- conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
35
- self.k1 = conv1.weight
36
- self.b1 = conv1.bias
37
-
38
- elif self.seq_type == 'conv1x1-sobelx':
39
- conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
40
- self.k0 = conv0.weight
41
- self.b0 = conv0.bias
42
-
43
- # init scale and bias
44
- scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
45
- self.scale = nn.Parameter(scale)
46
- bias = torch.randn(self.out_channels) * 1e-3
47
- bias = torch.reshape(bias, (self.out_channels, ))
48
- self.bias = nn.Parameter(bias)
49
- # init mask
50
- self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
51
- for i in range(self.out_channels):
52
- self.mask[i, 0, 0, 0] = 1.0
53
- self.mask[i, 0, 1, 0] = 2.0
54
- self.mask[i, 0, 2, 0] = 1.0
55
- self.mask[i, 0, 0, 2] = -1.0
56
- self.mask[i, 0, 1, 2] = -2.0
57
- self.mask[i, 0, 2, 2] = -1.0
58
- self.mask = nn.Parameter(data=self.mask, requires_grad=False)
59
-
60
- elif self.seq_type == 'conv1x1-sobely':
61
- conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
62
- self.k0 = conv0.weight
63
- self.b0 = conv0.bias
64
-
65
- # init scale and bias
66
- scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
67
- self.scale = nn.Parameter(torch.FloatTensor(scale))
68
- bias = torch.randn(self.out_channels) * 1e-3
69
- bias = torch.reshape(bias, (self.out_channels, ))
70
- self.bias = nn.Parameter(torch.FloatTensor(bias))
71
- # init mask
72
- self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
73
- for i in range(self.out_channels):
74
- self.mask[i, 0, 0, 0] = 1.0
75
- self.mask[i, 0, 0, 1] = 2.0
76
- self.mask[i, 0, 0, 2] = 1.0
77
- self.mask[i, 0, 2, 0] = -1.0
78
- self.mask[i, 0, 2, 1] = -2.0
79
- self.mask[i, 0, 2, 2] = -1.0
80
- self.mask = nn.Parameter(data=self.mask, requires_grad=False)
81
-
82
- elif self.seq_type == 'conv1x1-laplacian':
83
- conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
84
- self.k0 = conv0.weight
85
- self.b0 = conv0.bias
86
-
87
- # init scale and bias
88
- scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
89
- self.scale = nn.Parameter(torch.FloatTensor(scale))
90
- bias = torch.randn(self.out_channels) * 1e-3
91
- bias = torch.reshape(bias, (self.out_channels, ))
92
- self.bias = nn.Parameter(torch.FloatTensor(bias))
93
- # init mask
94
- self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
95
- for i in range(self.out_channels):
96
- self.mask[i, 0, 0, 1] = 1.0
97
- self.mask[i, 0, 1, 0] = 1.0
98
- self.mask[i, 0, 1, 2] = 1.0
99
- self.mask[i, 0, 2, 1] = 1.0
100
- self.mask[i, 0, 1, 1] = -4.0
101
- self.mask = nn.Parameter(data=self.mask, requires_grad=False)
102
- else:
103
- raise ValueError('The type of seqconv is not supported!')
104
-
105
- def forward(self, x):
106
- if self.seq_type == 'conv1x1-conv3x3':
107
- # conv-1x1
108
- y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
109
- # explicitly padding with bias
110
- y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
111
- b0_pad = self.b0.view(1, -1, 1, 1)
112
- y0[:, :, 0:1, :] = b0_pad
113
- y0[:, :, -1:, :] = b0_pad
114
- y0[:, :, :, 0:1] = b0_pad
115
- y0[:, :, :, -1:] = b0_pad
116
- # conv-3x3
117
- y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
118
- else:
119
- y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
120
- # explicitly padding with bias
121
- y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
122
- b0_pad = self.b0.view(1, -1, 1, 1)
123
- y0[:, :, 0:1, :] = b0_pad
124
- y0[:, :, -1:, :] = b0_pad
125
- y0[:, :, :, 0:1] = b0_pad
126
- y0[:, :, :, -1:] = b0_pad
127
- # conv-3x3
128
- y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
129
- return y1
130
-
131
- def rep_params(self):
132
- device = self.k0.get_device()
133
- if device < 0:
134
- device = None
135
-
136
- if self.seq_type == 'conv1x1-conv3x3':
137
- # re-param conv kernel
138
- rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
139
- # re-param conv bias
140
- rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
141
- rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
142
- else:
143
- tmp = self.scale * self.mask
144
- k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
145
- for i in range(self.out_channels):
146
- k1[i, i, :, :] = tmp[i, 0, :, :]
147
- b1 = self.bias
148
- # re-param conv kernel
149
- rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
150
- # re-param conv bias
151
- rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
152
- rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
153
- return rep_weight, rep_bias
154
-
155
-
156
- class ECB(nn.Module):
157
- """The ECB block used in the ECBSR architecture.
158
-
159
- Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
160
- Ref git repo: https://github.com/xindongzhang/ECBSR
161
-
162
- Args:
163
- in_channels (int): Channel number of input.
164
- out_channels (int): Channel number of output.
165
- depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
166
- act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
167
- with_idt (bool): Whether to use identity connection. Default: False.
168
- """
169
-
170
- def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
171
- super(ECB, self).__init__()
172
-
173
- self.depth_multiplier = depth_multiplier
174
- self.in_channels = in_channels
175
- self.out_channels = out_channels
176
- self.act_type = act_type
177
-
178
- if with_idt and (self.in_channels == self.out_channels):
179
- self.with_idt = True
180
- else:
181
- self.with_idt = False
182
-
183
- self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
184
- self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
185
- self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
186
- self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
187
- self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
188
-
189
- if self.act_type == 'prelu':
190
- self.act = nn.PReLU(num_parameters=self.out_channels)
191
- elif self.act_type == 'relu':
192
- self.act = nn.ReLU(inplace=True)
193
- elif self.act_type == 'rrelu':
194
- self.act = nn.RReLU(lower=-0.05, upper=0.05)
195
- elif self.act_type == 'softplus':
196
- self.act = nn.Softplus()
197
- elif self.act_type == 'linear':
198
- pass
199
- else:
200
- raise ValueError('The type of activation is not supported!')
201
-
202
- def forward(self, x):
203
- if self.training:
204
- y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
205
- if self.with_idt:
206
- y += x
207
- else:
208
- rep_weight, rep_bias = self.rep_params()
209
- y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
210
- if self.act_type != 'linear':
211
- y = self.act(y)
212
- return y
213
-
214
- def rep_params(self):
215
- weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
216
- weight1, bias1 = self.conv1x1_3x3.rep_params()
217
- weight2, bias2 = self.conv1x1_sbx.rep_params()
218
- weight3, bias3 = self.conv1x1_sby.rep_params()
219
- weight4, bias4 = self.conv1x1_lpl.rep_params()
220
- rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
221
- bias0 + bias1 + bias2 + bias3 + bias4)
222
-
223
- if self.with_idt:
224
- device = rep_weight.get_device()
225
- if device < 0:
226
- device = None
227
- weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
228
- for i in range(self.out_channels):
229
- weight_idt[i, i, 1, 1] = 1.0
230
- bias_idt = 0.0
231
- rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
232
- return rep_weight, rep_bias
233
-
234
-
235
- @ARCH_REGISTRY.register()
236
- class ECBSR(nn.Module):
237
- """ECBSR architecture.
238
-
239
- Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
240
- Ref git repo: https://github.com/xindongzhang/ECBSR
241
-
242
- Args:
243
- num_in_ch (int): Channel number of inputs.
244
- num_out_ch (int): Channel number of outputs.
245
- num_block (int): Block number in the trunk network.
246
- num_channel (int): Channel number.
247
- with_idt (bool): Whether to use identity connections in the convolution layers.
248
- act_type (str): Activation type.
249
- scale (int): Upsampling factor.
250
- """
251
-
252
- def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
253
- super(ECBSR, self).__init__()
254
- self.num_in_ch = num_in_ch
255
- self.scale = scale
256
-
257
- backbone = []
258
- backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
259
- for _ in range(num_block):
260
- backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
261
- backbone += [
262
- ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
263
- ]
264
-
265
- self.backbone = nn.Sequential(*backbone)
266
- self.upsampler = nn.PixelShuffle(scale)
267
-
268
- def forward(self, x):
269
- if self.num_in_ch > 1:
270
- shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
271
- else:
272
- shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
273
- y = self.backbone(x) + shortcut
274
- y = self.upsampler(y)
275
- return y
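
The point of the ECB design above is that the five training-time branches collapse into a single 3x3 convolution at inference. A sanity-check sketch, assuming an environment where this ecbsr_arch module is still importable from basicsr; train-mode and eval-mode outputs should agree up to floating-point error:

import torch
from basicsr.archs.ecbsr_arch import ECB  # assumes basicsr provides this module

block = ECB(in_channels=8, out_channels=8, depth_multiplier=2.0, act_type='linear')
x = torch.rand(1, 8, 16, 16)

with torch.no_grad():
    block.train()
    y_branches = block(x)   # sum of the five parallel branches
    block.eval()
    y_reparam = block(x)    # single re-parameterised 3x3 convolution

print(torch.allclose(y_branches, y_reparam, atol=1e-4))  # expected: True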
 
basicsr/archs/edsr_arch.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
-
4
- from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
-
7
-
8
- @ARCH_REGISTRY.register()
9
- class EDSR(nn.Module):
10
- """EDSR network structure.
11
-
12
- Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
13
- Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
14
-
15
- Args:
16
- num_in_ch (int): Channel number of inputs.
17
- num_out_ch (int): Channel number of outputs.
18
- num_feat (int): Channel number of intermediate features.
19
- Default: 64.
20
- num_block (int): Block number in the trunk network. Default: 16.
21
- upscale (int): Upsampling factor. Supported factors: 2^n and 3.
22
- Default: 4.
23
- res_scale (float): Used to scale the residual in residual block.
24
- Default: 1.
25
- img_range (float): Image range. Default: 255.
26
- rgb_mean (tuple[float]): Image mean in RGB order.
27
- Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
28
- """
29
-
30
- def __init__(self,
31
- num_in_ch,
32
- num_out_ch,
33
- num_feat=64,
34
- num_block=16,
35
- upscale=4,
36
- res_scale=1,
37
- img_range=255.,
38
- rgb_mean=(0.4488, 0.4371, 0.4040)):
39
- super(EDSR, self).__init__()
40
-
41
- self.img_range = img_range
42
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
43
-
44
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
45
- self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
46
- self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
47
- self.upsample = Upsample(upscale, num_feat)
48
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
49
-
50
- def forward(self, x):
51
- self.mean = self.mean.type_as(x)
52
-
53
- x = (x - self.mean) * self.img_range
54
- x = self.conv_first(x)
55
- res = self.conv_after_body(self.body(x))
56
- res += x
57
-
58
- x = self.conv_last(self.upsample(res))
59
- x = x / self.img_range + self.mean
60
-
61
- return x
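
Finally, a minimal forward-pass sketch for the EDSR network above (again assuming a basicsr installation that still contains this file): a 32x32 low-resolution RGB tensor comes out four times larger:

import torch
from basicsr.archs.edsr_arch import EDSR  # assumes basicsr is installed

model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
lr = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 128, 128])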
 
basicsr/archs/edvr_arch.py DELETED
@@ -1,382 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
- from torch.nn import functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
7
-
8
-
9
- class PCDAlignment(nn.Module):
10
- """Alignment module using Pyramid, Cascading and Deformable convolution
11
- (PCD). It is used in EDVR.
12
-
13
- ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
14
-
15
- Args:
16
- num_feat (int): Channel number of middle features. Default: 64.
17
- deformable_groups (int): Deformable groups. Default: 8.
18
- """
19
-
20
- def __init__(self, num_feat=64, deformable_groups=8):
21
- super(PCDAlignment, self).__init__()
22
-
23
- # Pyramid has three levels:
24
- # L3: level 3, 1/4 spatial size
25
- # L2: level 2, 1/2 spatial size
26
- # L1: level 1, original spatial size
27
- self.offset_conv1 = nn.ModuleDict()
28
- self.offset_conv2 = nn.ModuleDict()
29
- self.offset_conv3 = nn.ModuleDict()
30
- self.dcn_pack = nn.ModuleDict()
31
- self.feat_conv = nn.ModuleDict()
32
-
33
- # Pyramids
34
- for i in range(3, 0, -1):
35
- level = f'l{i}'
36
- self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
37
- if i == 3:
38
- self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
39
- else:
40
- self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
41
- self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
42
- self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
43
-
44
- if i < 3:
45
- self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
46
-
47
- # Cascading dcn
48
- self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
49
- self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
50
- self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
51
-
52
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
53
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
54
-
55
- def forward(self, nbr_feat_l, ref_feat_l):
56
- """Align neighboring frame features to the reference frame features.
57
-
58
- Args:
59
- nbr_feat_l (list[Tensor]): Neighboring feature list. It
60
- contains three pyramid levels (L1, L2, L3),
61
- each with shape (b, c, h, w).
62
- ref_feat_l (list[Tensor]): Reference feature list. It
63
- contains three pyramid levels (L1, L2, L3),
64
- each with shape (b, c, h, w).
65
-
66
- Returns:
67
- Tensor: Aligned features.
68
- """
69
- # Pyramids
70
- upsampled_offset, upsampled_feat = None, None
71
- for i in range(3, 0, -1):
72
- level = f'l{i}'
73
- offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
74
- offset = self.lrelu(self.offset_conv1[level](offset))
75
- if i == 3:
76
- offset = self.lrelu(self.offset_conv2[level](offset))
77
- else:
78
- offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
79
- offset = self.lrelu(self.offset_conv3[level](offset))
80
-
81
- feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
82
- if i < 3:
83
- feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
84
- if i > 1:
85
- feat = self.lrelu(feat)
86
-
87
- if i > 1: # upsample offset and features
88
- # x2: when we upsample the offset, we should also enlarge
89
- # the magnitude.
90
- upsampled_offset = self.upsample(offset) * 2
91
- upsampled_feat = self.upsample(feat)
92
-
93
- # Cascading
94
- offset = torch.cat([feat, ref_feat_l[0]], dim=1)
95
- offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
96
- feat = self.lrelu(self.cas_dcnpack(feat, offset))
97
- return feat
98
-
99
-
100
- class TSAFusion(nn.Module):
101
- """Temporal Spatial Attention (TSA) fusion module.
102
-
103
- Temporal: Calculate the correlation between center frame and
104
- neighboring frames;
105
- Spatial: It has 3 pyramid levels, the attention is similar to SFT.
106
- (SFT: Recovering realistic texture in image super-resolution by deep
107
- spatial feature transform.)
108
-
109
- Args:
110
- num_feat (int): Channel number of middle features. Default: 64.
111
- num_frame (int): Number of frames. Default: 5.
112
- center_frame_idx (int): The index of center frame. Default: 2.
113
- """
114
-
115
- def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
116
- super(TSAFusion, self).__init__()
117
- self.center_frame_idx = center_frame_idx
118
- # temporal attention (before fusion conv)
119
- self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
120
- self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
- self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
122
-
123
- # spatial attention (after fusion conv)
124
- self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
125
- self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
126
- self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
127
- self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
128
- self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
129
- self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
130
- self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
131
- self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
132
- self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
133
- self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
134
- self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
135
- self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
136
-
137
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
138
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
139
-
140
- def forward(self, aligned_feat):
141
- """
142
- Args:
143
- aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
144
-
145
- Returns:
146
- Tensor: Features after TSA with the shape (b, c, h, w).
147
- """
148
- b, t, c, h, w = aligned_feat.size()
149
- # temporal attention
150
- embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
151
- embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
152
- embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
153
-
154
- corr_l = [] # correlation list
155
- for i in range(t):
156
- emb_neighbor = embedding[:, i, :, :, :]
157
- corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
158
- corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
159
- corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
160
- corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
161
- corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
162
- aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
163
-
164
- # fusion
165
- feat = self.lrelu(self.feat_fusion(aligned_feat))
166
-
167
- # spatial attention
168
- attn = self.lrelu(self.spatial_attn1(aligned_feat))
169
- attn_max = self.max_pool(attn)
170
- attn_avg = self.avg_pool(attn)
171
- attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
172
- # pyramid levels
173
- attn_level = self.lrelu(self.spatial_attn_l1(attn))
174
- attn_max = self.max_pool(attn_level)
175
- attn_avg = self.avg_pool(attn_level)
176
- attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
177
- attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
178
- attn_level = self.upsample(attn_level)
179
-
180
- attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
181
- attn = self.lrelu(self.spatial_attn4(attn))
182
- attn = self.upsample(attn)
183
- attn = self.spatial_attn5(attn)
184
- attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
185
- attn = torch.sigmoid(attn)
186
-
187
- # after initialization, * 2 makes (attn * 2) to be close to 1.
188
- feat = feat * attn * 2 + attn_add
189
- return feat
190
-
191
-
192
- class PredeblurModule(nn.Module):
193
- """Pre-deblur module.
194
-
195
- Args:
196
- num_in_ch (int): Channel number of input image. Default: 3.
197
- num_feat (int): Channel number of intermediate features. Default: 64.
198
- hr_in (bool): Whether the input has high resolution. Default: False.
199
- """
200
-
201
- def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
202
- super(PredeblurModule, self).__init__()
203
- self.hr_in = hr_in
204
-
205
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
206
- if self.hr_in:
207
- # downsample x4 by stride conv
208
- self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
209
- self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
210
-
211
- # generate feature pyramid
212
- self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
213
- self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
214
-
215
- self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
216
- self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
217
- self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
218
- self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
219
-
220
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
221
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
222
-
223
- def forward(self, x):
224
- feat_l1 = self.lrelu(self.conv_first(x))
225
- if self.hr_in:
226
- feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
227
- feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
228
-
229
- # generate feature pyramid
230
- feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
231
- feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
232
-
233
- feat_l3 = self.upsample(self.resblock_l3(feat_l3))
234
- feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
235
- feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
236
-
237
- for i in range(2):
238
- feat_l1 = self.resblock_l1[i](feat_l1)
239
- feat_l1 = feat_l1 + feat_l2
240
- for i in range(2, 5):
241
- feat_l1 = self.resblock_l1[i](feat_l1)
242
- return feat_l1
243
-
244
-
245
- @ARCH_REGISTRY.register()
246
- class EDVR(nn.Module):
247
- """EDVR network structure for video super-resolution.
248
-
249
- Now only support X4 upsampling factor.
250
-
251
- ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
252
-
253
- Args:
254
- num_in_ch (int): Channel number of input image. Default: 3.
255
- num_out_ch (int): Channel number of output image. Default: 3.
256
- num_feat (int): Channel number of intermediate features. Default: 64.
257
- num_frame (int): Number of input frames. Default: 5.
258
- deformable_groups (int): Deformable groups. Defaults: 8.
259
- num_extract_block (int): Number of blocks for feature extraction.
260
- Default: 5.
261
- num_reconstruct_block (int): Number of blocks for reconstruction.
262
- Default: 10.
263
- center_frame_idx (int): The index of center frame. Frame counting from
264
- 0. Default: Middle of input frames.
265
- hr_in (bool): Whether the input has high resolution. Default: False.
266
- with_predeblur (bool): Whether has predeblur module.
267
- Default: False.
268
- with_tsa (bool): Whether has TSA module. Default: True.
269
- """
270
-
271
- def __init__(self,
272
- num_in_ch=3,
273
- num_out_ch=3,
274
- num_feat=64,
275
- num_frame=5,
276
- deformable_groups=8,
277
- num_extract_block=5,
278
- num_reconstruct_block=10,
279
- center_frame_idx=None,
280
- hr_in=False,
281
- with_predeblur=False,
282
- with_tsa=True):
283
- super(EDVR, self).__init__()
284
- if center_frame_idx is None:
285
- self.center_frame_idx = num_frame // 2
286
- else:
287
- self.center_frame_idx = center_frame_idx
288
- self.hr_in = hr_in
289
- self.with_predeblur = with_predeblur
290
- self.with_tsa = with_tsa
291
-
292
- # extract features for each frame
293
- if self.with_predeblur:
294
- self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
295
- self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
296
- else:
297
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
298
-
299
- # extract pyramid features
300
- self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
301
- self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
302
- self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
303
- self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
304
- self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
305
-
306
- # pcd and tsa module
307
- self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
308
- if self.with_tsa:
309
- self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
310
- else:
311
- self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
312
-
313
- # reconstruction
314
- self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
315
- # upsample
316
- self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
317
- self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
318
- self.pixel_shuffle = nn.PixelShuffle(2)
319
- self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
320
- self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
321
-
322
- # activation function
323
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
324
-
325
- def forward(self, x):
326
- b, t, c, h, w = x.size()
327
- if self.hr_in:
328
- assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
329
- else:
330
- assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
331
-
332
- x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
333
-
334
- # extract features for each frame
335
- # L1
336
- if self.with_predeblur:
337
- feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
338
- if self.hr_in:
339
- h, w = h // 4, w // 4
340
- else:
341
- feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
342
-
343
- feat_l1 = self.feature_extraction(feat_l1)
344
- # L2
345
- feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
346
- feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
347
- # L3
348
- feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
349
- feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
350
-
351
- feat_l1 = feat_l1.view(b, t, -1, h, w)
352
- feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
353
- feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
354
-
355
- # PCD alignment
356
- ref_feat_l = [ # reference feature list
357
- feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
358
- feat_l3[:, self.center_frame_idx, :, :, :].clone()
359
- ]
360
- aligned_feat = []
361
- for i in range(t):
362
- nbr_feat_l = [ # neighboring feature list
363
- feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
364
- ]
365
- aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
366
- aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
367
-
368
- if not self.with_tsa:
369
- aligned_feat = aligned_feat.view(b, -1, h, w)
370
- feat = self.fusion(aligned_feat)
371
-
372
- out = self.reconstruction(feat)
373
- out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
374
- out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
375
- out = self.lrelu(self.conv_hr(out))
376
- out = self.conv_last(out)
377
- if self.hr_in:
378
- base = x_center
379
- else:
380
- base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
381
- out += base
382
- return out
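For reference, the removed EDVR network takes a stack of frames and super-resolves the center one. A rough usage sketch (illustrative only; assumes a basicsr installation whose deformable-convolution ops, used by DCNv2Pack, were built):

import torch
from basicsr.archs.edvr_arch import EDVR

model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True)
model.eval()

# (batch, num_frame, channels, height, width); height and width must be multiples of 4
frames = torch.rand(1, 5, 3, 64, 64)
with torch.no_grad():
    out = model(frames)         # x4 super-resolved center frame, (1, 3, 256, 256)
print(out.shape)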
 
basicsr/archs/hifacegan_arch.py DELETED
@@ -1,260 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
- from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
8
-
9
-
10
- class SPADEGenerator(BaseNetwork):
11
- """Generator with SPADEResBlock"""
12
-
13
- def __init__(self,
14
- num_in_ch=3,
15
- num_feat=64,
16
- use_vae=False,
17
- z_dim=256,
18
- crop_size=512,
19
- norm_g='spectralspadesyncbatch3x3',
20
- is_train=True,
21
- init_train_phase=3): # progressive training disabled
22
- super().__init__()
23
- self.nf = num_feat
24
- self.input_nc = num_in_ch
25
- self.is_train = is_train
26
- self.train_phase = init_train_phase
27
-
28
- self.scale_ratio = 5 # hardcoded now
29
- self.sw = crop_size // (2**self.scale_ratio)
30
- self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
31
-
32
- if use_vae:
33
- # In case of VAE, we will sample from random z vector
34
- self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
35
- else:
36
- # Otherwise, we make the network deterministic by starting with
37
- # downsampled segmentation map instead of random z
38
- self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
39
-
40
- self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
41
-
42
- self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
43
- self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
44
-
45
- self.ups = nn.ModuleList([
46
- SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
47
- SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
48
- SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
49
- SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
50
- ])
51
-
52
- self.to_rgbs = nn.ModuleList([
53
- nn.Conv2d(8 * self.nf, 3, 3, padding=1),
54
- nn.Conv2d(4 * self.nf, 3, 3, padding=1),
55
- nn.Conv2d(2 * self.nf, 3, 3, padding=1),
56
- nn.Conv2d(1 * self.nf, 3, 3, padding=1)
57
- ])
58
-
59
- self.up = nn.Upsample(scale_factor=2)
60
-
61
- def encode(self, input_tensor):
62
- """
63
- Encode input_tensor into feature maps, can be overridden in derived classes
64
- Default: nearest downsampling of 2**5 = 32 times
65
- """
66
- h, w = input_tensor.size()[-2:]
67
- sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
68
- x = F.interpolate(input_tensor, size=(sh, sw))
69
- return self.fc(x)
70
-
71
- def forward(self, x):
72
- # In original SPADE, seg means a segmentation map, but here we use x instead.
73
- seg = x
74
-
75
- x = self.encode(x)
76
- x = self.head_0(x, seg)
77
-
78
- x = self.up(x)
79
- x = self.g_middle_0(x, seg)
80
- x = self.g_middle_1(x, seg)
81
-
82
- if self.is_train:
83
- phase = self.train_phase + 1
84
- else:
85
- phase = len(self.to_rgbs)
86
-
87
- for i in range(phase):
88
- x = self.up(x)
89
- x = self.ups[i](x, seg)
90
-
91
- x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
92
- x = torch.tanh(x)
93
-
94
- return x
95
-
96
- def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
97
- """
98
- A helper class for subspace visualization. Input and seg are different images.
99
- For the first n levels (including encoder) we use input, for the rest we use seg.
100
-
101
- If mode = 'progressive', the output's like: AAABBB
102
- If mode = 'one_plug', the output's like: AAABAA
103
- If mode = 'one_ablate', the output's like: BBBABB
104
- """
105
-
106
- if seg is None:
107
- return self.forward(input_x)
108
-
109
- if self.is_train:
110
- phase = self.train_phase + 1
111
- else:
112
- phase = len(self.to_rgbs)
113
-
114
- if mode == 'progressive':
115
- n = max(min(n, 4 + phase), 0)
116
- guide_list = [input_x] * n + [seg] * (4 + phase - n)
117
- elif mode == 'one_plug':
118
- n = max(min(n, 4 + phase - 1), 0)
119
- guide_list = [seg] * (4 + phase)
120
- guide_list[n] = input_x
121
- elif mode == 'one_ablate':
122
- if n > 3 + phase:
123
- return self.forward(input_x)
124
- guide_list = [input_x] * (4 + phase)
125
- guide_list[n] = seg
126
-
127
- x = self.encode(guide_list[0])
128
- x = self.head_0(x, guide_list[1])
129
-
130
- x = self.up(x)
131
- x = self.g_middle_0(x, guide_list[2])
132
- x = self.g_middle_1(x, guide_list[3])
133
-
134
- for i in range(phase):
135
- x = self.up(x)
136
- x = self.ups[i](x, guide_list[4 + i])
137
-
138
- x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
139
- x = torch.tanh(x)
140
-
141
- return x
142
-
143
-
144
- @ARCH_REGISTRY.register()
145
- class HiFaceGAN(SPADEGenerator):
146
- """
147
- HiFaceGAN: SPADEGenerator with a learnable feature encoder
148
- Current encoder design: LIPEncoder
149
- """
150
-
151
- def __init__(self,
152
- num_in_ch=3,
153
- num_feat=64,
154
- use_vae=False,
155
- z_dim=256,
156
- crop_size=512,
157
- norm_g='spectralspadesyncbatch3x3',
158
- is_train=True,
159
- init_train_phase=3):
160
- super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
161
- self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
162
-
163
- def encode(self, input_tensor):
164
- return self.lip_encoder(input_tensor)
165
-
166
-
167
- @ARCH_REGISTRY.register()
168
- class HiFaceGANDiscriminator(BaseNetwork):
169
- """
170
- Inspired by pix2pixHD multiscale discriminator.
171
-
172
- Args:
173
- num_in_ch (int): Channel number of inputs. Default: 3.
174
- num_out_ch (int): Channel number of outputs. Default: 3.
175
- conditional_d (bool): Whether use conditional discriminator.
176
- Default: True.
177
- num_d (int): Number of Multiscale discriminators. Default: 3.
178
- n_layers_d (int): Number of downsample layers in each D. Default: 4.
179
- num_feat (int): Channel number of base intermediate features.
180
- Default: 64.
181
- norm_d (str): String to determine normalization layers in D.
182
- Choices: [spectral][instance/batch/syncbatch]
183
- Default: 'spectralinstance'.
184
- keep_features (bool): Keep intermediate features for matching loss, etc.
185
- Default: True.
186
- """
187
-
188
- def __init__(self,
189
- num_in_ch=3,
190
- num_out_ch=3,
191
- conditional_d=True,
192
- num_d=2,
193
- n_layers_d=4,
194
- num_feat=64,
195
- norm_d='spectralinstance',
196
- keep_features=True):
197
- super().__init__()
198
- self.num_d = num_d
199
-
200
- input_nc = num_in_ch
201
- if conditional_d:
202
- input_nc += num_out_ch
203
-
204
- for i in range(num_d):
205
- subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
206
- self.add_module(f'discriminator_{i}', subnet_d)
207
-
208
- def downsample(self, x):
209
- return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
210
-
211
- # Returns list of lists of discriminator outputs.
212
- # The final result is of size opt.num_d x opt.n_layers_D
213
- def forward(self, x):
214
- result = []
215
- for _, _net_d in self.named_children():
216
- out = _net_d(x)
217
- result.append(out)
218
- x = self.downsample(x)
219
-
220
- return result
221
-
222
-
223
- class NLayerDiscriminator(BaseNetwork):
224
- """Defines the PatchGAN discriminator with the specified arguments."""
225
-
226
- def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
227
- super().__init__()
228
- kw = 4
229
- padw = int(np.ceil((kw - 1.0) / 2))
230
- nf = num_feat
231
- self.keep_features = keep_features
232
-
233
- norm_layer = get_nonspade_norm_layer(norm_d)
234
- sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
235
-
236
- for n in range(1, n_layers_d):
237
- nf_prev = nf
238
- nf = min(nf * 2, 512)
239
- stride = 1 if n == n_layers_d - 1 else 2
240
- sequence += [[
241
- norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
242
- nn.LeakyReLU(0.2, False)
243
- ]]
244
-
245
- sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
246
-
247
- # We divide the layers into groups to extract intermediate layer outputs
248
- for n in range(len(sequence)):
249
- self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
250
-
251
- def forward(self, x):
252
- results = [x]
253
- for submodel in self.children():
254
- intermediate_output = submodel(results[-1])
255
- results.append(intermediate_output)
256
-
257
- if self.keep_features:
258
- return results[1:]
259
- else:
260
- return results[-1]
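For reference, once is_train=False the removed HiFaceGAN behaves as a plain image-to-image generator. A usage sketch (illustrative only, assuming basicsr remains importable):

import torch
from basicsr.archs.hifacegan_arch import HiFaceGAN

net = HiFaceGAN(num_in_ch=3, num_feat=64, crop_size=512, is_train=False)
net.eval()

x = torch.rand(1, 3, 512, 512)  # degraded face at the 512x512 size implied by crop_size
with torch.no_grad():
    y = net(x)                  # restored face, same resolution, tanh range (-1, 1)
print(y.shape)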
 
basicsr/archs/hifacegan_util.py DELETED
@@ -1,255 +0,0 @@
1
- import re
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.nn import init
6
- # Warning: spectral norm could be buggy
7
- # under eval mode and multi-GPU inference
8
- # A workaround is sticking to single-GPU inference and train mode
9
- from torch.nn.utils import spectral_norm
10
-
11
-
12
- class SPADE(nn.Module):
13
-
14
- def __init__(self, config_text, norm_nc, label_nc):
15
- super().__init__()
16
-
17
- assert config_text.startswith('spade')
18
- parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
19
- param_free_norm_type = str(parsed.group(1))
20
- ks = int(parsed.group(2))
21
-
22
- if param_free_norm_type == 'instance':
23
- self.param_free_norm = nn.InstanceNorm2d(norm_nc)
24
- elif param_free_norm_type == 'syncbatch':
25
- print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
26
- self.param_free_norm = nn.InstanceNorm2d(norm_nc)
27
- elif param_free_norm_type == 'batch':
28
- self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
29
- else:
30
- raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
31
-
32
- # The dimension of the intermediate embedding space. Yes, hardcoded.
33
- nhidden = 128 if norm_nc > 128 else norm_nc
34
-
35
- pw = ks // 2
36
- self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
37
- self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
38
- self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
39
-
40
- def forward(self, x, segmap):
41
-
42
- # Part 1. generate parameter-free normalized activations
43
- normalized = self.param_free_norm(x)
44
-
45
- # Part 2. produce scaling and bias conditioned on semantic map
46
- segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
47
- actv = self.mlp_shared(segmap)
48
- gamma = self.mlp_gamma(actv)
49
- beta = self.mlp_beta(actv)
50
-
51
- # apply scale and bias
52
- out = normalized * gamma + beta
53
-
54
- return out
55
-
56
-
57
- class SPADEResnetBlock(nn.Module):
58
- """
59
- ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
60
- it takes in the segmentation map as input, learns the skip connection if necessary,
61
- and applies normalization first and then convolution.
62
- This architecture seemed like a standard architecture for unconditional or
63
- class-conditional GAN architecture using residual block.
64
- The code was inspired from https://github.com/LMescheder/GAN_stability.
65
- """
66
-
67
- def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
68
- super().__init__()
69
- # Attributes
70
- self.learned_shortcut = (fin != fout)
71
- fmiddle = min(fin, fout)
72
-
73
- # create conv layers
74
- self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
75
- self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
76
- if self.learned_shortcut:
77
- self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
78
-
79
- # apply spectral norm if specified
80
- if 'spectral' in norm_g:
81
- self.conv_0 = spectral_norm(self.conv_0)
82
- self.conv_1 = spectral_norm(self.conv_1)
83
- if self.learned_shortcut:
84
- self.conv_s = spectral_norm(self.conv_s)
85
-
86
- # define normalization layers
87
- spade_config_str = norm_g.replace('spectral', '')
88
- self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
89
- self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
90
- if self.learned_shortcut:
91
- self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
92
-
93
- # note the resnet block with SPADE also takes in |seg|,
94
- # the semantic segmentation map as input
95
- def forward(self, x, seg):
96
- x_s = self.shortcut(x, seg)
97
- dx = self.conv_0(self.act(self.norm_0(x, seg)))
98
- dx = self.conv_1(self.act(self.norm_1(dx, seg)))
99
- out = x_s + dx
100
- return out
101
-
102
- def shortcut(self, x, seg):
103
- if self.learned_shortcut:
104
- x_s = self.conv_s(self.norm_s(x, seg))
105
- else:
106
- x_s = x
107
- return x_s
108
-
109
- def act(self, x):
110
- return F.leaky_relu(x, 2e-1)
111
-
112
-
113
- class BaseNetwork(nn.Module):
114
- """ A basis for hifacegan archs with custom initialization """
115
-
116
- def init_weights(self, init_type='normal', gain=0.02):
117
-
118
- def init_func(m):
119
- classname = m.__class__.__name__
120
- if classname.find('BatchNorm2d') != -1:
121
- if hasattr(m, 'weight') and m.weight is not None:
122
- init.normal_(m.weight.data, 1.0, gain)
123
- if hasattr(m, 'bias') and m.bias is not None:
124
- init.constant_(m.bias.data, 0.0)
125
- elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
126
- if init_type == 'normal':
127
- init.normal_(m.weight.data, 0.0, gain)
128
- elif init_type == 'xavier':
129
- init.xavier_normal_(m.weight.data, gain=gain)
130
- elif init_type == 'xavier_uniform':
131
- init.xavier_uniform_(m.weight.data, gain=1.0)
132
- elif init_type == 'kaiming':
133
- init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
134
- elif init_type == 'orthogonal':
135
- init.orthogonal_(m.weight.data, gain=gain)
136
- elif init_type == 'none': # uses pytorch's default init method
137
- m.reset_parameters()
138
- else:
139
- raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
140
- if hasattr(m, 'bias') and m.bias is not None:
141
- init.constant_(m.bias.data, 0.0)
142
-
143
- self.apply(init_func)
144
-
145
- # propagate to children
146
- for m in self.children():
147
- if hasattr(m, 'init_weights'):
148
- m.init_weights(init_type, gain)
149
-
150
- def forward(self, x):
151
- pass
152
-
153
-
154
- def lip2d(x, logit, kernel=3, stride=2, padding=1):
155
- weight = logit.exp()
156
- return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
157
-
158
-
159
- class SoftGate(nn.Module):
160
- COEFF = 12.0
161
-
162
- def forward(self, x):
163
- return torch.sigmoid(x).mul(self.COEFF)
164
-
165
-
166
- class SimplifiedLIP(nn.Module):
167
-
168
- def __init__(self, channels):
169
- super(SimplifiedLIP, self).__init__()
170
- self.logit = nn.Sequential(
171
- nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
172
- SoftGate())
173
-
174
- def init_layer(self):
175
- self.logit[0].weight.data.fill_(0.0)
176
-
177
- def forward(self, x):
178
- frac = lip2d(x, self.logit(x))
179
- return frac
180
-
181
-
182
- class LIPEncoder(BaseNetwork):
183
- """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)"""
184
-
185
- def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
186
- super().__init__()
187
- self.sw = sw
188
- self.sh = sh
189
- self.max_ratio = 16
190
- # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
191
- kw = 3
192
- pw = (kw - 1) // 2
193
-
194
- model = [
195
- nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
196
- norm_layer(ngf),
197
- nn.ReLU(),
198
- ]
199
- cur_ratio = 1
200
- for i in range(n_2xdown):
201
- next_ratio = min(cur_ratio * 2, self.max_ratio)
202
- model += [
203
- SimplifiedLIP(ngf * cur_ratio),
204
- nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
205
- norm_layer(ngf * next_ratio),
206
- ]
207
- cur_ratio = next_ratio
208
- if i < n_2xdown - 1:
209
- model += [nn.ReLU(inplace=True)]
210
-
211
- self.model = nn.Sequential(*model)
212
-
213
- def forward(self, x):
214
- return self.model(x)
215
-
216
-
217
- def get_nonspade_norm_layer(norm_type='instance'):
218
- # helper function to get # output channels of the previous layer
219
- def get_out_channel(layer):
220
- if hasattr(layer, 'out_channels'):
221
- return getattr(layer, 'out_channels')
222
- return layer.weight.size(0)
223
-
224
- # this function will be returned
225
- def add_norm_layer(layer):
226
- nonlocal norm_type
227
- if norm_type.startswith('spectral'):
228
- layer = spectral_norm(layer)
229
- subnorm_type = norm_type[len('spectral'):]
230
-
231
- if subnorm_type == 'none' or len(subnorm_type) == 0:
232
- return layer
233
-
234
- # remove bias in the previous layer, which is meaningless
235
- # since it has no effect after normalization
236
- if getattr(layer, 'bias', None) is not None:
237
- delattr(layer, 'bias')
238
- layer.register_parameter('bias', None)
239
-
240
- if subnorm_type == 'batch':
241
- norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
242
- elif subnorm_type == 'sync_batch':
243
- print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
244
- # norm_layer = SynchronizedBatchNorm2d(
245
- # get_out_channel(layer), affine=True)
246
- norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
247
- elif subnorm_type == 'instance':
248
- norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
249
- else:
250
- raise ValueError(f'normalization layer {subnorm_type} is not recognized')
251
-
252
- return nn.Sequential(layer, norm_layer)
253
-
254
- print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
255
- return add_norm_layer
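For reference, get_nonspade_norm_layer (removed above) returns a factory that wraps a layer with spectral norm plus a parameter-free norm. A small sketch of the call pattern (illustrative only):

import torch
import torch.nn as nn
from basicsr.archs.hifacegan_util import get_nonspade_norm_layer

norm_layer = get_nonspade_norm_layer('spectralinstance')
block = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))  # Sequential(spectral-norm conv, InstanceNorm2d)

x = torch.rand(1, 3, 32, 32)
print(block(x).shape)           # (1, 64, 32, 32)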
 
basicsr/archs/inception.py DELETED
@@ -1,307 +0,0 @@
1
- # Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
2
- # For FID metric
3
-
4
- import os
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch.utils.model_zoo import load_url
9
- from torchvision import models
10
-
11
- # Inception weights ported to Pytorch from
12
- # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
- FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
- LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
15
-
16
-
17
- class InceptionV3(nn.Module):
18
- """Pretrained InceptionV3 network returning feature maps"""
19
-
20
- # Index of default block of inception to return,
21
- # corresponds to output of final average pooling
22
- DEFAULT_BLOCK_INDEX = 3
23
-
24
- # Maps feature dimensionality to their output blocks indices
25
- BLOCK_INDEX_BY_DIM = {
26
- 64: 0, # First max pooling features
27
- 192: 1, # Second max pooling features
28
- 768: 2, # Pre-aux classifier features
29
- 2048: 3 # Final average pooling features
30
- }
31
-
32
- def __init__(self,
33
- output_blocks=(DEFAULT_BLOCK_INDEX),
34
- resize_input=True,
35
- normalize_input=True,
36
- requires_grad=False,
37
- use_fid_inception=True):
38
- """Build pretrained InceptionV3.
39
-
40
- Args:
41
- output_blocks (list[int]): Indices of blocks to return features of.
42
- Possible values are:
43
- - 0: corresponds to output of first max pooling
44
- - 1: corresponds to output of second max pooling
45
- - 2: corresponds to output which is fed to aux classifier
46
- - 3: corresponds to output of final average pooling
47
- resize_input (bool): If true, bilinearly resizes input to width and
48
- height 299 before feeding input to model. As the network
49
- without fully connected layers is fully convolutional, it
50
- should be able to handle inputs of arbitrary size, so resizing
51
- might not be strictly needed. Default: True.
52
- normalize_input (bool): If true, scales the input from range (0, 1)
53
- to the range the pretrained Inception network expects,
54
- namely (-1, 1). Default: True.
55
- requires_grad (bool): If true, parameters of the model require
56
- gradients. Possibly useful for finetuning the network.
57
- Default: False.
58
- use_fid_inception (bool): If true, uses the pretrained Inception
59
- model used in Tensorflow's FID implementation.
60
- If false, uses the pretrained Inception model available in
61
- torchvision. The FID Inception model has different weights
62
- and a slightly different structure from torchvision's
63
- Inception model. If you want to compute FID scores, you are
64
- strongly advised to set this parameter to true to get
65
- comparable results. Default: True.
66
- """
67
- super(InceptionV3, self).__init__()
68
-
69
- self.resize_input = resize_input
70
- self.normalize_input = normalize_input
71
- self.output_blocks = sorted(output_blocks)
72
- self.last_needed_block = max(output_blocks)
73
-
74
- assert self.last_needed_block <= 3, ('Last possible output block index is 3')
75
-
76
- self.blocks = nn.ModuleList()
77
-
78
- if use_fid_inception:
79
- inception = fid_inception_v3()
80
- else:
81
- try:
82
- inception = models.inception_v3(pretrained=True, init_weights=False)
83
- except TypeError:
84
- # pytorch < 1.5 does not have init_weights for inception_v3
85
- inception = models.inception_v3(pretrained=True)
86
-
87
- # Block 0: input to maxpool1
88
- block0 = [
89
- inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
90
- nn.MaxPool2d(kernel_size=3, stride=2)
91
- ]
92
- self.blocks.append(nn.Sequential(*block0))
93
-
94
- # Block 1: maxpool1 to maxpool2
95
- if self.last_needed_block >= 1:
96
- block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
97
- self.blocks.append(nn.Sequential(*block1))
98
-
99
- # Block 2: maxpool2 to aux classifier
100
- if self.last_needed_block >= 2:
101
- block2 = [
102
- inception.Mixed_5b,
103
- inception.Mixed_5c,
104
- inception.Mixed_5d,
105
- inception.Mixed_6a,
106
- inception.Mixed_6b,
107
- inception.Mixed_6c,
108
- inception.Mixed_6d,
109
- inception.Mixed_6e,
110
- ]
111
- self.blocks.append(nn.Sequential(*block2))
112
-
113
- # Block 3: aux classifier to final avgpool
114
- if self.last_needed_block >= 3:
115
- block3 = [
116
- inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
117
- nn.AdaptiveAvgPool2d(output_size=(1, 1))
118
- ]
119
- self.blocks.append(nn.Sequential(*block3))
120
-
121
- for param in self.parameters():
122
- param.requires_grad = requires_grad
123
-
124
- def forward(self, x):
125
- """Get Inception feature maps.
126
-
127
- Args:
128
- x (Tensor): Input tensor of shape (b, 3, h, w).
129
- Values are expected to be in range (-1, 1). You can also input
130
- (0, 1) with setting normalize_input = True.
131
-
132
- Returns:
133
- list[Tensor]: Corresponding to the selected output block, sorted
134
- ascending by index.
135
- """
136
- output = []
137
-
138
- if self.resize_input:
139
- x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
140
-
141
- if self.normalize_input:
142
- x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
143
-
144
- for idx, block in enumerate(self.blocks):
145
- x = block(x)
146
- if idx in self.output_blocks:
147
- output.append(x)
148
-
149
- if idx == self.last_needed_block:
150
- break
151
-
152
- return output
153
-
154
-
155
- def fid_inception_v3():
156
- """Build pretrained Inception model for FID computation.
157
-
158
- The Inception model for FID computation uses a different set of weights
159
- and has a slightly different structure than torchvision's Inception.
160
-
161
- This method first constructs torchvision's Inception and then patches the
162
- necessary parts that are different in the FID Inception model.
163
- """
164
- try:
165
- inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
166
- except TypeError:
167
- # pytorch < 1.5 does not have init_weights for inception_v3
168
- inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
169
-
170
- inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
171
- inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
172
- inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
173
- inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
174
- inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
175
- inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
176
- inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
177
- inception.Mixed_7b = FIDInceptionE_1(1280)
178
- inception.Mixed_7c = FIDInceptionE_2(2048)
179
-
180
- if os.path.exists(LOCAL_FID_WEIGHTS):
181
- state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
182
- else:
183
- state_dict = load_url(FID_WEIGHTS_URL, progress=True)
184
-
185
- inception.load_state_dict(state_dict)
186
- return inception
187
-
188
-
189
- class FIDInceptionA(models.inception.InceptionA):
190
- """InceptionA block patched for FID computation"""
191
-
192
- def __init__(self, in_channels, pool_features):
193
- super(FIDInceptionA, self).__init__(in_channels, pool_features)
194
-
195
- def forward(self, x):
196
- branch1x1 = self.branch1x1(x)
197
-
198
- branch5x5 = self.branch5x5_1(x)
199
- branch5x5 = self.branch5x5_2(branch5x5)
200
-
201
- branch3x3dbl = self.branch3x3dbl_1(x)
202
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
203
- branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
204
-
205
- # Patch: Tensorflow's average pool does not use the padded zero's in
206
- # its average calculation
207
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
208
- branch_pool = self.branch_pool(branch_pool)
209
-
210
- outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
211
- return torch.cat(outputs, 1)
212
-
213
-
214
- class FIDInceptionC(models.inception.InceptionC):
215
- """InceptionC block patched for FID computation"""
216
-
217
- def __init__(self, in_channels, channels_7x7):
218
- super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
219
-
220
- def forward(self, x):
221
- branch1x1 = self.branch1x1(x)
222
-
223
- branch7x7 = self.branch7x7_1(x)
224
- branch7x7 = self.branch7x7_2(branch7x7)
225
- branch7x7 = self.branch7x7_3(branch7x7)
226
-
227
- branch7x7dbl = self.branch7x7dbl_1(x)
228
- branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
229
- branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
230
- branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
231
- branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
232
-
233
- # Patch: Tensorflow's average pool does not use the padded zero's in
234
- # its average calculation
235
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
236
- branch_pool = self.branch_pool(branch_pool)
237
-
238
- outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
239
- return torch.cat(outputs, 1)
240
-
241
-
242
- class FIDInceptionE_1(models.inception.InceptionE):
243
- """First InceptionE block patched for FID computation"""
244
-
245
- def __init__(self, in_channels):
246
- super(FIDInceptionE_1, self).__init__(in_channels)
247
-
248
- def forward(self, x):
249
- branch1x1 = self.branch1x1(x)
250
-
251
- branch3x3 = self.branch3x3_1(x)
252
- branch3x3 = [
253
- self.branch3x3_2a(branch3x3),
254
- self.branch3x3_2b(branch3x3),
255
- ]
256
- branch3x3 = torch.cat(branch3x3, 1)
257
-
258
- branch3x3dbl = self.branch3x3dbl_1(x)
259
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
260
- branch3x3dbl = [
261
- self.branch3x3dbl_3a(branch3x3dbl),
262
- self.branch3x3dbl_3b(branch3x3dbl),
263
- ]
264
- branch3x3dbl = torch.cat(branch3x3dbl, 1)
265
-
266
- # Patch: Tensorflow's average pool does not use the padded zero's in
267
- # its average calculation
268
- branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
269
- branch_pool = self.branch_pool(branch_pool)
270
-
271
- outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
272
- return torch.cat(outputs, 1)
273
-
274
-
275
- class FIDInceptionE_2(models.inception.InceptionE):
276
- """Second InceptionE block patched for FID computation"""
277
-
278
- def __init__(self, in_channels):
279
- super(FIDInceptionE_2, self).__init__(in_channels)
280
-
281
- def forward(self, x):
282
- branch1x1 = self.branch1x1(x)
283
-
284
- branch3x3 = self.branch3x3_1(x)
285
- branch3x3 = [
286
- self.branch3x3_2a(branch3x3),
287
- self.branch3x3_2b(branch3x3),
288
- ]
289
- branch3x3 = torch.cat(branch3x3, 1)
290
-
291
- branch3x3dbl = self.branch3x3dbl_1(x)
292
- branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
293
- branch3x3dbl = [
294
- self.branch3x3dbl_3a(branch3x3dbl),
295
- self.branch3x3dbl_3b(branch3x3dbl),
296
- ]
297
- branch3x3dbl = torch.cat(branch3x3dbl, 1)
298
-
299
- # Patch: The FID Inception model uses max pooling instead of average
300
- # pooling. This is likely an error in this specific Inception
301
- # implementation, as other Inception models use average pooling here
302
- # (which matches the description in the paper).
303
- branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
304
- branch_pool = self.branch_pool(branch_pool)
305
-
306
- outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
- return torch.cat(outputs, 1)
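For reference, the removed InceptionV3 wrapper exists purely to extract FID features. A usage sketch (illustrative only; the first call downloads the FID Inception weights, so network access or the local weight path is assumed):

import torch
from basicsr.archs.inception import InceptionV3

# Block index 3 = final average-pooling features (2048-d), the ones FID is computed on.
net = InceptionV3(output_blocks=[3], use_fid_inception=True)
net.eval()

imgs = torch.rand(4, 3, 299, 299)   # values in (0, 1); normalize_input maps them to (-1, 1)
with torch.no_grad():
    feats = net(imgs)[0]            # (4, 2048, 1, 1)
print(feats.squeeze().shape)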
 
basicsr/archs/rcan_arch.py DELETED
@@ -1,135 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
-
4
- from basicsr.utils.registry import ARCH_REGISTRY
5
- from .arch_util import Upsample, make_layer
6
-
7
-
8
- class ChannelAttention(nn.Module):
9
- """Channel attention used in RCAN.
10
-
11
- Args:
12
- num_feat (int): Channel number of intermediate features.
13
- squeeze_factor (int): Channel squeeze factor. Default: 16.
14
- """
15
-
16
- def __init__(self, num_feat, squeeze_factor=16):
17
- super(ChannelAttention, self).__init__()
18
- self.attention = nn.Sequential(
19
- nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
20
- nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
21
-
22
- def forward(self, x):
23
- y = self.attention(x)
24
- return x * y
25
-
26
-
27
- class RCAB(nn.Module):
28
- """Residual Channel Attention Block (RCAB) used in RCAN.
29
-
30
- Args:
31
- num_feat (int): Channel number of intermediate features.
32
- squeeze_factor (int): Channel squeeze factor. Default: 16.
33
- res_scale (float): Scale the residual. Default: 1.
34
- """
35
-
36
- def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
37
- super(RCAB, self).__init__()
38
- self.res_scale = res_scale
39
-
40
- self.rcab = nn.Sequential(
41
- nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
42
- ChannelAttention(num_feat, squeeze_factor))
43
-
44
- def forward(self, x):
45
- res = self.rcab(x) * self.res_scale
46
- return res + x
47
-
48
-
49
- class ResidualGroup(nn.Module):
50
- """Residual Group of RCAB.
51
-
52
- Args:
53
- num_feat (int): Channel number of intermediate features.
54
- num_block (int): Block number in the body network.
55
- squeeze_factor (int): Channel squeeze factor. Default: 16.
56
- res_scale (float): Scale the residual. Default: 1.
57
- """
58
-
59
- def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
60
- super(ResidualGroup, self).__init__()
61
-
62
- self.residual_group = make_layer(
63
- RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
64
- self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
65
-
66
- def forward(self, x):
67
- res = self.conv(self.residual_group(x))
68
- return res + x
69
-
70
-
71
- @ARCH_REGISTRY.register()
72
- class RCAN(nn.Module):
73
- """Residual Channel Attention Networks.
74
-
75
- ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
76
-
77
- Reference: https://github.com/yulunzhang/RCAN
78
-
79
- Args:
80
- num_in_ch (int): Channel number of inputs.
81
- num_out_ch (int): Channel number of outputs.
82
- num_feat (int): Channel number of intermediate features.
83
- Default: 64.
84
- num_group (int): Number of ResidualGroup. Default: 10.
85
- num_block (int): Number of RCAB in ResidualGroup. Default: 16.
86
- squeeze_factor (int): Channel squeeze factor. Default: 16.
87
- upscale (int): Upsampling factor. Support 2^n and 3.
88
- Default: 4.
89
- res_scale (float): Used to scale the residual in residual block.
90
- Default: 1.
91
- img_range (float): Image range. Default: 255.
92
- rgb_mean (tuple[float]): Image mean in RGB orders.
93
- Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
94
- """
95
-
96
- def __init__(self,
97
- num_in_ch,
98
- num_out_ch,
99
- num_feat=64,
100
- num_group=10,
101
- num_block=16,
102
- squeeze_factor=16,
103
- upscale=4,
104
- res_scale=1,
105
- img_range=255.,
106
- rgb_mean=(0.4488, 0.4371, 0.4040)):
107
- super(RCAN, self).__init__()
108
-
109
- self.img_range = img_range
110
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
111
-
112
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
113
- self.body = make_layer(
114
- ResidualGroup,
115
- num_group,
116
- num_feat=num_feat,
117
- num_block=num_block,
118
- squeeze_factor=squeeze_factor,
119
- res_scale=res_scale)
120
- self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
- self.upsample = Upsample(upscale, num_feat)
122
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
123
-
124
- def forward(self, x):
125
- self.mean = self.mean.type_as(x)
126
-
127
- x = (x - self.mean) * self.img_range
128
- x = self.conv_first(x)
129
- res = self.conv_after_body(self.body(x))
130
- res += x
131
-
132
- x = self.conv_last(self.upsample(res))
133
- x = x / self.img_range + self.mean
134
-
135
- return x
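For reference, the removed RCAN exposes the same call pattern as EDSR. A short sketch (illustrative only, assuming basicsr stays installed):

import torch
from basicsr.archs.rcan_arch import RCAN

model = RCAN(num_in_ch=3, num_out_ch=3, num_group=10, num_block=16, upscale=4)
model.eval()

lr = torch.rand(1, 3, 48, 48)
with torch.no_grad():
    sr = model(lr)                  # (1, 3, 192, 192) for the default x4 factor
print(sr.shape)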
 
basicsr/archs/ridnet_arch.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from basicsr.utils.registry import ARCH_REGISTRY
5
- from .arch_util import ResidualBlockNoBN, make_layer
6
-
7
-
8
- class MeanShift(nn.Conv2d):
9
- """ Data normalization with mean and std.
10
-
11
- Args:
12
- rgb_range (int): Maximum value of RGB.
13
- rgb_mean (list[float]): Mean for RGB channels.
14
- rgb_std (list[float]): Std for RGB channels.
15
- sign (int): For subtraction, sign is -1, for addition, sign is 1.
16
- Default: -1.
17
- requires_grad (bool): Whether to update the self.weight and self.bias.
18
- Default: True.
19
- """
20
-
21
- def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
22
- super(MeanShift, self).__init__(3, 3, kernel_size=1)
23
- std = torch.Tensor(rgb_std)
24
- self.weight.data = torch.eye(3).view(3, 3, 1, 1)
25
- self.weight.data.div_(std.view(3, 1, 1, 1))
26
- self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
27
- self.bias.data.div_(std)
28
- self.requires_grad = requires_grad
29
-
30
-
31
- class EResidualBlockNoBN(nn.Module):
32
- """Enhanced Residual block without BN.
33
-
34
- There are three convolution layers in residual branch.
35
- """
36
-
37
- def __init__(self, in_channels, out_channels):
38
- super(EResidualBlockNoBN, self).__init__()
39
-
40
- self.body = nn.Sequential(
41
- nn.Conv2d(in_channels, out_channels, 3, 1, 1),
42
- nn.ReLU(inplace=True),
43
- nn.Conv2d(out_channels, out_channels, 3, 1, 1),
44
- nn.ReLU(inplace=True),
45
- nn.Conv2d(out_channels, out_channels, 1, 1, 0),
46
- )
47
- self.relu = nn.ReLU(inplace=True)
48
-
49
- def forward(self, x):
50
- out = self.body(x)
51
- out = self.relu(out + x)
52
- return out
53
-
54
-
55
- class MergeRun(nn.Module):
56
- """ Merge-and-run unit.
57
-
58
- This unit contains two branches with different dilated convolutions,
59
- followed by a convolution to process the concatenated features.
60
-
61
- Paper: Real Image Denoising with Feature Attention
62
- Ref git repo: https://github.com/saeed-anwar/RIDNet
63
- """
64
-
65
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
66
- super(MergeRun, self).__init__()
67
-
68
- self.dilation1 = nn.Sequential(
69
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
70
- nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
71
- self.dilation2 = nn.Sequential(
72
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
73
- nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
74
-
75
- self.aggregation = nn.Sequential(
76
- nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
77
-
78
- def forward(self, x):
79
- dilation1 = self.dilation1(x)
80
- dilation2 = self.dilation2(x)
81
- out = torch.cat([dilation1, dilation2], dim=1)
82
- out = self.aggregation(out)
83
- out = out + x
84
- return out
85
-
86
-
87
- class ChannelAttention(nn.Module):
88
- """Channel attention.
89
-
90
- Args:
91
- num_feat (int): Channel number of intermediate features.
92
- squeeze_factor (int): Channel squeeze factor. Default: 16.
93
- """
94
-
95
- def __init__(self, mid_channels, squeeze_factor=16):
96
- super(ChannelAttention, self).__init__()
97
- self.attention = nn.Sequential(
98
- nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
99
- nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
100
-
101
- def forward(self, x):
102
- y = self.attention(x)
103
- return x * y
104
-
105
-
106
- class EAM(nn.Module):
107
- """Enhancement attention modules (EAM) in RIDNet.
108
-
109
- This module contains a merge-and-run unit, a residual block,
110
- an enhanced residual block and a feature attention unit.
111
-
112
- Attributes:
113
- merge: The merge-and-run unit.
114
- block1: The residual block.
115
- block2: The enhanced residual block.
116
- ca: The feature/channel attention unit.
117
- """
118
-
119
- def __init__(self, in_channels, mid_channels, out_channels):
120
- super(EAM, self).__init__()
121
-
122
- self.merge = MergeRun(in_channels, mid_channels)
123
- self.block1 = ResidualBlockNoBN(mid_channels)
124
- self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
125
- self.ca = ChannelAttention(out_channels)
126
- # The residual block in the paper contains a relu after addition.
127
- self.relu = nn.ReLU(inplace=True)
128
-
129
- def forward(self, x):
130
- out = self.merge(x)
131
- out = self.relu(self.block1(out))
132
- out = self.block2(out)
133
- out = self.ca(out)
134
- return out
135
-
136
-
137
- @ARCH_REGISTRY.register()
138
- class RIDNet(nn.Module):
139
- """RIDNet: Real Image Denoising with Feature Attention.
140
-
141
- Ref git repo: https://github.com/saeed-anwar/RIDNet
142
-
143
- Args:
144
- in_channels (int): Channel number of inputs.
145
- mid_channels (int): Channel number of EAM modules.
146
- Default: 64.
147
- out_channels (int): Channel number of outputs.
148
- num_block (int): Number of EAM. Default: 4.
149
- img_range (float): Image range. Default: 255.
150
- rgb_mean (tuple[float]): Image mean in RGB orders.
151
- Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
152
- """
153
-
154
- def __init__(self,
155
- in_channels,
156
- mid_channels,
157
- out_channels,
158
- num_block=4,
159
- img_range=255.,
160
- rgb_mean=(0.4488, 0.4371, 0.4040),
161
- rgb_std=(1.0, 1.0, 1.0)):
162
- super(RIDNet, self).__init__()
163
-
164
- self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
165
- self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
166
-
167
- self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
168
- self.body = make_layer(
169
- EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
170
- self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
171
-
172
- self.relu = nn.ReLU(inplace=True)
173
-
174
- def forward(self, x):
175
- res = self.sub_mean(x)
176
- res = self.tail(self.body(self.relu(self.head(res))))
177
- res = self.add_mean(res)
178
-
179
- out = x + res
180
- return out
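A minimal usage sketch for the RIDNet denoiser deleted above (editor's illustration, not part of this commit; it assumes the basicsr package remains importable so that basicsr.archs.ridnet_arch still resolves):

import torch
from basicsr.archs.ridnet_arch import RIDNet  # assumption: basicsr is installed elsewhere

# 3-channel input/output, 64 feature channels, 4 EAM blocks (matching the defaults above).
model = RIDNet(in_channels=3, mid_channels=64, out_channels=3, num_block=4).eval()
noisy = torch.rand(1, 3, 64, 64) * 255.0   # dummy image; img_range defaults to 255
with torch.no_grad():
    denoised = model(noisy)                # residual denoising: output = input + predicted residual
print(denoised.shape)                      # torch.Size([1, 3, 64, 64])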
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/rrdbnet_arch.py DELETED
@@ -1,119 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
- from torch.nn import functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
-
8
-
9
- class ResidualDenseBlock(nn.Module):
10
- """Residual Dense Block.
11
-
12
- Used in RRDB block in ESRGAN.
13
-
14
- Args:
15
- num_feat (int): Channel number of intermediate features.
16
- num_grow_ch (int): Channels for each growth.
17
- """
18
-
19
- def __init__(self, num_feat=64, num_grow_ch=32):
20
- super(ResidualDenseBlock, self).__init__()
21
- self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
22
- self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
23
- self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
- self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
25
- self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
26
-
27
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
28
-
29
- # initialization
30
- default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
31
-
32
- def forward(self, x):
33
- x1 = self.lrelu(self.conv1(x))
34
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
35
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
36
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
37
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
38
- # Empirically, we use 0.2 to scale the residual for better performance
39
- return x5 * 0.2 + x
40
-
41
-
42
- class RRDB(nn.Module):
43
- """Residual in Residual Dense Block.
44
-
45
- Used in RRDB-Net in ESRGAN.
46
-
47
- Args:
48
- num_feat (int): Channel number of intermediate features.
49
- num_grow_ch (int): Channels for each growth.
50
- """
51
-
52
- def __init__(self, num_feat, num_grow_ch=32):
53
- super(RRDB, self).__init__()
54
- self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
55
- self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
56
- self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
57
-
58
- def forward(self, x):
59
- out = self.rdb1(x)
60
- out = self.rdb2(out)
61
- out = self.rdb3(out)
62
- # Empirically, we use 0.2 to scale the residual for better performance
63
- return out * 0.2 + x
64
-
65
-
66
- @ARCH_REGISTRY.register()
67
- class RRDBNet(nn.Module):
68
- """Networks consisting of Residual in Residual Dense Block, which is used
69
- in ESRGAN.
70
-
71
- ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
72
-
73
- We extend ESRGAN for scale x2 and scale x1.
74
- Note: This is one option for scale 1, scale 2 in RRDBNet.
75
- We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
76
- and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
77
-
78
- Args:
79
- num_in_ch (int): Channel number of inputs.
80
- num_out_ch (int): Channel number of outputs.
81
- num_feat (int): Channel number of intermediate features.
82
- Default: 64
83
- num_block (int): Block number in the trunk network. Default: 23.
84
- num_grow_ch (int): Channels for each growth. Default: 32.
85
- """
86
-
87
- def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
88
- super(RRDBNet, self).__init__()
89
- self.scale = scale
90
- if scale == 2:
91
- num_in_ch = num_in_ch * 4
92
- elif scale == 1:
93
- num_in_ch = num_in_ch * 16
94
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
95
- self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
96
- self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
- # upsample
98
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
-
103
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
-
105
- def forward(self, x):
106
- if self.scale == 2:
107
- feat = pixel_unshuffle(x, scale=2)
108
- elif self.scale == 1:
109
- feat = pixel_unshuffle(x, scale=4)
110
- else:
111
- feat = x
112
- feat = self.conv_first(feat)
113
- body_feat = self.conv_body(self.body(feat))
114
- feat = feat + body_feat
115
- # upsample
116
- feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
- feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
- out = self.conv_last(self.lrelu(self.conv_hr(feat)))
119
- return out
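A minimal usage sketch for the RRDBNet (ESRGAN) generator deleted above (editor's illustration, not part of this commit; assumes basicsr remains importable):

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet  # assumption: basicsr is installed elsewhere

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32).eval()
lr = torch.rand(1, 3, 32, 32)      # dummy low-resolution input
with torch.no_grad():
    sr = net(lr)                   # two nearest-neighbour x2 upsamples -> x4 output
print(sr.shape)                    # torch.Size([1, 3, 128, 128])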
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/spynet_arch.py DELETED
@@ -1,96 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn as nn
4
- from torch.nn import functional as F
5
-
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
- from .arch_util import flow_warp
8
-
9
-
10
- class BasicModule(nn.Module):
11
- """Basic Module for SpyNet.
12
- """
13
-
14
- def __init__(self):
15
- super(BasicModule, self).__init__()
16
-
17
- self.basic_module = nn.Sequential(
18
- nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
19
- nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
20
- nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
21
- nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
22
- nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
23
-
24
- def forward(self, tensor_input):
25
- return self.basic_module(tensor_input)
26
-
27
-
28
- @ARCH_REGISTRY.register()
29
- class SpyNet(nn.Module):
30
- """SpyNet architecture.
31
-
32
- Args:
33
- load_path (str): path for pretrained SpyNet. Default: None.
34
- """
35
-
36
- def __init__(self, load_path=None):
37
- super(SpyNet, self).__init__()
38
- self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
39
- if load_path:
40
- self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
41
-
42
- self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
43
- self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
44
-
45
- def preprocess(self, tensor_input):
46
- tensor_output = (tensor_input - self.mean) / self.std
47
- return tensor_output
48
-
49
- def process(self, ref, supp):
50
- flow = []
51
-
52
- ref = [self.preprocess(ref)]
53
- supp = [self.preprocess(supp)]
54
-
55
- for level in range(5):
56
- ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
57
- supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
58
-
59
- flow = ref[0].new_zeros(
60
- [ref[0].size(0), 2,
61
- int(math.floor(ref[0].size(2) / 2.0)),
62
- int(math.floor(ref[0].size(3) / 2.0))])
63
-
64
- for level in range(len(ref)):
65
- upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
66
-
67
- if upsampled_flow.size(2) != ref[level].size(2):
68
- upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
69
- if upsampled_flow.size(3) != ref[level].size(3):
70
- upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
71
-
72
- flow = self.basic_module[level](torch.cat([
73
- ref[level],
74
- flow_warp(
75
- supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
76
- upsampled_flow
77
- ], 1)) + upsampled_flow
78
-
79
- return flow
80
-
81
- def forward(self, ref, supp):
82
- assert ref.size() == supp.size()
83
-
84
- h, w = ref.size(2), ref.size(3)
85
- w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
86
- h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
87
-
88
- ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
89
- supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
90
-
91
- flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
92
-
93
- flow[:, 0, :, :] *= float(w) / float(w_floor)
94
- flow[:, 1, :, :] *= float(h) / float(h_floor)
95
-
96
- return flow
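A minimal usage sketch for the SpyNet optical-flow estimator deleted above (editor's illustration, not part of this commit; the network is randomly initialised here, a real checkpoint path would normally be passed):

import torch
from basicsr.archs.spynet_arch import SpyNet  # assumption: basicsr is installed elsewhere

flow_net = SpyNet(load_path=None).eval()   # pass a pretrained checkpoint path for meaningful flow
ref = torch.rand(1, 3, 64, 64)             # reference frame
supp = torch.rand(1, 3, 64, 64)            # supporting frame (must match ref size)
with torch.no_grad():
    flow = flow_net(ref, supp)             # per-pixel (dx, dy) displacement field
print(flow.shape)                          # torch.Size([1, 2, 64, 64])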
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/srresnet_arch.py DELETED
@@ -1,65 +0,0 @@
1
- from torch import nn as nn
2
- from torch.nn import functional as F
3
-
4
- from basicsr.utils.registry import ARCH_REGISTRY
5
- from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
6
-
7
-
8
- @ARCH_REGISTRY.register()
9
- class MSRResNet(nn.Module):
10
- """Modified SRResNet.
11
-
12
- A compact version modified from SRResNet in
13
- "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
14
- It uses residual blocks without BN, similar to EDSR.
15
- Currently, it supports x2, x3 and x4 upsampling scale factors.
16
-
17
- Args:
18
- num_in_ch (int): Channel number of inputs. Default: 3.
19
- num_out_ch (int): Channel number of outputs. Default: 3.
20
- num_feat (int): Channel number of intermediate features. Default: 64.
21
- num_block (int): Block number in the body network. Default: 16.
22
- upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
23
- """
24
-
25
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
26
- super(MSRResNet, self).__init__()
27
- self.upscale = upscale
28
-
29
- self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
30
- self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
31
-
32
- # upsampling
33
- if self.upscale in [2, 3]:
34
- self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
35
- self.pixel_shuffle = nn.PixelShuffle(self.upscale)
36
- elif self.upscale == 4:
37
- self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
38
- self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
39
- self.pixel_shuffle = nn.PixelShuffle(2)
40
-
41
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
42
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
43
-
44
- # activation function
45
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46
-
47
- # initialization
48
- default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
49
- if self.upscale == 4:
50
- default_init_weights(self.upconv2, 0.1)
51
-
52
- def forward(self, x):
53
- feat = self.lrelu(self.conv_first(x))
54
- out = self.body(feat)
55
-
56
- if self.upscale == 4:
57
- out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
58
- out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
59
- elif self.upscale in [2, 3]:
60
- out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
61
-
62
- out = self.conv_last(self.lrelu(self.conv_hr(out)))
63
- base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
64
- out += base
65
- return out
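A minimal usage sketch for the MSRResNet super-resolution network deleted above (editor's illustration, not part of this commit; assumes basicsr remains importable):

import torch
from basicsr.archs.srresnet_arch import MSRResNet  # assumption: basicsr is installed elsewhere

net = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).eval()
lr = torch.rand(1, 3, 24, 24)      # dummy low-resolution input
with torch.no_grad():
    sr = net(lr)                   # bilinear-upsampled base + learned residual
print(sr.shape)                    # torch.Size([1, 3, 96, 96])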
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/srvgg_arch.py DELETED
@@ -1,70 +0,0 @@
1
- from torch import nn as nn
2
- from torch.nn import functional as F
3
-
4
- from basicsr.utils.registry import ARCH_REGISTRY
5
-
6
-
7
- @ARCH_REGISTRY.register(suffix='basicsr')
8
- class SRVGGNetCompact(nn.Module):
9
- """A compact VGG-style network structure for super-resolution.
10
-
11
- It is a compact network structure, which performs upsampling in the last layer and no convolution is
12
- conducted in the HR feature space.
13
-
14
- Args:
15
- num_in_ch (int): Channel number of inputs. Default: 3.
16
- num_out_ch (int): Channel number of outputs. Default: 3.
17
- num_feat (int): Channel number of intermediate features. Default: 64.
18
- num_conv (int): Number of convolution layers in the body network. Default: 16.
19
- upscale (int): Upsampling factor. Default: 4.
20
- act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
21
- """
22
-
23
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
24
- super(SRVGGNetCompact, self).__init__()
25
- self.num_in_ch = num_in_ch
26
- self.num_out_ch = num_out_ch
27
- self.num_feat = num_feat
28
- self.num_conv = num_conv
29
- self.upscale = upscale
30
- self.act_type = act_type
31
-
32
- self.body = nn.ModuleList()
33
- # the first conv
34
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
35
- # the first activation
36
- if act_type == 'relu':
37
- activation = nn.ReLU(inplace=True)
38
- elif act_type == 'prelu':
39
- activation = nn.PReLU(num_parameters=num_feat)
40
- elif act_type == 'leakyrelu':
41
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
42
- self.body.append(activation)
43
-
44
- # the body structure
45
- for _ in range(num_conv):
46
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
47
- # activation
48
- if act_type == 'relu':
49
- activation = nn.ReLU(inplace=True)
50
- elif act_type == 'prelu':
51
- activation = nn.PReLU(num_parameters=num_feat)
52
- elif act_type == 'leakyrelu':
53
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
54
- self.body.append(activation)
55
-
56
- # the last conv
57
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
58
- # upsample
59
- self.upsampler = nn.PixelShuffle(upscale)
60
-
61
- def forward(self, x):
62
- out = x
63
- for i in range(0, len(self.body)):
64
- out = self.body[i](out)
65
-
66
- out = self.upsampler(out)
67
- # add the nearest upsampled image, so that the network learns the residual
68
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
69
- out += base
70
- return out
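A minimal usage sketch for the compact SRVGG network deleted above (editor's illustration, not part of this commit; assumes basicsr remains importable):

import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact  # assumption: basicsr is installed elsewhere

net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu').eval()
lr = torch.rand(1, 3, 32, 32)      # dummy low-resolution input
with torch.no_grad():
    sr = net(lr)                   # pixel-shuffle upsample + nearest-upsampled residual base
print(sr.shape)                    # torch.Size([1, 3, 128, 128])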
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/stylegan2_arch.py DELETED
@@ -1,799 +0,0 @@
1
- import math
2
- import random
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
8
- from basicsr.ops.upfirdn2d import upfirdn2d
9
- from basicsr.utils.registry import ARCH_REGISTRY
10
-
11
-
12
- class NormStyleCode(nn.Module):
13
-
14
- def forward(self, x):
15
- """Normalize the style codes.
16
-
17
- Args:
18
- x (Tensor): Style codes with shape (b, c).
19
-
20
- Returns:
21
- Tensor: Normalized tensor.
22
- """
23
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
24
-
25
-
26
- def make_resample_kernel(k):
27
- """Make resampling kernel for UpFirDn.
28
-
29
- Args:
30
- k (list[int]): A list indicating the 1D resample kernel magnitude.
31
-
32
- Returns:
33
- Tensor: 2D resampled kernel.
34
- """
35
- k = torch.tensor(k, dtype=torch.float32)
36
- if k.ndim == 1:
37
- k = k[None, :] * k[:, None] # to 2D kernel, outer product
38
- # normalize
39
- k /= k.sum()
40
- return k
41
-
42
-
43
- class UpFirDnUpsample(nn.Module):
44
- """Upsample, FIR filter, and downsample (upsampole version).
45
-
46
- References:
47
- 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
48
- 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
49
-
50
- Args:
51
- resample_kernel (list[int]): A list indicating the 1D resample kernel
52
- magnitude.
53
- factor (int): Upsampling scale factor. Default: 2.
54
- """
55
-
56
- def __init__(self, resample_kernel, factor=2):
57
- super(UpFirDnUpsample, self).__init__()
58
- self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
59
- self.factor = factor
60
-
61
- pad = self.kernel.shape[0] - factor
62
- self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
63
-
64
- def forward(self, x):
65
- out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
66
- return out
67
-
68
- def __repr__(self):
69
- return (f'{self.__class__.__name__}(factor={self.factor})')
70
-
71
-
72
- class UpFirDnDownsample(nn.Module):
73
- """Upsample, FIR filter, and downsample (downsampole version).
74
-
75
- Args:
76
- resample_kernel (list[int]): A list indicating the 1D resample kernel
77
- magnitude.
78
- factor (int): Downsampling scale factor. Default: 2.
79
- """
80
-
81
- def __init__(self, resample_kernel, factor=2):
82
- super(UpFirDnDownsample, self).__init__()
83
- self.kernel = make_resample_kernel(resample_kernel)
84
- self.factor = factor
85
-
86
- pad = self.kernel.shape[0] - factor
87
- self.pad = ((pad + 1) // 2, pad // 2)
88
-
89
- def forward(self, x):
90
- out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
91
- return out
92
-
93
- def __repr__(self):
94
- return (f'{self.__class__.__name__}(factor={self.factor})')
95
-
96
-
97
- class UpFirDnSmooth(nn.Module):
98
- """Upsample, FIR filter, and downsample (smooth version).
99
-
100
- Args:
101
- resample_kernel (list[int]): A list indicating the 1D resample kernel
102
- magnitude.
103
- upsample_factor (int): Upsampling scale factor. Default: 1.
104
- downsample_factor (int): Downsampling scale factor. Default: 1.
105
- kernel_size (int): Kernel size: Default: 1.
106
- """
107
-
108
- def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
109
- super(UpFirDnSmooth, self).__init__()
110
- self.upsample_factor = upsample_factor
111
- self.downsample_factor = downsample_factor
112
- self.kernel = make_resample_kernel(resample_kernel)
113
- if upsample_factor > 1:
114
- self.kernel = self.kernel * (upsample_factor**2)
115
-
116
- if upsample_factor > 1:
117
- pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
118
- self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
119
- elif downsample_factor > 1:
120
- pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
121
- self.pad = ((pad + 1) // 2, pad // 2)
122
- else:
123
- raise NotImplementedError
124
-
125
- def forward(self, x):
126
- out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
127
- return out
128
-
129
- def __repr__(self):
130
- return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
131
- f', downsample_factor={self.downsample_factor})')
132
-
133
-
134
- class EqualLinear(nn.Module):
135
- """Equalized Linear as StyleGAN2.
136
-
137
- Args:
138
- in_channels (int): Size of each sample.
139
- out_channels (int): Size of each output sample.
140
- bias (bool): If set to ``False``, the layer will not learn an additive
141
- bias. Default: ``True``.
142
- bias_init_val (float): Bias initialized value. Default: 0.
143
- lr_mul (float): Learning rate multiplier. Default: 1.
144
- activation (None | str): The activation after ``linear`` operation.
145
- Supported: 'fused_lrelu', None. Default: None.
146
- """
147
-
148
- def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
149
- super(EqualLinear, self).__init__()
150
- self.in_channels = in_channels
151
- self.out_channels = out_channels
152
- self.lr_mul = lr_mul
153
- self.activation = activation
154
- if self.activation not in ['fused_lrelu', None]:
155
- raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
156
- "Supported ones are: ['fused_lrelu', None].")
157
- self.scale = (1 / math.sqrt(in_channels)) * lr_mul
158
-
159
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
160
- if bias:
161
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
162
- else:
163
- self.register_parameter('bias', None)
164
-
165
- def forward(self, x):
166
- if self.bias is None:
167
- bias = None
168
- else:
169
- bias = self.bias * self.lr_mul
170
- if self.activation == 'fused_lrelu':
171
- out = F.linear(x, self.weight * self.scale)
172
- out = fused_leaky_relu(out, bias)
173
- else:
174
- out = F.linear(x, self.weight * self.scale, bias=bias)
175
- return out
176
-
177
- def __repr__(self):
178
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
179
- f'out_channels={self.out_channels}, bias={self.bias is not None})')
180
-
181
-
182
- class ModulatedConv2d(nn.Module):
183
- """Modulated Conv2d used in StyleGAN2.
184
-
185
- There is no bias in ModulatedConv2d.
186
-
187
- Args:
188
- in_channels (int): Channel number of the input.
189
- out_channels (int): Channel number of the output.
190
- kernel_size (int): Size of the convolving kernel.
191
- num_style_feat (int): Channel number of style features.
192
- demodulate (bool): Whether to demodulate in the conv layer.
193
- Default: True.
194
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
195
- Default: None.
196
- resample_kernel (list[int]): A list indicating the 1D resample kernel
197
- magnitude. Default: (1, 3, 3, 1).
198
- eps (float): A value added to the denominator for numerical stability.
199
- Default: 1e-8.
200
- """
201
-
202
- def __init__(self,
203
- in_channels,
204
- out_channels,
205
- kernel_size,
206
- num_style_feat,
207
- demodulate=True,
208
- sample_mode=None,
209
- resample_kernel=(1, 3, 3, 1),
210
- eps=1e-8):
211
- super(ModulatedConv2d, self).__init__()
212
- self.in_channels = in_channels
213
- self.out_channels = out_channels
214
- self.kernel_size = kernel_size
215
- self.demodulate = demodulate
216
- self.sample_mode = sample_mode
217
- self.eps = eps
218
-
219
- if self.sample_mode == 'upsample':
220
- self.smooth = UpFirDnSmooth(
221
- resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
222
- elif self.sample_mode == 'downsample':
223
- self.smooth = UpFirDnSmooth(
224
- resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
225
- elif self.sample_mode is None:
226
- pass
227
- else:
228
- raise ValueError(f'Wrong sample mode {self.sample_mode}, '
229
- "supported ones are ['upsample', 'downsample', None].")
230
-
231
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
232
- # modulation inside each modulated conv
233
- self.modulation = EqualLinear(
234
- num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
235
-
236
- self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
237
- self.padding = kernel_size // 2
238
-
239
- def forward(self, x, style):
240
- """Forward function.
241
-
242
- Args:
243
- x (Tensor): Tensor with shape (b, c, h, w).
244
- style (Tensor): Tensor with shape (b, num_style_feat).
245
-
246
- Returns:
247
- Tensor: Modulated tensor after convolution.
248
- """
249
- b, c, h, w = x.shape # c = c_in
250
- # weight modulation
251
- style = self.modulation(style).view(b, 1, c, 1, 1)
252
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
253
- weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
254
-
255
- if self.demodulate:
256
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
257
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
258
-
259
- weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
260
-
261
- if self.sample_mode == 'upsample':
262
- x = x.view(1, b * c, h, w)
263
- weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
264
- weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
265
- out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
266
- out = out.view(b, self.out_channels, *out.shape[2:4])
267
- out = self.smooth(out)
268
- elif self.sample_mode == 'downsample':
269
- x = self.smooth(x)
270
- x = x.view(1, b * c, *x.shape[2:4])
271
- out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
272
- out = out.view(b, self.out_channels, *out.shape[2:4])
273
- else:
274
- x = x.view(1, b * c, h, w)
275
- # weight: (b*c_out, c_in, k, k), groups=b
276
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
277
- out = out.view(b, self.out_channels, *out.shape[2:4])
278
-
279
- return out
280
-
281
- def __repr__(self):
282
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
283
- f'out_channels={self.out_channels}, '
284
- f'kernel_size={self.kernel_size}, '
285
- f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
286
-
287
-
288
- class StyleConv(nn.Module):
289
- """Style conv.
290
-
291
- Args:
292
- in_channels (int): Channel number of the input.
293
- out_channels (int): Channel number of the output.
294
- kernel_size (int): Size of the convolving kernel.
295
- num_style_feat (int): Channel number of style features.
296
- demodulate (bool): Whether demodulate in the conv layer. Default: True.
297
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
298
- Default: None.
299
- resample_kernel (list[int]): A list indicating the 1D resample kernel
300
- magnitude. Default: (1, 3, 3, 1).
301
- """
302
-
303
- def __init__(self,
304
- in_channels,
305
- out_channels,
306
- kernel_size,
307
- num_style_feat,
308
- demodulate=True,
309
- sample_mode=None,
310
- resample_kernel=(1, 3, 3, 1)):
311
- super(StyleConv, self).__init__()
312
- self.modulated_conv = ModulatedConv2d(
313
- in_channels,
314
- out_channels,
315
- kernel_size,
316
- num_style_feat,
317
- demodulate=demodulate,
318
- sample_mode=sample_mode,
319
- resample_kernel=resample_kernel)
320
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
321
- self.activate = FusedLeakyReLU(out_channels)
322
-
323
- def forward(self, x, style, noise=None):
324
- # modulate
325
- out = self.modulated_conv(x, style)
326
- # noise injection
327
- if noise is None:
328
- b, _, h, w = out.shape
329
- noise = out.new_empty(b, 1, h, w).normal_()
330
- out = out + self.weight * noise
331
- # activation (with bias)
332
- out = self.activate(out)
333
- return out
334
-
335
-
336
- class ToRGB(nn.Module):
337
- """To RGB from features.
338
-
339
- Args:
340
- in_channels (int): Channel number of input.
341
- num_style_feat (int): Channel number of style features.
342
- upsample (bool): Whether to upsample. Default: True.
343
- resample_kernel (list[int]): A list indicating the 1D resample kernel
344
- magnitude. Default: (1, 3, 3, 1).
345
- """
346
-
347
- def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
348
- super(ToRGB, self).__init__()
349
- if upsample:
350
- self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
351
- else:
352
- self.upsample = None
353
- self.modulated_conv = ModulatedConv2d(
354
- in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
355
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
356
-
357
- def forward(self, x, style, skip=None):
358
- """Forward function.
359
-
360
- Args:
361
- x (Tensor): Feature tensor with shape (b, c, h, w).
362
- style (Tensor): Tensor with shape (b, num_style_feat).
363
- skip (Tensor): Base/skip tensor. Default: None.
364
-
365
- Returns:
366
- Tensor: RGB images.
367
- """
368
- out = self.modulated_conv(x, style)
369
- out = out + self.bias
370
- if skip is not None:
371
- if self.upsample:
372
- skip = self.upsample(skip)
373
- out = out + skip
374
- return out
375
-
376
-
377
- class ConstantInput(nn.Module):
378
- """Constant input.
379
-
380
- Args:
381
- num_channel (int): Channel number of constant input.
382
- size (int): Spatial size of constant input.
383
- """
384
-
385
- def __init__(self, num_channel, size):
386
- super(ConstantInput, self).__init__()
387
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
388
-
389
- def forward(self, batch):
390
- out = self.weight.repeat(batch, 1, 1, 1)
391
- return out
392
-
393
-
394
- @ARCH_REGISTRY.register()
395
- class StyleGAN2Generator(nn.Module):
396
- """StyleGAN2 Generator.
397
-
398
- Args:
399
- out_size (int): The spatial size of outputs.
400
- num_style_feat (int): Channel number of style features. Default: 512.
401
- num_mlp (int): Layer number of MLP style layers. Default: 8.
402
- channel_multiplier (int): Channel multiplier for large networks of
403
- StyleGAN2. Default: 2.
404
- resample_kernel (list[int]): A list indicating the 1D resample kernel
405
- magnitude. An outer product will be applied to extend the 1D resample
406
- kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
407
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
408
- narrow (float): Narrow ratio for channels. Default: 1.0.
409
- """
410
-
411
- def __init__(self,
412
- out_size,
413
- num_style_feat=512,
414
- num_mlp=8,
415
- channel_multiplier=2,
416
- resample_kernel=(1, 3, 3, 1),
417
- lr_mlp=0.01,
418
- narrow=1):
419
- super(StyleGAN2Generator, self).__init__()
420
- # Style MLP layers
421
- self.num_style_feat = num_style_feat
422
- style_mlp_layers = [NormStyleCode()]
423
- for i in range(num_mlp):
424
- style_mlp_layers.append(
425
- EqualLinear(
426
- num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
427
- activation='fused_lrelu'))
428
- self.style_mlp = nn.Sequential(*style_mlp_layers)
429
-
430
- channels = {
431
- '4': int(512 * narrow),
432
- '8': int(512 * narrow),
433
- '16': int(512 * narrow),
434
- '32': int(512 * narrow),
435
- '64': int(256 * channel_multiplier * narrow),
436
- '128': int(128 * channel_multiplier * narrow),
437
- '256': int(64 * channel_multiplier * narrow),
438
- '512': int(32 * channel_multiplier * narrow),
439
- '1024': int(16 * channel_multiplier * narrow)
440
- }
441
- self.channels = channels
442
-
443
- self.constant_input = ConstantInput(channels['4'], size=4)
444
- self.style_conv1 = StyleConv(
445
- channels['4'],
446
- channels['4'],
447
- kernel_size=3,
448
- num_style_feat=num_style_feat,
449
- demodulate=True,
450
- sample_mode=None,
451
- resample_kernel=resample_kernel)
452
- self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
453
-
454
- self.log_size = int(math.log(out_size, 2))
455
- self.num_layers = (self.log_size - 2) * 2 + 1
456
- self.num_latent = self.log_size * 2 - 2
457
-
458
- self.style_convs = nn.ModuleList()
459
- self.to_rgbs = nn.ModuleList()
460
- self.noises = nn.Module()
461
-
462
- in_channels = channels['4']
463
- # noise
464
- for layer_idx in range(self.num_layers):
465
- resolution = 2**((layer_idx + 5) // 2)
466
- shape = [1, 1, resolution, resolution]
467
- self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
468
- # style convs and to_rgbs
469
- for i in range(3, self.log_size + 1):
470
- out_channels = channels[f'{2**i}']
471
- self.style_convs.append(
472
- StyleConv(
473
- in_channels,
474
- out_channels,
475
- kernel_size=3,
476
- num_style_feat=num_style_feat,
477
- demodulate=True,
478
- sample_mode='upsample',
479
- resample_kernel=resample_kernel,
480
- ))
481
- self.style_convs.append(
482
- StyleConv(
483
- out_channels,
484
- out_channels,
485
- kernel_size=3,
486
- num_style_feat=num_style_feat,
487
- demodulate=True,
488
- sample_mode=None,
489
- resample_kernel=resample_kernel))
490
- self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
491
- in_channels = out_channels
492
-
493
- def make_noise(self):
494
- """Make noise for noise injection."""
495
- device = self.constant_input.weight.device
496
- noises = [torch.randn(1, 1, 4, 4, device=device)]
497
-
498
- for i in range(3, self.log_size + 1):
499
- for _ in range(2):
500
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
501
-
502
- return noises
503
-
504
- def get_latent(self, x):
505
- return self.style_mlp(x)
506
-
507
- def mean_latent(self, num_latent):
508
- latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
509
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
510
- return latent
511
-
512
- def forward(self,
513
- styles,
514
- input_is_latent=False,
515
- noise=None,
516
- randomize_noise=True,
517
- truncation=1,
518
- truncation_latent=None,
519
- inject_index=None,
520
- return_latents=False):
521
- """Forward function for StyleGAN2Generator.
522
-
523
- Args:
524
- styles (list[Tensor]): Sample codes of styles.
525
- input_is_latent (bool): Whether input is latent style.
526
- Default: False.
527
- noise (Tensor | None): Input noise or None. Default: None.
528
- randomize_noise (bool): Randomize noise, used when 'noise' is
529
- None. Default: True.
530
- truncation (float): Truncation ratio for the truncation trick. Default: 1.
531
- truncation_latent (Tensor | None): Mean latent used by the truncation trick. Default: None.
532
- inject_index (int | None): The injection index for mixing noise.
533
- Default: None.
534
- return_latents (bool): Whether to return style latents.
535
- Default: False.
536
- """
537
- # style codes -> latents with Style MLP layer
538
- if not input_is_latent:
539
- styles = [self.style_mlp(s) for s in styles]
540
- # noises
541
- if noise is None:
542
- if randomize_noise:
543
- noise = [None] * self.num_layers # for each style conv layer
544
- else: # use the stored noise
545
- noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
546
- # style truncation
547
- if truncation < 1:
548
- style_truncation = []
549
- for style in styles:
550
- style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
551
- styles = style_truncation
552
- # get style latent with injection
553
- if len(styles) == 1:
554
- inject_index = self.num_latent
555
-
556
- if styles[0].ndim < 3:
557
- # repeat latent code for all the layers
558
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
559
- else: # used for encoder with different latent code for each layer
560
- latent = styles[0]
561
- elif len(styles) == 2: # mixing noises
562
- if inject_index is None:
563
- inject_index = random.randint(1, self.num_latent - 1)
564
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
565
- latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
566
- latent = torch.cat([latent1, latent2], 1)
567
-
568
- # main generation
569
- out = self.constant_input(latent.shape[0])
570
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
571
- skip = self.to_rgb1(out, latent[:, 1])
572
-
573
- i = 1
574
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
575
- noise[2::2], self.to_rgbs):
576
- out = conv1(out, latent[:, i], noise=noise1)
577
- out = conv2(out, latent[:, i + 1], noise=noise2)
578
- skip = to_rgb(out, latent[:, i + 2], skip)
579
- i += 2
580
-
581
- image = skip
582
-
583
- if return_latents:
584
- return image, latent
585
- else:
586
- return image, None
587
-
588
-
589
- class ScaledLeakyReLU(nn.Module):
590
- """Scaled LeakyReLU.
591
-
592
- Args:
593
- negative_slope (float): Negative slope. Default: 0.2.
594
- """
595
-
596
- def __init__(self, negative_slope=0.2):
597
- super(ScaledLeakyReLU, self).__init__()
598
- self.negative_slope = negative_slope
599
-
600
- def forward(self, x):
601
- out = F.leaky_relu(x, negative_slope=self.negative_slope)
602
- return out * math.sqrt(2)
603
-
604
-
605
- class EqualConv2d(nn.Module):
606
- """Equalized Linear as StyleGAN2.
607
-
608
- Args:
609
- in_channels (int): Channel number of the input.
610
- out_channels (int): Channel number of the output.
611
- kernel_size (int): Size of the convolving kernel.
612
- stride (int): Stride of the convolution. Default: 1
613
- padding (int): Zero-padding added to both sides of the input.
614
- Default: 0.
615
- bias (bool): If ``True``, adds a learnable bias to the output.
616
- Default: ``True``.
617
- bias_init_val (float): Bias initialized value. Default: 0.
618
- """
619
-
620
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
621
- super(EqualConv2d, self).__init__()
622
- self.in_channels = in_channels
623
- self.out_channels = out_channels
624
- self.kernel_size = kernel_size
625
- self.stride = stride
626
- self.padding = padding
627
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
628
-
629
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
630
- if bias:
631
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
632
- else:
633
- self.register_parameter('bias', None)
634
-
635
- def forward(self, x):
636
- out = F.conv2d(
637
- x,
638
- self.weight * self.scale,
639
- bias=self.bias,
640
- stride=self.stride,
641
- padding=self.padding,
642
- )
643
-
644
- return out
645
-
646
- def __repr__(self):
647
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
648
- f'out_channels={self.out_channels}, '
649
- f'kernel_size={self.kernel_size},'
650
- f' stride={self.stride}, padding={self.padding}, '
651
- f'bias={self.bias is not None})')
652
-
653
-
654
- class ConvLayer(nn.Sequential):
655
- """Conv Layer used in StyleGAN2 Discriminator.
656
-
657
- Args:
658
- in_channels (int): Channel number of the input.
659
- out_channels (int): Channel number of the output.
660
- kernel_size (int): Kernel size.
661
- downsample (bool): Whether downsample by a factor of 2.
662
- Default: False.
663
- resample_kernel (list[int]): A list indicating the 1D resample
664
- kernel magnitude. An outer product will be applied to
665
- extend the 1D resample kernel to a 2D resample kernel.
666
- Default: (1, 3, 3, 1).
667
- bias (bool): Whether with bias. Default: True.
668
- activate (bool): Whether to use activation. Default: True.
669
- """
670
-
671
- def __init__(self,
672
- in_channels,
673
- out_channels,
674
- kernel_size,
675
- downsample=False,
676
- resample_kernel=(1, 3, 3, 1),
677
- bias=True,
678
- activate=True):
679
- layers = []
680
- # downsample
681
- if downsample:
682
- layers.append(
683
- UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
684
- stride = 2
685
- self.padding = 0
686
- else:
687
- stride = 1
688
- self.padding = kernel_size // 2
689
- # conv
690
- layers.append(
691
- EqualConv2d(
692
- in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
693
- and not activate))
694
- # activation
695
- if activate:
696
- if bias:
697
- layers.append(FusedLeakyReLU(out_channels))
698
- else:
699
- layers.append(ScaledLeakyReLU(0.2))
700
-
701
- super(ConvLayer, self).__init__(*layers)
702
-
703
-
704
- class ResBlock(nn.Module):
705
- """Residual block used in StyleGAN2 Discriminator.
706
-
707
- Args:
708
- in_channels (int): Channel number of the input.
709
- out_channels (int): Channel number of the output.
710
- resample_kernel (list[int]): A list indicating the 1D resample
711
- kernel magnitude. An outer product will be applied to
712
- extend the 1D resample kernel to a 2D resample kernel.
713
- Default: (1, 3, 3, 1).
714
- """
715
-
716
- def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
717
- super(ResBlock, self).__init__()
718
-
719
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
720
- self.conv2 = ConvLayer(
721
- in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
722
- self.skip = ConvLayer(
723
- in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
724
-
725
- def forward(self, x):
726
- out = self.conv1(x)
727
- out = self.conv2(out)
728
- skip = self.skip(x)
729
- out = (out + skip) / math.sqrt(2)
730
- return out
731
-
732
-
733
- @ARCH_REGISTRY.register()
734
- class StyleGAN2Discriminator(nn.Module):
735
- """StyleGAN2 Discriminator.
736
-
737
- Args:
738
- out_size (int): The spatial size of outputs.
739
- channel_multiplier (int): Channel multiplier for large networks of
740
- StyleGAN2. Default: 2.
741
- resample_kernel (list[int]): A list indicating the 1D resample kernel
742
- magnitude. An outer product will be applied to extend the 1D resample
743
- kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
744
- stddev_group (int): For group stddev statistics. Default: 4.
745
- narrow (float): Narrow ratio for channels. Default: 1.0.
746
- """
747
-
748
- def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
749
- super(StyleGAN2Discriminator, self).__init__()
750
-
751
- channels = {
752
- '4': int(512 * narrow),
753
- '8': int(512 * narrow),
754
- '16': int(512 * narrow),
755
- '32': int(512 * narrow),
756
- '64': int(256 * channel_multiplier * narrow),
757
- '128': int(128 * channel_multiplier * narrow),
758
- '256': int(64 * channel_multiplier * narrow),
759
- '512': int(32 * channel_multiplier * narrow),
760
- '1024': int(16 * channel_multiplier * narrow)
761
- }
762
-
763
- log_size = int(math.log(out_size, 2))
764
-
765
- conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
766
-
767
- in_channels = channels[f'{out_size}']
768
- for i in range(log_size, 2, -1):
769
- out_channels = channels[f'{2**(i - 1)}']
770
- conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
771
- in_channels = out_channels
772
- self.conv_body = nn.Sequential(*conv_body)
773
-
774
- self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
775
- self.final_linear = nn.Sequential(
776
- EqualLinear(
777
- channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
778
- EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
779
- )
780
- self.stddev_group = stddev_group
781
- self.stddev_feat = 1
782
-
783
- def forward(self, x):
784
- out = self.conv_body(x)
785
-
786
- b, c, h, w = out.shape
787
- # concatenate a group stddev statistics to out
788
- group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
789
- stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
790
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
791
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
792
- stddev = stddev.repeat(group, 1, h, w)
793
- out = torch.cat([out, stddev], 1)
794
-
795
- out = self.final_conv(out)
796
- out = out.view(b, -1)
797
- out = self.final_linear(out)
798
-
799
- return out
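A minimal usage sketch for the StyleGAN2 generator/discriminator pair deleted above (editor's illustration, not part of this commit; both classes depend on the custom fused_act and upfirdn2d ops shipped with basicsr, so this assumes those CUDA extensions are built and basicsr remains importable):

import torch
from basicsr.archs.stylegan2_arch import StyleGAN2Generator, StyleGAN2Discriminator

gen = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8).eval()
disc = StyleGAN2Discriminator(out_size=256).eval()
z = torch.randn(1, 512)                         # a single style code
with torch.no_grad():
    image, _ = gen([z], input_is_latent=False)  # forward returns (image, latents-or-None)
    score = disc(image)                         # realness logit, shape (1, 1)
print(image.shape, score.shape)                 # torch.Size([1, 3, 256, 256]) torch.Size([1, 1])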
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basicsr/archs/stylegan2_bilinear_arch.py DELETED
@@ -1,614 +0,0 @@
1
- import math
2
- import random
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
8
- from basicsr.utils.registry import ARCH_REGISTRY
9
-
10
-
11
- class NormStyleCode(nn.Module):
12
-
13
- def forward(self, x):
14
- """Normalize the style codes.
15
-
16
- Args:
17
- x (Tensor): Style codes with shape (b, c).
18
-
19
- Returns:
20
- Tensor: Normalized tensor.
21
- """
22
- return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
23
-
24
-
25
- class EqualLinear(nn.Module):
26
- """Equalized Linear as StyleGAN2.
27
-
28
- Args:
29
- in_channels (int): Size of each sample.
30
- out_channels (int): Size of each output sample.
31
- bias (bool): If set to ``False``, the layer will not learn an additive
32
- bias. Default: ``True``.
33
- bias_init_val (float): Bias initialized value. Default: 0.
34
- lr_mul (float): Learning rate multiplier. Default: 1.
35
- activation (None | str): The activation after ``linear`` operation.
36
- Supported: 'fused_lrelu', None. Default: None.
37
- """
38
-
39
- def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
40
- super(EqualLinear, self).__init__()
41
- self.in_channels = in_channels
42
- self.out_channels = out_channels
43
- self.lr_mul = lr_mul
44
- self.activation = activation
45
- if self.activation not in ['fused_lrelu', None]:
46
- raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
47
- "Supported ones are: ['fused_lrelu', None].")
48
- self.scale = (1 / math.sqrt(in_channels)) * lr_mul
49
-
50
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
51
- if bias:
52
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
53
- else:
54
- self.register_parameter('bias', None)
55
-
56
- def forward(self, x):
57
- if self.bias is None:
58
- bias = None
59
- else:
60
- bias = self.bias * self.lr_mul
61
- if self.activation == 'fused_lrelu':
62
- out = F.linear(x, self.weight * self.scale)
63
- out = fused_leaky_relu(out, bias)
64
- else:
65
- out = F.linear(x, self.weight * self.scale, bias=bias)
66
- return out
67
-
68
- def __repr__(self):
69
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
70
- f'out_channels={self.out_channels}, bias={self.bias is not None})')
71
-
72
-
73
- class ModulatedConv2d(nn.Module):
74
- """Modulated Conv2d used in StyleGAN2.
75
-
76
- There is no bias in ModulatedConv2d.
77
-
78
- Args:
79
- in_channels (int): Channel number of the input.
80
- out_channels (int): Channel number of the output.
81
- kernel_size (int): Size of the convolving kernel.
82
- num_style_feat (int): Channel number of style features.
83
- demodulate (bool): Whether to demodulate in the conv layer.
84
- Default: True.
85
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
86
- Default: None.
87
- eps (float): A value added to the denominator for numerical stability.
88
- Default: 1e-8.
89
- """
90
-
91
- def __init__(self,
92
- in_channels,
93
- out_channels,
94
- kernel_size,
95
- num_style_feat,
96
- demodulate=True,
97
- sample_mode=None,
98
- eps=1e-8,
99
- interpolation_mode='bilinear'):
100
- super(ModulatedConv2d, self).__init__()
101
- self.in_channels = in_channels
102
- self.out_channels = out_channels
103
- self.kernel_size = kernel_size
104
- self.demodulate = demodulate
105
- self.sample_mode = sample_mode
106
- self.eps = eps
107
- self.interpolation_mode = interpolation_mode
108
- if self.interpolation_mode == 'nearest':
109
- self.align_corners = None
110
- else:
111
- self.align_corners = False
112
-
113
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
114
- # modulation inside each modulated conv
115
- self.modulation = EqualLinear(
116
- num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
117
-
118
- self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
119
- self.padding = kernel_size // 2
120
-
121
- def forward(self, x, style):
122
- """Forward function.
123
-
124
- Args:
125
- x (Tensor): Tensor with shape (b, c, h, w).
126
- style (Tensor): Tensor with shape (b, num_style_feat).
127
-
128
- Returns:
129
- Tensor: Modulated tensor after convolution.
130
- """
131
- b, c, h, w = x.shape # c = c_in
132
- # weight modulation
133
- style = self.modulation(style).view(b, 1, c, 1, 1)
134
- # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
135
- weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
136
-
137
- if self.demodulate:
138
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
139
- weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
140
-
141
- weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
142
-
143
- if self.sample_mode == 'upsample':
144
- x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
145
- elif self.sample_mode == 'downsample':
146
- x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
147
-
148
- b, c, h, w = x.shape
149
- x = x.view(1, b * c, h, w)
150
- # weight: (b*c_out, c_in, k, k), groups=b
151
- out = F.conv2d(x, weight, padding=self.padding, groups=b)
152
- out = out.view(b, self.out_channels, *out.shape[2:4])
153
-
154
- return out
155
-
156
- def __repr__(self):
157
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
158
- f'out_channels={self.out_channels}, '
159
- f'kernel_size={self.kernel_size}, '
160
- f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
161
-
162
-
163
- class StyleConv(nn.Module):
164
- """Style conv.
165
-
166
- Args:
167
- in_channels (int): Channel number of the input.
168
- out_channels (int): Channel number of the output.
169
- kernel_size (int): Size of the convolving kernel.
170
- num_style_feat (int): Channel number of style features.
171
- demodulate (bool): Whether demodulate in the conv layer. Default: True.
172
- sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
173
- Default: None.
174
- """
175
-
176
- def __init__(self,
177
- in_channels,
178
- out_channels,
179
- kernel_size,
180
- num_style_feat,
181
- demodulate=True,
182
- sample_mode=None,
183
- interpolation_mode='bilinear'):
184
- super(StyleConv, self).__init__()
185
- self.modulated_conv = ModulatedConv2d(
186
- in_channels,
187
- out_channels,
188
- kernel_size,
189
- num_style_feat,
190
- demodulate=demodulate,
191
- sample_mode=sample_mode,
192
- interpolation_mode=interpolation_mode)
193
- self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
194
- self.activate = FusedLeakyReLU(out_channels)
195
-
196
- def forward(self, x, style, noise=None):
197
- # modulate
198
- out = self.modulated_conv(x, style)
199
- # noise injection
200
- if noise is None:
201
- b, _, h, w = out.shape
202
- noise = out.new_empty(b, 1, h, w).normal_()
203
- out = out + self.weight * noise
204
- # activation (with bias)
205
- out = self.activate(out)
206
- return out
207
-
208
-
209
- class ToRGB(nn.Module):
210
- """To RGB from features.
211
-
212
- Args:
213
- in_channels (int): Channel number of input.
214
- num_style_feat (int): Channel number of style features.
215
- upsample (bool): Whether to upsample. Default: True.
216
- """
217
-
218
- def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
219
- super(ToRGB, self).__init__()
220
- self.upsample = upsample
221
- self.interpolation_mode = interpolation_mode
222
- if self.interpolation_mode == 'nearest':
223
- self.align_corners = None
224
- else:
225
- self.align_corners = False
226
- self.modulated_conv = ModulatedConv2d(
227
- in_channels,
228
- 3,
229
- kernel_size=1,
230
- num_style_feat=num_style_feat,
231
- demodulate=False,
232
- sample_mode=None,
233
- interpolation_mode=interpolation_mode)
234
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
235
-
236
- def forward(self, x, style, skip=None):
237
- """Forward function.
238
-
239
- Args:
240
- x (Tensor): Feature tensor with shape (b, c, h, w).
241
- style (Tensor): Tensor with shape (b, num_style_feat).
242
- skip (Tensor): Base/skip tensor. Default: None.
243
-
244
- Returns:
245
- Tensor: RGB images.
246
- """
247
- out = self.modulated_conv(x, style)
248
- out = out + self.bias
249
- if skip is not None:
250
- if self.upsample:
251
- skip = F.interpolate(
252
- skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
253
- out = out + skip
254
- return out
255
-
256
-
257
- class ConstantInput(nn.Module):
258
- """Constant input.
259
-
260
- Args:
261
- num_channel (int): Channel number of constant input.
262
- size (int): Spatial size of constant input.
263
- """
264
-
265
- def __init__(self, num_channel, size):
266
- super(ConstantInput, self).__init__()
267
- self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
268
-
269
- def forward(self, batch):
270
- out = self.weight.repeat(batch, 1, 1, 1)
271
- return out
272
-
273
-
274
- @ARCH_REGISTRY.register(suffix='basicsr')
275
- class StyleGAN2GeneratorBilinear(nn.Module):
276
- """StyleGAN2 Generator.
277
-
278
- Args:
279
- out_size (int): The spatial size of outputs.
280
- num_style_feat (int): Channel number of style features. Default: 512.
281
- num_mlp (int): Layer number of MLP style layers. Default: 8.
282
- channel_multiplier (int): Channel multiplier for large networks of
283
- StyleGAN2. Default: 2.
284
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
285
- narrow (float): Narrow ratio for channels. Default: 1.0.
286
- """
287
-
288
- def __init__(self,
289
- out_size,
290
- num_style_feat=512,
291
- num_mlp=8,
292
- channel_multiplier=2,
293
- lr_mlp=0.01,
294
- narrow=1,
295
- interpolation_mode='bilinear'):
296
- super(StyleGAN2GeneratorBilinear, self).__init__()
297
- # Style MLP layers
298
- self.num_style_feat = num_style_feat
299
- style_mlp_layers = [NormStyleCode()]
300
- for i in range(num_mlp):
301
- style_mlp_layers.append(
302
- EqualLinear(
303
- num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
304
- activation='fused_lrelu'))
305
- self.style_mlp = nn.Sequential(*style_mlp_layers)
306
-
307
- channels = {
308
- '4': int(512 * narrow),
309
- '8': int(512 * narrow),
310
- '16': int(512 * narrow),
311
- '32': int(512 * narrow),
312
- '64': int(256 * channel_multiplier * narrow),
313
- '128': int(128 * channel_multiplier * narrow),
314
- '256': int(64 * channel_multiplier * narrow),
315
- '512': int(32 * channel_multiplier * narrow),
316
- '1024': int(16 * channel_multiplier * narrow)
317
- }
318
- self.channels = channels
319
-
320
- self.constant_input = ConstantInput(channels['4'], size=4)
321
- self.style_conv1 = StyleConv(
322
- channels['4'],
323
- channels['4'],
324
- kernel_size=3,
325
- num_style_feat=num_style_feat,
326
- demodulate=True,
327
- sample_mode=None,
328
- interpolation_mode=interpolation_mode)
329
- self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
330
-
331
- self.log_size = int(math.log(out_size, 2))
332
- self.num_layers = (self.log_size - 2) * 2 + 1
333
- self.num_latent = self.log_size * 2 - 2
334
-
335
- self.style_convs = nn.ModuleList()
336
- self.to_rgbs = nn.ModuleList()
337
- self.noises = nn.Module()
338
-
339
- in_channels = channels['4']
340
- # noise
341
- for layer_idx in range(self.num_layers):
342
- resolution = 2**((layer_idx + 5) // 2)
343
- shape = [1, 1, resolution, resolution]
344
- self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
345
- # style convs and to_rgbs
346
- for i in range(3, self.log_size + 1):
347
- out_channels = channels[f'{2**i}']
348
- self.style_convs.append(
349
- StyleConv(
350
- in_channels,
351
- out_channels,
352
- kernel_size=3,
353
- num_style_feat=num_style_feat,
354
- demodulate=True,
355
- sample_mode='upsample',
356
- interpolation_mode=interpolation_mode))
357
- self.style_convs.append(
358
- StyleConv(
359
- out_channels,
360
- out_channels,
361
- kernel_size=3,
362
- num_style_feat=num_style_feat,
363
- demodulate=True,
364
- sample_mode=None,
365
- interpolation_mode=interpolation_mode))
366
- self.to_rgbs.append(
367
- ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
368
- in_channels = out_channels
369
-
370
- def make_noise(self):
371
- """Make noise for noise injection."""
372
- device = self.constant_input.weight.device
373
- noises = [torch.randn(1, 1, 4, 4, device=device)]
374
-
375
- for i in range(3, self.log_size + 1):
376
- for _ in range(2):
377
- noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
378
-
379
- return noises
380
-
381
- def get_latent(self, x):
382
- return self.style_mlp(x)
383
-
384
- def mean_latent(self, num_latent):
385
- latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
386
- latent = self.style_mlp(latent_in).mean(0, keepdim=True)
387
- return latent
388
-
389
- def forward(self,
390
- styles,
391
- input_is_latent=False,
392
- noise=None,
393
- randomize_noise=True,
394
- truncation=1,
395
- truncation_latent=None,
396
- inject_index=None,
397
- return_latents=False):
398
- """Forward function for StyleGAN2Generator.
399
-
400
- Args:
401
- styles (list[Tensor]): Sample codes of styles.
402
- input_is_latent (bool): Whether input is latent style.
403
- Default: False.
404
- noise (Tensor | None): Input noise or None. Default: None.
405
- randomize_noise (bool): Randomize noise, used when 'noise' is
406
- None. Default: True.
407
- truncation (float): Truncation ratio for the truncation trick. Default: 1.
408
- truncation_latent (Tensor | None): Mean latent code used as the truncation center. Default: None.
409
- inject_index (int | None): The injection index for mixing noise.
410
- Default: None.
411
- return_latents (bool): Whether to return style latents.
412
- Default: False.
413
- """
414
- # style codes -> latents with Style MLP layer
415
- if not input_is_latent:
416
- styles = [self.style_mlp(s) for s in styles]
417
- # noises
418
- if noise is None:
419
- if randomize_noise:
420
- noise = [None] * self.num_layers # for each style conv layer
421
- else: # use the stored noise
422
- noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
423
- # style truncation
424
- if truncation < 1:
425
- style_truncation = []
426
- for style in styles:
427
- style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
428
- styles = style_truncation
429
- # get style latent with injection
430
- if len(styles) == 1:
431
- inject_index = self.num_latent
432
-
433
- if styles[0].ndim < 3:
434
- # repeat latent code for all the layers
435
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
436
- else: # used for encoder with different latent code for each layer
437
- latent = styles[0]
438
- elif len(styles) == 2: # mixing noises
439
- if inject_index is None:
440
- inject_index = random.randint(1, self.num_latent - 1)
441
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
442
- latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
443
- latent = torch.cat([latent1, latent2], 1)
444
-
445
- # main generation
446
- out = self.constant_input(latent.shape[0])
447
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
448
- skip = self.to_rgb1(out, latent[:, 1])
449
-
450
- i = 1
451
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
452
- noise[2::2], self.to_rgbs):
453
- out = conv1(out, latent[:, i], noise=noise1)
454
- out = conv2(out, latent[:, i + 1], noise=noise2)
455
- skip = to_rgb(out, latent[:, i + 2], skip)
456
- i += 2
457
-
458
- image = skip
459
-
460
- if return_latents:
461
- return image, latent
462
- else:
463
- return image, None
464
-
465
-
466
- class ScaledLeakyReLU(nn.Module):
467
- """Scaled LeakyReLU.
468
-
469
- Args:
470
- negative_slope (float): Negative slope. Default: 0.2.
471
- """
472
-
473
- def __init__(self, negative_slope=0.2):
474
- super(ScaledLeakyReLU, self).__init__()
475
- self.negative_slope = negative_slope
476
-
477
- def forward(self, x):
478
- out = F.leaky_relu(x, negative_slope=self.negative_slope)
479
- return out * math.sqrt(2)
480
-
481
-
482
- class EqualConv2d(nn.Module):
483
- """Equalized Linear as StyleGAN2.
484
-
485
- Args:
486
- in_channels (int): Channel number of the input.
487
- out_channels (int): Channel number of the output.
488
- kernel_size (int): Size of the convolving kernel.
489
- stride (int): Stride of the convolution. Default: 1
490
- padding (int): Zero-padding added to both sides of the input.
491
- Default: 0.
492
- bias (bool): If ``True``, adds a learnable bias to the output.
493
- Default: ``True``.
494
- bias_init_val (float): Bias initialized value. Default: 0.
495
- """
496
-
497
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
498
- super(EqualConv2d, self).__init__()
499
- self.in_channels = in_channels
500
- self.out_channels = out_channels
501
- self.kernel_size = kernel_size
502
- self.stride = stride
503
- self.padding = padding
504
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
505
-
506
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
507
- if bias:
508
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
509
- else:
510
- self.register_parameter('bias', None)
511
-
512
- def forward(self, x):
513
- out = F.conv2d(
514
- x,
515
- self.weight * self.scale,
516
- bias=self.bias,
517
- stride=self.stride,
518
- padding=self.padding,
519
- )
520
-
521
- return out
522
-
523
- def __repr__(self):
524
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
525
- f'out_channels={self.out_channels}, '
526
- f'kernel_size={self.kernel_size},'
527
- f' stride={self.stride}, padding={self.padding}, '
528
- f'bias={self.bias is not None})')
529
-
530
-
531
- class ConvLayer(nn.Sequential):
532
- """Conv Layer used in StyleGAN2 Discriminator.
533
-
534
- Args:
535
- in_channels (int): Channel number of the input.
536
- out_channels (int): Channel number of the output.
537
- kernel_size (int): Kernel size.
538
- downsample (bool): Whether downsample by a factor of 2.
539
- Default: False.
540
- bias (bool): Whether with bias. Default: True.
541
- activate (bool): Whether to use activation. Default: True.
542
- """
543
-
544
- def __init__(self,
545
- in_channels,
546
- out_channels,
547
- kernel_size,
548
- downsample=False,
549
- bias=True,
550
- activate=True,
551
- interpolation_mode='bilinear'):
552
- layers = []
553
- self.interpolation_mode = interpolation_mode
554
- # downsample
555
- if downsample:
556
- if self.interpolation_mode == 'nearest':
557
- self.align_corners = None
558
- else:
559
- self.align_corners = False
560
-
561
- layers.append(
562
- torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
563
- stride = 1
564
- self.padding = kernel_size // 2
565
- # conv
566
- layers.append(
567
- EqualConv2d(
568
- in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
569
- and not activate))
570
- # activation
571
- if activate:
572
- if bias:
573
- layers.append(FusedLeakyReLU(out_channels))
574
- else:
575
- layers.append(ScaledLeakyReLU(0.2))
576
-
577
- super(ConvLayer, self).__init__(*layers)
578
-
579
-
580
- class ResBlock(nn.Module):
581
- """Residual block used in StyleGAN2 Discriminator.
582
-
583
- Args:
584
- in_channels (int): Channel number of the input.
585
- out_channels (int): Channel number of the output.
586
- """
587
-
588
- def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
589
- super(ResBlock, self).__init__()
590
-
591
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
592
- self.conv2 = ConvLayer(
593
- in_channels,
594
- out_channels,
595
- 3,
596
- downsample=True,
597
- interpolation_mode=interpolation_mode,
598
- bias=True,
599
- activate=True)
600
- self.skip = ConvLayer(
601
- in_channels,
602
- out_channels,
603
- 1,
604
- downsample=True,
605
- interpolation_mode=interpolation_mode,
606
- bias=False,
607
- activate=False)
608
-
609
- def forward(self, x):
610
- out = self.conv1(x)
611
- out = self.conv2(out)
612
- skip = self.skip(x)
613
- out = (out + skip) / math.sqrt(2)
614
- return out
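
The block above closes `basicsr/archs/stylegan2_bilinear_arch.py`. The generator's depth is driven entirely by `out_size`: `__init__` derives the number of style-conv layers, the number of latent slots consumed per forward pass, and one noise buffer per layer. A small sketch of that bookkeeping follows; the `out_size` of 256 is an illustrative assumption, not a value taken from this repo.

```python
# Sketch of the layer/latent arithmetic from StyleGAN2GeneratorBilinear.__init__,
# evaluated for an illustrative out_size of 256 (assumption, not a repo setting).
import math

out_size = 256
log_size = int(math.log(out_size, 2))        # 8
num_layers = (log_size - 2) * 2 + 1          # 13 style-conv layers
num_latent = log_size * 2 - 2                # 14 latent slots per forward pass

# One registered noise buffer per layer, doubling in resolution every two layers.
noise_resolutions = [2 ** ((layer_idx + 5) // 2) for layer_idx in range(num_layers)]
print(num_layers, num_latent)
print(noise_resolutions)   # [4, 8, 8, 16, 16, 32, 32, 64, 64, 128, 128, 256, 256]
```

This matches the mixing logic in `forward`: with two style codes, `inject_index` splits those latent slots between them.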
basicsr/archs/swinir_arch.py DELETED
@@ -1,956 +0,0 @@
1
- # Modified from https://github.com/JingyunLiang/SwinIR
2
- # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
- # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
-
5
- import math
6
- import torch
7
- import torch.nn as nn
8
- import torch.utils.checkpoint as checkpoint
9
-
10
- from basicsr.utils.registry import ARCH_REGISTRY
11
- from .arch_util import to_2tuple, trunc_normal_
12
-
13
-
14
- def drop_path(x, drop_prob: float = 0., training: bool = False):
15
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
16
-
17
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
18
- """
19
- if drop_prob == 0. or not training:
20
- return x
21
- keep_prob = 1 - drop_prob
22
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
23
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
24
- random_tensor.floor_() # binarize
25
- output = x.div(keep_prob) * random_tensor
26
- return output
27
-
28
-
29
- class DropPath(nn.Module):
30
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31
-
32
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
33
- """
34
-
35
- def __init__(self, drop_prob=None):
36
- super(DropPath, self).__init__()
37
- self.drop_prob = drop_prob
38
-
39
- def forward(self, x):
40
- return drop_path(x, self.drop_prob, self.training)
41
-
42
-
43
- class Mlp(nn.Module):
44
-
45
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
- super().__init__()
47
- out_features = out_features or in_features
48
- hidden_features = hidden_features or in_features
49
- self.fc1 = nn.Linear(in_features, hidden_features)
50
- self.act = act_layer()
51
- self.fc2 = nn.Linear(hidden_features, out_features)
52
- self.drop = nn.Dropout(drop)
53
-
54
- def forward(self, x):
55
- x = self.fc1(x)
56
- x = self.act(x)
57
- x = self.drop(x)
58
- x = self.fc2(x)
59
- x = self.drop(x)
60
- return x
61
-
62
-
63
- def window_partition(x, window_size):
64
- """
65
- Args:
66
- x: (b, h, w, c)
67
- window_size (int): window size
68
-
69
- Returns:
70
- windows: (num_windows*b, window_size, window_size, c)
71
- """
72
- b, h, w, c = x.shape
73
- x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
74
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
75
- return windows
76
-
77
-
78
- def window_reverse(windows, window_size, h, w):
79
- """
80
- Args:
81
- windows: (num_windows*b, window_size, window_size, c)
82
- window_size (int): Window size
83
- h (int): Height of image
84
- w (int): Width of image
85
-
86
- Returns:
87
- x: (b, h, w, c)
88
- """
89
- b = int(windows.shape[0] / (h * w / window_size / window_size))
90
- x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
91
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
92
- return x
93
-
94
-
95
- class WindowAttention(nn.Module):
96
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
97
- It supports both shifted and non-shifted windows.
98
-
99
- Args:
100
- dim (int): Number of input channels.
101
- window_size (tuple[int]): The height and width of the window.
102
- num_heads (int): Number of attention heads.
103
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
104
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
105
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
106
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
107
- """
108
-
109
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
110
-
111
- super().__init__()
112
- self.dim = dim
113
- self.window_size = window_size # Wh, Ww
114
- self.num_heads = num_heads
115
- head_dim = dim // num_heads
116
- self.scale = qk_scale or head_dim**-0.5
117
-
118
- # define a parameter table of relative position bias
119
- self.relative_position_bias_table = nn.Parameter(
120
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
121
-
122
- # get pair-wise relative position index for each token inside the window
123
- coords_h = torch.arange(self.window_size[0])
124
- coords_w = torch.arange(self.window_size[1])
125
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
126
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
127
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
128
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
129
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
130
- relative_coords[:, :, 1] += self.window_size[1] - 1
131
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
132
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
133
- self.register_buffer('relative_position_index', relative_position_index)
134
-
135
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
136
- self.attn_drop = nn.Dropout(attn_drop)
137
- self.proj = nn.Linear(dim, dim)
138
-
139
- self.proj_drop = nn.Dropout(proj_drop)
140
-
141
- trunc_normal_(self.relative_position_bias_table, std=.02)
142
- self.softmax = nn.Softmax(dim=-1)
143
-
144
- def forward(self, x, mask=None):
145
- """
146
- Args:
147
- x: input features with shape of (num_windows*b, n, c)
148
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
149
- """
150
- b_, n, c = x.shape
151
- qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
152
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
153
-
154
- q = q * self.scale
155
- attn = (q @ k.transpose(-2, -1))
156
-
157
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
158
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
159
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
160
- attn = attn + relative_position_bias.unsqueeze(0)
161
-
162
- if mask is not None:
163
- nw = mask.shape[0]
164
- attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
165
- attn = attn.view(-1, self.num_heads, n, n)
166
- attn = self.softmax(attn)
167
- else:
168
- attn = self.softmax(attn)
169
-
170
- attn = self.attn_drop(attn)
171
-
172
- x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
173
- x = self.proj(x)
174
- x = self.proj_drop(x)
175
- return x
176
-
177
- def extra_repr(self) -> str:
178
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
179
-
180
- def flops(self, n):
181
- # calculate flops for 1 window with token length of n
182
- flops = 0
183
- # qkv = self.qkv(x)
184
- flops += n * self.dim * 3 * self.dim
185
- # attn = (q @ k.transpose(-2, -1))
186
- flops += self.num_heads * n * (self.dim // self.num_heads) * n
187
- # x = (attn @ v)
188
- flops += self.num_heads * n * n * (self.dim // self.num_heads)
189
- # x = self.proj(x)
190
- flops += n * self.dim * self.dim
191
- return flops
192
-
193
-
194
- class SwinTransformerBlock(nn.Module):
195
- r""" Swin Transformer Block.
196
-
197
- Args:
198
- dim (int): Number of input channels.
199
- input_resolution (tuple[int]): Input resolution.
200
- num_heads (int): Number of attention heads.
201
- window_size (int): Window size.
202
- shift_size (int): Shift size for SW-MSA.
203
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
204
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
205
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
206
- drop (float, optional): Dropout rate. Default: 0.0
207
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
208
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
209
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
210
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
211
- """
212
-
213
- def __init__(self,
214
- dim,
215
- input_resolution,
216
- num_heads,
217
- window_size=7,
218
- shift_size=0,
219
- mlp_ratio=4.,
220
- qkv_bias=True,
221
- qk_scale=None,
222
- drop=0.,
223
- attn_drop=0.,
224
- drop_path=0.,
225
- act_layer=nn.GELU,
226
- norm_layer=nn.LayerNorm):
227
- super().__init__()
228
- self.dim = dim
229
- self.input_resolution = input_resolution
230
- self.num_heads = num_heads
231
- self.window_size = window_size
232
- self.shift_size = shift_size
233
- self.mlp_ratio = mlp_ratio
234
- if min(self.input_resolution) <= self.window_size:
235
- # if window size is larger than input resolution, we don't partition windows
236
- self.shift_size = 0
237
- self.window_size = min(self.input_resolution)
238
- assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
239
-
240
- self.norm1 = norm_layer(dim)
241
- self.attn = WindowAttention(
242
- dim,
243
- window_size=to_2tuple(self.window_size),
244
- num_heads=num_heads,
245
- qkv_bias=qkv_bias,
246
- qk_scale=qk_scale,
247
- attn_drop=attn_drop,
248
- proj_drop=drop)
249
-
250
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
251
- self.norm2 = norm_layer(dim)
252
- mlp_hidden_dim = int(dim * mlp_ratio)
253
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
254
-
255
- if self.shift_size > 0:
256
- attn_mask = self.calculate_mask(self.input_resolution)
257
- else:
258
- attn_mask = None
259
-
260
- self.register_buffer('attn_mask', attn_mask)
261
-
262
- def calculate_mask(self, x_size):
263
- # calculate attention mask for SW-MSA
264
- h, w = x_size
265
- img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
266
- h_slices = (slice(0, -self.window_size), slice(-self.window_size,
267
- -self.shift_size), slice(-self.shift_size, None))
268
- w_slices = (slice(0, -self.window_size), slice(-self.window_size,
269
- -self.shift_size), slice(-self.shift_size, None))
270
- cnt = 0
271
- for h in h_slices:
272
- for w in w_slices:
273
- img_mask[:, h, w, :] = cnt
274
- cnt += 1
275
-
276
- mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
277
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
278
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
279
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
280
-
281
- return attn_mask
282
-
283
- def forward(self, x, x_size):
284
- h, w = x_size
285
- b, _, c = x.shape
286
- # assert seq_len == h * w, "input feature has wrong size"
287
-
288
- shortcut = x
289
- x = self.norm1(x)
290
- x = x.view(b, h, w, c)
291
-
292
- # cyclic shift
293
- if self.shift_size > 0:
294
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
295
- else:
296
- shifted_x = x
297
-
298
- # partition windows
299
- x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
300
- x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
301
-
302
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
303
- if self.input_resolution == x_size:
304
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
305
- else:
306
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
307
-
308
- # merge windows
309
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
310
- shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
311
-
312
- # reverse cyclic shift
313
- if self.shift_size > 0:
314
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
315
- else:
316
- x = shifted_x
317
- x = x.view(b, h * w, c)
318
-
319
- # FFN
320
- x = shortcut + self.drop_path(x)
321
- x = x + self.drop_path(self.mlp(self.norm2(x)))
322
-
323
- return x
324
-
325
- def extra_repr(self) -> str:
326
- return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
327
- f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
328
-
329
- def flops(self):
330
- flops = 0
331
- h, w = self.input_resolution
332
- # norm1
333
- flops += self.dim * h * w
334
- # W-MSA/SW-MSA
335
- nw = h * w / self.window_size / self.window_size
336
- flops += nw * self.attn.flops(self.window_size * self.window_size)
337
- # mlp
338
- flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
339
- # norm2
340
- flops += self.dim * h * w
341
- return flops
342
-
343
-
344
- class PatchMerging(nn.Module):
345
- r""" Patch Merging Layer.
346
-
347
- Args:
348
- input_resolution (tuple[int]): Resolution of input feature.
349
- dim (int): Number of input channels.
350
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
351
- """
352
-
353
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
354
- super().__init__()
355
- self.input_resolution = input_resolution
356
- self.dim = dim
357
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
358
- self.norm = norm_layer(4 * dim)
359
-
360
- def forward(self, x):
361
- """
362
- x: b, h*w, c
363
- """
364
- h, w = self.input_resolution
365
- b, seq_len, c = x.shape
366
- assert seq_len == h * w, 'input feature has wrong size'
367
- assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
368
-
369
- x = x.view(b, h, w, c)
370
-
371
- x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
372
- x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
373
- x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
374
- x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
375
- x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
376
- x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
377
-
378
- x = self.norm(x)
379
- x = self.reduction(x)
380
-
381
- return x
382
-
383
- def extra_repr(self) -> str:
384
- return f'input_resolution={self.input_resolution}, dim={self.dim}'
385
-
386
- def flops(self):
387
- h, w = self.input_resolution
388
- flops = h * w * self.dim
389
- flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
390
- return flops
391
-
392
-
393
- class BasicLayer(nn.Module):
394
- """ A basic Swin Transformer layer for one stage.
395
-
396
- Args:
397
- dim (int): Number of input channels.
398
- input_resolution (tuple[int]): Input resolution.
399
- depth (int): Number of blocks.
400
- num_heads (int): Number of attention heads.
401
- window_size (int): Local window size.
402
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
403
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
404
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
405
- drop (float, optional): Dropout rate. Default: 0.0
406
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
407
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
408
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
409
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
410
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
411
- """
412
-
413
- def __init__(self,
414
- dim,
415
- input_resolution,
416
- depth,
417
- num_heads,
418
- window_size,
419
- mlp_ratio=4.,
420
- qkv_bias=True,
421
- qk_scale=None,
422
- drop=0.,
423
- attn_drop=0.,
424
- drop_path=0.,
425
- norm_layer=nn.LayerNorm,
426
- downsample=None,
427
- use_checkpoint=False):
428
-
429
- super().__init__()
430
- self.dim = dim
431
- self.input_resolution = input_resolution
432
- self.depth = depth
433
- self.use_checkpoint = use_checkpoint
434
-
435
- # build blocks
436
- self.blocks = nn.ModuleList([
437
- SwinTransformerBlock(
438
- dim=dim,
439
- input_resolution=input_resolution,
440
- num_heads=num_heads,
441
- window_size=window_size,
442
- shift_size=0 if (i % 2 == 0) else window_size // 2,
443
- mlp_ratio=mlp_ratio,
444
- qkv_bias=qkv_bias,
445
- qk_scale=qk_scale,
446
- drop=drop,
447
- attn_drop=attn_drop,
448
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
449
- norm_layer=norm_layer) for i in range(depth)
450
- ])
451
-
452
- # patch merging layer
453
- if downsample is not None:
454
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
455
- else:
456
- self.downsample = None
457
-
458
- def forward(self, x, x_size):
459
- for blk in self.blocks:
460
- if self.use_checkpoint:
461
- x = checkpoint.checkpoint(blk, x)
462
- else:
463
- x = blk(x, x_size)
464
- if self.downsample is not None:
465
- x = self.downsample(x)
466
- return x
467
-
468
- def extra_repr(self) -> str:
469
- return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
470
-
471
- def flops(self):
472
- flops = 0
473
- for blk in self.blocks:
474
- flops += blk.flops()
475
- if self.downsample is not None:
476
- flops += self.downsample.flops()
477
- return flops
478
-
479
-
480
- class RSTB(nn.Module):
481
- """Residual Swin Transformer Block (RSTB).
482
-
483
- Args:
484
- dim (int): Number of input channels.
485
- input_resolution (tuple[int]): Input resolution.
486
- depth (int): Number of blocks.
487
- num_heads (int): Number of attention heads.
488
- window_size (int): Local window size.
489
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
490
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
491
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
492
- drop (float, optional): Dropout rate. Default: 0.0
493
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
494
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
495
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
496
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
497
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
498
- img_size: Input image size.
499
- patch_size: Patch size.
500
- resi_connection: The convolutional block before residual connection.
501
- """
502
-
503
- def __init__(self,
504
- dim,
505
- input_resolution,
506
- depth,
507
- num_heads,
508
- window_size,
509
- mlp_ratio=4.,
510
- qkv_bias=True,
511
- qk_scale=None,
512
- drop=0.,
513
- attn_drop=0.,
514
- drop_path=0.,
515
- norm_layer=nn.LayerNorm,
516
- downsample=None,
517
- use_checkpoint=False,
518
- img_size=224,
519
- patch_size=4,
520
- resi_connection='1conv'):
521
- super(RSTB, self).__init__()
522
-
523
- self.dim = dim
524
- self.input_resolution = input_resolution
525
-
526
- self.residual_group = BasicLayer(
527
- dim=dim,
528
- input_resolution=input_resolution,
529
- depth=depth,
530
- num_heads=num_heads,
531
- window_size=window_size,
532
- mlp_ratio=mlp_ratio,
533
- qkv_bias=qkv_bias,
534
- qk_scale=qk_scale,
535
- drop=drop,
536
- attn_drop=attn_drop,
537
- drop_path=drop_path,
538
- norm_layer=norm_layer,
539
- downsample=downsample,
540
- use_checkpoint=use_checkpoint)
541
-
542
- if resi_connection == '1conv':
543
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
544
- elif resi_connection == '3conv':
545
- # to save parameters and memory
546
- self.conv = nn.Sequential(
547
- nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
549
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
550
-
551
- self.patch_embed = PatchEmbed(
552
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
553
-
554
- self.patch_unembed = PatchUnEmbed(
555
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
556
-
557
- def forward(self, x, x_size):
558
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
559
-
560
- def flops(self):
561
- flops = 0
562
- flops += self.residual_group.flops()
563
- h, w = self.input_resolution
564
- flops += h * w * self.dim * self.dim * 9
565
- flops += self.patch_embed.flops()
566
- flops += self.patch_unembed.flops()
567
-
568
- return flops
569
-
570
-
571
- class PatchEmbed(nn.Module):
572
- r""" Image to Patch Embedding
573
-
574
- Args:
575
- img_size (int): Image size. Default: 224.
576
- patch_size (int): Patch token size. Default: 4.
577
- in_chans (int): Number of input image channels. Default: 3.
578
- embed_dim (int): Number of linear projection output channels. Default: 96.
579
- norm_layer (nn.Module, optional): Normalization layer. Default: None
580
- """
581
-
582
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
- super().__init__()
584
- img_size = to_2tuple(img_size)
585
- patch_size = to_2tuple(patch_size)
586
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
- self.img_size = img_size
588
- self.patch_size = patch_size
589
- self.patches_resolution = patches_resolution
590
- self.num_patches = patches_resolution[0] * patches_resolution[1]
591
-
592
- self.in_chans = in_chans
593
- self.embed_dim = embed_dim
594
-
595
- if norm_layer is not None:
596
- self.norm = norm_layer(embed_dim)
597
- else:
598
- self.norm = None
599
-
600
- def forward(self, x):
601
- x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
602
- if self.norm is not None:
603
- x = self.norm(x)
604
- return x
605
-
606
- def flops(self):
607
- flops = 0
608
- h, w = self.img_size
609
- if self.norm is not None:
610
- flops += h * w * self.embed_dim
611
- return flops
612
-
613
-
614
- class PatchUnEmbed(nn.Module):
615
- r""" Image to Patch Unembedding
616
-
617
- Args:
618
- img_size (int): Image size. Default: 224.
619
- patch_size (int): Patch token size. Default: 4.
620
- in_chans (int): Number of input image channels. Default: 3.
621
- embed_dim (int): Number of linear projection output channels. Default: 96.
622
- norm_layer (nn.Module, optional): Normalization layer. Default: None
623
- """
624
-
625
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
626
- super().__init__()
627
- img_size = to_2tuple(img_size)
628
- patch_size = to_2tuple(patch_size)
629
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
630
- self.img_size = img_size
631
- self.patch_size = patch_size
632
- self.patches_resolution = patches_resolution
633
- self.num_patches = patches_resolution[0] * patches_resolution[1]
634
-
635
- self.in_chans = in_chans
636
- self.embed_dim = embed_dim
637
-
638
- def forward(self, x, x_size):
639
- x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
640
- return x
641
-
642
- def flops(self):
643
- flops = 0
644
- return flops
645
-
646
-
647
- class Upsample(nn.Sequential):
648
- """Upsample module.
649
-
650
- Args:
651
- scale (int): Scale factor. Supported scales: 2^n and 3.
652
- num_feat (int): Channel number of intermediate features.
653
- """
654
-
655
- def __init__(self, scale, num_feat):
656
- m = []
657
- if (scale & (scale - 1)) == 0: # scale = 2^n
658
- for _ in range(int(math.log(scale, 2))):
659
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
660
- m.append(nn.PixelShuffle(2))
661
- elif scale == 3:
662
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
663
- m.append(nn.PixelShuffle(3))
664
- else:
665
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
666
- super(Upsample, self).__init__(*m)
667
-
668
-
669
- class UpsampleOneStep(nn.Sequential):
670
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
671
- Used in lightweight SR to save parameters.
672
-
673
- Args:
674
- scale (int): Scale factor. Supported scales: 2^n and 3.
675
- num_feat (int): Channel number of intermediate features.
676
-
677
- """
678
-
679
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
680
- self.num_feat = num_feat
681
- self.input_resolution = input_resolution
682
- m = []
683
- m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
684
- m.append(nn.PixelShuffle(scale))
685
- super(UpsampleOneStep, self).__init__(*m)
686
-
687
- def flops(self):
688
- h, w = self.input_resolution
689
- flops = h * w * self.num_feat * 3 * 9
690
- return flops
691
-
692
-
693
- @ARCH_REGISTRY.register()
694
- class SwinIR(nn.Module):
695
- r""" SwinIR
696
- A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
697
-
698
- Args:
699
- img_size (int | tuple(int)): Input image size. Default 64
700
- patch_size (int | tuple(int)): Patch size. Default: 1
701
- in_chans (int): Number of input image channels. Default: 3
702
- embed_dim (int): Patch embedding dimension. Default: 96
703
- depths (tuple(int)): Depth of each Swin Transformer layer.
704
- num_heads (tuple(int)): Number of attention heads in different layers.
705
- window_size (int): Window size. Default: 7
706
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
707
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
708
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
709
- drop_rate (float): Dropout rate. Default: 0
710
- attn_drop_rate (float): Attention dropout rate. Default: 0
711
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
712
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
713
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
714
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
715
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
716
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
717
- img_range: Image range. 1. or 255.
718
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
719
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
720
- """
721
-
722
- def __init__(self,
723
- img_size=64,
724
- patch_size=1,
725
- in_chans=3,
726
- embed_dim=96,
727
- depths=(6, 6, 6, 6),
728
- num_heads=(6, 6, 6, 6),
729
- window_size=7,
730
- mlp_ratio=4.,
731
- qkv_bias=True,
732
- qk_scale=None,
733
- drop_rate=0.,
734
- attn_drop_rate=0.,
735
- drop_path_rate=0.1,
736
- norm_layer=nn.LayerNorm,
737
- ape=False,
738
- patch_norm=True,
739
- use_checkpoint=False,
740
- upscale=2,
741
- img_range=1.,
742
- upsampler='',
743
- resi_connection='1conv',
744
- **kwargs):
745
- super(SwinIR, self).__init__()
746
- num_in_ch = in_chans
747
- num_out_ch = in_chans
748
- num_feat = 64
749
- self.img_range = img_range
750
- if in_chans == 3:
751
- rgb_mean = (0.4488, 0.4371, 0.4040)
752
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
753
- else:
754
- self.mean = torch.zeros(1, 1, 1, 1)
755
- self.upscale = upscale
756
- self.upsampler = upsampler
757
-
758
- # ------------------------- 1, shallow feature extraction ------------------------- #
759
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
760
-
761
- # ------------------------- 2, deep feature extraction ------------------------- #
762
- self.num_layers = len(depths)
763
- self.embed_dim = embed_dim
764
- self.ape = ape
765
- self.patch_norm = patch_norm
766
- self.num_features = embed_dim
767
- self.mlp_ratio = mlp_ratio
768
-
769
- # split image into non-overlapping patches
770
- self.patch_embed = PatchEmbed(
771
- img_size=img_size,
772
- patch_size=patch_size,
773
- in_chans=embed_dim,
774
- embed_dim=embed_dim,
775
- norm_layer=norm_layer if self.patch_norm else None)
776
- num_patches = self.patch_embed.num_patches
777
- patches_resolution = self.patch_embed.patches_resolution
778
- self.patches_resolution = patches_resolution
779
-
780
- # merge non-overlapping patches into image
781
- self.patch_unembed = PatchUnEmbed(
782
- img_size=img_size,
783
- patch_size=patch_size,
784
- in_chans=embed_dim,
785
- embed_dim=embed_dim,
786
- norm_layer=norm_layer if self.patch_norm else None)
787
-
788
- # absolute position embedding
789
- if self.ape:
790
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
791
- trunc_normal_(self.absolute_pos_embed, std=.02)
792
-
793
- self.pos_drop = nn.Dropout(p=drop_rate)
794
-
795
- # stochastic depth
796
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
797
-
798
- # build Residual Swin Transformer blocks (RSTB)
799
- self.layers = nn.ModuleList()
800
- for i_layer in range(self.num_layers):
801
- layer = RSTB(
802
- dim=embed_dim,
803
- input_resolution=(patches_resolution[0], patches_resolution[1]),
804
- depth=depths[i_layer],
805
- num_heads=num_heads[i_layer],
806
- window_size=window_size,
807
- mlp_ratio=self.mlp_ratio,
808
- qkv_bias=qkv_bias,
809
- qk_scale=qk_scale,
810
- drop=drop_rate,
811
- attn_drop=attn_drop_rate,
812
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
813
- norm_layer=norm_layer,
814
- downsample=None,
815
- use_checkpoint=use_checkpoint,
816
- img_size=img_size,
817
- patch_size=patch_size,
818
- resi_connection=resi_connection)
819
- self.layers.append(layer)
820
- self.norm = norm_layer(self.num_features)
821
-
822
- # build the last conv layer in deep feature extraction
823
- if resi_connection == '1conv':
824
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
825
- elif resi_connection == '3conv':
826
- # to save parameters and memory
827
- self.conv_after_body = nn.Sequential(
828
- nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
829
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
830
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
831
-
832
- # ------------------------- 3, high quality image reconstruction ------------------------- #
833
- if self.upsampler == 'pixelshuffle':
834
- # for classical SR
835
- self.conv_before_upsample = nn.Sequential(
836
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
837
- self.upsample = Upsample(upscale, num_feat)
838
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
839
- elif self.upsampler == 'pixelshuffledirect':
840
- # for lightweight SR (to save parameters)
841
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
842
- (patches_resolution[0], patches_resolution[1]))
843
- elif self.upsampler == 'nearest+conv':
844
- # for real-world SR (less artifacts)
845
- assert self.upscale == 4, 'only support x4 now.'
846
- self.conv_before_upsample = nn.Sequential(
847
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
848
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
849
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
850
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
851
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
852
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
853
- else:
854
- # for image denoising and JPEG compression artifact reduction
855
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
856
-
857
- self.apply(self._init_weights)
858
-
859
- def _init_weights(self, m):
860
- if isinstance(m, nn.Linear):
861
- trunc_normal_(m.weight, std=.02)
862
- if isinstance(m, nn.Linear) and m.bias is not None:
863
- nn.init.constant_(m.bias, 0)
864
- elif isinstance(m, nn.LayerNorm):
865
- nn.init.constant_(m.bias, 0)
866
- nn.init.constant_(m.weight, 1.0)
867
-
868
- @torch.jit.ignore
869
- def no_weight_decay(self):
870
- return {'absolute_pos_embed'}
871
-
872
- @torch.jit.ignore
873
- def no_weight_decay_keywords(self):
874
- return {'relative_position_bias_table'}
875
-
876
- def forward_features(self, x):
877
- x_size = (x.shape[2], x.shape[3])
878
- x = self.patch_embed(x)
879
- if self.ape:
880
- x = x + self.absolute_pos_embed
881
- x = self.pos_drop(x)
882
-
883
- for layer in self.layers:
884
- x = layer(x, x_size)
885
-
886
- x = self.norm(x) # b seq_len c
887
- x = self.patch_unembed(x, x_size)
888
-
889
- return x
890
-
891
- def forward(self, x):
892
- self.mean = self.mean.type_as(x)
893
- x = (x - self.mean) * self.img_range
894
-
895
- if self.upsampler == 'pixelshuffle':
896
- # for classical SR
897
- x = self.conv_first(x)
898
- x = self.conv_after_body(self.forward_features(x)) + x
899
- x = self.conv_before_upsample(x)
900
- x = self.conv_last(self.upsample(x))
901
- elif self.upsampler == 'pixelshuffledirect':
902
- # for lightweight SR
903
- x = self.conv_first(x)
904
- x = self.conv_after_body(self.forward_features(x)) + x
905
- x = self.upsample(x)
906
- elif self.upsampler == 'nearest+conv':
907
- # for real-world SR
908
- x = self.conv_first(x)
909
- x = self.conv_after_body(self.forward_features(x)) + x
910
- x = self.conv_before_upsample(x)
911
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
912
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
913
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
914
- else:
915
- # for image denoising and JPEG compression artifact reduction
916
- x_first = self.conv_first(x)
917
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
918
- x = x + self.conv_last(res)
919
-
920
- x = x / self.img_range + self.mean
921
-
922
- return x
923
-
924
- def flops(self):
925
- flops = 0
926
- h, w = self.patches_resolution
927
- flops += h * w * 3 * self.embed_dim * 9
928
- flops += self.patch_embed.flops()
929
- for layer in self.layers:
930
- flops += layer.flops()
931
- flops += h * w * 3 * self.embed_dim * self.embed_dim
932
- flops += self.upsample.flops()
933
- return flops
934
-
935
-
936
- if __name__ == '__main__':
937
- upscale = 4
938
- window_size = 8
939
- height = (1024 // upscale // window_size + 1) * window_size
940
- width = (720 // upscale // window_size + 1) * window_size
941
- model = SwinIR(
942
- upscale=2,
943
- img_size=(height, width),
944
- window_size=window_size,
945
- img_range=1.,
946
- depths=[6, 6, 6, 6],
947
- embed_dim=60,
948
- num_heads=[6, 6, 6, 6],
949
- mlp_ratio=2,
950
- upsampler='pixelshuffledirect')
951
- print(model)
952
- print(height, width, model.flops() / 1e9)
953
-
954
- x = torch.randn((1, 3, height, width))
955
- x = model(x)
956
- print(x.shape)
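
The deleted `swinir_arch.py` ends with its original `__main__` smoke test (above). The reshaping that everything in the file relies on is the `window_partition` / `window_reverse` pair, which move between image-shaped tensors and batches of local windows for (shifted) window attention. The sketch below copies the two helpers (minus docstrings) and checks that they are exact inverses; the tensor sizes are illustrative.

```python
# Round-trip check for the window helpers copied from the deleted swinir_arch.py.
import torch


def window_partition(x, window_size):
    # (b, h, w, c) -> (num_windows * b, window_size, window_size, c)
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows


def window_reverse(windows, window_size, h, w):
    # (num_windows * b, window_size, window_size, c) -> (b, h, w, c)
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x


x = torch.randn(2, 16, 16, 8)              # (b, h, w, c), h and w divisible by the window size
windows = window_partition(x, 8)           # 2 images * 4 windows each -> (8, 8, 8, 8)
assert torch.equal(window_reverse(windows, 8, 16, 16), x)
```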
basicsr/archs/tof_arch.py DELETED
@@ -1,172 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
- from torch.nn import functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from .arch_util import flow_warp
7
-
8
-
9
- class BasicModule(nn.Module):
10
- """Basic module of SPyNet.
11
-
12
- Note that unlike the architecture in spynet_arch.py, the basic module
13
- here contains batch normalization.
14
- """
15
-
16
- def __init__(self):
17
- super(BasicModule, self).__init__()
18
- self.basic_module = nn.Sequential(
19
- nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
20
- nn.BatchNorm2d(32), nn.ReLU(inplace=True),
21
- nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
22
- nn.BatchNorm2d(64), nn.ReLU(inplace=True),
23
- nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
24
- nn.BatchNorm2d(32), nn.ReLU(inplace=True),
25
- nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
26
- nn.BatchNorm2d(16), nn.ReLU(inplace=True),
27
- nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
28
-
29
- def forward(self, tensor_input):
30
- """
31
- Args:
32
- tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
33
- 8 channels contain:
34
- [reference image (3), neighbor image (3), initial flow (2)].
35
-
36
- Returns:
37
- Tensor: Estimated flow with shape (b, 2, h, w)
38
- """
39
- return self.basic_module(tensor_input)
40
-
41
-
42
- class SPyNetTOF(nn.Module):
43
- """SPyNet architecture for TOF.
44
-
45
- Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
46
- They differ in the following aspects:
47
-
48
- 1. The basic modules here contain BatchNorm.
49
- 2. Normalization and denormalization are not done here, as they are done in TOFlow.
50
-
51
- ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``
52
-
53
- Reference: https://github.com/Coldog2333/pytoflow
54
-
55
- Args:
56
- load_path (str): Path for pretrained SPyNet. Default: None.
57
- """
58
-
59
- def __init__(self, load_path=None):
60
- super(SPyNetTOF, self).__init__()
61
-
62
- self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
63
- if load_path:
64
- self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
65
-
66
- def forward(self, ref, supp):
67
- """
68
- Args:
69
- ref (Tensor): Reference image with shape of (b, 3, h, w).
70
- supp: The supporting image to be warped: (b, 3, h, w).
71
-
72
- Returns:
73
- Tensor: Estimated optical flow: (b, 2, h, w).
74
- """
75
- num_batches, _, h, w = ref.size()
76
- ref = [ref]
77
- supp = [supp]
78
-
79
- # generate downsampled frames
80
- for _ in range(3):
81
- ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
82
- supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
83
-
84
- # flow computation
85
- flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
86
- for i in range(4):
87
- flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
88
- flow = flow_up + self.basic_module[i](
89
- torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
90
- return flow
91
-
92
-
93
- @ARCH_REGISTRY.register()
94
- class TOFlow(nn.Module):
95
- """PyTorch implementation of TOFlow.
96
-
97
- In TOFlow, the LR frames are pre-upsampled and have the same size as the GT frames.
98
-
99
- ``Paper: Video Enhancement with Task-Oriented Flow``
100
-
101
- Reference: https://github.com/anchen1011/toflow
102
-
103
- Reference: https://github.com/Coldog2333/pytoflow
104
-
105
- Args:
106
- adapt_official_weights (bool): Whether to adapt the weights translated
107
- from the official implementation. Set to false if you want to
108
- train from scratch. Default: False
109
- """
110
-
111
- def __init__(self, adapt_official_weights=False):
112
- super(TOFlow, self).__init__()
113
- self.adapt_official_weights = adapt_official_weights
114
- self.ref_idx = 0 if adapt_official_weights else 3
115
-
116
- self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
117
- self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
118
-
119
- # flow estimation module
120
- self.spynet = SPyNetTOF()
121
-
122
- # reconstruction module
123
- self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
124
- self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
125
- self.conv_3 = nn.Conv2d(64, 64, 1)
126
- self.conv_4 = nn.Conv2d(64, 3, 1)
127
-
128
- # activation function
129
- self.relu = nn.ReLU(inplace=True)
130
-
131
- def normalize(self, img):
132
- return (img - self.mean) / self.std
133
-
134
- def denormalize(self, img):
135
- return img * self.std + self.mean
136
-
137
- def forward(self, lrs):
138
- """
139
- Args:
140
- lrs: Input lr frames: (b, 7, 3, h, w).
141
-
142
- Returns:
143
- Tensor: SR frame: (b, 3, h, w).
144
- """
145
- # In the official implementation, the 0-th frame is the reference frame
146
- if self.adapt_official_weights:
147
- lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
148
-
149
- num_batches, num_lrs, _, h, w = lrs.size()
150
-
151
- lrs = self.normalize(lrs.view(-1, 3, h, w))
152
- lrs = lrs.view(num_batches, num_lrs, 3, h, w)
153
-
154
- lr_ref = lrs[:, self.ref_idx, :, :, :]
155
- lr_aligned = []
156
- for i in range(7): # 7 frames
157
- if i == self.ref_idx:
158
- lr_aligned.append(lr_ref)
159
- else:
160
- lr_supp = lrs[:, i, :, :, :]
161
- flow = self.spynet(lr_ref, lr_supp)
162
- lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
163
-
164
- # reconstruction
165
- hr = torch.stack(lr_aligned, dim=1)
166
- hr = hr.view(num_batches, -1, h, w)
167
- hr = self.relu(self.conv_1(hr))
168
- hr = self.relu(self.conv_2(hr))
169
- hr = self.relu(self.conv_3(hr))
170
- hr = self.conv_4(hr) + lr_ref
171
-
172
- return self.denormalize(hr)
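
`tof_arch.py` (removed above) chains a BatchNorm variant of SPyNet with a small fusion network: flows are estimated coarse-to-fine over a four-level pyramid, the six neighbouring frames are warped onto the reference frame, and four conv layers reconstruct the output. A hedged sketch of a dummy forward pass follows; it assumes an environment where `basicsr` is still importable with this file present (e.g. the upstream pip package, which ships the same module), and the frame size of 64 is chosen only so the pyramid divides evenly.

```python
# Dummy forward pass through TOFlow (assumes basicsr with tof_arch.py is importable).
import torch
from basicsr.archs.tof_arch import TOFlow

model = TOFlow(adapt_official_weights=False).eval()

# 7 pre-upsampled LR frames per sample; h and w must be divisible by 16
# so the 4-level SPyNet pyramid produces matching shapes.
lrs = torch.rand(1, 7, 3, 64, 64)
with torch.no_grad():
    sr = model(lrs)
print(sr.shape)   # torch.Size([1, 3, 64, 64])
```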
basicsr/archs/vgg_arch.py DELETED
@@ -1,161 +0,0 @@
1
- import os
2
- import torch
3
- from collections import OrderedDict
4
- from torch import nn as nn
5
- from torchvision.models import vgg as vgg
6
-
7
- from basicsr.utils.registry import ARCH_REGISTRY
8
-
9
- VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10
- NAMES = {
11
- 'vgg11': [
12
- 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13
- 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14
- 'pool5'
15
- ],
16
- 'vgg13': [
17
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19
- 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20
- ],
21
- 'vgg16': [
22
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24
- 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25
- 'pool5'
26
- ],
27
- 'vgg19': [
28
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30
- 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31
- 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32
- ]
33
- }
34
-
35
-
36
- def insert_bn(names):
37
- """Insert bn layer after each conv.
38
-
39
- Args:
40
- names (list): The list of layer names.
41
-
42
- Returns:
43
- list: The list of layer names with bn layers.
44
- """
45
- names_bn = []
46
- for name in names:
47
- names_bn.append(name)
48
- if 'conv' in name:
49
- position = name.replace('conv', '')
50
- names_bn.append('bn' + position)
51
- return names_bn
52
-
53
-
54
- @ARCH_REGISTRY.register()
55
- class VGGFeatureExtractor(nn.Module):
56
- """VGG network for feature extraction.
57
-
58
- In this implementation, we allow users to choose whether to use normalization
59
- in the input feature and the type of vgg network. Note that the pretrained
60
- path must match the vgg type.
61
-
62
- Args:
63
- layer_name_list (list[str]): Forward function returns the corresponding
64
- features according to the layer_name_list.
65
- Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66
- vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67
- use_input_norm (bool): If True, normalize the input image. Importantly,
68
- the input feature must be in the range [0, 1]. Default: True.
69
- range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
70
- Default: False.
71
- requires_grad (bool): If true, the parameters of VGG network will be
72
- optimized. Default: False.
73
- remove_pooling (bool): If true, the max pooling operations in VGG net
74
- will be removed. Default: False.
75
- pooling_stride (int): The stride of max pooling operation. Default: 2.
76
- """
77
-
78
- def __init__(self,
79
- layer_name_list,
80
- vgg_type='vgg19',
81
- use_input_norm=True,
82
- range_norm=False,
83
- requires_grad=False,
84
- remove_pooling=False,
85
- pooling_stride=2):
86
- super(VGGFeatureExtractor, self).__init__()
87
-
88
- self.layer_name_list = layer_name_list
89
- self.use_input_norm = use_input_norm
90
- self.range_norm = range_norm
91
-
92
- self.names = NAMES[vgg_type.replace('_bn', '')]
93
- if 'bn' in vgg_type:
94
- self.names = insert_bn(self.names)
95
-
96
- # only borrow layers that will be used to avoid unused params
97
- max_idx = 0
98
- for v in layer_name_list:
99
- idx = self.names.index(v)
100
- if idx > max_idx:
101
- max_idx = idx
102
-
103
- if os.path.exists(VGG_PRETRAIN_PATH):
104
- vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105
- state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106
- vgg_net.load_state_dict(state_dict)
107
- else:
108
- vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109
-
110
- features = vgg_net.features[:max_idx + 1]
111
-
112
- modified_net = OrderedDict()
113
- for k, v in zip(self.names, features):
114
- if 'pool' in k:
115
- # if remove_pooling is true, pooling operation will be removed
116
- if remove_pooling:
117
- continue
118
- else:
119
- # in some cases, we may want to change the default stride
120
- modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121
- else:
122
- modified_net[k] = v
123
-
124
- self.vgg_net = nn.Sequential(modified_net)
125
-
126
- if not requires_grad:
127
- self.vgg_net.eval()
128
- for param in self.parameters():
129
- param.requires_grad = False
130
- else:
131
- self.vgg_net.train()
132
- for param in self.parameters():
133
- param.requires_grad = True
134
-
135
- if self.use_input_norm:
136
- # the mean is for image with range [0, 1]
137
- self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138
- # the std is for image with range [0, 1]
139
- self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140
-
141
- def forward(self, x):
142
- """Forward function.
143
-
144
- Args:
145
- x (Tensor): Input tensor with shape (n, c, h, w).
146
-
147
- Returns:
148
- Tensor: Forward results.
149
- """
150
- if self.range_norm:
151
- x = (x + 1) / 2
152
- if self.use_input_norm:
153
- x = (x - self.mean) / self.std
154
-
155
- output = {}
156
- for key, layer in self.vgg_net._modules.items():
157
- x = layer(x)
158
- if key in self.layer_name_list:
159
- output[key] = x.clone()
160
-
161
- return output
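For context, the extractor above was typically used as a perceptual-loss backbone. A small usage sketch (it imports the module removed by this commit, so it only runs against a checkout that still contains it); inputs are expected in [0, 1] when `use_input_norm=True`, and the forward pass returns a dict keyed by the requested layer names:

```python
import torch
import torch.nn.functional as F

from basicsr.archs.vgg_arch import VGGFeatureExtractor  # module removed by this commit

extractor = VGGFeatureExtractor(
    layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'],
    vgg_type='vgg19',
    use_input_norm=True)

img_a = torch.rand(1, 3, 224, 224)  # range [0, 1]
img_b = torch.rand(1, 3, 224, 224)

feats_a = extractor(img_a)
feats_b = extractor(img_b)
# a typical perceptual loss: L1 distance between the requested VGG features
loss = sum(F.l1_loss(feats_a[k], feats_b[k]) for k in feats_a)
```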
 
basicsr/data/__init__.py DELETED
@@ -1,101 +0,0 @@
1
- import importlib
2
- import numpy as np
3
- import random
4
- import torch
5
- import torch.utils.data
6
- from copy import deepcopy
7
- from functools import partial
8
- from os import path as osp
9
-
10
- from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
- from basicsr.utils import get_root_logger, scandir
12
- from basicsr.utils.dist_util import get_dist_info
13
- from basicsr.utils.registry import DATASET_REGISTRY
14
-
15
- __all__ = ['build_dataset', 'build_dataloader']
16
-
17
- # automatically scan and import dataset modules for registry
18
- # scan all the files under the data folder with '_dataset' in file names
19
- data_folder = osp.dirname(osp.abspath(__file__))
20
- dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
- # import all the dataset modules
22
- _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
-
24
-
25
- def build_dataset(dataset_opt):
26
- """Build dataset from options.
27
-
28
- Args:
29
- dataset_opt (dict): Configuration for dataset. It must contain:
30
- name (str): Dataset name.
31
- type (str): Dataset type.
32
- """
33
- dataset_opt = deepcopy(dataset_opt)
34
- dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
- logger = get_root_logger()
36
- logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37
- return dataset
38
-
39
-
40
- def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
- """Build dataloader.
42
-
43
- Args:
44
- dataset (torch.utils.data.Dataset): Dataset.
45
- dataset_opt (dict): Dataset options. It contains the following keys:
46
- phase (str): 'train' or 'val'.
47
- num_worker_per_gpu (int): Number of workers for each GPU.
48
- batch_size_per_gpu (int): Training batch size for each GPU.
49
- num_gpu (int): Number of GPUs. Used only in the train phase.
50
- Default: 1.
51
- dist (bool): Whether in distributed training. Used only in the train
52
- phase. Default: False.
53
- sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
- seed (int | None): Seed. Default: None
55
- """
56
- phase = dataset_opt['phase']
57
- rank, _ = get_dist_info()
58
- if phase == 'train':
59
- if dist: # distributed training
60
- batch_size = dataset_opt['batch_size_per_gpu']
61
- num_workers = dataset_opt['num_worker_per_gpu']
62
- else: # non-distributed training
63
- multiplier = 1 if num_gpu == 0 else num_gpu
64
- batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
- num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
- dataloader_args = dict(
67
- dataset=dataset,
68
- batch_size=batch_size,
69
- shuffle=False,
70
- num_workers=num_workers,
71
- sampler=sampler,
72
- drop_last=True)
73
- if sampler is None:
74
- dataloader_args['shuffle'] = True
75
- dataloader_args['worker_init_fn'] = partial(
76
- worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
- elif phase in ['val', 'test']: # validation
78
- dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
- else:
80
- raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
81
-
82
- dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
- dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
-
85
- prefetch_mode = dataset_opt.get('prefetch_mode')
86
- if prefetch_mode == 'cpu': # CPUPrefetcher
87
- num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
- logger = get_root_logger()
89
- logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
- return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
- else:
92
- # prefetch_mode=None: Normal dataloader
93
- # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
- return torch.utils.data.DataLoader(**dataloader_args)
95
-
96
-
97
- def worker_init_fn(worker_id, num_workers, rank, seed):
98
- # Set the worker seed to num_workers * rank + worker_id + seed
99
- worker_seed = num_workers * rank + worker_id + seed
100
- np.random.seed(worker_seed)
101
- random.seed(worker_seed)
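A hedged sketch of how the two builders above fit together; the `type` value, folder paths, and size options below are placeholders and must match a dataset class registered in `DATASET_REGISTRY`:

```python
from basicsr.data import build_dataloader, build_dataset  # module removed by this commit

# placeholder options: 'type' must name a registered dataset, paths are illustrative
dataset_opt = {
    'name': 'example_train',
    'type': 'PairedImageDataset',
    'phase': 'train',
    'scale': 4,
    'dataroot_gt': 'datasets/example/gt',
    'dataroot_lq': 'datasets/example/lq',
    'io_backend': {'type': 'disk'},
    'gt_size': 128,
    'use_hflip': True,
    'use_rot': True,
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
}

train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
```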
 
basicsr/data/data_sampler.py DELETED
@@ -1,48 +0,0 @@
1
- import math
2
- import torch
3
- from torch.utils.data.sampler import Sampler
4
-
5
-
6
- class EnlargedSampler(Sampler):
7
- """Sampler that restricts data loading to a subset of the dataset.
8
-
9
- Modified from torch.utils.data.distributed.DistributedSampler
10
- Supports enlarging the dataset for iteration-based training, saving
11
- time when restarting the dataloader after each epoch
12
-
13
- Args:
14
- dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
- num_replicas (int | None): Number of processes participating in
16
- the training. It is usually the world_size.
17
- rank (int | None): Rank of the current process within num_replicas.
18
- ratio (int): Enlarging ratio. Default: 1.
19
- """
20
-
21
- def __init__(self, dataset, num_replicas, rank, ratio=1):
22
- self.dataset = dataset
23
- self.num_replicas = num_replicas
24
- self.rank = rank
25
- self.epoch = 0
26
- self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
- self.total_size = self.num_samples * self.num_replicas
28
-
29
- def __iter__(self):
30
- # deterministically shuffle based on epoch
31
- g = torch.Generator()
32
- g.manual_seed(self.epoch)
33
- indices = torch.randperm(self.total_size, generator=g).tolist()
34
-
35
- dataset_size = len(self.dataset)
36
- indices = [v % dataset_size for v in indices]
37
-
38
- # subsample
39
- indices = indices[self.rank:self.total_size:self.num_replicas]
40
- assert len(indices) == self.num_samples
41
-
42
- return iter(indices)
43
-
44
- def __len__(self):
45
- return self.num_samples
46
-
47
- def set_epoch(self, epoch):
48
- self.epoch = epoch
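A short sketch of how the sampler above plugs into a loader for iteration-based training; a single process is assumed (`num_replicas=1`, `rank=0`) and the dataset is a stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from basicsr.data.data_sampler import EnlargedSampler  # module removed by this commit

dataset = TensorDataset(torch.arange(10))  # stand-in dataset with 10 items
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=4)  # 4x enlarged epoch

loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)
for epoch in range(2):
    sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
    for (batch,) in loader:
        pass  # training step would go here
```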
 
basicsr/data/data_util.py DELETED
@@ -1,315 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
- from os import path as osp
5
- from torch.nn import functional as F
6
-
7
- from basicsr.data.transforms import mod_crop
8
- from basicsr.utils import img2tensor, scandir
9
-
10
-
11
- def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
12
- """Read a sequence of images from a given folder path.
13
-
14
- Args:
15
- path (list[str] | str): List of image paths or image folder path.
16
- require_mod_crop (bool): Require mod crop for each image.
17
- Default: False.
18
- scale (int): Scale factor for mod_crop. Default: 1.
19
- return_imgname(bool): Whether return image names. Default False.
20
-
21
- Returns:
22
- Tensor: size (t, c, h, w), RGB, [0, 1].
23
- list[str]: Returned image name list.
24
- """
25
- if isinstance(path, list):
26
- img_paths = path
27
- else:
28
- img_paths = sorted(list(scandir(path, full_path=True)))
29
- imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
-
31
- if require_mod_crop:
32
- imgs = [mod_crop(img, scale) for img in imgs]
33
- imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
34
- imgs = torch.stack(imgs, dim=0)
35
-
36
- if return_imgname:
37
- imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
38
- return imgs, imgnames
39
- else:
40
- return imgs
41
-
42
-
43
- def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
44
- """Generate an index list for reading `num_frames` frames from a sequence
45
- of images.
46
-
47
- Args:
48
- crt_idx (int): Current center index.
49
- max_frame_num (int): Max number of the sequence of images (from 1).
50
- num_frames (int): Reading num_frames frames.
51
- padding (str): Padding mode, one of
52
- 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
53
- Examples: current_idx = 0, num_frames = 5
54
- The generated frame indices under different padding mode:
55
- replicate: [0, 0, 0, 1, 2]
56
- reflection: [2, 1, 0, 1, 2]
57
- reflection_circle: [4, 3, 0, 1, 2]
58
- circle: [3, 4, 0, 1, 2]
59
-
60
- Returns:
61
- list[int]: A list of indices.
62
- """
63
- assert num_frames % 2 == 1, 'num_frames should be an odd number.'
64
- assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
65
-
66
- max_frame_num = max_frame_num - 1 # start from 0
67
- num_pad = num_frames // 2
68
-
69
- indices = []
70
- for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
71
- if i < 0:
72
- if padding == 'replicate':
73
- pad_idx = 0
74
- elif padding == 'reflection':
75
- pad_idx = -i
76
- elif padding == 'reflection_circle':
77
- pad_idx = crt_idx + num_pad - i
78
- else:
79
- pad_idx = num_frames + i
80
- elif i > max_frame_num:
81
- if padding == 'replicate':
82
- pad_idx = max_frame_num
83
- elif padding == 'reflection':
84
- pad_idx = max_frame_num * 2 - i
85
- elif padding == 'reflection_circle':
86
- pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
87
- else:
88
- pad_idx = i - num_frames
89
- else:
90
- pad_idx = i
91
- indices.append(pad_idx)
92
- return indices
93
-
94
-
95
- def paired_paths_from_lmdb(folders, keys):
96
- """Generate paired paths from lmdb files.
97
-
98
- Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
99
-
100
- ::
101
-
102
- lq.lmdb
103
- ├── data.mdb
104
- ├── lock.mdb
105
- ├── meta_info.txt
106
-
107
- The data.mdb and lock.mdb are standard lmdb files and you can refer to
108
- https://lmdb.readthedocs.io/en/release/ for more details.
109
-
110
- The meta_info.txt is a specified txt file to record the meta information
111
- of our datasets. It will be automatically created when preparing
112
- datasets by our provided dataset tools.
113
- Each line in the txt file records
114
- 1)image name (with extension),
115
- 2)image shape,
116
- 3)compression level, separated by a white space.
117
- Example: `baboon.png (120,125,3) 1`
118
-
119
- We use the image name without extension as the lmdb key.
120
- Note that we use the same key for the corresponding lq and gt images.
121
-
122
- Args:
123
- folders (list[str]): A list of folder path. The order of list should
124
- be [input_folder, gt_folder].
125
- keys (list[str]): A list of keys identifying folders. The order should
126
- be consistent with folders, e.g., ['lq', 'gt'].
127
- Note that this key is different from lmdb keys.
128
-
129
- Returns:
130
- list[str]: Returned path list.
131
- """
132
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
133
- f'But got {len(folders)}')
134
- assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
135
- input_folder, gt_folder = folders
136
- input_key, gt_key = keys
137
-
138
- if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
139
- raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
140
- f'formats. But received {input_key}: {input_folder}; '
141
- f'{gt_key}: {gt_folder}')
142
- # ensure that the two meta_info files are the same
143
- with open(osp.join(input_folder, 'meta_info.txt')) as fin:
144
- input_lmdb_keys = [line.split('.')[0] for line in fin]
145
- with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
146
- gt_lmdb_keys = [line.split('.')[0] for line in fin]
147
- if set(input_lmdb_keys) != set(gt_lmdb_keys):
148
- raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
149
- else:
150
- paths = []
151
- for lmdb_key in sorted(input_lmdb_keys):
152
- paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
153
- return paths
154
-
155
-
156
- def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
157
- """Generate paired paths from a meta information file.
158
-
159
- Each line in the meta information file contains the image names and
160
- image shape (usually for gt), separated by a white space.
161
-
162
- Example of a meta information file:
163
- ```
164
- 0001_s001.png (480,480,3)
165
- 0001_s002.png (480,480,3)
166
- ```
167
-
168
- Args:
169
- folders (list[str]): A list of folder path. The order of list should
170
- be [input_folder, gt_folder].
171
- keys (list[str]): A list of keys identifying folders. The order should
172
- be consistent with folders, e.g., ['lq', 'gt'].
173
- meta_info_file (str): Path to the meta information file.
174
- filename_tmpl (str): Template for each filename. Note that the
175
- template excludes the file extension. Usually the filename_tmpl is
176
- for files in the input folder.
177
-
178
- Returns:
179
- list[str]: Returned path list.
180
- """
181
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
182
- f'But got {len(folders)}')
183
- assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
184
- input_folder, gt_folder = folders
185
- input_key, gt_key = keys
186
-
187
- with open(meta_info_file, 'r') as fin:
188
- gt_names = [line.strip().split(' ')[0] for line in fin]
189
-
190
- paths = []
191
- for gt_name in gt_names:
192
- basename, ext = osp.splitext(osp.basename(gt_name))
193
- input_name = f'{filename_tmpl.format(basename)}{ext}'
194
- input_path = osp.join(input_folder, input_name)
195
- gt_path = osp.join(gt_folder, gt_name)
196
- paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
197
- return paths
198
-
199
-
200
- def paired_paths_from_folder(folders, keys, filename_tmpl):
201
- """Generate paired paths from folders.
202
-
203
- Args:
204
- folders (list[str]): A list of folder path. The order of list should
205
- be [input_folder, gt_folder].
206
- keys (list[str]): A list of keys identifying folders. The order should
207
- be consistent with folders, e.g., ['lq', 'gt'].
208
- filename_tmpl (str): Template for each filename. Note that the
209
- template excludes the file extension. Usually the filename_tmpl is
210
- for files in the input folder.
211
-
212
- Returns:
213
- list[str]: Returned path list.
214
- """
215
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
216
- f'But got {len(folders)}')
217
- assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
218
- input_folder, gt_folder = folders
219
- input_key, gt_key = keys
220
-
221
- input_paths = list(scandir(input_folder))
222
- gt_paths = list(scandir(gt_folder))
223
- assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
224
- f'{len(input_paths)}, {len(gt_paths)}.')
225
- paths = []
226
- for gt_path in gt_paths:
227
- basename, ext = osp.splitext(osp.basename(gt_path))
228
- input_name = f'{filename_tmpl.format(basename)}{ext}'
229
- input_path = osp.join(input_folder, input_name)
230
- assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
231
- gt_path = osp.join(gt_folder, gt_path)
232
- paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
233
- return paths
234
-
235
-
236
- def paths_from_folder(folder):
237
- """Generate paths from folder.
238
-
239
- Args:
240
- folder (str): Folder path.
241
-
242
- Returns:
243
- list[str]: Returned path list.
244
- """
245
-
246
- paths = list(scandir(folder))
247
- paths = [osp.join(folder, path) for path in paths]
248
- return paths
249
-
250
-
251
- def paths_from_lmdb(folder):
252
- """Generate paths from lmdb.
253
-
254
- Args:
255
- folder (str): Folder path.
256
-
257
- Returns:
258
- list[str]: Returned path list.
259
- """
260
- if not folder.endswith('.lmdb'):
261
- raise ValueError(f'Folder {folder} should be in lmdb format.')
262
- with open(osp.join(folder, 'meta_info.txt')) as fin:
263
- paths = [line.split('.')[0] for line in fin]
264
- return paths
265
-
266
-
267
- def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
268
- """Generate Gaussian kernel used in `duf_downsample`.
269
-
270
- Args:
271
- kernel_size (int): Kernel size. Default: 13.
272
- sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
273
-
274
- Returns:
275
- np.array: The Gaussian kernel.
276
- """
277
- from scipy.ndimage import filters as filters
278
- kernel = np.zeros((kernel_size, kernel_size))
279
- # set element at the middle to one, a dirac delta
280
- kernel[kernel_size // 2, kernel_size // 2] = 1
281
- # gaussian-smooth the dirac, resulting in a gaussian filter
282
- return filters.gaussian_filter(kernel, sigma)
283
-
284
-
285
- def duf_downsample(x, kernel_size=13, scale=4):
286
- """Downsampling with Gaussian kernel used in the DUF official code.
287
-
288
- Args:
289
- x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
290
- kernel_size (int): Kernel size. Default: 13.
291
- scale (int): Downsampling factor. Supported scale: (2, 3, 4).
292
- Default: 4.
293
-
294
- Returns:
295
- Tensor: DUF downsampled frames.
296
- """
297
- assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
298
-
299
- squeeze_flag = False
300
- if x.ndim == 4:
301
- squeeze_flag = True
302
- x = x.unsqueeze(0)
303
- b, t, c, h, w = x.size()
304
- x = x.view(-1, 1, h, w)
305
- pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
306
- x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
307
-
308
- gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
309
- gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
310
- x = F.conv2d(x, gaussian_filter, stride=scale)
311
- x = x[:, :, 2:-2, 2:-2]
312
- x = x.view(b, t, c, x.size(2), x.size(3))
313
- if squeeze_flag:
314
- x = x.squeeze(0)
315
- return x
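The padding behaviour of `generate_frame_indices` above is easiest to see on concrete indices; a standalone check that reproduces the examples in its docstring (the import only resolves on a checkout that still has the module):

```python
from basicsr.data.data_util import generate_frame_indices  # module removed by this commit

# centre index 0 of a 100-frame clip, reading 5 frames per sample
print(generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='replicate'))
# [0, 0, 0, 1, 2]
print(generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='reflection'))
# [2, 1, 0, 1, 2]
# at the far end of the clip, reflection folds the indices back
print(generate_frame_indices(99, max_frame_num=100, num_frames=5, padding='reflection'))
# [97, 98, 99, 98, 97]
```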
 
basicsr/data/degradations.py DELETED
@@ -1,764 +0,0 @@
1
- import cv2
2
- import math
3
- import numpy as np
4
- import random
5
- import torch
6
- from scipy import special
7
- from scipy.stats import multivariate_normal
8
- from torchvision.transforms.functional import rgb_to_grayscale
9
-
10
- # -------------------------------------------------------------------- #
11
- # --------------------------- blur kernels --------------------------- #
12
- # -------------------------------------------------------------------- #
13
-
14
-
15
- # --------------------------- util functions --------------------------- #
16
- def sigma_matrix2(sig_x, sig_y, theta):
17
- """Calculate the rotated sigma matrix (two dimensional matrix).
18
-
19
- Args:
20
- sig_x (float):
21
- sig_y (float):
22
- theta (float): Radian measurement.
23
-
24
- Returns:
25
- ndarray: Rotated sigma matrix.
26
- """
27
- d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
28
- u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
29
- return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
30
-
31
-
32
- def mesh_grid(kernel_size):
33
- """Generate the mesh grid, centering at zero.
34
-
35
- Args:
36
- kernel_size (int):
37
-
38
- Returns:
39
- xy (ndarray): with the shape (kernel_size, kernel_size, 2)
40
- xx (ndarray): with the shape (kernel_size, kernel_size)
41
- yy (ndarray): with the shape (kernel_size, kernel_size)
42
- """
43
- ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
44
- xx, yy = np.meshgrid(ax, ax)
45
- xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
46
- 1))).reshape(kernel_size, kernel_size, 2)
47
- return xy, xx, yy
48
-
49
-
50
- def pdf2(sigma_matrix, grid):
51
- """Calculate PDF of the bivariate Gaussian distribution.
52
-
53
- Args:
54
- sigma_matrix (ndarray): with the shape (2, 2)
55
- grid (ndarray): generated by :func:`mesh_grid`,
56
- with the shape (K, K, 2), K is the kernel size.
57
-
58
- Returns:
59
- kernel (ndarray): un-normalized kernel.
60
- """
61
- inverse_sigma = np.linalg.inv(sigma_matrix)
62
- kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
63
- return kernel
64
-
65
-
66
- def cdf2(d_matrix, grid):
67
- """Calculate the CDF of the standard bivariate Gaussian distribution.
68
- Used in skewed Gaussian distribution.
69
-
70
- Args:
71
- d_matrix (ndarray): skew matrix.
72
- grid (ndarray): generated by :func:`mesh_grid`,
73
- with the shape (K, K, 2), K is the kernel size.
74
-
75
- Returns:
76
- cdf (ndarray): skewed cdf.
77
- """
78
- rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
79
- grid = np.dot(grid, d_matrix)
80
- cdf = rv.cdf(grid)
81
- return cdf
82
-
83
-
84
- def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
85
- """Generate a bivariate isotropic or anisotropic Gaussian kernel.
86
-
87
- In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
88
-
89
- Args:
90
- kernel_size (int):
91
- sig_x (float):
92
- sig_y (float):
93
- theta (float): Radian measurement.
94
- grid (ndarray, optional): generated by :func:`mesh_grid`,
95
- with the shape (K, K, 2), K is the kernel size. Default: None
96
- isotropic (bool):
97
-
98
- Returns:
99
- kernel (ndarray): normalized kernel.
100
- """
101
- if grid is None:
102
- grid, _, _ = mesh_grid(kernel_size)
103
- if isotropic:
104
- sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
105
- else:
106
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
107
- kernel = pdf2(sigma_matrix, grid)
108
- kernel = kernel / np.sum(kernel)
109
- return kernel
110
-
111
-
112
- def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
113
- """Generate a bivariate generalized Gaussian kernel.
114
-
115
- ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
116
-
117
- In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
118
-
119
- Args:
120
- kernel_size (int):
121
- sig_x (float):
122
- sig_y (float):
123
- theta (float): Radian measurement.
124
- beta (float): shape parameter, beta = 1 is the normal distribution.
125
- grid (ndarray, optional): generated by :func:`mesh_grid`,
126
- with the shape (K, K, 2), K is the kernel size. Default: None
127
-
128
- Returns:
129
- kernel (ndarray): normalized kernel.
130
- """
131
- if grid is None:
132
- grid, _, _ = mesh_grid(kernel_size)
133
- if isotropic:
134
- sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
135
- else:
136
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
137
- inverse_sigma = np.linalg.inv(sigma_matrix)
138
- kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
139
- kernel = kernel / np.sum(kernel)
140
- return kernel
141
-
142
-
143
- def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
144
- """Generate a plateau-like anisotropic kernel.
145
-
146
- 1 / (1+x^(beta))
147
-
148
- Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
149
-
150
- In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
151
-
152
- Args:
153
- kernel_size (int):
154
- sig_x (float):
155
- sig_y (float):
156
- theta (float): Radian measurement.
157
- beta (float): shape parameter, beta = 1 is the normal distribution.
158
- grid (ndarray, optional): generated by :func:`mesh_grid`,
159
- with the shape (K, K, 2), K is the kernel size. Default: None
160
-
161
- Returns:
162
- kernel (ndarray): normalized kernel.
163
- """
164
- if grid is None:
165
- grid, _, _ = mesh_grid(kernel_size)
166
- if isotropic:
167
- sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
168
- else:
169
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
170
- inverse_sigma = np.linalg.inv(sigma_matrix)
171
- kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
172
- kernel = kernel / np.sum(kernel)
173
- return kernel
174
-
175
-
176
- def random_bivariate_Gaussian(kernel_size,
177
- sigma_x_range,
178
- sigma_y_range,
179
- rotation_range,
180
- noise_range=None,
181
- isotropic=True):
182
- """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
183
-
184
- In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
185
-
186
- Args:
187
- kernel_size (int):
188
- sigma_x_range (tuple): [0.6, 5]
189
- sigma_y_range (tuple): [0.6, 5]
190
- rotation range (tuple): [-math.pi, math.pi]
191
- noise_range(tuple, optional): multiplicative kernel noise,
192
- [0.75, 1.25]. Default: None
193
-
194
- Returns:
195
- kernel (ndarray):
196
- """
197
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
198
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
199
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
200
- if isotropic is False:
201
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
202
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
203
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
204
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
205
- else:
206
- sigma_y = sigma_x
207
- rotation = 0
208
-
209
- kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
210
-
211
- # add multiplicative noise
212
- if noise_range is not None:
213
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
214
- noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
215
- kernel = kernel * noise
216
- kernel = kernel / np.sum(kernel)
217
- return kernel
218
-
219
-
220
- def random_bivariate_generalized_Gaussian(kernel_size,
221
- sigma_x_range,
222
- sigma_y_range,
223
- rotation_range,
224
- beta_range,
225
- noise_range=None,
226
- isotropic=True):
227
- """Randomly generate bivariate generalized Gaussian kernels.
228
-
229
- In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
230
-
231
- Args:
232
- kernel_size (int):
233
- sigma_x_range (tuple): [0.6, 5]
234
- sigma_y_range (tuple): [0.6, 5]
235
- rotation range (tuple): [-math.pi, math.pi]
236
- beta_range (tuple): [0.5, 8]
237
- noise_range(tuple, optional): multiplicative kernel noise,
238
- [0.75, 1.25]. Default: None
239
-
240
- Returns:
241
- kernel (ndarray):
242
- """
243
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
244
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
245
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
246
- if isotropic is False:
247
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
248
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
249
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
250
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
251
- else:
252
- sigma_y = sigma_x
253
- rotation = 0
254
-
255
- # assume beta_range[0] < 1 < beta_range[1]
256
- if np.random.uniform() < 0.5:
257
- beta = np.random.uniform(beta_range[0], 1)
258
- else:
259
- beta = np.random.uniform(1, beta_range[1])
260
-
261
- kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
262
-
263
- # add multiplicative noise
264
- if noise_range is not None:
265
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
266
- noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
267
- kernel = kernel * noise
268
- kernel = kernel / np.sum(kernel)
269
- return kernel
270
-
271
-
272
- def random_bivariate_plateau(kernel_size,
273
- sigma_x_range,
274
- sigma_y_range,
275
- rotation_range,
276
- beta_range,
277
- noise_range=None,
278
- isotropic=True):
279
- """Randomly generate bivariate plateau kernels.
280
-
281
- In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
282
-
283
- Args:
284
- kernel_size (int):
285
- sigma_x_range (tuple): [0.6, 5]
286
- sigma_y_range (tuple): [0.6, 5]
287
- rotation range (tuple): [-math.pi/2, math.pi/2]
288
- beta_range (tuple): [1, 4]
289
- noise_range(tuple, optional): multiplicative kernel noise,
290
- [0.75, 1.25]. Default: None
291
-
292
- Returns:
293
- kernel (ndarray):
294
- """
295
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
296
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
297
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
298
- if isotropic is False:
299
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
300
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
301
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
302
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
303
- else:
304
- sigma_y = sigma_x
305
- rotation = 0
306
-
307
- # TODO: this may be not proper
308
- if np.random.uniform() < 0.5:
309
- beta = np.random.uniform(beta_range[0], 1)
310
- else:
311
- beta = np.random.uniform(1, beta_range[1])
312
-
313
- kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
314
- # add multiplicative noise
315
- if noise_range is not None:
316
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
317
- noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
318
- kernel = kernel * noise
319
- kernel = kernel / np.sum(kernel)
320
-
321
- return kernel
322
-
323
-
324
- def random_mixed_kernels(kernel_list,
325
- kernel_prob,
326
- kernel_size=21,
327
- sigma_x_range=(0.6, 5),
328
- sigma_y_range=(0.6, 5),
329
- rotation_range=(-math.pi, math.pi),
330
- betag_range=(0.5, 8),
331
- betap_range=(0.5, 8),
332
- noise_range=None):
333
- """Randomly generate mixed kernels.
334
-
335
- Args:
336
- kernel_list (tuple): a list name of kernel types,
337
- support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
338
- 'plateau_aniso']
339
- kernel_prob (tuple): corresponding kernel probability for each
340
- kernel type
341
- kernel_size (int):
342
- sigma_x_range (tuple): [0.6, 5]
343
- sigma_y_range (tuple): [0.6, 5]
344
- rotation range (tuple): [-math.pi, math.pi]
345
- beta_range (tuple): [0.5, 8]
346
- noise_range(tuple, optional): multiplicative kernel noise,
347
- [0.75, 1.25]. Default: None
348
-
349
- Returns:
350
- kernel (ndarray):
351
- """
352
- kernel_type = random.choices(kernel_list, kernel_prob)[0]
353
- if kernel_type == 'iso':
354
- kernel = random_bivariate_Gaussian(
355
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
356
- elif kernel_type == 'aniso':
357
- kernel = random_bivariate_Gaussian(
358
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
359
- elif kernel_type == 'generalized_iso':
360
- kernel = random_bivariate_generalized_Gaussian(
361
- kernel_size,
362
- sigma_x_range,
363
- sigma_y_range,
364
- rotation_range,
365
- betag_range,
366
- noise_range=noise_range,
367
- isotropic=True)
368
- elif kernel_type == 'generalized_aniso':
369
- kernel = random_bivariate_generalized_Gaussian(
370
- kernel_size,
371
- sigma_x_range,
372
- sigma_y_range,
373
- rotation_range,
374
- betag_range,
375
- noise_range=noise_range,
376
- isotropic=False)
377
- elif kernel_type == 'plateau_iso':
378
- kernel = random_bivariate_plateau(
379
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
380
- elif kernel_type == 'plateau_aniso':
381
- kernel = random_bivariate_plateau(
382
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
383
- return kernel
384
-
385
-
386
- np.seterr(divide='ignore', invalid='ignore')
387
-
388
-
389
- def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
390
- """2D sinc filter
391
-
392
- Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
393
-
394
- Args:
395
- cutoff (float): cutoff frequency in radians (pi is max)
396
- kernel_size (int): horizontal and vertical size, must be odd.
397
- pad_to (int): pad kernel size to desired size, must be odd or zero.
398
- """
399
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
400
- kernel = np.fromfunction(
401
- lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
402
- (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
403
- (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
404
- kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
405
- kernel = kernel / np.sum(kernel)
406
- if pad_to > kernel_size:
407
- pad_size = (pad_to - kernel_size) // 2
408
- kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
409
- return kernel
410
-
411
-
412
- # ------------------------------------------------------------- #
413
- # --------------------------- noise --------------------------- #
414
- # ------------------------------------------------------------- #
415
-
416
- # ----------------------- Gaussian Noise ----------------------- #
417
-
418
-
419
- def generate_gaussian_noise(img, sigma=10, gray_noise=False):
420
- """Generate Gaussian noise.
421
-
422
- Args:
423
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
424
- sigma (float): Noise scale (measured in range 255). Default: 10.
425
-
426
- Returns:
427
- (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
428
- float32.
429
- """
430
- if gray_noise:
431
- noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
432
- noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
433
- else:
434
- noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
435
- return noise
436
-
437
-
438
- def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
439
- """Add Gaussian noise.
440
-
441
- Args:
442
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
443
- sigma (float): Noise scale (measured in range 255). Default: 10.
444
-
445
- Returns:
446
- (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
447
- float32.
448
- """
449
- noise = generate_gaussian_noise(img, sigma, gray_noise)
450
- out = img + noise
451
- if clip and rounds:
452
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
453
- elif clip:
454
- out = np.clip(out, 0, 1)
455
- elif rounds:
456
- out = (out * 255.0).round() / 255.
457
- return out
458
-
459
-
460
- def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
461
- """Generate Gaussian noise (PyTorch version).
462
-
463
- Args:
464
- img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
465
- sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
466
-
467
- Returns:
468
- (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
469
- float32.
470
- """
471
- b, _, h, w = img.size()
472
- if not isinstance(sigma, (float, int)):
473
- sigma = sigma.view(img.size(0), 1, 1, 1)
474
- if isinstance(gray_noise, (float, int)):
475
- cal_gray_noise = gray_noise > 0
476
- else:
477
- gray_noise = gray_noise.view(b, 1, 1, 1)
478
- cal_gray_noise = torch.sum(gray_noise) > 0
479
-
480
- if cal_gray_noise:
481
- noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
482
- noise_gray = noise_gray.view(b, 1, h, w)
483
-
484
- # always calculate color noise
485
- noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
486
-
487
- if cal_gray_noise:
488
- noise = noise * (1 - gray_noise) + noise_gray * gray_noise
489
- return noise
490
-
491
-
492
- def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
493
- """Add Gaussian noise (PyTorch version).
494
-
495
- Args:
496
- img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
497
- sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
498
-
499
- Returns:
500
- (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
501
- float32.
502
- """
503
- noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
504
- out = img + noise
505
- if clip and rounds:
506
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
507
- elif clip:
508
- out = torch.clamp(out, 0, 1)
509
- elif rounds:
510
- out = (out * 255.0).round() / 255.
511
- return out
512
-
513
-
514
- # ----------------------- Random Gaussian Noise ----------------------- #
515
- def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
516
- sigma = np.random.uniform(sigma_range[0], sigma_range[1])
517
- if np.random.uniform() < gray_prob:
518
- gray_noise = True
519
- else:
520
- gray_noise = False
521
- return generate_gaussian_noise(img, sigma, gray_noise)
522
-
523
-
524
- def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
525
- noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
526
- out = img + noise
527
- if clip and rounds:
528
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
529
- elif clip:
530
- out = np.clip(out, 0, 1)
531
- elif rounds:
532
- out = (out * 255.0).round() / 255.
533
- return out
534
-
535
-
536
- def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
537
- sigma = torch.rand(
538
- img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
539
- gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
540
- gray_noise = (gray_noise < gray_prob).float()
541
- return generate_gaussian_noise_pt(img, sigma, gray_noise)
542
-
543
-
544
- def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
545
- noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
546
- out = img + noise
547
- if clip and rounds:
548
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
549
- elif clip:
550
- out = torch.clamp(out, 0, 1)
551
- elif rounds:
552
- out = (out * 255.0).round() / 255.
553
- return out
554
-
555
-
556
- # ----------------------- Poisson (Shot) Noise ----------------------- #
557
-
558
-
559
- def generate_poisson_noise(img, scale=1.0, gray_noise=False):
560
- """Generate poisson noise.
561
-
562
- Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
563
-
564
- Args:
565
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
566
- scale (float): Noise scale. Default: 1.0.
567
- gray_noise (bool): Whether generate gray noise. Default: False.
568
-
569
- Returns:
570
- (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
571
- float32.
572
- """
573
- if gray_noise:
574
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
575
- # round and clip image for counting vals correctly
576
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
577
- vals = len(np.unique(img))
578
- vals = 2**np.ceil(np.log2(vals))
579
- out = np.float32(np.random.poisson(img * vals) / float(vals))
580
- noise = out - img
581
- if gray_noise:
582
- noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
583
- return noise * scale
584
-
585
-
586
- def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
587
- """Add poisson noise.
588
-
589
- Args:
590
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
591
- scale (float): Noise scale. Default: 1.0.
592
- gray_noise (bool): Whether generate gray noise. Default: False.
593
-
594
- Returns:
595
- (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
596
- float32.
597
- """
598
- noise = generate_poisson_noise(img, scale, gray_noise)
599
- out = img + noise
600
- if clip and rounds:
601
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
602
- elif clip:
603
- out = np.clip(out, 0, 1)
604
- elif rounds:
605
- out = (out * 255.0).round() / 255.
606
- return out
607
-
608
-
609
- def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
610
- """Generate a batch of poisson noise (PyTorch version)
611
-
612
- Args:
613
- img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
614
- scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
615
- Default: 1.0.
616
- gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
617
- 0 for False, 1 for True. Default: 0.
618
-
619
- Returns:
620
- (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
621
- float32.
622
- """
623
- b, _, h, w = img.size()
624
- if isinstance(gray_noise, (float, int)):
625
- cal_gray_noise = gray_noise > 0
626
- else:
627
- gray_noise = gray_noise.view(b, 1, 1, 1)
628
- cal_gray_noise = torch.sum(gray_noise) > 0
629
- if cal_gray_noise:
630
- img_gray = rgb_to_grayscale(img, num_output_channels=1)
631
- # round and clip image for counting vals correctly
632
- img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
633
- # use for-loop to get the unique values for each sample
634
- vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
635
- vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
636
- vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
637
- out = torch.poisson(img_gray * vals) / vals
638
- noise_gray = out - img_gray
639
- noise_gray = noise_gray.expand(b, 3, h, w)
640
-
641
- # always calculate color noise
642
- # round and clip image for counting vals correctly
643
- img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
644
- # use for-loop to get the unique values for each sample
645
- vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
646
- vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
647
- vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
648
- out = torch.poisson(img * vals) / vals
649
- noise = out - img
650
- if cal_gray_noise:
651
- noise = noise * (1 - gray_noise) + noise_gray * gray_noise
652
- if not isinstance(scale, (float, int)):
653
- scale = scale.view(b, 1, 1, 1)
654
- return noise * scale
655
-
656
-
657
- def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
658
- """Add poisson noise to a batch of images (PyTorch version).
659
-
660
- Args:
661
- img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
662
- scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
663
- Default: 1.0.
664
- gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
665
- 0 for False, 1 for True. Default: 0.
666
-
667
- Returns:
668
- (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
669
- float32.
670
- """
671
- noise = generate_poisson_noise_pt(img, scale, gray_noise)
672
- out = img + noise
673
- if clip and rounds:
674
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
675
- elif clip:
676
- out = torch.clamp(out, 0, 1)
677
- elif rounds:
678
- out = (out * 255.0).round() / 255.
679
- return out
680
-
681
-
682
- # ----------------------- Random Poisson (Shot) Noise ----------------------- #
683
-
684
-
685
- def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
686
- scale = np.random.uniform(scale_range[0], scale_range[1])
687
- if np.random.uniform() < gray_prob:
688
- gray_noise = True
689
- else:
690
- gray_noise = False
691
- return generate_poisson_noise(img, scale, gray_noise)
692
-
693
-
694
- def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
695
- noise = random_generate_poisson_noise(img, scale_range, gray_prob)
696
- out = img + noise
697
- if clip and rounds:
698
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
699
- elif clip:
700
- out = np.clip(out, 0, 1)
701
- elif rounds:
702
- out = (out * 255.0).round() / 255.
703
- return out
704
-
705
-
706
- def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
707
- scale = torch.rand(
708
- img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
709
- gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
710
- gray_noise = (gray_noise < gray_prob).float()
711
- return generate_poisson_noise_pt(img, scale, gray_noise)
712
-
713
-
714
- def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
715
- noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
716
- out = img + noise
717
- if clip and rounds:
718
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
719
- elif clip:
720
- out = torch.clamp(out, 0, 1)
721
- elif rounds:
722
- out = (out * 255.0).round() / 255.
723
- return out
724
-
725
-
726
- # ------------------------------------------------------------------------ #
727
- # --------------------------- JPEG compression --------------------------- #
728
- # ------------------------------------------------------------------------ #
729
-
730
-
731
- def add_jpg_compression(img, quality=90):
732
- """Add JPG compression artifacts.
733
-
734
- Args:
735
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
736
- quality (float): JPG compression quality. 0 for lowest quality, 100 for
737
- best quality. Default: 90.
738
-
739
- Returns:
740
- (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
741
- float32.
742
- """
743
- img = np.clip(img, 0, 1)
744
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
745
- _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
746
- img = np.float32(cv2.imdecode(encimg, 1)) / 255.
747
- return img
748
-
749
-
750
- def random_add_jpg_compression(img, quality_range=(90, 100)):
751
- """Randomly add JPG compression artifacts.
752
-
753
- Args:
754
- img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
755
- quality_range (tuple[float] | list[float]): JPG compression quality
756
- range. 0 for lowest quality, 100 for best quality.
757
- Default: (90, 100).
758
-
759
- Returns:
760
- (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
761
- float32.
762
- """
763
- quality = np.random.uniform(quality_range[0], quality_range[1])
764
- return add_jpg_compression(img, quality)
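The blur, noise and JPEG helpers above compose into a simple first-order degradation pipeline; a hedged sketch on a random placeholder image (numpy path, values in [0, 1], import resolves only on a checkout that still has the module):

```python
import math

import cv2
import numpy as np

from basicsr.data.degradations import (  # module removed by this commit
    add_jpg_compression, random_add_gaussian_noise, random_mixed_kernels)

img = np.random.rand(128, 128, 3).astype(np.float32)  # placeholder clean image in [0, 1]

# 1) blur with a randomly sampled isotropic/anisotropic Gaussian kernel
kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso'],
    kernel_prob=[0.7, 0.3],
    kernel_size=21,
    sigma_x_range=(0.2, 3),
    sigma_y_range=(0.2, 3),
    rotation_range=(-math.pi, math.pi))
blurred = cv2.filter2D(img, -1, kernel)

# 2) Gaussian noise (sigma measured on the 0-255 scale), occasionally grayscale
noisy = random_add_gaussian_noise(blurred, sigma_range=(1, 30), gray_prob=0.4)

# 3) JPEG compression artifacts
degraded = add_jpg_compression(noisy, quality=50)
```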
 
basicsr/data/ffhq_dataset.py DELETED
@@ -1,80 +0,0 @@
1
- import random
2
- import time
3
- from os import path as osp
4
- from torch.utils import data as data
5
- from torchvision.transforms.functional import normalize
6
-
7
- from basicsr.data.transforms import augment
8
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9
- from basicsr.utils.registry import DATASET_REGISTRY
10
-
11
-
12
- @DATASET_REGISTRY.register()
13
- class FFHQDataset(data.Dataset):
14
- """FFHQ dataset for StyleGAN.
15
-
16
- Args:
17
- opt (dict): Config for train datasets. It contains the following keys:
18
- dataroot_gt (str): Data root path for gt.
19
- io_backend (dict): IO backend type and other kwarg.
20
- mean (list | tuple): Image mean.
21
- std (list | tuple): Image std.
22
- use_hflip (bool): Whether to horizontally flip.
23
-
24
- """
25
-
26
- def __init__(self, opt):
27
- super(FFHQDataset, self).__init__()
28
- self.opt = opt
29
- # file client (io backend)
30
- self.file_client = None
31
- self.io_backend_opt = opt['io_backend']
32
-
33
- self.gt_folder = opt['dataroot_gt']
34
- self.mean = opt['mean']
35
- self.std = opt['std']
36
-
37
- if self.io_backend_opt['type'] == 'lmdb':
38
- self.io_backend_opt['db_paths'] = self.gt_folder
39
- if not self.gt_folder.endswith('.lmdb'):
40
- raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
41
- with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
42
- self.paths = [line.split('.')[0] for line in fin]
43
- else:
44
- # FFHQ has 70000 images in total
45
- self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
46
-
47
- def __getitem__(self, index):
48
- if self.file_client is None:
49
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
50
-
51
- # load gt image
52
- gt_path = self.paths[index]
53
- # avoid errors caused by high latency in reading files
54
- retry = 3
55
- while retry > 0:
56
- try:
57
- img_bytes = self.file_client.get(gt_path)
58
- except Exception as e:
59
- logger = get_root_logger()
60
- logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
61
- # change another file to read
62
- index = random.randint(0, self.__len__())
63
- gt_path = self.paths[index]
64
- time.sleep(1) # sleep 1s for occasional server congestion
65
- else:
66
- break
67
- finally:
68
- retry -= 1
69
- img_gt = imfrombytes(img_bytes, float32=True)
70
-
71
- # random horizontal flip
72
- img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
73
- # BGR to RGB, HWC to CHW, numpy to tensor
74
- img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
75
- # normalize
76
- normalize(img_gt, self.mean, self.std, inplace=True)
77
- return {'gt': img_gt, 'gt_path': gt_path}
78
-
79
- def __len__(self):
80
- return len(self.paths)
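The dataset above is driven entirely by the `opt` dict documented in its docstring; a minimal disk-backend sketch with placeholder paths and the usual normalisation of [0, 1] images to [-1, 1]:

```python
from basicsr.data.ffhq_dataset import FFHQDataset  # module removed by this commit

opt = {
    'dataroot_gt': 'datasets/ffhq/ffhq_512',  # placeholder: folder holding 00000000.png ... 00069999.png
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],  # together with std, maps [0, 1] images to [-1, 1]
    'std': [0.5, 0.5, 0.5],
    'use_hflip': True,
}

dataset = FFHQDataset(opt)
sample = dataset[0]  # requires the placeholder folder to actually contain the images
print(sample['gt'].shape, sample['gt_path'])
```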
 
basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt DELETED
The diff for this file is too large to render. See raw diff