Shokoufehhh committed
Commit b427b58
1 Parent(s): 9e1402a

Upload 27 files
Adding sgmse folder
- sgmse/backbones/__init__.py +6 -0
- sgmse/backbones/dcunet.py +627 -0
- sgmse/backbones/ncsnpp.py +419 -0
- sgmse/backbones/ncsnpp_48k.py +424 -0
- sgmse/backbones/ncsnpp_utils/layers.py +662 -0
- sgmse/backbones/ncsnpp_utils/layerspp.py +274 -0
- sgmse/backbones/ncsnpp_utils/normalization.py +215 -0
- sgmse/backbones/ncsnpp_utils/op/__init__.py +1 -0
- sgmse/backbones/ncsnpp_utils/op/fused_act.py +97 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp +21 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu +99 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp +23 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py +203 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu +369 -0
- sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py +257 -0
- sgmse/backbones/ncsnpp_utils/utils.py +189 -0
- sgmse/backbones/shared.py +123 -0
- sgmse/data_module.py +236 -0
- sgmse/model.py +253 -0
- sgmse/sampling/__init__.py +143 -0
- sgmse/sampling/correctors.py +96 -0
- sgmse/sampling/predictors.py +76 -0
- sgmse/sdes.py +310 -0
- sgmse/util/inference.py +64 -0
- sgmse/util/other.py +141 -0
- sgmse/util/registry.py +34 -0
- sgmse/util/tensors.py +16 -0
sgmse/backbones/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .shared import BackboneRegistry
+from .ncsnpp import NCSNpp
+from .ncsnpp_48k import NCSNpp_48k
+from .dcunet import DCUNet
+
+__all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
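For context: the imports above pull `BackboneRegistry` from `sgmse/backbones/shared.py`, which is part of this commit but not shown in this excerpt. A minimal sketch of the registry pattern these files rely on, assuming a decorator-based `register` (visible in the backbone files below) and a hypothetical name-lookup helper; the exact API in `shared.py` may differ:

# Sketch only: a name-based class registry. The real BackboneRegistry lives in
# sgmse/backbones/shared.py; get_by_name is a hypothetical lookup helper.
class BackboneRegistry:
    _registry = {}

    @classmethod
    def register(cls, name):
        # Used as @BackboneRegistry.register("dcunet") on each backbone class.
        def decorator(backbone_cls):
            cls._registry[name] = backbone_cls
            return backbone_cls
        return decorator

    @classmethod
    def get_by_name(cls, name):
        return cls._registry[name]

Each backbone module below registers itself under a string key (e.g. `@BackboneRegistry.register("dcunet")`), so training scripts can select a backbone by name.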
sgmse/backbones/dcunet.py
ADDED
@@ -0,0 +1,627 @@
+from functools import partial
+import numpy as np
+
+import torch
+from torch import nn, Tensor
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
+    DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
+
+
+def get_activation(name):
+    if name == "silu":
+        return nn.SiLU
+    elif name == "relu":
+        return nn.ReLU
+    elif name == "leaky_relu":
+        return nn.LeakyReLU
+    else:
+        raise NotImplementedError(f"Unknown activation: {name}")
+
+
+class BatchNorm(_BatchNorm):
+    def _check_input_dim(self, input):
+        if input.dim() < 2 or input.dim() > 4:
+            raise ValueError("expected 2D to 4D input (got {}D input)".format(input.dim()))
+
+
+class OnReIm(nn.Module):
+    def __init__(self, module_cls, *args, **kwargs):
+        super().__init__()
+        self.re_module = module_cls(*args, **kwargs)
+        self.im_module = module_cls(*args, **kwargs)
+
+    def forward(self, x):
+        return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
+
+
+# Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
+
+def unet_decoder_args(encoders, *, skip_connections):
+    """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
+    given the arguments used to construct the encoder.
+    Args:
+        encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation)):
+            List of arguments used to construct the encoders
+        skip_connections (bool): Whether to include skip connections in the
+            calculation of decoder input channels.
+    Return:
+        tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation):
+            Arguments to be used to construct decoders
+    """
+    decoder_args = []
+    for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
+        if skip_connections and decoder_args:
+            skip_in_chan = enc_out_chan
+        else:
+            skip_in_chan = 0
+        decoder_args.append(
+            (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
+        )
+    return tuple(decoder_args)
+
+
+def make_unet_encoder_decoder_args(encoder_args, decoder_args):
+    encoder_args = tuple(
+        (
+            in_chan,
+            out_chan,
+            tuple(kernel_size),
+            tuple(stride),
+            tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
+            tuple(dilation)
+        )
+        for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
+    )
+
+    if decoder_args == "auto":
+        decoder_args = unet_decoder_args(
+            encoder_args,
+            skip_connections=True,
+        )
+    else:
+        decoder_args = tuple(
+            (
+                in_chan,
+                out_chan,
+                tuple(kernel_size),
+                tuple(stride),
+                tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
+                tuple(dilation),
+                output_padding,
+            )
+            for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
+        )
+
+    return encoder_args, decoder_args
+
+
+DCUNET_ARCHITECTURES = {
+    "DCUNet-10": make_unet_encoder_decoder_args(
+        # Encoders:
+        # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+        (
+            (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+            (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+        ),
+        # Decoders: automatic inverse
+        "auto",
+    ),
+    "DCUNet-16": make_unet_encoder_decoder_args(
+        # Encoders:
+        # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+        (
+            (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+            (32, 32, (7, 5), (2, 1), "auto", (1, 1)),
+            (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+        ),
+        # Decoders: automatic inverse
+        "auto",
+    ),
+    "DCUNet-20": make_unet_encoder_decoder_args(
+        # Encoders:
+        # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+        (
+            (1, 32, (7, 1), (1, 1), "auto", (1, 1)),
+            (32, 32, (1, 7), (1, 1), "auto", (1, 1)),
+            (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+            (64, 64, (7, 5), (2, 1), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+            (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+            (64, 90, (5, 3), (2, 1), "auto", (1, 1)),
+        ),
+        # Decoders: automatic inverse
+        "auto",
+    ),
+    "DilDCUNet-v2": make_unet_encoder_decoder_args(  # architecture used in SGMSE / Interspeech paper
+        # Encoders:
+        # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+        (
+            (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
+            (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+            (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+            (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
+            (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
+            (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
+        ),
+        # Decoders: automatic inverse
+        "auto",
+    ),
+}
+
+
+@BackboneRegistry.register("dcunet")
+class DCUNet(nn.Module):
+    @staticmethod
+    def add_argparse_args(parser):
+        parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
+        parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
+        parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
+        parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
+        parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
+        parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
+        parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
+        parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
+        parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
+        parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
+        return parser
+
+    def __init__(
+        self,
+        dcunet_architecture: str = "DilDCUNet-v2",
+        dcunet_time_embedding: str = "gfp",
+        dcunet_temb_layers_global: int = 2,
+        dcunet_temb_layers_local: int = 1,
+        dcunet_temb_activation: str = "silu",
+        dcunet_time_embedding_complex: bool = False,
+        dcunet_fix_length: str = "pad",
+        dcunet_mask_bound: str = "none",
+        dcunet_norm_type: str = "bN",
+        dcunet_activation: str = "relu",
+        embed_dim: int = 128,
+        **kwargs
+    ):
+        super().__init__()
+
+        self.architecture = dcunet_architecture
+        self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
+        self.norm_type = dcunet_norm_type
+        self.activation = dcunet_activation
+        self.input_channels = 2  # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
+        self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
+        self.time_embedding_complex = dcunet_time_embedding_complex
+        self.temb_layers_global = dcunet_temb_layers_global
+        self.temb_layers_local = dcunet_temb_layers_local
+        self.temb_activation = dcunet_temb_activation
+        conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
+
+        # Replace `input_channels` in encoders config
+        _replaced_input_channels, *rest = conf_encoders[0]
+        encoders = ((self.input_channels, *rest), *conf_encoders[1:])
+        decoders = conf_decoders
+        self.encoders_stride_product = np.prod(
+            [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
+        )
+
+        # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
+        encoder_decoder_kwargs = dict(
+            norm_type=self.norm_type, activation=self.activation,
+            temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
+
+        # Instantiate (global) time embedding layer
+        embed_ops = []
+        if self.time_embedding is not None:
+            complex_valued = self.time_embedding_complex
+            if self.time_embedding == "gfp":
+                embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
+                encoder_decoder_kwargs["embed_dim"] = embed_dim
+            elif self.time_embedding == "ds":
+                embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
+                encoder_decoder_kwargs["embed_dim"] = embed_dim
+
+            if self.time_embedding_complex:
+                assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
+                encoder_decoder_kwargs["complex_time_embedding"] = True
+            for _ in range(self.temb_layers_global):
+                embed_ops += [
+                    ComplexLinear(embed_dim, embed_dim, complex_valued=True),
+                    OnReIm(get_activation(dcunet_temb_activation))
+                ]
+        self.embed = nn.Sequential(*embed_ops)
+
+        ### Instantiate DCUNet layers ###
+        output_layer = ComplexConvTranspose2d(*decoders[-1])
+        encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
+        decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
+
+        self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
+        if self.mask_bound is not None:
+            raise NotImplementedError("sorry, mask bounding not implemented at the moment")
+            # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
+            #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
+            #output_layer = nn.Sequential(*[x for x in operations if x is not None])
+
+        assert len(encoders) == len(decoders) + 1
+        self.encoders = nn.ModuleList(encoders)
+        self.decoders = nn.ModuleList(decoders)
+        self.output_layer = output_layer or nn.Identity()
+
+    def forward(self, spec, t) -> Tensor:
+        """
+        Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
+        by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
+        and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_k$ are the time
+        strides of the encoders.
+        Args:
+            spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last.
+        Returns:
+            Tensor, of shape (batch, time) or (time).
+        """
+        # TF-rep shape: (batch, self.input_channels, n_fft, frames)
+        # Estimate mask from time-frequency representation.
+        x_in = self.fix_input_dims(spec)
+        x = x_in
+        t_embed = self.embed(t+0j) if self.time_embedding is not None else None
+
+        enc_outs = []
+        for idx, enc in enumerate(self.encoders):
+            x = enc(x, t_embed)
+            # UNet skip connection
+            enc_outs.append(x)
+        for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
+            x = dec(x, t_embed, output_size=enc_out.shape)
+            x = torch.cat([x, enc_out], dim=1)
+
+        output = self.output_layer(x, output_size=x_in.shape)
+        # output shape: (batch, 1, n_fft, frames)
+        output = self.fix_output_dims(output, spec)
+        return output
+
+    def fix_input_dims(self, x):
+        return _fix_dcu_input_dims(
+            self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
+        )
+
+    def fix_output_dims(self, out, x):
+        return _fix_dcu_output_dims(self.fix_length_mode, out, x)
+
+
+def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
+    """Pad or trim `x` to a length compatible with DCUNet."""
+    freq_prod = int(encoders_stride_product[0])
+    time_prod = int(encoders_stride_product[1])
+    if (x.shape[2] - 1) % freq_prod:
+        raise TypeError(
+            f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
+            f"{freq_prod}, got {x.shape} instead"
+        )
+    time_remainder = (x.shape[3] - 1) % time_prod
+    if time_remainder:
+        if fix_length_mode is None:
+            raise TypeError(
+                f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
+                f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
+                f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
+            )
+        elif fix_length_mode == "pad":
+            pad_shape = [0, time_prod - time_remainder]
+            x = nn.functional.pad(x, pad_shape, mode="constant")
+        elif fix_length_mode == "trim":
+            pad_shape = [0, -time_remainder]
+            x = nn.functional.pad(x, pad_shape, mode="constant")
+        else:
+            raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
+    return x
+
+
+def _fix_dcu_output_dims(fix_length_mode, out, x):
+    """Fix shape of `out` to the original shape of `x` by padding/cropping."""
+    inp_len = x.shape[-1]
+    output_len = out.shape[-1]
+    return nn.functional.pad(out, [0, inp_len - output_len])
+
+
+def _get_norm(norm_type):
+    if norm_type == "CbN":
+        return ComplexBatchNorm
+    elif norm_type == "bN":
+        return partial(OnReIm, BatchNorm)
+    else:
+        raise NotImplementedError(f"Unknown norm type: {norm_type}")
+
+
+class DCUNetComplexEncoderBlock(nn.Module):
+    def __init__(
+        self,
+        in_chan,
+        out_chan,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        norm_type="bN",
+        activation="leaky_relu",
+        embed_dim=None,
+        complex_time_embedding=False,
+        temb_layers=1,
+        temb_activation="silu"
+    ):
+        super().__init__()
+
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.temb_layers = temb_layers
+        self.temb_activation = temb_activation
+        self.complex_time_embedding = complex_time_embedding
+
+        self.conv = ComplexConv2d(
+            in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
+        )
+        self.norm = _get_norm(norm_type)(out_chan)
+        self.activation = OnReIm(get_activation(activation))
+        self.embed_dim = embed_dim
+        if self.embed_dim is not None:
+            ops = []
+            for _ in range(max(0, self.temb_layers - 1)):
+                ops += [
+                    ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+                    OnReIm(get_activation(self.temb_activation))
+                ]
+            ops += [
+                FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+                OnReIm(get_activation(self.temb_activation))
+            ]
+            self.embed_layer = nn.Sequential(*ops)
+
+    def forward(self, x, t_embed):
+        y = self.conv(x)
+        if self.embed_dim is not None:
+            y = y + self.embed_layer(t_embed)
+        return self.activation(self.norm(y))
+
+
+class DCUNetComplexDecoderBlock(nn.Module):
+    def __init__(
+        self,
+        in_chan,
+        out_chan,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        output_padding=(0, 0),
+        norm_type="bN",
+        activation="leaky_relu",
+        embed_dim=None,
+        temb_layers=1,
+        temb_activation='swish',
+        complex_time_embedding=False,
+    ):
+        super().__init__()
+
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.output_padding = output_padding
+        self.complex_time_embedding = complex_time_embedding
+        self.temb_layers = temb_layers
+        self.temb_activation = temb_activation
+
+        self.deconv = ComplexConvTranspose2d(
+            in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
+        )
+        self.norm = _get_norm(norm_type)(out_chan)
+        self.activation = OnReIm(get_activation(activation))
+        self.embed_dim = embed_dim
+        if self.embed_dim is not None:
+            ops = []
+            for _ in range(max(0, self.temb_layers - 1)):
+                ops += [
+                    ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+                    OnReIm(get_activation(self.temb_activation))
+                ]
+            ops += [
+                FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+                OnReIm(get_activation(self.temb_activation))
+            ]
+            self.embed_layer = nn.Sequential(*ops)
+
+    def forward(self, x, t_embed, output_size=None):
+        y = self.deconv(x, output_size=output_size)
+        if self.embed_dim is not None:
+            y = y + self.embed_layer(t_embed)
+        return self.activation(self.norm(y))
+
+
+# From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
+class ComplexBatchNorm(torch.nn.Module):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
+        super(ComplexBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        if self.affine:
+            self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
+            self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
+            self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
+            self.Br = torch.nn.Parameter(torch.Tensor(num_features))
+            self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
+        else:
+            self.register_parameter('Wrr', None)
+            self.register_parameter('Wri', None)
+            self.register_parameter('Wii', None)
+            self.register_parameter('Br', None)
+            self.register_parameter('Bi', None)
+        if self.track_running_stats:
+            self.register_buffer('RMr', torch.zeros(num_features))
+            self.register_buffer('RMi', torch.zeros(num_features))
+            self.register_buffer('RVrr', torch.ones(num_features))
+            self.register_buffer('RVri', torch.zeros(num_features))
+            self.register_buffer('RVii', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
+        else:
+            self.register_parameter('RMr', None)
+            self.register_parameter('RMi', None)
+            self.register_parameter('RVrr', None)
+            self.register_parameter('RVri', None)
+            self.register_parameter('RVii', None)
+            self.register_parameter('num_batches_tracked', None)
+        self.reset_parameters()
+
+    def reset_running_stats(self):
+        if self.track_running_stats:
+            self.RMr.zero_()
+            self.RMi.zero_()
+            self.RVrr.fill_(1)
+            self.RVri.zero_()
+            self.RVii.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        self.reset_running_stats()
+        if self.affine:
+            self.Br.data.zero_()
+            self.Bi.data.zero_()
+            self.Wrr.data.fill_(1)
+            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
+            self.Wii.data.fill_(1)
+
+    def _check_input_dim(self, xr, xi):
+        assert(xr.shape == xi.shape)
+        assert(xr.size(1) == self.num_features)
+
+    def forward(self, x):
+        xr, xi = x.real, x.imag
+        self._check_input_dim(xr, xi)
+
+        exponential_average_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
+        #
+        # NOTE: The precise meaning of the "training flag" is:
+        #       True:  Normalize using batch statistics, update running statistics
+        #              if they are being collected.
+        #       False: Normalize using running statistics, ignore batch statistics.
+        #
+        training = self.training or not self.track_running_stats
+        redux = [i for i in reversed(range(xr.dim())) if i != 1]
+        vdim = [1] * xr.dim()
+        vdim[1] = xr.size(1)
+
+        #
+        # Mean M Computation and Centering
+        #
+        # Includes running mean update if training and running.
+        #
+        if training:
+            Mr, Mi = xr, xi
+            for d in redux:
+                Mr = Mr.mean(d, keepdim=True)
+                Mi = Mi.mean(d, keepdim=True)
+            if self.track_running_stats:
+                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
+                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
+        else:
+            Mr = self.RMr.view(vdim)
+            Mi = self.RMi.view(vdim)
+        xr, xi = xr - Mr, xi - Mi
+
+        #
+        # Variance Matrix V Computation
+        #
+        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
+        # Includes running variance update if training and running.
+        #
+        if training:
+            Vrr = xr * xr
+            Vri = xr * xi
+            Vii = xi * xi
+            for d in redux:
+                Vrr = Vrr.mean(d, keepdim=True)
+                Vri = Vri.mean(d, keepdim=True)
+                Vii = Vii.mean(d, keepdim=True)
+            if self.track_running_stats:
+                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
+                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
+                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
+        else:
+            Vrr = self.RVrr.view(vdim)
+            Vri = self.RVri.view(vdim)
+            Vii = self.RVii.view(vdim)
+        Vrr = Vrr + self.eps
+        Vri = Vri
+        Vii = Vii + self.eps
+
+        #
+        # Matrix Inverse Square Root U = V^-0.5
+        #
+        # sqrt of a 2x2 matrix,
+        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+        tau = Vrr + Vii
+        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
+        s = delta.sqrt()
+        t = (tau + 2 * s).sqrt()
+
+        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
+        rst = (s * t).reciprocal()
+        Urr = (s + Vii) * rst
+        Uii = (s + Vrr) * rst
+        Uri = (-Vri) * rst
+
+        #
+        # Optionally left-multiply U by affine weights W to produce combined
+        # weights Z, left-multiply the inputs by Z, then optionally bias them.
+        #
+        # y = Zx + B
+        # y = WUx + B
+        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
+        #     [Wir Wii][Uir Uii] [xi]   [Bi]
+        #
+        if self.affine:
+            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
+            Zrr = (Wrr * Urr) + (Wri * Uri)
+            Zri = (Wrr * Uri) + (Wri * Uii)
+            Zir = (Wri * Urr) + (Wii * Uri)
+            Zii = (Wri * Uri) + (Wii * Uii)
+        else:
+            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+        yr = (Zrr * xr) + (Zri * xi)
+        yi = (Zir * xr) + (Zii * xi)
+
+        if self.affine:
+            yr = yr + self.Br.view(vdim)
+            yi = yi + self.Bi.view(vdim)
+
+        return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
+
+    def extra_repr(self):
+        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
+               'track_running_stats={track_running_stats}'.format(**self.__dict__)
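To make the "auto" decoder construction above concrete: the decoders are the reversed encoders with their in/out channels swapped, and (except for the deepest layer) the input channel count widened by the skip connection. A short, self-contained check, using a standalone copy of the `unet_decoder_args` logic from this file (the three-layer encoder config below is illustrative, not one of the named architectures):

# Illustrative check of the "auto" decoder derivation (standalone copy of the
# unet_decoder_args logic from this commit; not an import of the package).
def unet_decoder_args(encoders, *, skip_connections):
    decoder_args = []
    for in_ch, out_ch, k, s, p, d in reversed(encoders):
        skip = out_ch if (skip_connections and decoder_args) else 0
        decoder_args.append((out_ch + skip, in_ch, k, s, p, d))
    return tuple(decoder_args)

encoders = (
    (1, 32, (7, 5), (2, 2), (3, 2), (1, 1)),
    (32, 64, (7, 5), (2, 2), (3, 2), (1, 1)),
    (64, 64, (5, 3), (2, 1), (2, 1), (1, 1)),
)
for dec in unet_decoder_args(encoders, skip_connections=True):
    print(dec)
# (64, 64, (5, 3), (2, 1), (2, 1), (1, 1))    # deepest decoder: no skip input yet
# (128, 32, (7, 5), (2, 2), (3, 2), (1, 1))   # 64 out + 64-channel skip
# (64, 1, (7, 5), (2, 2), (3, 2), (1, 1))     # 32 out + 32-channel skip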
sgmse/backbones/ncsnpp.py
ADDED
@@ -0,0 +1,419 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+from .ncsnpp_utils import layers, layerspp, normalization
+import torch.nn as nn
+import functools
+import torch
+import numpy as np
+
+from .shared import BackboneRegistry
+
+ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+Combine = layerspp.Combine
+conv3x3 = layerspp.conv3x3
+conv1x1 = layerspp.conv1x1
+get_act = layers.get_act
+get_normalization = normalization.get_normalization
+default_initializer = layers.default_init
+
+
+@BackboneRegistry.register("ncsnpp")
+class NCSNpp(nn.Module):
+    """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+    @staticmethod
+    def add_argparse_args(parser):
+        parser.add_argument("--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2])
+        parser.add_argument("--num_res_blocks", type=int, default=2)
+        parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
+        parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
+        parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
+        parser.set_defaults(centered=True)
+        return parser
+
+    def __init__(self,
+        scale_by_sigma=True,
+        nonlinearity='swish',
+        nf=128,
+        ch_mult=(1, 1, 2, 2, 2, 2, 2),
+        num_res_blocks=2,
+        attn_resolutions=(16,),
+        resamp_with_conv=True,
+        conditional=True,
+        fir=True,
+        fir_kernel=[1, 3, 3, 1],
+        skip_rescale=True,
+        resblock_type='biggan',
+        progressive='output_skip',
+        progressive_input='input_skip',
+        progressive_combine='sum',
+        init_scale=0.,
+        fourier_scale=16,
+        image_size=256,
+        embedding_type='fourier',
+        dropout=.0,
+        centered=True,
+        **unused_kwargs
+    ):
+        super().__init__()
+        self.act = act = get_act(nonlinearity)
+
+        self.nf = nf = nf
+        ch_mult = ch_mult
+        self.num_res_blocks = num_res_blocks = num_res_blocks
+        self.attn_resolutions = attn_resolutions = attn_resolutions
+        dropout = dropout
+        resamp_with_conv = resamp_with_conv
+        self.num_resolutions = num_resolutions = len(ch_mult)
+        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+        self.conditional = conditional = conditional  # noise-conditional
+        self.centered = centered
+        self.scale_by_sigma = scale_by_sigma
+
+        fir = fir
+        fir_kernel = fir_kernel
+        self.skip_rescale = skip_rescale = skip_rescale
+        self.resblock_type = resblock_type = resblock_type.lower()
+        self.progressive = progressive = progressive.lower()
+        self.progressive_input = progressive_input = progressive_input.lower()
+        self.embedding_type = embedding_type = embedding_type.lower()
+        init_scale = init_scale
+        assert progressive in ['none', 'output_skip', 'residual']
+        assert progressive_input in ['none', 'input_skip', 'residual']
+        assert embedding_type in ['fourier', 'positional']
+        combine_method = progressive_combine.lower()
+        combiner = functools.partial(Combine, method=combine_method)
+
+        num_channels = 4  # x.real, x.imag, y.real, y.imag
+        self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+        modules = []
+        # timestep/noise_level embedding
+        if embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            modules.append(layerspp.GaussianFourierProjection(
+                embedding_size=nf, scale=fourier_scale
+            ))
+            embed_dim = 2 * nf
+        elif embedding_type == 'positional':
+            embed_dim = nf
+        else:
+            raise ValueError(f'embedding type {embedding_type} unknown.')
+
+        if conditional:
+            modules.append(nn.Linear(embed_dim, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+            modules.append(nn.Linear(nf * 4, nf * 4))
+            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+            nn.init.zeros_(modules[-1].bias)
+
+        AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                      init_scale=init_scale, skip_rescale=skip_rescale)
+
+        Upsample = functools.partial(layerspp.Upsample,
+                                     with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive == 'output_skip':
+            self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive == 'residual':
+            pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
+                                                 fir_kernel=fir_kernel, with_conv=True)
+
+        Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+        if progressive_input == 'input_skip':
+            self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+        elif progressive_input == 'residual':
+            pyramid_downsample = functools.partial(layerspp.Downsample,
+                                                   fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+        if resblock_type == 'ddpm':
+            ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                                            dropout=dropout, init_scale=init_scale,
+                                            skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        elif resblock_type == 'biggan':
+            ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                                            dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                                            init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+        else:
+            raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+        # Downsampling block
+
+        channels = num_channels
+        if progressive_input != 'none':
+            input_pyramid_ch = channels
+
+        modules.append(conv3x3(channels, nf))
+        hs_c = [nf]
+
+        in_ch = nf
+        for i_level in range(num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(num_res_blocks):
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                in_ch = out_ch
+
+                if all_resolutions[i_level] in attn_resolutions:
+                    modules.append(AttnBlock(channels=in_ch))
+                hs_c.append(in_ch)
+
+            if i_level != num_resolutions - 1:
+                if resblock_type == 'ddpm':
+                    modules.append(Downsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                if progressive_input == 'input_skip':
+                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                    if combine_method == 'cat':
+                        in_ch *= 2
+
+                elif progressive_input == 'residual':
+                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    input_pyramid_ch = in_ch
+
+                hs_c.append(in_ch)
+
+        in_ch = hs_c[-1]
+        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(AttnBlock(channels=in_ch))
+        modules.append(ResnetBlock(in_ch=in_ch))
+
+        pyramid_ch = 0
+        # Upsampling block
+        for i_level in reversed(range(num_resolutions)):
+            for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                out_ch = nf * ch_mult[i_level]
+                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                in_ch = out_ch
+
+            if all_resolutions[i_level] in attn_resolutions:
+                modules.append(AttnBlock(channels=in_ch))
+
+            if progressive != 'none':
+                if i_level == num_resolutions - 1:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                    num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, in_ch, bias=True))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name.')
+                else:
+                    if progressive == 'output_skip':
+                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                    num_channels=in_ch, eps=1e-6))
+                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                        pyramid_ch = channels
+                    elif progressive == 'residual':
+                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                        pyramid_ch = in_ch
+                    else:
+                        raise ValueError(f'{progressive} is not a valid name')
+
+            if i_level != 0:
+                if resblock_type == 'ddpm':
+                    modules.append(Upsample(in_ch=in_ch))
+                else:
+                    modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+        assert not hs_c
+
+        if progressive != 'output_skip':
+            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                        num_channels=in_ch, eps=1e-6))
+            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+        self.all_modules = nn.ModuleList(modules)
+
+
+    def forward(self, x, time_cond):
+        # timestep/noise_level embedding; only for continuous training
+        modules = self.all_modules
+        m_idx = 0
+
+        # Convert real and imaginary parts of (x, y) into four channel dimensions
+        x = torch.cat((x[:, [0], :, :].real, x[:, [0], :, :].imag,
+                       x[:, [1], :, :].real, x[:, [1], :, :].imag), dim=1)
+
+        if self.embedding_type == 'fourier':
+            # Gaussian Fourier features embeddings.
+            used_sigmas = time_cond
+            temb = modules[m_idx](torch.log(used_sigmas))
+            m_idx += 1
+
+        elif self.embedding_type == 'positional':
+            # Sinusoidal positional embeddings.
+            timesteps = time_cond
+            used_sigmas = self.sigmas[time_cond.long()]
+            temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+        else:
+            raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+        if self.conditional:
+            temb = modules[m_idx](temb)
+            m_idx += 1
+            temb = modules[m_idx](self.act(temb))
+            m_idx += 1
+        else:
+            temb = None
+
+        if not self.centered:
+            # If input data is in [0, 1]
+            x = 2 * x - 1.
+
+        # Downsampling block
+        input_pyramid = None
+        if self.progressive_input != 'none':
+            input_pyramid = x
+
+        # Input layer: Conv2d: 4ch -> 128ch
+        hs = [modules[m_idx](x)]
+        m_idx += 1
+
+        # Down path in U-Net
+        for i_level in range(self.num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(self.num_res_blocks):
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
+                # Attention layer (optional)
+                if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                hs.append(h)
+
+            # Downsampling
+            if i_level != self.num_resolutions - 1:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](hs[-1])
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](hs[-1], temb)
+                    m_idx += 1
+
+                if self.progressive_input == 'input_skip':  # Combine h with x
+                    input_pyramid = self.pyramid_downsample(input_pyramid)
+                    h = modules[m_idx](input_pyramid, h)
+                    m_idx += 1
+
+                elif self.progressive_input == 'residual':
+                    input_pyramid = modules[m_idx](input_pyramid)
+                    m_idx += 1
+                    if self.skip_rescale:
+                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                    else:
+                        input_pyramid = input_pyramid + h
+                    h = input_pyramid
+                hs.append(h)
+
+        h = hs[-1]  # actually equal to: h = h
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+        h = modules[m_idx](h)  # Attention block
+        m_idx += 1
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+
+        pyramid = None
+
+        # Upsampling block
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                m_idx += 1
+
+            # edit: from -1 to -2
+            if h.shape[-2] in self.attn_resolutions:
+                h = modules[m_idx](h)
+                m_idx += 1
+
+            if self.progressive != 'none':
+                if i_level == self.num_resolutions - 1:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
+                        m_idx += 1
+                    elif self.progressive == 'residual':
+                        pyramid = self.act(modules[m_idx](h))
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name.')
+                else:
+                    if self.progressive == 'output_skip':
+                        pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                        pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid_h = modules[m_idx](pyramid_h)
+                        m_idx += 1
+                        pyramid = pyramid + pyramid_h
+                    elif self.progressive == 'residual':
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                        if self.skip_rescale:
+                            pyramid = (pyramid + h) / np.sqrt(2.)
+                        else:
+                            pyramid = pyramid + h
+                        h = pyramid
+                    else:
+                        raise ValueError(f'{self.progressive} is not a valid name')
+
+            # Upsampling Layer
+            if i_level != 0:
+                if self.resblock_type == 'ddpm':
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](h, temb)  # Upsampling
+                    m_idx += 1
+
+        assert not hs
+
+        if self.progressive == 'output_skip':
+            h = pyramid
+        else:
+            h = self.act(modules[m_idx](h))
+            m_idx += 1
+            h = modules[m_idx](h)
+            m_idx += 1
+
+        assert m_idx == len(modules), "Implementation error"
+        if self.scale_by_sigma:
+            used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+            h = h / used_sigmas
+
+        # Convert back to complex number
+        h = self.output_layer(h)
+        h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+        h = torch.view_as_complex(h)[:, None, :, :]
+        return h
sgmse/backbones/ncsnpp_48k.py
ADDED
@@ -0,0 +1,424 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
|
18 |
+
from .ncsnpp_utils import layers, layerspp, normalization
|
19 |
+
import torch.nn as nn
|
20 |
+
import functools
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .shared import BackboneRegistry
|
25 |
+
|
26 |
+
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
|
27 |
+
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
|
28 |
+
Combine = layerspp.Combine
|
29 |
+
conv3x3 = layerspp.conv3x3
|
30 |
+
conv1x1 = layerspp.conv1x1
|
31 |
+
get_act = layers.get_act
|
32 |
+
get_normalization = normalization.get_normalization
|
33 |
+
default_initializer = layers.default_init
|
34 |
+
|
35 |
+
|
36 |
+
@BackboneRegistry.register("ncsnpp_48k")
|
37 |
+
class NCSNpp_48k(nn.Module):
|
38 |
+
"""NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def add_argparse_args(parser):
|
42 |
+
parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
|
43 |
+
parser.add_argument("--num_res_blocks", type=int, default=2)
|
44 |
+
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
|
45 |
+
parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
|
46 |
+
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
|
47 |
+
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
|
48 |
+
parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
|
49 |
+
parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
|
50 |
+
parser.set_defaults(centered=True)
|
51 |
+
return parser
|
52 |
+
|
53 |
+
def __init__(self,
|
54 |
+
scale_by_sigma = True,
|
55 |
+
nonlinearity = 'swish',
|
56 |
+
nf = 128,
|
57 |
+
ch_mult = (1, 1, 2, 2, 2, 2, 2),
|
58 |
+
num_res_blocks = 2,
|
59 |
+
attn_resolutions = (),
|
60 |
+
resamp_with_conv = True,
|
61 |
+
conditional = True,
|
62 |
+
fir = True,
|
63 |
+
fir_kernel = [1, 3, 3, 1],
|
64 |
+
skip_rescale = True,
|
65 |
+
resblock_type = 'biggan',
|
66 |
+
progressive = 'none',
|
67 |
+
progressive_input = 'none',
|
68 |
+
progressive_combine = 'sum',
|
69 |
+
init_scale = 0.,
|
70 |
+
fourier_scale = 16,
|
71 |
+
image_size = 256,
|
72 |
+
embedding_type = 'fourier',
|
73 |
+
dropout = .0,
|
74 |
+
centered = True,
|
75 |
+
**unused_kwargs
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
self.act = act = get_act(nonlinearity)
|
79 |
+
|
80 |
+
self.nf = nf = nf
|
81 |
+
ch_mult = ch_mult
|
82 |
+
self.num_res_blocks = num_res_blocks = num_res_blocks
|
83 |
+
self.attn_resolutions = attn_resolutions
|
84 |
+
dropout = dropout
|
85 |
+
resamp_with_conv = resamp_with_conv
|
86 |
+
self.num_resolutions = num_resolutions = len(ch_mult)
|
87 |
+
self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
|
88 |
+
|
89 |
+
self.conditional = conditional = conditional # noise-conditional
|
90 |
+
self.centered = centered
|
91 |
+
self.scale_by_sigma = scale_by_sigma
|
92 |
+
|
93 |
+
fir = fir
|
94 |
+
fir_kernel = fir_kernel
|
95 |
+
self.skip_rescale = skip_rescale = skip_rescale
|
96 |
+
self.resblock_type = resblock_type = resblock_type.lower()
|
97 |
+
self.progressive = progressive = progressive.lower()
|
98 |
+
self.progressive_input = progressive_input = progressive_input.lower()
|
99 |
+
self.embedding_type = embedding_type = embedding_type.lower()
|
100 |
+
init_scale = init_scale
|
101 |
+
assert progressive in ['none', 'output_skip', 'residual']
|
102 |
+
assert progressive_input in ['none', 'input_skip', 'residual']
|
103 |
+
assert embedding_type in ['fourier', 'positional']
|
104 |
+
combine_method = progressive_combine.lower()
|
105 |
+
combiner = functools.partial(Combine, method=combine_method)
|
106 |
+
|
107 |
+
num_channels = 4 # x.real, x.imag, y.real, y.imag
|
108 |
+
self.output_layer = nn.Conv2d(num_channels, 2, 1)
|
109 |
+
|
110 |
+
modules = []
|
111 |
+
# timestep/noise_level embedding
|
112 |
+
if embedding_type == 'fourier':
|
113 |
+
# Gaussian Fourier features embeddings.
|
114 |
+
modules.append(layerspp.GaussianFourierProjection(
|
115 |
+
embedding_size=nf, scale=fourier_scale
|
116 |
+
))
|
117 |
+
embed_dim = 2 * nf
|
118 |
+
elif embedding_type == 'positional':
|
119 |
+
embed_dim = nf
|
120 |
+
else:
|
121 |
+
raise ValueError(f'embedding type {embedding_type} unknown.')
|
122 |
+
|
123 |
+
if conditional:
|
124 |
+
modules.append(nn.Linear(embed_dim, nf * 4))
|
125 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
126 |
+
nn.init.zeros_(modules[-1].bias)
|
127 |
+
modules.append(nn.Linear(nf * 4, nf * 4))
|
128 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
129 |
+
nn.init.zeros_(modules[-1].bias)
|
130 |
+
|
131 |
+
AttnBlock = functools.partial(layerspp.AttnBlockpp,
|
132 |
+
init_scale=init_scale, skip_rescale=skip_rescale)
|
133 |
+
|
134 |
+
Upsample = functools.partial(layerspp.Upsample,
|
135 |
+
with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
136 |
+
|
137 |
+
if progressive == 'output_skip':
|
138 |
+
self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
139 |
+
elif progressive == 'residual':
|
140 |
+
pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
|
141 |
+
fir_kernel=fir_kernel, with_conv=True)
|
142 |
+
|
143 |
+
Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
144 |
+
|
145 |
+
if progressive_input == 'input_skip':
|
146 |
+
self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
147 |
+
elif progressive_input == 'residual':
|
148 |
+
pyramid_downsample = functools.partial(layerspp.Downsample,
|
149 |
+
fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
150 |
+
|
151 |
+
if resblock_type == 'ddpm':
|
152 |
+
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
|
153 |
+
dropout=dropout, init_scale=init_scale,
|
154 |
+
skip_rescale=skip_rescale, temb_dim=nf * 4)
|
155 |
+
|
156 |
+
elif resblock_type == 'biggan':
|
157 |
+
ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
|
158 |
+
dropout=dropout, fir=fir, fir_kernel=fir_kernel,
|
159 |
+
init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
|
160 |
+
|
161 |
+
else:
|
162 |
+
raise ValueError(f'resblock type {resblock_type} unrecognized.')
|
163 |
+
|
164 |
+
# Downsampling block
|
165 |
+
|
166 |
+
channels = num_channels
|
167 |
+
if progressive_input != 'none':
|
168 |
+
input_pyramid_ch = channels
|
169 |
+
|
170 |
+
modules.append(conv3x3(channels, nf))
|
171 |
+
hs_c = [nf]
|
172 |
+
|
173 |
+
in_ch = nf
|
174 |
+
for i_level in range(num_resolutions):
|
175 |
+
# Residual blocks for this resolution
|
176 |
+
for i_block in range(num_res_blocks):
|
177 |
+
out_ch = nf * ch_mult[i_level]
|
178 |
+
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
179 |
+
in_ch = out_ch
|
180 |
+
|
181 |
+
if all_resolutions[i_level] in attn_resolutions:
|
182 |
+
modules.append(AttnBlock(channels=in_ch))
|
183 |
+
hs_c.append(in_ch)
|
184 |
+
|
185 |
+
if i_level != num_resolutions - 1:
|
186 |
+
if resblock_type == 'ddpm':
|
187 |
+
modules.append(Downsample(in_ch=in_ch))
|
188 |
+
else:
|
189 |
+
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
190 |
+
|
191 |
+
if progressive_input == 'input_skip':
|
192 |
+
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
193 |
+
if combine_method == 'cat':
|
194 |
+
in_ch *= 2
|
195 |
+
|
196 |
+
elif progressive_input == 'residual':
|
197 |
+
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
198 |
+
input_pyramid_ch = in_ch
|
199 |
+
|
200 |
+
hs_c.append(in_ch)
|
201 |
+
|
202 |
+
in_ch = hs_c[-1]
|
203 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
204 |
+
modules.append(AttnBlock(channels=in_ch))
|
205 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
206 |
+
|
207 |
+
pyramid_ch = 0
|
208 |
+
# Upsampling block
|
209 |
+
for i_level in reversed(range(num_resolutions)):
|
210 |
+
for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
|
211 |
+
out_ch = nf * ch_mult[i_level]
|
212 |
+
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
213 |
+
in_ch = out_ch
|
214 |
+
|
215 |
+
if all_resolutions[i_level] in attn_resolutions:
|
216 |
+
modules.append(AttnBlock(channels=in_ch))
|
217 |
+
|
218 |
+
if progressive != 'none':
|
219 |
+
if i_level == num_resolutions - 1:
|
220 |
+
if progressive == 'output_skip':
|
221 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
222 |
+
num_channels=in_ch, eps=1e-6))
|
223 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
224 |
+
pyramid_ch = channels
|
225 |
+
elif progressive == 'residual':
|
226 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
227 |
+
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
228 |
+
pyramid_ch = in_ch
|
229 |
+
else:
|
230 |
+
raise ValueError(f'{progressive} is not a valid name.')
|
231 |
+
else:
|
232 |
+
if progressive == 'output_skip':
|
233 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
234 |
+
num_channels=in_ch, eps=1e-6))
|
235 |
+
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
236 |
+
pyramid_ch = channels
|
237 |
+
elif progressive == 'residual':
|
238 |
+
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
239 |
+
pyramid_ch = in_ch
|
240 |
+
else:
|
241 |
+
raise ValueError(f'{progressive} is not a valid name')
|
242 |
+
|
243 |
+
if i_level != 0:
|
244 |
+
if resblock_type == 'ddpm':
|
245 |
+
modules.append(Upsample(in_ch=in_ch))
|
246 |
+
else:
|
247 |
+
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
248 |
+
|
249 |
+
assert not hs_c
|
250 |
+
|
251 |
+
if progressive != 'output_skip':
|
252 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
253 |
+
num_channels=in_ch, eps=1e-6))
|
254 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
255 |
+
|
256 |
+
self.all_modules = nn.ModuleList(modules)
|
257 |
+
|
258 |
+
|
259 |
+
def forward(self, x, time_cond):
|
260 |
+
# timestep/noise_level embedding; only for continuous training
|
261 |
+
modules = self.all_modules
|
262 |
+
m_idx = 0
|
263 |
+
|
264 |
+
# Convert real and imaginary parts of (x,y) into four channel dimensions
|
265 |
+
x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
|
266 |
+
x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
|
267 |
+
|
268 |
+
if self.embedding_type == 'fourier':
|
269 |
+
# Gaussian Fourier features embeddings.
|
270 |
+
used_sigmas = time_cond
|
271 |
+
temb = modules[m_idx](torch.log(used_sigmas))
|
272 |
+
m_idx += 1
|
273 |
+
|
274 |
+
elif self.embedding_type == 'positional':
|
275 |
+
# Sinusoidal positional embeddings.
|
276 |
+
timesteps = time_cond
|
277 |
+
used_sigmas = self.sigmas[time_cond.long()]
|
278 |
+
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
279 |
+
|
280 |
+
else:
|
281 |
+
raise ValueError(f'embedding type {self.embedding_type} unknown.')
|
282 |
+
|
283 |
+
if self.conditional:
|
284 |
+
temb = modules[m_idx](temb)
|
285 |
+
m_idx += 1
|
286 |
+
temb = modules[m_idx](self.act(temb))
|
287 |
+
m_idx += 1
|
288 |
+
else:
|
289 |
+
temb = None
|
290 |
+
|
291 |
+
if not self.centered:
|
292 |
+
# If input data is in [0, 1]
|
293 |
+
x = 2 * x - 1.
|
294 |
+
|
295 |
+
# Downsampling block
|
296 |
+
input_pyramid = None
|
297 |
+
if self.progressive_input != 'none':
|
298 |
+
input_pyramid = x
|
299 |
+
|
300 |
+
# Input layer: Conv2d: 4ch -> 128ch
|
301 |
+
hs = [modules[m_idx](x)]
|
302 |
+
m_idx += 1
|
303 |
+
|
304 |
+
# Down path in U-Net
|
305 |
+
for i_level in range(self.num_resolutions):
|
306 |
+
# Residual blocks for this resolution
|
307 |
+
for i_block in range(self.num_res_blocks):
|
308 |
+
h = modules[m_idx](hs[-1], temb)
|
309 |
+
m_idx += 1
|
310 |
+
# Attention layer (optional)
|
311 |
+
if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
|
312 |
+
h = modules[m_idx](h)
|
313 |
+
m_idx += 1
|
314 |
+
hs.append(h)
|
315 |
+
|
316 |
+
# Downsampling
|
317 |
+
if i_level != self.num_resolutions - 1:
|
318 |
+
if self.resblock_type == 'ddpm':
|
319 |
+
h = modules[m_idx](hs[-1])
|
320 |
+
m_idx += 1
|
321 |
+
else:
|
322 |
+
h = modules[m_idx](hs[-1], temb)
|
323 |
+
m_idx += 1
|
324 |
+
|
325 |
+
if self.progressive_input == 'input_skip': # Combine h with x
|
326 |
+
input_pyramid = self.pyramid_downsample(input_pyramid)
|
327 |
+
h = modules[m_idx](input_pyramid, h)
|
328 |
+
m_idx += 1
|
329 |
+
|
330 |
+
elif self.progressive_input == 'residual':
|
331 |
+
input_pyramid = modules[m_idx](input_pyramid)
|
332 |
+
m_idx += 1
|
333 |
+
if self.skip_rescale:
|
334 |
+
input_pyramid = (input_pyramid + h) / np.sqrt(2.)
|
335 |
+
else:
|
336 |
+
input_pyramid = input_pyramid + h
|
337 |
+
h = input_pyramid
|
338 |
+
hs.append(h)
|
339 |
+
|
340 |
+
h = hs[-1] # actualy equal to: h = h
|
341 |
+
h = modules[m_idx](h, temb) # ResNet block
|
342 |
+
m_idx += 1
|
343 |
+
h = modules[m_idx](h) # Attention block
|
344 |
+
m_idx += 1
|
345 |
+
h = modules[m_idx](h, temb) # ResNet block
|
346 |
+
m_idx += 1
|
347 |
+
|
348 |
+
pyramid = None
|
349 |
+
|
350 |
+
# Upsampling block
|
351 |
+
for i_level in reversed(range(self.num_resolutions)):
|
352 |
+
for i_block in range(self.num_res_blocks + 1):
|
353 |
+
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
|
354 |
+
m_idx += 1
|
355 |
+
|
356 |
+
# edit: from -1 to -2
|
357 |
+
if h.shape[-2] in self.attn_resolutions:
|
358 |
+
h = modules[m_idx](h)
|
359 |
+
m_idx += 1
|
360 |
+
|
361 |
+
if self.progressive != 'none':
|
362 |
+
if i_level == self.num_resolutions - 1:
|
363 |
+
if self.progressive == 'output_skip':
|
364 |
+
pyramid = self.act(modules[m_idx](h)) # GroupNorm
|
365 |
+
m_idx += 1
|
366 |
+
pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
|
367 |
+
m_idx += 1
|
368 |
+
elif self.progressive == 'residual':
|
369 |
+
pyramid = self.act(modules[m_idx](h))
|
370 |
+
m_idx += 1
|
371 |
+
pyramid = modules[m_idx](pyramid)
|
372 |
+
m_idx += 1
|
373 |
+
else:
|
374 |
+
raise ValueError(f'{self.progressive} is not a valid name.')
|
375 |
+
else:
|
376 |
+
if self.progressive == 'output_skip':
|
377 |
+
pyramid = self.pyramid_upsample(pyramid) # Upsample
|
378 |
+
pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
|
379 |
+
m_idx += 1
|
380 |
+
pyramid_h = modules[m_idx](pyramid_h)
|
381 |
+
m_idx += 1
|
382 |
+
pyramid = pyramid + pyramid_h
|
383 |
+
elif self.progressive == 'residual':
|
384 |
+
pyramid = modules[m_idx](pyramid)
|
385 |
+
m_idx += 1
|
386 |
+
if self.skip_rescale:
|
387 |
+
pyramid = (pyramid + h) / np.sqrt(2.)
|
388 |
+
else:
|
389 |
+
pyramid = pyramid + h
|
390 |
+
h = pyramid
|
391 |
+
else:
|
392 |
+
raise ValueError(f'{self.progressive} is not a valid name')
|
393 |
+
|
394 |
+
# Upsampling Layer
|
395 |
+
if i_level != 0:
|
396 |
+
if self.resblock_type == 'ddpm':
|
397 |
+
h = modules[m_idx](h)
|
398 |
+
m_idx += 1
|
399 |
+
else:
|
400 |
+
h = modules[m_idx](h, temb) # Upspampling
|
401 |
+
m_idx += 1
|
402 |
+
|
403 |
+
assert not hs
|
404 |
+
|
405 |
+
if self.progressive == 'output_skip':
|
406 |
+
h = pyramid
|
407 |
+
else:
|
408 |
+
h = self.act(modules[m_idx](h))
|
409 |
+
m_idx += 1
|
410 |
+
h = modules[m_idx](h)
|
411 |
+
m_idx += 1
|
412 |
+
|
413 |
+
assert m_idx == len(modules), "Implementation error"
|
414 |
+
|
415 |
+
# Convert back to complex number
|
416 |
+
h = self.output_layer(h)
|
417 |
+
|
418 |
+
if self.scale_by_sigma:
|
419 |
+
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
|
420 |
+
h = h / used_sigmas
|
421 |
+
|
422 |
+
h = torch.permute(h, (0, 2, 3, 1)).contiguous()
|
423 |
+
h = torch.view_as_complex(h)[:,None, :, :]
|
424 |
+
return h
|
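A minimal usage sketch of the backbone above (assuming it is the NCSNpp_48k class from this commit's file listing; the spectrogram size and sigma values are illustrative, not taken from the repo). fir=False keeps the sketch on the naive resampling path, so it does not require the compiled upfirdn2d op:

import torch

model = NCSNpp_48k(fir=False, image_size=64)
x = torch.randn(2, 2, 64, 64, dtype=torch.cfloat)  # (batch, [x_t, y], freq, time): complex spectrograms
sigmas = torch.rand(2) + 0.1                       # positive noise levels; forward() takes their log
score = model(x, sigmas)                           # complex output of shape (2, 1, 64, 64)
assert score.shape == (2, 1, 64, 64) and score.is_complex()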
sgmse/backbones/ncsnpp_utils/layers.py
ADDED
@@ -0,0 +1,662 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Common layers for defining score networks."""
import math
import string
from functools import partial
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .normalization import ConditionalInstanceNorm2dPlus


def get_act(config):
    """Get activation functions from the config file."""
    if config == 'elu':
        return nn.ELU()
    elif config == 'relu':
        return nn.ReLU()
    elif config == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config == 'swish':
        return nn.SiLU()
    else:
        raise NotImplementedError('activation function does not exist!')


def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
    """1x1 convolution. Same as NCSNv1/v2."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                     padding=padding)
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv.weight.data *= init_scale
    conv.bias.data *= init_scale
    return conv


def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX."""

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init


def default_init(scale=1.):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, 'fan_avg', 'uniform')
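# Editor's sketch (illustrative, not part of the original file): `default_init`
# returns an initializer *function*, which is then called on a weight shape to
# sample from a fan-averaged uniform distribution, as in DDPM.
_w = default_init(scale=1.0)((32, 16, 3, 3))  # e.g. a 3x3 conv weight (out=32, in=16)
assert _w.shape == (32, 16, 3, 3)
del _w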

class Dense(nn.Module):
    """Linear layer with `default_init`."""
    def __init__(self):
        super().__init__()


def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
                     dilation=dilation, padding=padding, kernel_size=3)
    conv.weight.data *= init_scale
    conv.bias.data *= init_scale
    return conv


def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv

###########################################################################
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
# https://github.com/ermongroup/ncsn
# https://github.com/ermongroup/ncsnv2
###########################################################################


class CRPBlock(nn.Module):
    def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
        super().__init__()
        self.convs = nn.ModuleList()
        for i in range(n_stages):
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
        self.n_stages = n_stages
        if maxpool:
            self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
        else:
            self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)

        self.act = act

    def forward(self, x):
        x = self.act(x)
        path = x
        for i in range(self.n_stages):
            path = self.pool(path)
            path = self.convs[i](path)
            x = path + x
        return x


class CondCRPBlock(nn.Module):
    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.normalizer = normalizer
        for i in range(n_stages):
            self.norms.append(normalizer(features, num_classes, bias=True))
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))

        self.n_stages = n_stages
        self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        x = self.act(x)
        path = x
        for i in range(self.n_stages):
            path = self.norms[i](path, y)
            path = self.pool(path)
            path = self.convs[i](path)
            x = path + x
        return x


class RCUBlock(nn.Module):
    def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
        super().__init__()

        for i in range(n_blocks):
            for j in range(n_stages):
                setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act

    def forward(self, x):
        for i in range(self.n_blocks):
            residual = x
            for j in range(self.n_stages):
                x = self.act(x)
                x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
            x += residual
        return x


class CondRCUBlock(nn.Module):
    def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()

        for i in range(n_blocks):
            for j in range(n_stages):
                setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
                setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act
        self.normalizer = normalizer

    def forward(self, x, y):
        for i in range(self.n_blocks):
            residual = x
            for j in range(self.n_stages):
                x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
                x = self.act(x)
                x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
            x += residual
        return x


class MSFBlock(nn.Module):
    def __init__(self, in_planes, features):
        super().__init__()
        assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
        self.convs = nn.ModuleList()
        self.features = features

        for i in range(len(in_planes)):
            self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))

    def forward(self, xs, shape):
        sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
        for i in range(len(self.convs)):
            h = self.convs[i](xs[i])
            h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
            sums += h
        return sums


class CondMSFBlock(nn.Module):
    def __init__(self, in_planes, features, num_classes, normalizer):
        super().__init__()
        assert isinstance(in_planes, list) or isinstance(in_planes, tuple)

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.features = features
        self.normalizer = normalizer

        for i in range(len(in_planes)):
            self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
            self.norms.append(normalizer(in_planes[i], num_classes, bias=True))

    def forward(self, xs, y, shape):
        sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
        for i in range(len(self.convs)):
            h = self.norms[i](xs[i], y)
            h = self.convs[i](h)
            h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
            sums += h
        return sums


class RefineBlock(nn.Module):
    def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
        super().__init__()

        assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
        self.n_blocks = n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList()
        for i in range(n_blocks):
            self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))

        self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)

        if not start:
            self.msf = MSFBlock(in_planes, features)

        self.crp = CRPBlock(features, 2, act, maxpool=maxpool)

    def forward(self, xs, output_shape):
        assert isinstance(xs, tuple) or isinstance(xs, list)
        hs = []
        for i in range(len(xs)):
            h = self.adapt_convs[i](xs[i])
            hs.append(h)

        if self.n_blocks > 1:
            h = self.msf(hs, output_shape)
        else:
            h = hs[0]

        h = self.crp(h)
        h = self.output_convs(h)

        return h


class CondRefineBlock(nn.Module):
    def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
        super().__init__()

        assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
        self.n_blocks = n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList()
        for i in range(n_blocks):
            self.adapt_convs.append(
                CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
            )

        self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)

        if not start:
            self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)

        self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)

    def forward(self, xs, y, output_shape):
        assert isinstance(xs, tuple) or isinstance(xs, list)
        hs = []
        for i in range(len(xs)):
            h = self.adapt_convs[i](xs[i], y)
            hs.append(h)

        if self.n_blocks > 1:
            h = self.msf(hs, y, output_shape)
        else:
            h = hs[0]

        h = self.crp(h, y)
        h = self.output_convs(h, y)

        return h


class ConvMeanPool(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
        super().__init__()
        if not adjust_padding:
            conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
            self.conv = conv
        else:
            conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
            self.conv = nn.Sequential(
                nn.ZeroPad2d((1, 0, 1, 0)),
                conv
            )

    def forward(self, inputs):
        output = self.conv(inputs)
        output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
                      output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
        return output


class MeanPoolConv(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)

    def forward(self, inputs):
        output = inputs
        output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
                      output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
        return self.conv(output)


class UpsampleConv(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, inputs):
        output = inputs
        output = torch.cat([output, output, output, output], dim=1)
        output = self.pixelshuffle(output)
        return self.conv(output)


class ConditionalResidualBlock(nn.Module):
    # edit: defaults were resample=1, dilation=None, which crash the branch
    # checks below; use resample=None, dilation=1 as in ResidualBlock.
    def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU(),
                 normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=1):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization
        if resample == 'down':
            if dilation > 1:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
            else:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

        elif resample is None:
            if dilation > 1:
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
            else:
                # edit: use a 1x1 conv shortcut, as in ResidualBlock below
                # (nn.Conv2d cannot be called without a kernel_size)
                conv_shortcut = partial(ncsn_conv1x1)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim)
        else:
            raise Exception('invalid resample value')

        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim, num_classes)

    def forward(self, x, y):
        output = self.normalize1(x, y)
        output = self.non_linearity(output)
        output = self.conv1(output)
        output = self.normalize2(output, y)
        output = self.non_linearity(output)
        output = self.conv2(output)

        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = x
        else:
            shortcut = self.shortcut(x)

        return shortcut + output


class ResidualBlock(nn.Module):
    def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
                 normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization
        if resample == 'down':
            if dilation > 1:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
                self.normalize2 = normalization(input_dim)
                self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
            else:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                self.normalize2 = normalization(input_dim)
                self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

        elif resample is None:
            if dilation > 1:
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                self.normalize2 = normalization(output_dim)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
            else:
                # conv_shortcut = nn.Conv2d  ### Something weird here.
                conv_shortcut = partial(ncsn_conv1x1)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                self.normalize2 = normalization(output_dim)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim)
        else:
            raise Exception('invalid resample value')

        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim)

    def forward(self, x):
        output = self.normalize1(x)
        output = self.non_linearity(output)
        output = self.conv1(output)
        output = self.normalize2(output)
        output = self.non_linearity(output)
        output = self.conv2(output)

        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = x
        else:
            shortcut = self.shortcut(x)

        return shortcut + output


###########################################################################
# Functions below are ported over from the DDPM codebase:
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
###########################################################################

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
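# Editor's sketch (illustrative, not part of the original file): the sinusoidal
# embedding maps a batch of integer timesteps to vectors of size `embedding_dim`,
# with sin components in the first half and cos components in the second half.
_t = torch.arange(4)                  # four timesteps: 0, 1, 2, 3
_e = get_timestep_embedding(_t, 128)  # shape (4, 128)
assert _e.shape == (4, 128)
del _t, _e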

def _einsum(a, b, c, x, y):
    einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
    return torch.einsum(einsum_str, x, y)


def contract_inner(x, y):
    """tensordot(x, y, 1)."""
    x_chars = list(string.ascii_lowercase[:len(x.shape)])
    y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
    out_chars = x_chars[:-1] + y_chars[1:]
    return _einsum(x_chars, y_chars, out_chars, x, y)


class NIN(nn.Module):
    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        y = contract_inner(x, self.W) + self.b
        return y.permute(0, 3, 1, 2)


class AttnBlock(nn.Module):
    """Channel-wise self-attention block."""
    def __init__(self, channels):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=0.)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)

        w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
        w = torch.reshape(w, (B, H, W, H * W))
        w = F.softmax(w, dim=-1)
        w = torch.reshape(w, (B, H, W, H, W))
        h = torch.einsum('bhwij,bcij->bchw', w, v)
        h = self.NIN_3(h)
        return x + h
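# Editor's sketch (illustrative, not part of the original file): the attention
# block is shape-preserving -- softmax attention over all H*W positions, added
# back onto the input as a residual.
_attn = AttnBlock(channels=64)
_x = torch.randn(1, 64, 8, 8)
assert _attn(_x).shape == _x.shape
del _attn, _x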

class Upsample(nn.Module):
    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels)
        self.with_conv = with_conv

    def forward(self, x):
        B, C, H, W = x.shape
        h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
        if self.with_conv:
            h = self.Conv_0(h)
        return h


class Downsample(nn.Module):
    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
        self.with_conv = with_conv

    def forward(self, x):
        B, C, H, W = x.shape
        # Emulate 'SAME' padding
        if self.with_conv:
            x = F.pad(x, (0, 1, 0, 1))
            x = self.Conv_0(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)

        assert x.shape == (B, C, H // 2, W // 2)
        return x


class ResnetBlockDDPM(nn.Module):
    """The ResNet Blocks used in DDPM."""
    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
        super().__init__()
        if out_ch is None:
            out_ch = in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
        self.act = act
        self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        B, C, H, W = x.shape
        assert C == self.in_ch
        out_ch = self.out_ch if self.out_ch else self.in_ch
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if C != out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        return x + h
sgmse/backbones/ncsnpp_utils/layerspp.py
ADDED
@@ -0,0 +1,274 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Layers for defining NCSN++."""
from . import layers
from . import up_or_down_sampling
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

conv1x1 = layers.ddpm_conv1x1
conv3x3 = layers.ddpm_conv3x3
NIN = layers.NIN
default_init = layers.default_init


class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
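# Editor's sketch (illustrative, not part of the original file): the random
# frequencies W are fixed (requires_grad=False); a batch of noise levels is
# mapped to 2 * embedding_size features (sin and cos of the random projections).
_proj = GaussianFourierProjection(embedding_size=128, scale=16)
_sigma = torch.rand(4) + 0.1
assert _proj(torch.log(_sigma)).shape == (4, 256)  # log-sigmas, as in the NCSN++ forward pass
del _proj, _sigma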

class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method='cat'):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        h = self.Conv_0(x)
        if self.method == 'cat':
            return torch.cat([h, y], dim=1)
        elif self.method == 'sum':
            return h + y
        else:
            raise ValueError(f'Method {self.method} not recognized.')
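# Editor's sketch (illustrative, not part of the original file): `Combine`
# first projects the pyramid branch x to dim2 channels with a 1x1 conv, then
# either concatenates with or adds to the skip branch y.
_comb = Combine(dim1=4, dim2=32, method='cat')
_out = _comb(torch.randn(1, 4, 16, 16), torch.randn(1, 32, 16, 16))
assert _out.shape == (1, 64, 16, 16)  # 'sum' would instead keep 32 channels
del _comb, _out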

class AttnBlockpp(nn.Module):
    """Channel-wise self-attention block. Modified from DDPM."""

    def __init__(self, channels, skip_rescale=False, init_scale=0.):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
                                        eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)

        w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
        w = torch.reshape(w, (B, H, W, H * W))
        w = F.softmax(w, dim=-1)
        w = torch.reshape(w, (B, H, W, H, W))
        h = torch.einsum('bhwij,bcij->bchw', w, v)
        h = self.NIN_3(h)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)


class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
                 fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch)
        else:
            if with_conv:
                self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                           kernel=3, up=True,
                                                           resample_kernel=fir_kernel,
                                                           use_bias=True,
                                                           kernel_init=default_init())
        self.fir = fir
        self.with_conv = with_conv
        self.fir_kernel = fir_kernel
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            # edit: pass mode as a keyword; positionally it binds to scale_factor
            h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
            if self.with_conv:
                h = self.Conv_0(h)
        else:
            if not self.with_conv:
                h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = self.Conv2d_0(x)

        return h


class Downsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
                 fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        if not fir:
            if with_conv:
                self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
        else:
            if with_conv:
                self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                           kernel=3, down=True,
                                                           resample_kernel=fir_kernel,
                                                           use_bias=True,
                                                           kernel_init=default_init())
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.with_conv = with_conv
        self.out_ch = out_ch

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.fir:
            if self.with_conv:
                x = F.pad(x, (0, 1, 0, 1))
                x = self.Conv_0(x)
            else:
                x = F.avg_pool2d(x, 2, stride=2)
        else:
            if not self.with_conv:
                x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
            else:
                x = self.Conv2d_0(x)

        return x
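# Editor's sketch (illustrative, not part of the original file): both resamplers
# halve/double the spatial dims; with fir=True they use the FIR kernel from
# `up_or_down_sampling` (StyleGAN2-style) instead of nearest/avg-pool resampling.
_x = torch.randn(1, 8, 16, 16)
assert Upsample(in_ch=8, fir=False)(_x).shape == (1, 8, 32, 32)
assert Downsample(in_ch=8, fir=False)(_x).shape == (1, 8, 8, 8)
del _x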

class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM."""

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
                 dropout=0.1, skip_rescale=False, init_scale=0.):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)


class ResnetBlockBigGANpp(nn.Module):
    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                 dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
                 skip_rescale=True, init_scale=0.):
        super().__init__()

        out_ch = out_ch if out_ch else in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        if self.up:
            if self.fir:
                h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
                x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
        elif self.down:
            if self.fir:
                h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
                x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
            else:
                h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
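A small sketch of how the BigGAN-style block above behaves when it resamples (illustrative values, not from the repo; fir=False selects the naive_downsample_2d path and avoids the compiled upfirdn2d op):

import torch
import torch.nn as nn

# Downsampling resblock, 128 -> 256 channels, conditioned on a time embedding
# of dimension temb_dim (nf * 4 = 512 in the backbone above).
block = ResnetBlockBigGANpp(act=nn.SiLU(), in_ch=128, out_ch=256, temb_dim=512,
                            down=True, fir=False)
h = block(torch.randn(2, 128, 32, 32), temb=torch.randn(2, 512))
assert h.shape == (2, 256, 16, 16)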
sgmse/backbones/ncsnpp_utils/normalization.py
ADDED
@@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization layers."""
import torch.nn as nn
import torch
import functools


def get_normalization(config, conditional=False):
    """Obtain normalization modules from the config file."""
    norm = config.model.normalization
    if conditional:
        if norm == 'InstanceNorm++':
            return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
        else:
            raise NotImplementedError(f'{norm} not implemented yet.')
    else:
        if norm == 'InstanceNorm':
            return nn.InstanceNorm2d
        elif norm == 'InstanceNorm++':
            return InstanceNorm2dPlus
        elif norm == 'VarianceNorm':
            return VarianceNorm2d
        elif norm == 'GroupNorm':
            return nn.GroupNorm
        else:
            raise ValueError('Unknown normalization: %s' % norm)


class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        if self.bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        out = self.bn(x)
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=1)
            out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * out
        return out


class ConditionalInstanceNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        h = self.instance_norm(x)
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * h
        return out


class ConditionalVarianceNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.embed = nn.Embedding(num_classes, num_features)
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        vars = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(vars + 1e-5)

        gamma = self.embed(y)
        out = gamma.view(-1, self.num_features, 1, 1) * h
        return out


class VarianceNorm2d(nn.Module):
    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        vars = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(vars + 1e-5)

        out = self.alpha.view(-1, self.num_features, 1, 1) * h
        return out


class ConditionalNoneNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * x
        return out


class NoneNorm2d(nn.Module):
    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        return x


class InstanceNorm2dPlus(nn.Module):
    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)

        if self.bias:
            h = h + means[..., None, None] * self.alpha[..., None, None]
            out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
        else:
            h = h + means[..., None, None] * self.alpha[..., None, None]
            out = self.gamma.view(-1, self.num_features, 1, 1) * h
        return out
184 |
+
|
185 |
+
|
186 |
+
class ConditionalInstanceNorm2dPlus(nn.Module):
|
187 |
+
def __init__(self, num_features, num_classes, bias=True):
|
188 |
+
super().__init__()
|
189 |
+
self.num_features = num_features
|
190 |
+
self.bias = bias
|
191 |
+
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
|
192 |
+
if bias:
|
193 |
+
self.embed = nn.Embedding(num_classes, num_features * 3)
|
194 |
+
self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
|
195 |
+
self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
|
196 |
+
else:
|
197 |
+
self.embed = nn.Embedding(num_classes, 2 * num_features)
|
198 |
+
self.embed.weight.data.normal_(1, 0.02)
|
199 |
+
|
200 |
+
def forward(self, x, y):
|
201 |
+
means = torch.mean(x, dim=(2, 3))
|
202 |
+
m = torch.mean(means, dim=-1, keepdim=True)
|
203 |
+
v = torch.var(means, dim=-1, keepdim=True)
|
204 |
+
means = (means - m) / (torch.sqrt(v + 1e-5))
|
205 |
+
h = self.instance_norm(x)
|
206 |
+
|
207 |
+
if self.bias:
|
208 |
+
gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
|
209 |
+
h = h + means[..., None, None] * alpha[..., None, None]
|
210 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
|
211 |
+
else:
|
212 |
+
gamma, alpha = self.embed(y).chunk(2, dim=-1)
|
213 |
+
h = h + means[..., None, None] * alpha[..., None, None]
|
214 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h
|
215 |
+
return out
|
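`InstanceNorm2dPlus` augments plain instance normalization with the standardized per-channel means (scaled by `alpha`), so information about relative channel intensities survives the normalization; the conditional variant draws `gamma`, `alpha`, `beta` from a class embedding instead. A minimal usage sketch of these modules, with illustrative shapes and class count that are not taken from any config in this repo:

```python
import torch

# Hypothetical feature map: batch of 4, 32 channels, 16x16 spatial.
x = torch.randn(4, 32, 16, 16)

norm = InstanceNorm2dPlus(num_features=32)        # unconditional variant
y = norm(x)                                       # -> shape (4, 32, 16, 16)

cond_norm = ConditionalInstanceNorm2dPlus(num_features=32, num_classes=10)
labels = torch.randint(0, 10, (4,))               # hypothetical class labels
y_cond = cond_norm(x, labels)                     # -> shape (4, 32, 16, 16)
```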
sgmse/backbones/ncsnpp_utils/op/__init__.py
ADDED
@@ -0,0 +1 @@
from .upfirdn2d import upfirdn2d
sgmse/backbones/ncsnpp_utils/op/fused_act.py
ADDED
@@ -0,0 +1,97 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)

if torch.cuda.is_available():
    fused = load(
        "fused",
        sources=[
            os.path.join(module_path, "fused_bias_act.cpp"),
            os.path.join(module_path, "fused_bias_act_kernel.cu"),
        ],
    )
else:
    # Mirrors the guard in upfirdn2d.py: CPU-only environments fall back to
    # the pure-PyTorch path in fused_leaky_relu() below.
    fused = None


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                # Use the passed-in slope (the original hardcoded 0.2 here).
                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
            )
            * scale
        )
    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
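The CPU branch of `fused_leaky_relu` makes the op usable even when the CUDA extension is not compiled: it adds the bias, applies a leaky ReLU, and rescales by `sqrt(2)` to preserve signal magnitude. A minimal sketch with illustrative tensor sizes:

```python
import torch

# Hypothetical activations: batch 2, 8 channels, 4x4 spatial.
x = torch.randn(2, 8, 4, 4)
bias = torch.zeros(8)

# On CPU this dispatches to F.leaky_relu(x + bias) * sqrt(2);
# on a CUDA tensor it would run the fused kernel compiled above.
y = fused_leaky_relu(x, bias)   # -> shape (2, 8, 4, 4)
```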
sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    // Branch on act * 10 + grad: act 1 = linear, act 3 = leaky ReLU;
    // grad 0 = forward, grad 1 = first derivative, grad 2 = second derivative.
    switch (act * 10 + grad) {
      default:
      case 10: y = x; break;
      case 11: y = x; break;
      case 12: y = 0.0; break;

      case 30: y = (x > 0.0) ? x : x * alpha; break;
      case 31: y = (ref > 0.0) ? x : x * alpha; break;
      case 32: y = 0.0; break;
    }

    out[xi] = y * scale;
  }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  // step_b is the stride between consecutive bias entries: the product of
  // all dimensions after the channel dimension (dim 1).
  for (int i = 2; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
    fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
      y.data_ptr<scalar_t>(),
      x.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      ref.data_ptr<scalar_t>(),
      act,
      grad,
      alpha,
      scale,
      loop_x,
      size_x,
      step_b,
      size_b,
      use_bias,
      use_ref
    );
  });

  return y;
}
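For reference, the mode table above can be mirrored in plain Python, which is handy for sanity-checking the extension against the CPU path. The sketch below is an assumption-labelled reference implementation (the function name is hypothetical and not part of the upload); it covers only the modes the kernel actually uses:

```python
import torch

def fused_bias_act_reference(x, b, ref, act, grad, alpha, scale):
    """Pure-PyTorch mirror of the CUDA mode table (mode = act * 10 + grad)."""
    if b.numel():
        # Bias is broadcast along the channel dimension (dim 1), as in the kernel.
        x = x + b.view(1, -1, *([1] * (x.ndim - 2)))
    mode = act * 10 + grad
    if mode in (10, 11):      # linear: forward pass and first derivative
        y = x
    elif mode == 30:          # leaky ReLU forward
        y = torch.where(x > 0, x, x * alpha)
    elif mode == 31:          # leaky ReLU gradient, gated by the saved forward output
        y = torch.where(ref > 0, x, x * alpha)
    else:                     # modes 12 and 32: second derivatives are zero
        y = torch.zeros_like(x)
    return y * scale
```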
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py
ADDED
@@ -0,0 +1,203 @@
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)

if torch.cuda.is_available():
    upfirdn2d_op = load(
        "upfirdn2d",
        sources=[
            os.path.join(module_path, "upfirdn2d.cpp"),
            os.path.join(module_path, "upfirdn2d_kernel.cu"),
        ],
    )
else:
    upfirdn2d_op = None


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )
    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
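`upfirdn2d` performs upsampling by zero insertion, FIR filtering, and downsampling in a single pass, dispatching to the native PyTorch implementation on CPU and the JIT-compiled CUDA extension otherwise. Each spatial output size follows `out = (in * up + pad0 + pad1 - kernel) // down + 1`. A small CPU sketch with the `[1, 3, 3, 1]` resampling kernel the backbones use elsewhere; all concrete values here are illustrative:

```python
import torch

x = torch.randn(1, 3, 8, 8)              # hypothetical NCHW input
k = torch.tensor([1., 3., 3., 1.])
k = torch.outer(k, k) / k.sum() ** 2     # normalised 4x4 separable FIR filter

# 2x upsampling: gain of factor**2 = 4, padded so each side doubles exactly.
p = k.shape[0] - 2                        # kernel size minus factor
y = upfirdn2d(x, k * 4, up=2, pad=((p + 1) // 2 + 1, p // 2))
# (8 * 2 + 2 + 1 - 4) // 1 + 1 = 16 per spatial dimension
assert y.shape == (1, 3, 16, 16)
```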
sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py
ADDED
@@ -0,0 +1,257 @@
"""Layers used for up-sampling or down-sampling images.

Many functions are ported from https://github.com/NVlabs/stylegan2.
"""

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .op import upfirdn2d


# Function ported from StyleGAN2
def get_weight(module,
               shape,
               weight_var='weight',
               kernel_init=None):
    """Get/create weight tensor for a convolution or fully-connected layer."""

    return module.param(weight_var, kernel_init, shape)


class Conv2d(nn.Module):
    """Conv2d layer with optional upsampling or downsampling (StyleGAN2)."""

    def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
                 resample_kernel=(1, 3, 3, 1),
                 use_bias=True,
                 kernel_init=None):
        super().__init__()
        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1
        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
        if kernel_init is not None:
            self.weight.data = kernel_init(self.weight.data.shape)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_ch))

        self.up = up
        self.down = down
        self.resample_kernel = resample_kernel
        self.kernel = kernel
        self.use_bias = use_bias

    def forward(self, x):
        if self.up:
            x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
        elif self.down:
            x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
        else:
            x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

        if self.use_bias:
            x = x + self.bias.reshape(1, -1, 1, 1)

        return x


def naive_upsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H, 1, W, 1))
    x = x.repeat(1, 1, 1, factor, 1, factor)
    return torch.reshape(x, (-1, C, H * factor, W * factor))


def naive_downsample_2d(x, factor=2):
    _N, C, H, W = x.shape
    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
    return torch.mean(x, dim=(3, 5))


def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the
    operations. The fused op is considerably more efficient than performing
    the same calculation using standard ops. It supports gradients of
    arbitrary order.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w: Weight tensor of the shape `[filterH, filterW, inChannels,
            outChannels]`. Grouped convolution can be performed by
            `inChannels = x.shape[0] // numGroups`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to
            nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1

    # Check weight shape.
    assert len(w.shape) == 4
    convH = w.shape[2]
    convW = w.shape[3]
    inC = w.shape[1]
    outC = w.shape[0]

    assert convW == convH

    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)

    # Determine data dimensions.
    stride = (factor, factor)
    output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
    output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
                      output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
    assert output_padding[0] >= 0 and output_padding[1] >= 0
    num_groups = _shape(x, 1) // inC

    # Transpose weights. (torch tensors do not support negative-step slicing,
    # so the spatial dimensions are flipped with torch.flip.)
    w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
    w = torch.flip(w, [3, 4]).permute(0, 2, 1, 3, 4)
    w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

    x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
    ## Original TF code.
    # x = tf.nn.conv2d_transpose(
    #     x,
    #     w,
    #     output_shape=output_shape,
    #     strides=stride,
    #     padding='VALID',
    #     data_format=data_format)

    return upfirdn2d(x, torch.tensor(k, device=x.device),
                     pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the
    operations. The fused op is considerably more efficient than performing
    the same calculation using standard ops. It supports gradients of
    arbitrary order.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w: Weight tensor of the shape `[filterH, filterW, inChannels,
            outChannels]`. Grouped convolution can be performed by
            `inChannels = x.shape[0] // numGroups`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to average
            pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    _outC, _inC, convH, convW = w.shape
    assert convW == convH
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    s = [factor, factor]
    x = upfirdn2d(x, torch.tensor(k, device=x.device),
                  pad=((p + 1) // 2, p // 2))
    return F.conv2d(x, w, stride=s, padding=0)


def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k


def _shape(x, dim):
    return x.shape[dim]


def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and upsamples each image with the given filter. The filter is normalized
    so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the
    filter is padded with zeros so that its shape is a multiple of the
    upsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to
            nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device),
                     up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))


def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and downsamples each image with the given filter. The filter is normalized
    so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the
    filter is padded with zeros so that its shape is a multiple of the
    downsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
            The default is `[1] * factor`, which corresponds to average
            pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    return upfirdn2d(x, torch.tensor(k, device=x.device),
                     down=factor, pad=((p + 1) // 2, p // 2))
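A short round-trip sketch of the FIR resamplers defined above, run on CPU with the default `(1, 3, 3, 1)` resampling kernel; the tensor shapes are illustrative, not taken from any backbone config:

```python
import torch

x = torch.randn(2, 4, 16, 16)                        # hypothetical NCHW feature map

up = upsample_2d(x, k=(1, 3, 3, 1), factor=2)        # FIR-filtered 2x upsampling
down = downsample_2d(up, k=(1, 3, 3, 1), factor=2)   # FIR-filtered 2x downsampling

assert up.shape == (2, 4, 32, 32)
assert down.shape == (2, 4, 16, 16)
```

Note the gain convention: `upsample_2d` scales the normalized filter by `factor ** 2` so that constant inputs keep their magnitude after zero insertion, while `downsample_2d` leaves the normalized filter at unit gain.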
sgmse/backbones/ncsnpp_utils/utils.py
ADDED
@@ -0,0 +1,189 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions and modules related to model definition."""

import torch

import numpy as np
from ...sdes import OUVESDE, OUVPSDE


_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_model(name):
    return _MODELS[name]


def get_sigmas(sigma_min, sigma_max, num_scales):
    """Get sigmas --- the set of noise levels for SMLD.

    Args:
        sigma_min: The smallest noise level.
        sigma_max: The largest noise level.
        num_scales: The number of noise levels.
    Returns:
        sigmas: a numpy array of noise levels, geometrically spaced from
            `sigma_max` down to `sigma_min`.
    """
    sigmas = np.exp(
        np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))

    return sigmas


def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper."""
    num_diffusion_timesteps = 1000
    # parameters need to be adapted if number of time steps differs from 1000
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps
    }


def create_model(config):
    """Create the score model."""
    model_name = config.model.name
    score_model = get_model(model_name)(config)
    score_model = score_model.to(config.device)
    score_model = torch.nn.DataParallel(score_model)
    return score_model


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
            x: A mini-batch of input data.
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.

        Returns:
            A tuple of (model output, new mutable states)
        """
        if not train:
            model.eval()
            return model(x, labels)
        else:
            model.train()
            return model(x, labels)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A score function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, OUVPSDE):
        def score_fn(x, t):
            # Scale neural network output by standard deviation and flip sign
            if continuous:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level
                labels = t * (sde.N - 1)
                score = model_fn(x, labels)
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            score = -score / std[:, None, None, None]
            return score

    elif isinstance(sde, OUVESDE):
        def score_fn(x, t):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            score = model_fn(x, labels)
            return score

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))
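A brief sketch of the registry decorator and the SMLD sigma schedule above. The model class and name below are hypothetical, introduced only to illustrate the API:

```python
import torch.nn as nn

@register_model(name='toy_score_net')
class ToyScoreNet(nn.Module):           # hypothetical model, not part of this upload
    def __init__(self, config):
        super().__init__()
        self.net = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, labels):
        return self.net(x)

model_cls = get_model('toy_score_net')  # look up the registered class by name

sigmas = get_sigmas(sigma_min=0.01, sigma_max=50.0, num_scales=10)
# sigmas[0] == 50.0 and sigmas[-1] == 0.01: a geometric grid, largest first.
```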
sgmse/backbones/shared.py ADDED @@ -0,0 +1,123 @@

import functools
import numpy as np

import torch
import torch.nn as nn

from sgmse.util.registry import Registry


BackboneRegistry = Registry("Backbone")


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
            # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
            # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
            # and this halving is not necessary.
            embed_dim = embed_dim // 2
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)

    def forward(self, t):
        t_proj = t[:, None] * self.W[None, :] * 2*np.pi
        if self.complex_valued:
            return torch.exp(1j * t_proj)
        else:
            return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)


class DiffusionStepEmbedding(nn.Module):
    """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
            # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
            # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
            # and this halving is not necessary.
            embed_dim = embed_dim // 2
        self.embed_dim = embed_dim

    def forward(self, t):
        fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        else:
            return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)


class ComplexLinear(nn.Module):
    """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if self.complex_valued:
            self.re = nn.Linear(input_dim, output_dim)
            self.im = nn.Linear(input_dim, output_dim)
        else:
            self.lin = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        if self.complex_valued:
            return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
        else:
            return self.lin(x)


class FeatureMapDense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        return self.dense(x)[..., None, None]


def torch_complex_from_reim(re, im):
    return torch.view_as_complex(torch.stack([re, im], dim=-1))


class ArgsComplexMultiplicationWrapper(nn.Module):
    """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

    Make a complex-valued module `F` from a real-valued module `f` by applying
    complex multiplication rules:

        F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

    where `f1`, `f2` are instances of `f` that do *not* share weights.

    Args:
        module_cls (callable): A class or function that returns a Torch module/functional.
            Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
            to construct the real and imaginary component modules.
    """

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        return torch_complex_from_reim(
            self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
            self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
        )


ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
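A quick shape check for the building blocks above (an illustrative sketch, assuming the module is importable as `sgmse.backbones.shared`; the tensor sizes are arbitrary):

```python
import torch
from sgmse.backbones.shared import GaussianFourierProjection, ComplexConv2d

# Time embedding: in the real-valued case, embed_dim is halved internally and
# the sin/cos halves are concatenated back to the requested size.
t = torch.rand(8)  # a batch of diffusion times
emb = GaussianFourierProjection(embed_dim=128)(t)
assert emb.shape == (8, 128)

# Complex convolution via the multiplication-rule wrapper: complex in, complex out.
z = torch.randn(8, 4, 32, 32, dtype=torch.complex64)
out = ComplexConv2d(4, 16, kernel_size=3, padding=1)(z)
assert out.shape == (8, 16, 32, 32) and out.dtype == torch.complex64
```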
sgmse/data_module.py ADDED @@ -0,0 +1,236 @@

from os.path import join
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from glob import glob
from torchaudio import load
import numpy as np
import torch.nn.functional as F


def get_window(window_type, window_length):
    if window_type == 'sqrthann':
        return torch.sqrt(torch.hann_window(window_length, periodic=True))
    elif window_type == 'hann':
        return torch.hann_window(window_length, periodic=True)
    else:
        raise NotImplementedError(f"Window type {window_type} not implemented!")


class Specs(Dataset):
    def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
            format='default', normalize="noisy", spec_transform=None,
            stft_kwargs=None, **ignored_kwargs):

        # Read file paths according to file naming format.
        if format == "default":
            self.clean_files = []
            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
            self.noisy_files = []
            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
        elif format == "reverb":
            self.clean_files = []
            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
            self.noisy_files = []
            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
        else:
            # Feel free to add your own directory format
            raise NotImplementedError(f"Directory format {format} unknown!")

        self.dummy = dummy
        self.num_frames = num_frames
        self.shuffle_spec = shuffle_spec
        self.normalize = normalize
        self.spec_transform = spec_transform

        assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
        self.stft_kwargs = stft_kwargs
        self.hop_length = self.stft_kwargs["hop_length"]
        assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"

    def __getitem__(self, i):
        x, _ = load(self.clean_files[i])
        y, _ = load(self.noisy_files[i])

        # formula applies for center=True
        target_len = (self.num_frames - 1) * self.hop_length
        current_len = x.size(-1)
        pad = max(target_len - current_len, 0)
        if pad == 0:
            # extract random part of the audio file
            if self.shuffle_spec:
                start = int(np.random.uniform(0, current_len-target_len))
            else:
                start = int((current_len-target_len)/2)
            x = x[..., start:start+target_len]
            y = y[..., start:start+target_len]
        else:
            # pad audio if the length T is smaller than num_frames
            x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
            y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')

        # normalize w.r.t. the noisy or the clean signal, or not at all,
        # to ensure the same clean signal power in x and y.
        if self.normalize == "noisy":
            normfac = y.abs().max()
        elif self.normalize == "clean":
            normfac = x.abs().max()
        elif self.normalize == "not":
            normfac = 1.0
        x = x / normfac
        y = y / normfac

        X = torch.stft(x, **self.stft_kwargs)
        Y = torch.stft(y, **self.stft_kwargs)

        X, Y = self.spec_transform(X), self.spec_transform(Y)
        return X, Y

    def __len__(self):
        if self.dummy:
            # for debugging shrink the data set size
            return int(len(self.clean_files)/200)
        else:
            return len(self.clean_files)


class SpecsDataModule(pl.LightningDataModule):
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
        parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
        parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
        parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")   # to assure 256 freq bins
        parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
        parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
        parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
        parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
        parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
        parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
        parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
        parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
        parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectrogram transformation for input representation.")
        return parser

    def __init__(
        self, base_dir, format='default', batch_size=8,
        n_fft=510, hop_length=128, num_frames=256, window='hann',
        num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
        gpu=True, normalize='noisy', transform_type="exponent", **kwargs
    ):
        super().__init__()
        self.base_dir = base_dir
        self.format = format
        self.batch_size = batch_size
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frames = num_frames
        self.window = get_window(window, self.n_fft)
        self.windows = {}
        self.num_workers = num_workers
        self.dummy = dummy
        self.spec_factor = spec_factor
        self.spec_abs_exponent = spec_abs_exponent
        self.gpu = gpu
        self.normalize = normalize
        self.transform_type = transform_type
        self.kwargs = kwargs

    def setup(self, stage=None):
        specs_kwargs = dict(
            stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
            spec_transform=self.spec_fwd, **self.kwargs
        )
        if stage == 'fit' or stage is None:
            self.train_set = Specs(data_dir=self.base_dir, subset='train',
                dummy=self.dummy, shuffle_spec=True, format=self.format,
                normalize=self.normalize, **specs_kwargs)
            self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
                dummy=self.dummy, shuffle_spec=False, format=self.format,
                normalize=self.normalize, **specs_kwargs)
        if stage == 'test' or stage is None:
            self.test_set = Specs(data_dir=self.base_dir, subset='test',
                dummy=self.dummy, shuffle_spec=False, format=self.format,
                normalize=self.normalize, **specs_kwargs)

    def spec_fwd(self, spec):
        if self.transform_type == "exponent":
            if self.spec_abs_exponent != 1:
                # only do this calculation if spec_abs_exponent != 1, otherwise it's quite
                # a bit of wasted computation and introduces numerical error
                e = self.spec_abs_exponent
                spec = spec.abs()**e * torch.exp(1j * spec.angle())
            spec = spec * self.spec_factor
        elif self.transform_type == "log":
            spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
            spec = spec * self.spec_factor
        elif self.transform_type == "none":
            spec = spec
        return spec

    def spec_back(self, spec):
        if self.transform_type == "exponent":
            spec = spec / self.spec_factor
            if self.spec_abs_exponent != 1:
                e = self.spec_abs_exponent
                spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
        elif self.transform_type == "log":
            spec = spec / self.spec_factor
            spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
        elif self.transform_type == "none":
            spec = spec
        return spec

    @property
    def stft_kwargs(self):
        return {**self.istft_kwargs, "return_complex": True}

    @property
    def istft_kwargs(self):
        return dict(
            n_fft=self.n_fft, hop_length=self.hop_length,
            window=self.window, center=True
        )

    def _get_window(self, x):
        """
        Retrieve an appropriate window for the given tensor x, matching the device.
        Caches the retrieved windows so that only one window tensor will be allocated per device.
        """
        window = self.windows.get(x.device, None)
        if window is None:
            window = self.window.to(x.device)
            self.windows[x.device] = window
        return window

    def stft(self, sig):
        window = self._get_window(sig)
        return torch.stft(sig, **{**self.stft_kwargs, "window": window})

    def istft(self, spec, length=None):
        window = self._get_window(spec)
        return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})

    def train_dataloader(self):
        return DataLoader(
            self.train_set, batch_size=self.batch_size,
            num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_set, batch_size=self.batch_size,
            num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_set, batch_size=self.batch_size,
            num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
        )
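The `spec_fwd`/`spec_back` pair above is an invertible amplitude compression of the complex STFT. A standalone sketch with the default settings (`spec_factor=0.15`, `spec_abs_exponent=0.5`), mirroring the class methods outside the class:

```python
import torch

spec_factor, e = 0.15, 0.5  # the argparse defaults above

def spec_fwd(spec):
    # compress magnitudes (|z|^e), keep the phase, then scale
    return spec.abs()**e * torch.exp(1j * spec.angle()) * spec_factor

def spec_back(spec):
    # undo the scaling, then expand magnitudes (|z|^(1/e))
    spec = spec / spec_factor
    return spec.abs()**(1/e) * torch.exp(1j * spec.angle())

Z = torch.randn(1, 256, 256, dtype=torch.complex64)
assert torch.allclose(spec_back(spec_fwd(Z)), Z, atol=1e-3)
```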
sgmse/model.py ADDED @@ -0,0 +1,253 @@

import time
from math import ceil
import warnings

import torch
import pytorch_lightning as pl
from torch_ema import ExponentialMovingAverage

from sgmse import sampling
from sgmse.sdes import SDERegistry
from sgmse.backbones import BackboneRegistry
from sgmse.util.inference import evaluate_model
from sgmse.util.other import pad_spec


class ScoreModel(pl.LightningModule):
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
        parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
        parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
        parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
        return parser

    def __init__(
        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
        num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
    ):
        """
        Create a new ScoreModel.

        Args:
            backbone: Backbone DNN that serves as a score-based model.
            sde: The SDE that defines the diffusion process.
            lr: The learning rate of the optimizer (1e-4 by default).
            ema_decay: The decay constant of the parameter EMA (0.999 by default).
            t_eps: The minimum time to practically run for, to avoid issues very close to zero (0.03 by default).
            loss_type: The type of loss to use (w.r.t. noise z/std). Options are 'mse' (default) and 'mae'.
        """
        super().__init__()
        # Initialize Backbone DNN
        self.backbone = backbone
        dnn_cls = BackboneRegistry.get_by_name(backbone)
        self.dnn = dnn_cls(**kwargs)
        # Initialize SDE
        sde_cls = SDERegistry.get_by_name(sde)
        self.sde = sde_cls(**kwargs)
        # Store hyperparams and save them
        self.lr = lr
        self.ema_decay = ema_decay
        self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
        self._error_loading_ema = False
        self.t_eps = t_eps
        self.loss_type = loss_type
        self.num_eval_files = num_eval_files

        self.save_hyperparameters(ignore=['no_wandb'])
        self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        # Method overridden so that the EMA params are updated after each optimizer step
        super().optimizer_step(*args, **kwargs)
        self.ema.update(self.parameters())

    # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint.get('ema', None)
        if ema is not None:
            self.ema.load_state_dict(checkpoint['ema'])
        else:
            self._error_loading_ema = True
            warnings.warn("EMA state_dict not found in checkpoint!")

    def on_save_checkpoint(self, checkpoint):
        checkpoint['ema'] = self.ema.state_dict()

    def train(self, mode, no_ema=False):
        res = super().train(mode)  # call the standard `train` method with the given mode
        if not self._error_loading_ema:
            if mode == False and not no_ema:
                # eval
                self.ema.store(self.parameters())        # store current params in EMA
                self.ema.copy_to(self.parameters())      # copy EMA parameters over current params for evaluation
            else:
                # train
                if self.ema.collected_params is not None:
                    self.ema.restore(self.parameters())  # restore the EMA weights (if stored)
        return res

    def eval(self, no_ema=False):
        return self.train(False, no_ema=no_ema)

    def _loss(self, err):
        if self.loss_type == 'mse':
            losses = torch.square(err.abs())
        elif self.loss_type == 'mae':
            losses = err.abs()
        # taken from reduce_op function: sum over channels and position and mean over batch dim
        # presumably only important for absolute loss number, not for gradients
        loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
        return loss

    def _step(self, batch, batch_idx):
        x, y = batch
        t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
        mean, std = self.sde.marginal_prob(x, t, y)
        z = torch.randn_like(x)  # i.i.d. normal distributed with var=0.5
        sigmas = std[:, None, None, None]
        perturbed_data = mean + sigmas * z
        score = self(perturbed_data, t, y)
        err = score * sigmas + z
        loss = self._loss(err)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)

        # Evaluate speech enhancement performance
        if batch_idx == 0 and self.num_eval_files != 0:
            pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
            self.log('pesq', pesq, on_step=False, on_epoch=True)
            self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
            self.log('estoi', estoi, on_step=False, on_epoch=True)

        return loss

    def forward(self, x, t, y):
        # Concatenate y as an extra channel
        dnn_input = torch.cat([x, y], dim=1)

        # the minus is most likely unimportant here - taken from Song's repo
        score = -self.dnn(dnn_input, t)
        return score

    def to(self, *args, **kwargs):
        """Override PyTorch .to() to also transfer the EMA of the model weights"""
        self.ema.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
        else:
            M = y.shape[0]
            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(ceil(M / minibatch))):
                    y_mini = y[i*minibatch:(i+1)*minibatch]
                    sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                samples = torch.cat(samples, dim=0)
                return samples, ns
            return batched_sampling_fn

    def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
        else:
            M = y.shape[0]
            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(ceil(M / minibatch))):
                    y_mini = y[i*minibatch:(i+1)*minibatch]
                    sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                samples = torch.cat(samples, dim=0)
                return samples, ns
            return batched_sampling_fn

    def train_dataloader(self):
        return self.data_module.train_dataloader()

    def val_dataloader(self):
        return self.data_module.val_dataloader()

    def test_dataloader(self):
        return self.data_module.test_dataloader()

    def setup(self, stage=None):
        return self.data_module.setup(stage=stage)

    def to_audio(self, spec, length=None):
        return self._istft(self._backward_transform(spec), length)

    def _forward_transform(self, spec):
        return self.data_module.spec_fwd(spec)

    def _backward_transform(self, spec):
        return self.data_module.spec_back(spec)

    def _stft(self, sig):
        return self.data_module.stft(sig)

    def _istft(self, spec, length=None):
        return self.data_module.istft(spec, length)

    def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
        corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
        **kwargs
    ):
        """
        One-call speech enhancement of noisy speech `y`, for convenience.
        """
        sr = 16000
        start = time.time()
        T_orig = y.size(1)
        norm_factor = y.abs().max().item()
        y = y / norm_factor
        Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        if sampler_type == "pc":
            sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
                corrector_steps=corrector_steps, snr=snr, intermediate=False,
                **kwargs)
        elif sampler_type == "ode":
            sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
        else:
            raise ValueError("{} is not a valid sampler type!".format(sampler_type))
        sample, nfe = sampler()
        x_hat = self.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()
        end = time.time()
        if timeit:
            rtf = (end-start)/(len(x_hat)/sr)
            return x_hat, nfe, rtf
        else:
            return x_hat
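The `_step` above is denoising score matching: for x_t = mean + σ·z, the Gaussian perturbation kernel has score −z/σ, so `err = score * sigmas + z` vanishes exactly when the network outputs that score. A sketch of this identity with dummy tensors (not repo code):

```python
import torch

sigma = torch.tensor([0.3, 0.1])[:, None, None, None]  # per-example std
z = torch.randn(2, 1, 8, 8)          # the sampled perturbation noise
ideal_score = -z / sigma             # score of N(mean, sigma^2 I) at x_t
err = ideal_score * sigma + z
assert err.abs().max() < 1e-6        # zero loss for the ideal score
```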
sgmse/sampling/__init__.py ADDED @@ -0,0 +1,143 @@

# Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
"""Various sampling methods."""
from scipy import integrate
import torch

from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
from .correctors import Corrector, CorrectorRegistry


__all__ = [
    'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
    'get_sampler'
]


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))


def get_pc_sampler(
    predictor_name, corrector_name, sde, score_fn, y,
    denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
    intermediate=False, **kwargs
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        predictor_name: The name of a registered `sampling.Predictor`.
        corrector_name: The name of a registered `sampling.Corrector`.
        sde: An `sdes.SDE` object representing the forward SDE.
        score_fn: A function (typically a learned model) that predicts the score.
        y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        denoise: If `True`, add one-step denoising to the final samples.
        eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
        snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
        N: The number of reverse sampling steps. If `None`, uses the SDE's `N` property by default.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """
    predictor_cls = PredictorRegistry.get_by_name(predictor_name)
    corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
    predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
    corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)

    def pc_sampler():
        """The PC sampler function."""
        with torch.no_grad():
            xt = sde.prior_sampling(y.shape, y).to(y.device)
            timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
            for i in range(sde.N):
                t = timesteps[i]
                if i != len(timesteps) - 1:
                    stepsize = t - timesteps[i+1]
                else:
                    stepsize = timesteps[-1]  # from eps to 0
                vec_t = torch.ones(y.shape[0], device=y.device) * t
                xt, xt_mean = corrector.update_fn(xt, vec_t, y)
                xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
            x_result = xt_mean if denoise else xt
            ns = sde.N * (corrector.n_steps + 1)
            return x_result, ns

    return pc_sampler


def get_ode_sampler(
    sde, score_fn, y, inverse_scaler=None,
    denoise=True, rtol=1e-5, atol=1e-5,
    method='RK45', eps=3e-2, device='cuda', **kwargs
):
    """Probability flow ODE sampler with the black-box ODE solver.

    Args:
        sde: An `sdes.SDE` object representing the forward SDE.
        score_fn: A function (typically a learned model) that predicts the score.
        y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        inverse_scaler: The inverse data normalizer.
        denoise: If `True`, add one-step denoising to final samples.
        rtol: A `float` number. The relative tolerance level of the ODE solver.
        atol: A `float` number. The absolute tolerance level of the ODE solver.
        method: A `str`. The algorithm used for the black-box ODE solver.
            See the documentation of `scipy.integrate.solve_ivp`.
        eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
        device: PyTorch device.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """
    predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    rsde = sde.reverse(score_fn, probability_flow=True)

    def denoise_update_fn(x):
        vec_eps = torch.ones(x.shape[0], device=x.device) * eps
        _, x = predictor.update_fn(x, vec_eps, y, torch.tensor(eps, device=x.device))
        return x

    def drift_fn(x, t, y):
        """Get the drift function of the reverse-time SDE."""
        return rsde.sde(x, t, y)[0]

    def ode_sampler(z=None, **kwargs):
        """The probability flow ODE sampler with black-box ODE solver.

        Args:
            z: If present, generate samples from latent code `z`.
        Returns:
            samples, number of function evaluations.
        """
        with torch.no_grad():
            # If not present, sample the latent code from the prior distribution of the SDE.
            x = sde.prior_sampling(y.shape, y).to(device)

            def ode_func(t, x):
                x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
                vec_t = torch.ones(y.shape[0], device=x.device) * t
                drift = drift_fn(x, vec_t, y)
                return to_flattened_numpy(drift)

            # Black-box ODE solver for the probability flow ODE
            solution = integrate.solve_ivp(
                ode_func, (sde.T, eps), to_flattened_numpy(x),
                rtol=rtol, atol=atol, method=method, **kwargs
            )
            nfe = solution.nfev
            x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)

            # Denoising is equivalent to running one predictor step without adding noise
            if denoise:
                x = denoise_update_fn(x)

            if inverse_scaler is not None:
                x = inverse_scaler(x)
            return x, nfe

    return ode_sampler
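The grid logic in `pc_sampler` above, isolated: N points from T down to eps, with the final predictor step using eps itself as the step size (bridging eps → 0). A sketch with small numbers:

```python
import torch

T, eps, N = 1.0, 3e-2, 5
timesteps = torch.linspace(T, eps, N)
stepsizes = [(timesteps[i] - timesteps[i + 1]).item() for i in range(N - 1)]
stepsizes.append(timesteps[-1].item())  # last step: from eps down to 0
print(timesteps.tolist())  # ≈ [1.0, 0.7575, 0.515, 0.2725, 0.03]
print(stepsizes)           # four equal interior steps of ≈0.2425, then 0.03
```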
sgmse/sampling/correctors.py ADDED @@ -0,0 +1,96 @@

import abc
import torch

from sgmse import sdes
from sgmse.util.registry import Registry


CorrectorRegistry = Registry("Corrector")


class Corrector(abc.ABC):
    """The abstract class for a corrector algorithm."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.rsde = sde.reverse(score_fn)
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """One update of the corrector.

        Args:
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            *args: Possibly additional arguments, in particular `y` for OU processes.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
        pass


@CorrectorRegistry.register(name='langevin')
class LangevinCorrector(Corrector):
    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        self.score_fn = score_fn
        self.n_steps = n_steps
        self.snr = snr

    def update_fn(self, x, t, *args):
        target_snr = self.snr
        for _ in range(self.n_steps):
            grad = self.score_fn(x, t, *args)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
            x_mean = x + step_size[:, None, None, None] * grad
            x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]

        return x, x_mean


@CorrectorRegistry.register(name='ald')
class AnnealedLangevinDynamics(Corrector):
    """The original annealed Langevin dynamics corrector from NCSN/NCSNv2."""
    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, (sdes.OUVESDE,)):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    def update_fn(self, x, t, *args):
        n_steps = self.n_steps
        target_snr = self.snr
        std = self.sde.marginal_prob(x, t, *args)[1]

        for _ in range(n_steps):
            grad = self.score_fn(x, t, *args)
            noise = torch.randn_like(x)
            step_size = (target_snr * std) ** 2 * 2
            x_mean = x + step_size[:, None, None, None] * grad
            x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]

        return x, x_mean


@CorrectorRegistry.register(name='none')
class NoneCorrector(Corrector):
    """An empty corrector that does nothing."""

    def __init__(self, *args, **kwargs):
        self.snr = 0
        self.n_steps = 0

    def update_fn(self, x, t, *args):
        return x, x
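The annealed-Langevin step size above in isolation: with target SNR r and marginal std σ(t), the step is ε = 2(rσ)², and the injected noise has std √(2ε), i.e. x ← x + ε·score + √(2ε)·n with n ~ N(0, I). A numeric sketch:

```python
import torch

r = 0.5                            # target SNR (the `snr` argument)
sigma = torch.tensor(0.2)          # marginal std at the current t
eps = 2 * (r * sigma) ** 2         # step_size in the loop above
print(eps.item())                  # 0.02
print(torch.sqrt(2 * eps).item())  # 0.2 = std of the injected noise
```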
sgmse/sampling/predictors.py ADDED @@ -0,0 +1,76 @@

import abc

import torch
import numpy as np

from sgmse.util.registry import Registry


PredictorRegistry = Registry("Predictor")


class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        self.rsde = sde.reverse(score_fn)
        self.score_fn = score_fn
        self.probability_flow = probability_flow

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """One update of the predictor.

        Args:
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            *args: Possibly additional arguments, in particular `y` for OU processes.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
        pass

    def debug_update_fn(self, x, t, *args):
        raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")


@PredictorRegistry.register('euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow=probability_flow)

    def update_fn(self, x, t, *args):
        dt = -1. / self.rsde.N
        z = torch.randn_like(x)
        f, g = self.rsde.sde(x, t, *args)
        x_mean = x + f * dt
        x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
        return x, x_mean


@PredictorRegistry.register('reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow=probability_flow)

    def update_fn(self, x, t, y, stepsize):
        f, g = self.rsde.discretize(x, t, y, stepsize)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + g[:, None, None, None] * z
        return x, x_mean


@PredictorRegistry.register('none')
class NonePredictor(Predictor):
    """An empty predictor that does nothing."""

    def __init__(self, *args, **kwargs):
        pass

    def update_fn(self, x, t, *args):
        return x, x
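One Euler-Maruyama predictor step in isolation (a sketch with placeholder drift and diffusion; in the class above these come from `self.rsde.sde`): integrating the reverse SDE backwards means dt < 0, and the noise scale is g·√|dt|.

```python
import torch

N = 30
dt = -1.0 / N                    # negative: reverse-time integration
x = torch.randn(2, 1, 4, 4)
f = -0.5 * x                     # placeholder reverse drift
g = torch.full((2,), 0.5)        # placeholder per-batch diffusion coefficient
z = torch.randn_like(x)
x_mean = x + f * dt
x_next = x_mean + g[:, None, None, None] * (-dt) ** 0.5 * z
print(x_next.shape)              # torch.Size([2, 1, 4, 4])
```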
sgmse/sdes.py ADDED @@ -0,0 +1,310 @@

"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.

Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings

import numpy as np
from sgmse.util.tensors import batch_broadcast
import torch

from sgmse.util.registry import Registry


SDERegistry = Registry("SDE")


class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)

        Returns:
            f, G
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        G = diffusion * torch.sqrt(dt)
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def rsde_parts(self, x, t, *args):
                sde_drift, sde_diffusion = sde_fn(x, t, *args)
                score = score_model(x, t, *args)
                score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                }

            def discretize(self, x, t, y, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        pass


@SDERegistry.register("ouve")
class OUVESDE(SDE):
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

            dx = theta (y-x) dt + sigma(t) dw

        with

            sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps
        """
        super().__init__(N)
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.logsig = np.log(self.sigma_max / self.sigma_min)
        self.N = N

    def copy(self):
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

    @property
    def T(self):
        return 1

    def sde(self, x, t, y):
        drift = self.theta * (y - x)
        # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
        # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
        # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
        # unless this sqrt(2*logsig) factor is included.
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        diffusion = sigma * np.sqrt(2 * self.logsig)
        return drift, diffusion

    def _mean(self, x0, t, y):
        theta = self.theta
        exp_interp = torch.exp(-theta * t)[:, None, None, None]
        return exp_interp * x0 + (1 - exp_interp) * y

    def alpha(self, t):
        return torch.exp(-self.theta * t)

    def _std(self, t):
        # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        # could maybe replace the two torch.exp(... * t) terms here by cached values **t
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            /
            (theta + logsig)
        )

    def marginal_prob(self, x0, t, y):
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")


@SDERegistry.register("ouvp")
class OUVPSDE(SDE):
    # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--sde-n", type=int, default=1000,
            help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--beta-min", type=float, required=True,
            help="The minimum beta to use.")
        parser.add_argument("--beta-max", type=float, required=True,
            help="The maximum beta to use.")
        parser.add_argument("--stiffness", type=float, default=1,
            help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
        return parser

    def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
        """
        !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

        Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

            dx = 1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw

        with

            beta(t) = beta_min + t(beta_max - beta_min)

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

        Args:
            beta_min: smallest beta.
            beta_max: largest beta.
            stiffness: stiffness factor of the drift. 1 by default.
            N: number of discretization steps
        """
        super().__init__(N)
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.stiffness = stiffness
        self.N = N

    def copy(self):
        return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

    @property
    def T(self):
        return 1

    def _beta(self, t):
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x, t, y):
        drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
        diffusion = torch.sqrt(self._beta(t))
        return drift, diffusion

    def _mean(self, x0, t, y):
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
        return y + x0y_fac * (x0 - y)

    def _std(self, t):
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s

    def marginal_prob(self, x0, t, y):
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
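The closed form in `_std` above can be sanity-checked against a brute-force Euler-Maruyama simulation of the forward OUVE SDE. A sketch (assuming `sgmse.sdes` is importable; the parameters are the argparse defaults):

```python
import torch
from sgmse.sdes import OUVESDE

torch.manual_seed(0)
sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5)
n_paths, n_steps = 10000, 1000
dt = 1.0 / n_steps
x = torch.zeros(n_paths, 1, 1, 1)   # start at x0 = 0
y = torch.ones_like(x)              # fixed conditioning signal
for i in range(n_steps):
    t = torch.full((n_paths, 1, 1, 1), i * dt)
    drift, diffusion = sde.sde(x, t, y)
    x = x + drift * dt + diffusion * dt**0.5 * torch.randn_like(x)

print(x.std().item())                   # empirical std at t = 1
print(sde._std(torch.ones(1)).item())   # closed form, ~0.39; should agree to within a few percent
```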
sgmse/util/inference.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from torchaudio import load

from pesq import pesq
from pystoi import stoi

from .other import si_sdr, pad_spec

# Settings
sr = 16000            # sample rate
snr = 0.5             # SNR parameter of the ALD corrector
N = 30                # number of reverse diffusion steps
corrector_steps = 1


def evaluate_model(model, num_eval_files):
    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files

    # Select test files uniformly across validation files
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
    clean_files = list(clean_files[i] for i in indices)
    noisy_files = list(noisy_files[i] for i in indices)

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # Iterate over files
    for (clean_file, noisy_file) in zip(clean_files, noisy_files):
        # Load wavs
        x, _ = load(clean_file)
        y, _ = load(noisy_file)
        T_orig = x.size(1)

        # Normalize per utterance
        norm_factor = y.abs().max()
        y = y / norm_factor

        # Prepare DNN input
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        y = y * norm_factor

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'wb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
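For orientation, a minimal call-site sketch; `model` is assumed to be a trained ScoreModel (the LightningModule from sgmse/model.py) already moved to the GPU, and the value of `num_eval_files` here is arbitrary.

# Hypothetical validation snippet, not part of the upstream code:
pesq_avg, si_sdr_avg, estoi_avg = evaluate_model(model, num_eval_files=5)
print(f"PESQ: {pesq_avg:.2f} | SI-SDR: {si_sdr_avg:.2f} dB | ESTOI: {estoi_avg:.2f}")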
sgmse/util/other.py
ADDED
@@ -0,0 +1,141 @@
import os
import torch
import numpy as np
import scipy.stats
from scipy.signal import butter, sosfilt

from pesq import pesq
from pystoi import stoi


def si_sdr_components(s_hat, s, n):
    """Decompose an estimate s_hat into target, noise, and artifact components."""
    # s_target
    alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
    s_target = alpha_s * s

    # e_noise
    alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
    e_noise = alpha_n * n

    # e_art
    e_art = s_hat - s_target - e_noise

    return s_target, e_noise, e_art


def energy_ratios(s_hat, s, n):
    """Compute SI-SDR, SI-SIR, and SI-SAR from the component decomposition."""
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
    si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
    si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)

    return si_sdr, si_sir, si_sar


def mean_conf_int(data, confidence=0.95):
    """Return the sample mean and the half-width of its t-based confidence interval."""
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h


class Method():
    """Collects per-file metric values for one enhancement method."""
    def __init__(self, name, base_dir, metrics):
        self.name = name
        self.base_dir = base_dir
        self.metrics = {metric: [] for metric in metrics}

    def append(self, metric, value):
        self.metrics[metric].append(value)

    def get_mean_ci(self, metric):
        return mean_conf_int(np.array(self.metrics[metric]))


def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """High-pass filter a signal with a Butterworth SOS filter."""
    factor = cut_off / sr * 2  # cutoff normalized to the Nyquist frequency
    sos = butter(order, factor, 'hp', output='sos')
    filtered = sosfilt(sos, signal)
    return filtered


def si_sdr(s, s_hat):
    """Scale-invariant SDR in dB between reference s and estimate s_hat."""
    alpha = np.dot(s_hat, s) / np.linalg.norm(s)**2
    sdr = 10*np.log10(np.linalg.norm(alpha*s)**2 / np.linalg.norm(alpha*s - s_hat)**2)
    return sdr


def snr_dB(s, n):
    """SNR in dB between signal s and noise n."""
    s_power = 1 / len(s) * np.sum(s**2)
    n_power = 1 / len(n) * np.sum(n**2)
    return 10 * np.log10(s_power / n_power)


def pad_spec(Y, mode="zero_pad"):
    """Pad the time dimension of a spectrogram batch to the next multiple of 64."""
    T = Y.size(3)
    if T % 64 != 0:
        num_pad = 64 - T % 64
    else:
        num_pad = 0
    if mode == "zero_pad":
        pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
    elif mode == "reflection":
        pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0, 0))
    elif mode == "replication":
        pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0, 0))
    else:
        raise NotImplementedError(f"Padding mode '{mode}' hasn't been implemented yet.")
    return pad2d(Y)


def ensure_dir(directory):
    """Create the directory if it does not exist yet."""
    if not os.path.exists(directory):
        os.makedirs(directory)


def print_metrics(x, y, x_hat_list, labels, sr=16000):
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')


def mean_std(data):
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return mean, std


def print_mean_std(data, decimal=2):
    """Format data as a 'mean ± std' string with the given number of decimals."""
    data = np.array(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return f'{mean:.{decimal}f} ± {std:.{decimal}f}'


def set_torch_cuda_arch_list():
    """Set TORCH_CUDA_ARCH_LIST to the compute capabilities of all visible GPUs."""
    if not torch.cuda.is_available():
        print("CUDA is not available. No GPUs found.")
        return

    num_gpus = torch.cuda.device_count()
    compute_capabilities = []

    for i in range(num_gpus):
        cc_major, cc_minor = torch.cuda.get_device_capability(i)
        compute_capabilities.append(f"{cc_major}.{cc_minor}")

    cc_string = ";".join(compute_capabilities)
    os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
    print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
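To illustrate the helpers above, here is a small self-contained sketch on synthetic signals; all numbers are made up, and the helpers are assumed importable as `from sgmse.util.other import snr_dB, energy_ratios, pad_spec`. `energy_ratios` expects 1-D NumPy arrays where `s` is the clean reference and `n` the additive noise.

import numpy as np
import torch

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)           # 1 s of "clean" signal at 16 kHz
n = 0.1 * rng.standard_normal(16000)     # additive noise
s_hat = s + 0.5 * n                      # a crude "enhanced" estimate

print(f"input SNR: {snr_dB(s, n):.1f} dB")   # roughly 20 dB for these scales
print(energy_ratios(s_hat, s, n))            # (SI-SDR, SI-SIR, SI-SAR) in dB

Y = torch.randn(1, 2, 256, 200)          # (batch, channels, freq, time)
print(pad_spec(Y).shape)                 # time axis zero-padded from 200 to 256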
sgmse/util/registry.py
ADDED
@@ -0,0 +1,34 @@
import warnings
from typing import Callable


class Registry:
    def __init__(self, managed_thing: str):
        """
        Create a new registry.

        Args:
            managed_thing: A string describing what type of thing is managed by this registry. Will be used for
                warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
        """
        self.managed_thing = managed_thing
        self._registry = {}

    def register(self, name: str) -> Callable:
        """Return a class decorator that registers the decorated class under `name`."""
        def inner_wrapper(wrapped_class) -> Callable:
            if name in self._registry:
                warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
            self._registry[name] = wrapped_class
            return wrapped_class
        return inner_wrapper

    def get_by_name(self, name: str):
        """Get a managed thing by name."""
        if name in self._registry:
            return self._registry[name]
        else:
            raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")

    def get_all_names(self):
        """Get the list of names registered to this registry."""
        return list(self._registry.keys())
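A short usage sketch of the registry pattern; the names `DemoRegistry` and `MyNet` are illustrative only, but the real registries (e.g. `BackboneRegistry` in sgmse/backbones/shared.py) are used the same way.

DemoRegistry = Registry(managed_thing="Demo backbone")

@DemoRegistry.register("mynet")
class MyNet:
    pass

print(DemoRegistry.get_all_names())      # ['mynet']
cls = DemoRegistry.get_by_name("mynet")  # returns the MyNet class itself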
sgmse/util/tensors.py
ADDED
@@ -0,0 +1,16 @@
def batch_broadcast(a, x):
    """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""

    if len(a.shape) != 1:
        a = a.squeeze()
        if len(a.shape) != 1:
            raise ValueError(
                f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
            )

    if a.shape[0] != x.shape[0] and a.shape[0] != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

    out = a.view((x.shape[0], *(1 for _ in range(len(x.shape) - 1))))
    return out
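For example, with a per-batch time tensor of shape (B,) and a spectrogram batch of shape (B, C, F, T), `batch_broadcast` reshapes the former so that elementwise operations broadcast correctly; a minimal sketch with arbitrary sizes:

import torch

t = torch.rand(4)                  # one diffusion time per batch element, shape (4,)
x = torch.randn(4, 1, 256, 128)    # (batch, channels, freq, time)

t_b = batch_broadcast(t, x)        # shape (4, 1, 1, 1)
scaled = t_b * x                   # broadcasts over channels, freq, and time
print(t_b.shape, scaled.shape)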