Shokoufehhh committed on
Commit
b427b58
1 Parent(s): 9e1402a

Upload 27 files

Adding sgmse folder

sgmse/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .shared import BackboneRegistry
+ from .ncsnpp import NCSNpp
+ from .ncsnpp_48k import NCSNpp_48k
+ from .dcunet import DCUNet
+
+ __all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
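
As a minimal sketch of the registry pattern used throughout this package (only the `BackboneRegistry.register` decorator is visible in this diff, so the class body below is hypothetical), a new backbone would be added like the built-in ones:

    import torch.nn as nn
    from sgmse.backbones import BackboneRegistry

    @BackboneRegistry.register("my_backbone")  # name used to select this backbone
    class MyBackbone(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            self.proj = nn.Identity()  # placeholder body

        def forward(self, x, t):
            # Backbones in this package receive a spectrogram batch and diffusion time t.
            return self.proj(x)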
sgmse/backbones/dcunet.py ADDED
@@ -0,0 +1,627 @@
+ from functools import partial
+ import numpy as np
+
+ import torch
+ from torch import nn, Tensor
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
+     DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
+
+
+ def get_activation(name):
+     if name == "silu":
+         return nn.SiLU
+     elif name == "relu":
+         return nn.ReLU
+     elif name == "leaky_relu":
+         return nn.LeakyReLU
+     else:
+         raise NotImplementedError(f"Unknown activation: {name}")
+
+
+ class BatchNorm(_BatchNorm):
+     def _check_input_dim(self, input):
+         if input.dim() < 2 or input.dim() > 4:
+             raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(input.dim()))
+
+
+ class OnReIm(nn.Module):
+     def __init__(self, module_cls, *args, **kwargs):
+         super().__init__()
+         self.re_module = module_cls(*args, **kwargs)
+         self.im_module = module_cls(*args, **kwargs)
+
+     def forward(self, x):
+         return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
+
+
+ # Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
+
+ def unet_decoder_args(encoders, *, skip_connections):
+     """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
+     given the arguments used to construct the encoder.
+     Args:
+         encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation)):
+             List of arguments used to construct the encoders
+         skip_connections (bool): Whether to include skip connections in the
+             calculation of decoder input channels.
+     Return:
+         tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation):
+             Arguments to be used to construct decoders
+     """
+     decoder_args = []
+     for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
+         if skip_connections and decoder_args:
+             skip_in_chan = enc_out_chan
+         else:
+             skip_in_chan = 0
+         decoder_args.append(
+             (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
+         )
+     return tuple(decoder_args)
+
+
+ def make_unet_encoder_decoder_args(encoder_args, decoder_args):
+     encoder_args = tuple(
+         (
+             in_chan,
+             out_chan,
+             tuple(kernel_size),
+             tuple(stride),
+             tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
+             tuple(dilation)
+         )
+         for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
+     )
+
+     if decoder_args == "auto":
+         decoder_args = unet_decoder_args(
+             encoder_args,
+             skip_connections=True,
+         )
+     else:
+         decoder_args = tuple(
+             (
+                 in_chan,
+                 out_chan,
+                 tuple(kernel_size),
+                 tuple(stride),
+                 tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
+                 tuple(dilation),
+                 output_padding,
+             )
+             for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
+         )
+
+     return encoder_args, decoder_args
+
+
+ DCUNET_ARCHITECTURES = {
+     "DCUNet-10": make_unet_encoder_decoder_args(
+         # Encoders:
+         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+         (
+             (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+             (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+         ),
+         # Decoders: automatic inverse
+         "auto",
+     ),
+     "DCUNet-16": make_unet_encoder_decoder_args(
+         # Encoders:
+         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+         (
+             (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+             (32, 32, (7, 5), (2, 1), "auto", (1, 1)),
+             (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+         ),
+         # Decoders: automatic inverse
+         "auto",
+     ),
+     "DCUNet-20": make_unet_encoder_decoder_args(
+         # Encoders:
+         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+         (
+             (1, 32, (7, 1), (1, 1), "auto", (1, 1)),
+             (32, 32, (1, 7), (1, 1), "auto", (1, 1)),
+             (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+             (64, 64, (7, 5), (2, 1), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+             (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+             (64, 90, (5, 3), (2, 1), "auto", (1, 1)),
+         ),
+         # Decoders: automatic inverse
+         "auto",
+     ),
+     "DilDCUNet-v2": make_unet_encoder_decoder_args(  # architecture used in SGMSE / Interspeech paper
+         # Encoders:
+         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+         (
+             (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
+             (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+             (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+             (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
+             (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
+             (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
+         ),
+         # Decoders: automatic inverse
+         "auto",
+     ),
+ }
+
+
+ @BackboneRegistry.register("dcunet")
+ class DCUNet(nn.Module):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
+         parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
+         parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
+         parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
+         parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
+         parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
+         parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
+         parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
+         parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
+         parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
+         return parser
+
+     def __init__(
+         self,
+         dcunet_architecture: str = "DilDCUNet-v2",
+         dcunet_time_embedding: str = "gfp",
+         dcunet_temb_layers_global: int = 2,
+         dcunet_temb_layers_local: int = 1,
+         dcunet_temb_activation: str = "silu",
+         dcunet_time_embedding_complex: bool = False,
+         dcunet_fix_length: str = "pad",
+         dcunet_mask_bound: str = "none",
+         dcunet_norm_type: str = "bN",
+         dcunet_activation: str = "relu",
+         embed_dim: int = 128,
+         **kwargs
+     ):
+         super().__init__()
+
+         self.architecture = dcunet_architecture
+         self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
+         self.norm_type = dcunet_norm_type
+         self.activation = dcunet_activation
+         self.input_channels = 2  # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
+         self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
+         self.time_embedding_complex = dcunet_time_embedding_complex
+         self.temb_layers_global = dcunet_temb_layers_global
+         self.temb_layers_local = dcunet_temb_layers_local
+         self.temb_activation = dcunet_temb_activation
+         conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
+
+         # Replace `input_channels` in encoders config
+         _replaced_input_channels, *rest = conf_encoders[0]
+         encoders = ((self.input_channels, *rest), *conf_encoders[1:])
+         decoders = conf_decoders
+         self.encoders_stride_product = np.prod(
+             [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
+         )
+
+         # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
+         encoder_decoder_kwargs = dict(
+             norm_type=self.norm_type, activation=self.activation,
+             temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
+
+         # Instantiate (global) time embedding layer
+         embed_ops = []
+         if self.time_embedding is not None:
+             complex_valued = self.time_embedding_complex
+             if self.time_embedding == "gfp":
+                 embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
+                 encoder_decoder_kwargs["embed_dim"] = embed_dim
+             elif self.time_embedding == "ds":
+                 embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
+                 encoder_decoder_kwargs["embed_dim"] = embed_dim
+
+             if self.time_embedding_complex:
+                 assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
+                 encoder_decoder_kwargs["complex_time_embedding"] = True
+             for _ in range(self.temb_layers_global):
+                 embed_ops += [
+                     ComplexLinear(embed_dim, embed_dim, complex_valued=True),
+                     OnReIm(get_activation(dcunet_temb_activation))
+                 ]
+         self.embed = nn.Sequential(*embed_ops)
+
+         ### Instantiate DCUNet layers ###
+         output_layer = ComplexConvTranspose2d(*decoders[-1])
+         encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
+         decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
+
+         self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
+         if self.mask_bound is not None:
+             raise NotImplementedError("sorry, mask bounding not implemented at the moment")
+             # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
+             #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
+             #output_layer = nn.Sequential(*[x for x in operations if x is not None])
+
+         assert len(encoders) == len(decoders) + 1
+         self.encoders = nn.ModuleList(encoders)
+         self.decoders = nn.ModuleList(decoders)
+         self.output_layer = output_layer or nn.Identity()
+
+     def forward(self, spec, t) -> Tensor:
+         """
+         Input shape is expected to be $(batch, ch, nfreqs, time)$, with $nfreqs - 1$ divisible
+         by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
+         and $time - 1$ divisible by $t_0 * t_1 * ... * t_N$ where $t_k$ are the time
+         strides of the encoders.
+         Args:
+             spec (Tensor): complex spectrogram tensor of shape (batch, ch, nfreqs, time).
+             t (Tensor): diffusion timesteps, one per batch element.
+         Returns:
+             Tensor of the same shape as the (padded/trimmed) input.
+         """
+         # TF-rep shape: (batch, self.input_channels, n_fft, frames)
+         # Estimate mask from time-frequency representation.
+         x_in = self.fix_input_dims(spec)
+         x = x_in
+         t_embed = self.embed(t+0j) if self.time_embedding is not None else None
+
+         enc_outs = []
+         for idx, enc in enumerate(self.encoders):
+             x = enc(x, t_embed)
+             # UNet skip connection
+             enc_outs.append(x)
+         for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
+             x = dec(x, t_embed, output_size=enc_out.shape)
+             x = torch.cat([x, enc_out], dim=1)
+
+         output = self.output_layer(x, output_size=x_in.shape)
+         # output shape: (batch, 1, n_fft, frames)
+         output = self.fix_output_dims(output, spec)
+         return output
+
+     def fix_input_dims(self, x):
+         return _fix_dcu_input_dims(
+             self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
+         )
+
+     def fix_output_dims(self, out, x):
+         return _fix_dcu_output_dims(self.fix_length_mode, out, x)
+
+
+ def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
+     """Pad or trim `x` to a length compatible with DCUNet."""
+     freq_prod = int(encoders_stride_product[0])
+     time_prod = int(encoders_stride_product[1])
+     if (x.shape[2] - 1) % freq_prod:
+         raise TypeError(
+             f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
+             f"{freq_prod}, got {x.shape} instead"
+         )
+     time_remainder = (x.shape[3] - 1) % time_prod
+     if time_remainder:
+         if fix_length_mode is None:
+             raise TypeError(
+                 f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
+                 f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
+                 f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
+             )
+         elif fix_length_mode == "pad":
+             pad_shape = [0, time_prod - time_remainder]
+             x = nn.functional.pad(x, pad_shape, mode="constant")
+         elif fix_length_mode == "trim":
+             pad_shape = [0, -time_remainder]
+             x = nn.functional.pad(x, pad_shape, mode="constant")
+         else:
+             raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
+     return x
+
+
+ def _fix_dcu_output_dims(fix_length_mode, out, x):
+     """Fix shape of `out` to the original shape of `x` by padding/cropping."""
+     inp_len = x.shape[-1]
+     output_len = out.shape[-1]
+     return nn.functional.pad(out, [0, inp_len - output_len])
+
+
+ def _get_norm(norm_type):
+     if norm_type == "CbN":
+         return ComplexBatchNorm
+     elif norm_type == "bN":
+         return partial(OnReIm, BatchNorm)
+     else:
+         raise NotImplementedError(f"Unknown norm type: {norm_type}")
+
+
+ class DCUNetComplexEncoderBlock(nn.Module):
+     def __init__(
+         self,
+         in_chan,
+         out_chan,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         norm_type="bN",
+         activation="leaky_relu",
+         embed_dim=None,
+         complex_time_embedding=False,
+         temb_layers=1,
+         temb_activation="silu"
+     ):
+         super().__init__()
+
+         self.in_chan = in_chan
+         self.out_chan = out_chan
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.temb_layers = temb_layers
+         self.temb_activation = temb_activation
+         self.complex_time_embedding = complex_time_embedding
+
+         self.conv = ComplexConv2d(
+             in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
+         )
+         self.norm = _get_norm(norm_type)(out_chan)
+         self.activation = OnReIm(get_activation(activation))
+         self.embed_dim = embed_dim
+         if self.embed_dim is not None:
+             ops = []
+             for _ in range(max(0, self.temb_layers - 1)):
+                 ops += [
+                     ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+                     OnReIm(get_activation(self.temb_activation))
+                 ]
+             ops += [
+                 FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+                 OnReIm(get_activation(self.temb_activation))
+             ]
+             self.embed_layer = nn.Sequential(*ops)
+
+     def forward(self, x, t_embed):
+         y = self.conv(x)
+         if self.embed_dim is not None:
+             y = y + self.embed_layer(t_embed)
+         return self.activation(self.norm(y))
+
+
+ class DCUNetComplexDecoderBlock(nn.Module):
+     def __init__(
+         self,
+         in_chan,
+         out_chan,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         output_padding=(0, 0),
+         norm_type="bN",
+         activation="leaky_relu",
+         embed_dim=None,
+         temb_layers=1,
+         temb_activation='swish',
+         complex_time_embedding=False,
+     ):
+         super().__init__()
+
+         self.in_chan = in_chan
+         self.out_chan = out_chan
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.output_padding = output_padding
+         self.complex_time_embedding = complex_time_embedding
+         self.temb_layers = temb_layers
+         self.temb_activation = temb_activation
+
+         self.deconv = ComplexConvTranspose2d(
+             in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
+         )
+         self.norm = _get_norm(norm_type)(out_chan)
+         self.activation = OnReIm(get_activation(activation))
+         self.embed_dim = embed_dim
+         if self.embed_dim is not None:
+             ops = []
+             for _ in range(max(0, self.temb_layers - 1)):
+                 ops += [
+                     ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+                     OnReIm(get_activation(self.temb_activation))
+                 ]
+             ops += [
+                 FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+                 OnReIm(get_activation(self.temb_activation))
+             ]
+             self.embed_layer = nn.Sequential(*ops)
+
+     def forward(self, x, t_embed, output_size=None):
+         y = self.deconv(x, output_size=output_size)
+         if self.embed_dim is not None:
+             y = y + self.embed_layer(t_embed)
+         return self.activation(self.norm(y))
+
+
+ # From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
+ class ComplexBatchNorm(torch.nn.Module):
+     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
+         super(ComplexBatchNorm, self).__init__()
+         self.num_features = num_features
+         self.eps = eps
+         self.momentum = momentum
+         self.affine = affine
+         self.track_running_stats = track_running_stats
+         if self.affine:
+             self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
+             self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
+             self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
+             self.Br = torch.nn.Parameter(torch.Tensor(num_features))
+             self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
+         else:
+             self.register_parameter('Wrr', None)
+             self.register_parameter('Wri', None)
+             self.register_parameter('Wii', None)
+             self.register_parameter('Br', None)
+             self.register_parameter('Bi', None)
+         if self.track_running_stats:
+             self.register_buffer('RMr', torch.zeros(num_features))
+             self.register_buffer('RMi', torch.zeros(num_features))
+             self.register_buffer('RVrr', torch.ones(num_features))
+             self.register_buffer('RVri', torch.zeros(num_features))
+             self.register_buffer('RVii', torch.ones(num_features))
+             self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
+         else:
+             self.register_parameter('RMr', None)
+             self.register_parameter('RMi', None)
+             self.register_parameter('RVrr', None)
+             self.register_parameter('RVri', None)
+             self.register_parameter('RVii', None)
+             self.register_parameter('num_batches_tracked', None)
+         self.reset_parameters()
+
+     def reset_running_stats(self):
+         if self.track_running_stats:
+             self.RMr.zero_()
+             self.RMi.zero_()
+             self.RVrr.fill_(1)
+             self.RVri.zero_()
+             self.RVii.fill_(1)
+             self.num_batches_tracked.zero_()
+
+     def reset_parameters(self):
+         self.reset_running_stats()
+         if self.affine:
+             self.Br.data.zero_()
+             self.Bi.data.zero_()
+             self.Wrr.data.fill_(1)
+             self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
+             self.Wii.data.fill_(1)
+
+     def _check_input_dim(self, xr, xi):
+         assert(xr.shape == xi.shape)
+         assert(xr.size(1) == self.num_features)
+
+     def forward(self, x):
+         xr, xi = x.real, x.imag
+         self._check_input_dim(xr, xi)
+
+         exponential_average_factor = 0.0
+
+         if self.training and self.track_running_stats:
+             self.num_batches_tracked += 1
+             if self.momentum is None:  # use cumulative moving average
+                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+             else:  # use exponential moving average
+                 exponential_average_factor = self.momentum
+
+         #
+         # NOTE: The precise meaning of the "training flag" is:
+         #       True:  Normalize using batch statistics, update running statistics
+         #              if they are being collected.
+         #       False: Normalize using running statistics, ignore batch statistics.
+         #
+         training = self.training or not self.track_running_stats
+         redux = [i for i in reversed(range(xr.dim())) if i != 1]
+         vdim = [1] * xr.dim()
+         vdim[1] = xr.size(1)
+
+         #
+         # Mean M Computation and Centering
+         #
+         # Includes running mean update if training and running.
+         #
+         if training:
+             Mr, Mi = xr, xi
+             for d in redux:
+                 Mr = Mr.mean(d, keepdim=True)
+                 Mi = Mi.mean(d, keepdim=True)
+             if self.track_running_stats:
+                 self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
+                 self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
+         else:
+             Mr = self.RMr.view(vdim)
+             Mi = self.RMi.view(vdim)
+         xr, xi = xr - Mr, xi - Mi
+
+         #
+         # Variance Matrix V Computation
+         #
+         # Includes epsilon numerical stabilizer/Tikhonov regularizer.
+         # Includes running variance update if training and running.
+         #
+         if training:
+             Vrr = xr * xr
+             Vri = xr * xi
+             Vii = xi * xi
+             for d in redux:
+                 Vrr = Vrr.mean(d, keepdim=True)
+                 Vri = Vri.mean(d, keepdim=True)
+                 Vii = Vii.mean(d, keepdim=True)
+             if self.track_running_stats:
+                 self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
+                 self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
+                 self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
+         else:
+             Vrr = self.RVrr.view(vdim)
+             Vri = self.RVri.view(vdim)
+             Vii = self.RVii.view(vdim)
+         Vrr = Vrr + self.eps
+         Vri = Vri
+         Vii = Vii + self.eps
+
+         #
+         # Matrix Inverse Square Root U = V^-0.5
+         #
+         # sqrt of a 2x2 matrix,
+         # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+         tau = Vrr + Vii
+         delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
+         s = delta.sqrt()
+         t = (tau + 2 * s).sqrt()
+
+         # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
+         rst = (s * t).reciprocal()
+         Urr = (s + Vii) * rst
+         Uii = (s + Vrr) * rst
+         Uri = (- Vri) * rst
+
+         #
+         # Optionally left-multiply U by affine weights W to produce combined
+         # weights Z, left-multiply the inputs by Z, then optionally bias them.
+         #
+         # y = Zx + B
+         # y = WUx + B
+         # y = [Wrr Wri][Urr Uri] [xr] + [Br]
+         #     [Wir Wii][Uir Uii] [xi]   [Bi]
+         #
+         if self.affine:
+             Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
+             Zrr = (Wrr * Urr) + (Wri * Uri)
+             Zri = (Wrr * Uri) + (Wri * Uii)
+             Zir = (Wri * Urr) + (Wii * Uri)
+             Zii = (Wri * Uri) + (Wii * Uii)
+         else:
+             Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+         yr = (Zrr * xr) + (Zri * xi)
+         yi = (Zir * xr) + (Zii * xi)
+
+         if self.affine:
+             yr = yr + self.Br.view(vdim)
+             yi = yi + self.Bi.view(vdim)
+
+         return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
+
+     def extra_repr(self):
+         return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
+                'track_running_stats={track_running_stats}'.format(**self.__dict__)
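
For orientation, a hypothetical smoke test of the file above (values illustrative only; the shapes are chosen to satisfy the divisibility rules in `DCUNet.forward` for the default `DilDCUNet-v2` architecture, whose encoder stride products are 8 along frequency and 4 along time):

    import torch
    from sgmse.backbones import DCUNet

    net = DCUNet()  # defaults to the "DilDCUNet-v2" architecture
    # (batch, [x_t, y], freq, time); (257 - 1) % 8 == 0 and (129 - 1) % 4 == 0
    spec = torch.randn(4, 2, 257, 129, dtype=torch.cfloat)
    t = torch.rand(4)  # one diffusion time per batch element
    out = net(spec, t)  # complex tensor matching the (padded) input extent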
sgmse/backbones/ncsnpp.py ADDED
@@ -0,0 +1,419 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+
+ from .ncsnpp_utils import layers, layerspp, normalization
+ import torch.nn as nn
+ import functools
+ import torch
+ import numpy as np
+
+ from .shared import BackboneRegistry
+
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+ Combine = layerspp.Combine
+ conv3x3 = layerspp.conv3x3
+ conv1x1 = layerspp.conv1x1
+ get_act = layers.get_act
+ get_normalization = normalization.get_normalization
+ default_initializer = layers.default_init
+
+
+ @BackboneRegistry.register("ncsnpp")
+ class NCSNpp(nn.Module):
+     """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2])
+         parser.add_argument("--num_res_blocks", type=int, default=2)
+         parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
+         parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
+         parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
+         parser.set_defaults(centered=True)
+         return parser
+
+     def __init__(self,
+         scale_by_sigma=True,
+         nonlinearity='swish',
+         nf=128,
+         ch_mult=(1, 1, 2, 2, 2, 2, 2),
+         num_res_blocks=2,
+         attn_resolutions=(16,),
+         resamp_with_conv=True,
+         conditional=True,
+         fir=True,
+         fir_kernel=[1, 3, 3, 1],
+         skip_rescale=True,
+         resblock_type='biggan',
+         progressive='output_skip',
+         progressive_input='input_skip',
+         progressive_combine='sum',
+         init_scale=0.,
+         fourier_scale=16,
+         image_size=256,
+         embedding_type='fourier',
+         dropout=.0,
+         centered=True,
+         **unused_kwargs
+     ):
+         super().__init__()
+         self.act = act = get_act(nonlinearity)
+
+         self.nf = nf = nf
+         ch_mult = ch_mult
+         self.num_res_blocks = num_res_blocks = num_res_blocks
+         self.attn_resolutions = attn_resolutions = attn_resolutions
+         dropout = dropout
+         resamp_with_conv = resamp_with_conv
+         self.num_resolutions = num_resolutions = len(ch_mult)
+         self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+         self.conditional = conditional = conditional  # noise-conditional
+         self.centered = centered
+         self.scale_by_sigma = scale_by_sigma
+
+         fir = fir
+         fir_kernel = fir_kernel
+         self.skip_rescale = skip_rescale = skip_rescale
+         self.resblock_type = resblock_type = resblock_type.lower()
+         self.progressive = progressive = progressive.lower()
+         self.progressive_input = progressive_input = progressive_input.lower()
+         self.embedding_type = embedding_type = embedding_type.lower()
+         init_scale = init_scale
+         assert progressive in ['none', 'output_skip', 'residual']
+         assert progressive_input in ['none', 'input_skip', 'residual']
+         assert embedding_type in ['fourier', 'positional']
+         combine_method = progressive_combine.lower()
+         combiner = functools.partial(Combine, method=combine_method)
+
+         num_channels = 4  # x.real, x.imag, y.real, y.imag
+         self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+         modules = []
+         # timestep/noise_level embedding
+         if embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             modules.append(layerspp.GaussianFourierProjection(
+                 embedding_size=nf, scale=fourier_scale
+             ))
+             embed_dim = 2 * nf
+         elif embedding_type == 'positional':
+             embed_dim = nf
+         else:
+             raise ValueError(f'embedding type {embedding_type} unknown.')
+
+         if conditional:
+             modules.append(nn.Linear(embed_dim, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+             modules.append(nn.Linear(nf * 4, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+
+         AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                       init_scale=init_scale, skip_rescale=skip_rescale)
+
+         Upsample = functools.partial(layerspp.Upsample,
+                                      with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive == 'output_skip':
+             self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive == 'residual':
+             pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
+                                                  fir_kernel=fir_kernel, with_conv=True)
+
+         Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive_input == 'input_skip':
+             self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive_input == 'residual':
+             pyramid_downsample = functools.partial(layerspp.Downsample,
+                                                    fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+         if resblock_type == 'ddpm':
+             ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                                             dropout=dropout, init_scale=init_scale,
+                                             skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+         elif resblock_type == 'biggan':
+             ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                                             dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                                             init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+         else:
+             raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+         # Downsampling block
+
+         channels = num_channels
+         if progressive_input != 'none':
+             input_pyramid_ch = channels
+
+         modules.append(conv3x3(channels, nf))
+         hs_c = [nf]
+
+         in_ch = nf
+         for i_level in range(num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(num_res_blocks):
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                 in_ch = out_ch
+
+                 if all_resolutions[i_level] in attn_resolutions:
+                     modules.append(AttnBlock(channels=in_ch))
+                 hs_c.append(in_ch)
+
+             if i_level != num_resolutions - 1:
+                 if resblock_type == 'ddpm':
+                     modules.append(Downsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                 if progressive_input == 'input_skip':
+                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                     if combine_method == 'cat':
+                         in_ch *= 2
+
+                 elif progressive_input == 'residual':
+                     modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                     input_pyramid_ch = in_ch
+
+                 hs_c.append(in_ch)
+
+         in_ch = hs_c[-1]
+         modules.append(ResnetBlock(in_ch=in_ch))
+         modules.append(AttnBlock(channels=in_ch))
+         modules.append(ResnetBlock(in_ch=in_ch))
+
+         pyramid_ch = 0
+         # Upsampling block
+         for i_level in reversed(range(num_resolutions)):
+             for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                 in_ch = out_ch
+
+             if all_resolutions[i_level] in attn_resolutions:
+                 modules.append(AttnBlock(channels=in_ch))
+
+             if progressive != 'none':
+                 if i_level == num_resolutions - 1:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, in_ch, bias=True))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name.')
+                 else:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name')
+
+             if i_level != 0:
+                 if resblock_type == 'ddpm':
+                     modules.append(Upsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+         assert not hs_c
+
+         if progressive != 'output_skip':
+             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                         num_channels=in_ch, eps=1e-6))
+             modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+         self.all_modules = nn.ModuleList(modules)
+
+
+     def forward(self, x, time_cond):
+         # timestep/noise_level embedding; only for continuous training
+         modules = self.all_modules
+         m_idx = 0
+
+         # Convert real and imaginary parts of (x,y) into four channel dimensions
+         x = torch.cat((x[:, [0], :, :].real, x[:, [0], :, :].imag,
+                        x[:, [1], :, :].real, x[:, [1], :, :].imag), dim=1)
+
+         if self.embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             used_sigmas = time_cond
+             temb = modules[m_idx](torch.log(used_sigmas))
+             m_idx += 1
+
+         elif self.embedding_type == 'positional':
+             # Sinusoidal positional embeddings.
+             timesteps = time_cond
+             used_sigmas = self.sigmas[time_cond.long()]
+             temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+         else:
+             raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+         if self.conditional:
+             temb = modules[m_idx](temb)
+             m_idx += 1
+             temb = modules[m_idx](self.act(temb))
+             m_idx += 1
+         else:
+             temb = None
+
+         if not self.centered:
+             # If input data is in [0, 1]
+             x = 2 * x - 1.
+
+         # Downsampling block
+         input_pyramid = None
+         if self.progressive_input != 'none':
+             input_pyramid = x
+
+         # Input layer: Conv2d: 4ch -> 128ch
+         hs = [modules[m_idx](x)]
+         m_idx += 1
+
+         # Down path in U-Net
+         for i_level in range(self.num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(self.num_res_blocks):
+                 h = modules[m_idx](hs[-1], temb)
+                 m_idx += 1
+                 # Attention layer (optional)
+                 if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 hs.append(h)
+
+             # Downsampling
+             if i_level != self.num_resolutions - 1:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](hs[-1])
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](hs[-1], temb)
+                     m_idx += 1
+
+                 if self.progressive_input == 'input_skip':  # Combine h with x
+                     input_pyramid = self.pyramid_downsample(input_pyramid)
+                     h = modules[m_idx](input_pyramid, h)
+                     m_idx += 1
+
+                 elif self.progressive_input == 'residual':
+                     input_pyramid = modules[m_idx](input_pyramid)
+                     m_idx += 1
+                     if self.skip_rescale:
+                         input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                     else:
+                         input_pyramid = input_pyramid + h
+                     h = input_pyramid
+                 hs.append(h)
+
+         h = hs[-1]  # actually equal to: h = h
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+         h = modules[m_idx](h)  # Attention block
+         m_idx += 1
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+
+         pyramid = None
+
+         # Upsampling block
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                 m_idx += 1
+
+             # edit: from -1 to -2
+             if h.shape[-2] in self.attn_resolutions:
+                 h = modules[m_idx](h)
+                 m_idx += 1
+
+             if self.progressive != 'none':
+                 if i_level == self.num_resolutions - 1:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
+                         m_idx += 1
+                     elif self.progressive == 'residual':
+                         pyramid = self.act(modules[m_idx](h))
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name.')
+                 else:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                         pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid_h = modules[m_idx](pyramid_h)
+                         m_idx += 1
+                         pyramid = pyramid + pyramid_h
+                     elif self.progressive == 'residual':
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                         if self.skip_rescale:
+                             pyramid = (pyramid + h) / np.sqrt(2.)
+                         else:
+                             pyramid = pyramid + h
+                         h = pyramid
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name')
+
+             # Upsampling Layer
+             if i_level != 0:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](h, temb)  # Upsampling
+                     m_idx += 1
+
+         assert not hs
+
+         if self.progressive == 'output_skip':
+             h = pyramid
+         else:
+             h = self.act(modules[m_idx](h))
+             m_idx += 1
+             h = modules[m_idx](h)
+             m_idx += 1
+
+         assert m_idx == len(modules), "Implementation error"
+         if self.scale_by_sigma:
+             used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+             h = h / used_sigmas
+
+         # Convert back to complex number
+         h = self.output_layer(h)
+         h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+         h = torch.view_as_complex(h)[:, None, :, :]
+         return h
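
A hypothetical usage sketch for the class above (values illustrative; `image_size=256` is the default, and `torch.log` is applied to the conditioning in the 'fourier' branch, so the noise levels must be strictly positive):

    import torch
    from sgmse.backbones import NCSNpp

    net = NCSNpp()
    x = torch.randn(2, 2, 256, 256, dtype=torch.cfloat)  # (batch, [x_t, y], freq, time)
    sigmas = torch.rand(2) + 0.1  # strictly positive noise levels
    score = net(x, sigmas)  # complex tensor of shape (batch, 1, 256, 256)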
sgmse/backbones/ncsnpp_48k.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+
18
+ from .ncsnpp_utils import layers, layerspp, normalization
19
+ import torch.nn as nn
20
+ import functools
21
+ import torch
22
+ import numpy as np
23
+
24
+ from .shared import BackboneRegistry
25
+
26
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
+ Combine = layerspp.Combine
29
+ conv3x3 = layerspp.conv3x3
30
+ conv1x1 = layerspp.conv1x1
31
+ get_act = layers.get_act
32
+ get_normalization = normalization.get_normalization
33
+ default_initializer = layers.default_init
34
+
35
+
36
+ @BackboneRegistry.register("ncsnpp_48k")
37
+ class NCSNpp_48k(nn.Module):
38
+ """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
39
+
40
+ @staticmethod
41
+ def add_argparse_args(parser):
42
+ parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
+ parser.add_argument("--num_res_blocks", type=int, default=2)
44
+ parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
45
+ parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
46
+ parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
47
+ parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
48
+ parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
49
+ parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
50
+ parser.set_defaults(centered=True)
51
+ return parser
52
+
53
+ def __init__(self,
54
+ scale_by_sigma = True,
55
+ nonlinearity = 'swish',
56
+ nf = 128,
57
+ ch_mult = (1, 1, 2, 2, 2, 2, 2),
58
+ num_res_blocks = 2,
59
+ attn_resolutions = (),
60
+ resamp_with_conv = True,
61
+ conditional = True,
62
+ fir = True,
63
+ fir_kernel = [1, 3, 3, 1],
64
+ skip_rescale = True,
65
+ resblock_type = 'biggan',
66
+ progressive = 'none',
67
+ progressive_input = 'none',
68
+ progressive_combine = 'sum',
69
+ init_scale = 0.,
70
+ fourier_scale = 16,
71
+ image_size = 256,
72
+ embedding_type = 'fourier',
73
+ dropout = .0,
74
+ centered = True,
75
+ **unused_kwargs
76
+ ):
77
+ super().__init__()
78
+ self.act = act = get_act(nonlinearity)
79
+
80
+ self.nf = nf = nf
81
+ ch_mult = ch_mult
82
+ self.num_res_blocks = num_res_blocks = num_res_blocks
83
+ self.attn_resolutions = attn_resolutions
84
+ dropout = dropout
85
+ resamp_with_conv = resamp_with_conv
86
+ self.num_resolutions = num_resolutions = len(ch_mult)
87
+ self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
88
+
89
+ self.conditional = conditional = conditional # noise-conditional
90
+ self.centered = centered
91
+ self.scale_by_sigma = scale_by_sigma
92
+
93
+ fir = fir
94
+ fir_kernel = fir_kernel
95
+ self.skip_rescale = skip_rescale = skip_rescale
96
+ self.resblock_type = resblock_type = resblock_type.lower()
97
+ self.progressive = progressive = progressive.lower()
98
+ self.progressive_input = progressive_input = progressive_input.lower()
99
+ self.embedding_type = embedding_type = embedding_type.lower()
100
+ init_scale = init_scale
101
+ assert progressive in ['none', 'output_skip', 'residual']
102
+ assert progressive_input in ['none', 'input_skip', 'residual']
103
+ assert embedding_type in ['fourier', 'positional']
104
+ combine_method = progressive_combine.lower()
105
+ combiner = functools.partial(Combine, method=combine_method)
106
+
107
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
108
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
109
+
110
+ modules = []
111
+ # timestep/noise_level embedding
112
+ if embedding_type == 'fourier':
113
+ # Gaussian Fourier features embeddings.
114
+ modules.append(layerspp.GaussianFourierProjection(
115
+ embedding_size=nf, scale=fourier_scale
116
+ ))
117
+ embed_dim = 2 * nf
118
+ elif embedding_type == 'positional':
119
+ embed_dim = nf
120
+ else:
121
+ raise ValueError(f'embedding type {embedding_type} unknown.')
122
+
123
+ if conditional:
124
+ modules.append(nn.Linear(embed_dim, nf * 4))
125
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
+ nn.init.zeros_(modules[-1].bias)
127
+ modules.append(nn.Linear(nf * 4, nf * 4))
128
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
129
+ nn.init.zeros_(modules[-1].bias)
130
+
131
+ AttnBlock = functools.partial(layerspp.AttnBlockpp,
132
+ init_scale=init_scale, skip_rescale=skip_rescale)
133
+
134
+ Upsample = functools.partial(layerspp.Upsample,
135
+ with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
136
+
137
+ if progressive == 'output_skip':
138
+ self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
139
+ elif progressive == 'residual':
140
+ pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
141
+ fir_kernel=fir_kernel, with_conv=True)
142
+
143
+ Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
144
+
145
+ if progressive_input == 'input_skip':
146
+ self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
147
+ elif progressive_input == 'residual':
148
+ pyramid_downsample = functools.partial(layerspp.Downsample,
149
+ fir=fir, fir_kernel=fir_kernel, with_conv=True)
150
+
151
+ if resblock_type == 'ddpm':
152
+ ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
153
+ dropout=dropout, init_scale=init_scale,
154
+ skip_rescale=skip_rescale, temb_dim=nf * 4)
155
+
156
+ elif resblock_type == 'biggan':
157
+ ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
158
+ dropout=dropout, fir=fir, fir_kernel=fir_kernel,
159
+ init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
160
+
161
+ else:
162
+ raise ValueError(f'resblock type {resblock_type} unrecognized.')
163
+
164
+ # Downsampling block
165
+
166
+ channels = num_channels
167
+ if progressive_input != 'none':
168
+ input_pyramid_ch = channels
169
+
170
+ modules.append(conv3x3(channels, nf))
171
+ hs_c = [nf]
172
+
173
+ in_ch = nf
174
+ for i_level in range(num_resolutions):
175
+ # Residual blocks for this resolution
176
+ for i_block in range(num_res_blocks):
177
+ out_ch = nf * ch_mult[i_level]
178
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
179
+ in_ch = out_ch
180
+
181
+ if all_resolutions[i_level] in attn_resolutions:
182
+ modules.append(AttnBlock(channels=in_ch))
183
+ hs_c.append(in_ch)
184
+
185
+ if i_level != num_resolutions - 1:
186
+ if resblock_type == 'ddpm':
187
+ modules.append(Downsample(in_ch=in_ch))
188
+ else:
189
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
190
+
191
+ if progressive_input == 'input_skip':
192
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
193
+ if combine_method == 'cat':
194
+ in_ch *= 2
195
+
196
+ elif progressive_input == 'residual':
197
+ modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
198
+ input_pyramid_ch = in_ch
199
+
200
+ hs_c.append(in_ch)
201
+
202
+ in_ch = hs_c[-1]
203
+ modules.append(ResnetBlock(in_ch=in_ch))
204
+ modules.append(AttnBlock(channels=in_ch))
205
+ modules.append(ResnetBlock(in_ch=in_ch))
206
+
207
+ pyramid_ch = 0
208
+ # Upsampling block
209
+ for i_level in reversed(range(num_resolutions)):
210
+ for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
211
+ out_ch = nf * ch_mult[i_level]
212
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
213
+ in_ch = out_ch
214
+
215
+ if all_resolutions[i_level] in attn_resolutions:
216
+ modules.append(AttnBlock(channels=in_ch))
217
+
218
+ if progressive != 'none':
219
+ if i_level == num_resolutions - 1:
220
+ if progressive == 'output_skip':
221
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
222
+ num_channels=in_ch, eps=1e-6))
223
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
224
+ pyramid_ch = channels
225
+ elif progressive == 'residual':
226
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
227
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
228
+ pyramid_ch = in_ch
229
+ else:
230
+ raise ValueError(f'{progressive} is not a valid name.')
231
+ else:
232
+ if progressive == 'output_skip':
233
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
234
+ num_channels=in_ch, eps=1e-6))
235
+ modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
236
+ pyramid_ch = channels
237
+ elif progressive == 'residual':
238
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
239
+ pyramid_ch = in_ch
240
+ else:
241
+ raise ValueError(f'{progressive} is not a valid name')
242
+
243
+ if i_level != 0:
244
+ if resblock_type == 'ddpm':
245
+ modules.append(Upsample(in_ch=in_ch))
246
+ else:
247
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
248
+
249
+ assert not hs_c
250
+
251
+ if progressive != 'output_skip':
252
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
253
+ num_channels=in_ch, eps=1e-6))
254
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
255
+
256
+ self.all_modules = nn.ModuleList(modules)
257
+
258
+
259
+ def forward(self, x, time_cond):
260
+ # timestep/noise_level embedding; only for continuous training
261
+ modules = self.all_modules
262
+ m_idx = 0
263
+
264
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
265
+ x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
266
+ x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
267
+
268
+ if self.embedding_type == 'fourier':
269
+ # Gaussian Fourier features embeddings.
270
+ used_sigmas = time_cond
+ temb = modules[m_idx](torch.log(used_sigmas))
+ m_idx += 1
+
+ elif self.embedding_type == 'positional':
+ # Sinusoidal positional embeddings.
+ timesteps = time_cond
+ used_sigmas = self.sigmas[time_cond.long()]
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+ else:
+ raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+ if self.conditional:
+ temb = modules[m_idx](temb)
+ m_idx += 1
+ temb = modules[m_idx](self.act(temb))
+ m_idx += 1
+ else:
+ temb = None
+
+ if not self.centered:
+ # If input data is in [0, 1]
+ x = 2 * x - 1.
+
+ # Downsampling block
+ input_pyramid = None
+ if self.progressive_input != 'none':
+ input_pyramid = x
+
+ # Input layer: Conv2d: 4ch -> 128ch
+ hs = [modules[m_idx](x)]
+ m_idx += 1
+
+ # Down path in U-Net
+ for i_level in range(self.num_resolutions):
+ # Residual blocks for this resolution
+ for i_block in range(self.num_res_blocks):
+ h = modules[m_idx](hs[-1], temb)
+ m_idx += 1
+ # Attention layer (optional)
+ if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
+ h = modules[m_idx](h)
+ m_idx += 1
+ hs.append(h)
+
+ # Downsampling
+ if i_level != self.num_resolutions - 1:
+ if self.resblock_type == 'ddpm':
+ h = modules[m_idx](hs[-1])
+ m_idx += 1
+ else:
+ h = modules[m_idx](hs[-1], temb)
+ m_idx += 1
+
+ if self.progressive_input == 'input_skip': # Combine h with x
+ input_pyramid = self.pyramid_downsample(input_pyramid)
+ h = modules[m_idx](input_pyramid, h)
+ m_idx += 1
+
+ elif self.progressive_input == 'residual':
+ input_pyramid = modules[m_idx](input_pyramid)
+ m_idx += 1
+ if self.skip_rescale:
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+ else:
+ input_pyramid = input_pyramid + h
+ h = input_pyramid
+ hs.append(h)
+
+ h = hs[-1] # actually equal to: h = h
+ h = modules[m_idx](h, temb) # ResNet block
+ m_idx += 1
+ h = modules[m_idx](h) # Attention block
+ m_idx += 1
+ h = modules[m_idx](h, temb) # ResNet block
+ m_idx += 1
+
+ pyramid = None
+
+ # Upsampling block
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+ m_idx += 1
+
+ # edit: from -1 to -2
+ if h.shape[-2] in self.attn_resolutions:
+ h = modules[m_idx](h)
+ m_idx += 1
+
+ if self.progressive != 'none':
+ if i_level == self.num_resolutions - 1:
+ if self.progressive == 'output_skip':
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
+ m_idx += 1
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
+ m_idx += 1
+ elif self.progressive == 'residual':
+ pyramid = self.act(modules[m_idx](h))
+ m_idx += 1
+ pyramid = modules[m_idx](pyramid)
+ m_idx += 1
+ else:
+ raise ValueError(f'{self.progressive} is not a valid name.')
+ else:
+ if self.progressive == 'output_skip':
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
+ m_idx += 1
+ pyramid_h = modules[m_idx](pyramid_h)
+ m_idx += 1
+ pyramid = pyramid + pyramid_h
+ elif self.progressive == 'residual':
+ pyramid = modules[m_idx](pyramid)
+ m_idx += 1
+ if self.skip_rescale:
+ pyramid = (pyramid + h) / np.sqrt(2.)
+ else:
+ pyramid = pyramid + h
+ h = pyramid
+ else:
+ raise ValueError(f'{self.progressive} is not a valid name')
+
+ # Upsampling Layer
+ if i_level != 0:
+ if self.resblock_type == 'ddpm':
+ h = modules[m_idx](h)
+ m_idx += 1
+ else:
+ h = modules[m_idx](h, temb) # Upsampling
+ m_idx += 1
+
+ assert not hs
+
+ if self.progressive == 'output_skip':
+ h = pyramid
+ else:
+ h = self.act(modules[m_idx](h))
+ m_idx += 1
+ h = modules[m_idx](h)
+ m_idx += 1
+
+ assert m_idx == len(modules), "Implementation error"
+
+ # Convert back to complex number
+ h = self.output_layer(h)
+
+ if self.scale_by_sigma:
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+ h = h / used_sigmas
+
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+ h = torch.view_as_complex(h)[:, None, :, :]
+ return h
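For readers tracing the last few lines above: `view_as_complex` folds two real channels back into one complex channel. A minimal standalone sketch of just that conversion (the shapes here are illustrative, not taken from this repo):

import torch

# Toy stand-in for a real-valued network output: (batch, 2, freq, time),
# channel 0 holding the real part and channel 1 the imaginary part.
h = torch.randn(8, 2, 256, 128)

# view_as_complex requires the size-2 re/im axis to be last and contiguous.
h = torch.permute(h, (0, 2, 3, 1)).contiguous()
h = torch.view_as_complex(h)[:, None, :, :]

print(h.shape, h.dtype)  # torch.Size([8, 1, 256, 128]) torch.complex64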
sgmse/backbones/ncsnpp_utils/layers.py ADDED
@@ -0,0 +1,662 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Common layers for defining score networks.
+ """
+ import math
+ import string
+ from functools import partial
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .normalization import ConditionalInstanceNorm2dPlus
+
+
+ def get_act(config):
+ """Get activation functions from the config file."""
+
+ if config == 'elu':
+ return nn.ELU()
+ elif config == 'relu':
+ return nn.ReLU()
+ elif config == 'lrelu':
+ return nn.LeakyReLU(negative_slope=0.2)
+ elif config == 'swish':
+ return nn.SiLU()
+ else:
+ raise NotImplementedError('activation function does not exist!')
+
+
+ def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
+ """1x1 convolution. Same as NCSNv1/v2."""
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
+ padding=padding)
+ init_scale = 1e-10 if init_scale == 0 else init_scale
+ conv.weight.data *= init_scale
+ conv.bias.data *= init_scale
+ return conv
+
+
+ def variance_scaling(scale, mode, distribution,
+ in_axis=1, out_axis=0,
+ dtype=torch.float32,
+ device='cpu'):
+ """Ported from JAX. """
+
+ def _compute_fans(shape, in_axis=1, out_axis=0):
+ receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+ fan_in = shape[in_axis] * receptive_field_size
+ fan_out = shape[out_axis] * receptive_field_size
+ return fan_in, fan_out
+
+ def init(shape, dtype=dtype, device=device):
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+ if mode == "fan_in":
+ denominator = fan_in
+ elif mode == "fan_out":
+ denominator = fan_out
+ elif mode == "fan_avg":
+ denominator = (fan_in + fan_out) / 2
+ else:
+ raise ValueError(
+ "invalid mode for variance scaling initializer: {}".format(mode))
+ variance = scale / denominator
+ if distribution == "normal":
+ return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+ elif distribution == "uniform":
+ return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
+ else:
+ raise ValueError("invalid distribution for variance scaling initializer")
+
+ return init
+
+
+ def default_init(scale=1.):
+ """The same initialization used in DDPM."""
+ scale = 1e-10 if scale == 0 else scale
+ return variance_scaling(scale, 'fan_avg', 'uniform')
+
+
+ class Dense(nn.Module):
+ """Linear layer with `default_init`."""
+ def __init__(self):
+ super().__init__()
+
+
+ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
+ """1x1 convolution with DDPM initialization."""
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+ nn.init.zeros_(conv.bias)
+ return conv
+
+
+ def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+ """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
+ init_scale = 1e-10 if init_scale == 0 else init_scale
+ conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
+ dilation=dilation, padding=padding, kernel_size=3)
+ conv.weight.data *= init_scale
+ conv.bias.data *= init_scale
+ return conv
+
+
+ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+ """3x3 convolution with DDPM initialization."""
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
+ dilation=dilation, bias=bias)
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+ nn.init.zeros_(conv.bias)
+ return conv
+
+ ###########################################################################
+ # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
+ # https://github.com/ermongroup/ncsn
+ # https://github.com/ermongroup/ncsnv2
+ ###########################################################################
+
+
+ class CRPBlock(nn.Module):
+ def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ for i in range(n_stages):
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+ self.n_stages = n_stages
+ if maxpool:
+ self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+ else:
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+
+ self.act = act
+
+ def forward(self, x):
+ x = self.act(x)
+ path = x
+ for i in range(self.n_stages):
+ path = self.pool(path)
+ path = self.convs[i](path)
+ x = path + x
+ return x
+
+
+ class CondCRPBlock(nn.Module):
+ def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ self.normalizer = normalizer
+ for i in range(n_stages):
+ self.norms.append(normalizer(features, num_classes, bias=True))
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+
+ self.n_stages = n_stages
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+ self.act = act
+
+ def forward(self, x, y):
+ x = self.act(x)
+ path = x
+ for i in range(self.n_stages):
+ path = self.norms[i](path, y)
+ path = self.pool(path)
+ path = self.convs[i](path)
+
+ x = path + x
+ return x
+
+
+ class RCUBlock(nn.Module):
+ def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
+ super().__init__()
+
+ for i in range(n_blocks):
+ for j in range(n_stages):
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+ self.stride = 1
+ self.n_blocks = n_blocks
+ self.n_stages = n_stages
+ self.act = act
+
+ def forward(self, x):
+ for i in range(self.n_blocks):
+ residual = x
+ for j in range(self.n_stages):
+ x = self.act(x)
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+ x += residual
+ return x
+
+
+ class CondRCUBlock(nn.Module):
+ def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
+ super().__init__()
+
+ for i in range(n_blocks):
+ for j in range(n_stages):
+ setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+ self.stride = 1
+ self.n_blocks = n_blocks
+ self.n_stages = n_stages
+ self.act = act
+ self.normalizer = normalizer
+
+ def forward(self, x, y):
+ for i in range(self.n_blocks):
+ residual = x
+ for j in range(self.n_stages):
+ x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
+ x = self.act(x)
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+ x += residual
+ return x
+
+
+ class MSFBlock(nn.Module):
+ def __init__(self, in_planes, features):
+ super().__init__()
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+ self.convs = nn.ModuleList()
+ self.features = features
+
+ for i in range(len(in_planes)):
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+
+ def forward(self, xs, shape):
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+ for i in range(len(self.convs)):
+ h = self.convs[i](xs[i])
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+ sums += h
+ return sums
+
+
+ class CondMSFBlock(nn.Module):
+ def __init__(self, in_planes, features, num_classes, normalizer):
+ super().__init__()
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+
+ self.convs = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ self.features = features
+ self.normalizer = normalizer
+
+ for i in range(len(in_planes)):
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+ self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
+
+ def forward(self, xs, y, shape):
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+ for i in range(len(self.convs)):
+ h = self.norms[i](xs[i], y)
+ h = self.convs[i](h)
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+ sums += h
+ return sums
+
+
+ class RefineBlock(nn.Module):
+ def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
+ super().__init__()
+
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+ self.n_blocks = n_blocks = len(in_planes)
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(n_blocks):
+ self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
+
+ self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
+
+ if not start:
+ self.msf = MSFBlock(in_planes, features)
+
+ self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
+
+ def forward(self, xs, output_shape):
+ assert isinstance(xs, tuple) or isinstance(xs, list)
+ hs = []
+ for i in range(len(xs)):
+ h = self.adapt_convs[i](xs[i])
+ hs.append(h)
+
+ if self.n_blocks > 1:
+ h = self.msf(hs, output_shape)
+ else:
+ h = hs[0]
+
+ h = self.crp(h)
+ h = self.output_convs(h)
+
+ return h
+
+
+ class CondRefineBlock(nn.Module):
+ def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
+ super().__init__()
+
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+ self.n_blocks = n_blocks = len(in_planes)
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(n_blocks):
+ self.adapt_convs.append(
+ CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
+ )
+
+ self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
+
+ if not start:
+ self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
+
+ self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
+
+ def forward(self, xs, y, output_shape):
+ assert isinstance(xs, tuple) or isinstance(xs, list)
+ hs = []
+ for i in range(len(xs)):
+ h = self.adapt_convs[i](xs[i], y)
+ hs.append(h)
+
+ if self.n_blocks > 1:
+ h = self.msf(hs, y, output_shape)
+ else:
+ h = hs[0]
+
+ h = self.crp(h, y)
+ h = self.output_convs(h, y)
+
+ return h
+
+
+ class ConvMeanPool(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
+ super().__init__()
+ if not adjust_padding:
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+ self.conv = conv
+ else:
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+
+ self.conv = nn.Sequential(
+ nn.ZeroPad2d((1, 0, 1, 0)),
+ conv
+ )
+
+ def forward(self, inputs):
+ output = self.conv(inputs)
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+ return output
+
+
+ class MeanPoolConv(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+ super().__init__()
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+
+ def forward(self, inputs):
+ output = inputs
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+ return self.conv(output)
+
+
+ class UpsampleConv(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+ super().__init__()
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+ self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
+
+ def forward(self, inputs):
+ output = inputs
+ output = torch.cat([output, output, output, output], dim=1)
+ output = self.pixelshuffle(output)
+ return self.conv(output)
+
+
+ class ConditionalResidualBlock(nn.Module):
+ def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
+ normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
+ super().__init__()
+ self.non_linearity = act
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.resample = resample
+ self.normalization = normalization
+ if resample == 'down':
+ if dilation > 1:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+ self.normalize2 = normalization(input_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ else:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+ self.normalize2 = normalization(input_dim, num_classes)
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+ elif resample is None:
+ if dilation > 1:
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ self.normalize2 = normalization(output_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+ else:
+ conv_shortcut = nn.Conv2d
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+ self.normalize2 = normalization(output_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+ else:
+ raise Exception('invalid resample value')
+
+ if output_dim != input_dim or resample is not None:
+ self.shortcut = conv_shortcut(input_dim, output_dim)
+
+ self.normalize1 = normalization(input_dim, num_classes)
+
+ def forward(self, x, y):
+ output = self.normalize1(x, y)
+ output = self.non_linearity(output)
+ output = self.conv1(output)
+ output = self.normalize2(output, y)
+ output = self.non_linearity(output)
+ output = self.conv2(output)
+
+ if self.output_dim == self.input_dim and self.resample is None:
+ shortcut = x
+ else:
+ shortcut = self.shortcut(x)
+
+ return shortcut + output
+
+
+ class ResidualBlock(nn.Module):
+ def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
+ normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
+ super().__init__()
+ self.non_linearity = act
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.resample = resample
+ self.normalization = normalization
+ if resample == 'down':
+ if dilation > 1:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+ self.normalize2 = normalization(input_dim)
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ else:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+ self.normalize2 = normalization(input_dim)
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+ elif resample is None:
+ if dilation > 1:
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ self.normalize2 = normalization(output_dim)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+ else:
+ # conv_shortcut = nn.Conv2d ### Something weird here.
+ conv_shortcut = partial(ncsn_conv1x1)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+ self.normalize2 = normalization(output_dim)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+ else:
+ raise Exception('invalid resample value')
+
+ if output_dim != input_dim or resample is not None:
+ self.shortcut = conv_shortcut(input_dim, output_dim)
+
+ self.normalize1 = normalization(input_dim)
+
+ def forward(self, x):
+ output = self.normalize1(x)
+ output = self.non_linearity(output)
+ output = self.conv1(output)
+ output = self.normalize2(output)
+ output = self.non_linearity(output)
+ output = self.conv2(output)
+
+ if self.output_dim == self.input_dim and self.resample is None:
+ shortcut = x
+ else:
+ shortcut = self.shortcut(x)
+
+ return shortcut + output
+
+
+ ###########################################################################
+ # Functions below are ported over from the DDPM codebase:
+ # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+ ###########################################################################
+
+ def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
+ half_dim = embedding_dim // 2
+ # magic number 10000 is from transformers
+ emb = math.log(max_positions) / (half_dim - 1)
+ # emb = math.log(2.) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+ # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+ # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = F.pad(emb, (0, 1), mode='constant')
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
+ return emb
+
+
+ def _einsum(a, b, c, x, y):
+ einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
+ return torch.einsum(einsum_str, x, y)
+
+
+ def contract_inner(x, y):
+ """tensordot(x, y, 1)."""
+ x_chars = list(string.ascii_lowercase[:len(x.shape)])
+ y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
+ y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
+ out_chars = x_chars[:-1] + y_chars[1:]
+ return _einsum(x_chars, y_chars, out_chars, x, y)
+
+
+ class NIN(nn.Module):
+ def __init__(self, in_dim, num_units, init_scale=0.1):
+ super().__init__()
+ self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
+ self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1)
+ y = contract_inner(x, self.W) + self.b
+ return y.permute(0, 3, 1, 2)
+
+
+ class AttnBlock(nn.Module):
+ """Channel-wise self-attention block."""
+ def __init__(self, channels):
+ super().__init__()
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+ self.NIN_0 = NIN(channels, channels)
+ self.NIN_1 = NIN(channels, channels)
+ self.NIN_2 = NIN(channels, channels)
+ self.NIN_3 = NIN(channels, channels, init_scale=0.)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = self.GroupNorm_0(x)
+ q = self.NIN_0(h)
+ k = self.NIN_1(h)
+ v = self.NIN_2(h)
+
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+ w = torch.reshape(w, (B, H, W, H * W))
+ w = F.softmax(w, dim=-1)
+ w = torch.reshape(w, (B, H, W, H, W))
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
+ h = self.NIN_3(h)
+ return x + h
+
+
+ class Upsample(nn.Module):
+ def __init__(self, channels, with_conv=False):
+ super().__init__()
+ if with_conv:
+ self.Conv_0 = ddpm_conv3x3(channels, channels)
+ self.with_conv = with_conv
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
+ if self.with_conv:
+ h = self.Conv_0(h)
+ return h
+
+
+ class Downsample(nn.Module):
+ def __init__(self, channels, with_conv=False):
+ super().__init__()
+ if with_conv:
+ self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
+ self.with_conv = with_conv
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # Emulate 'SAME' padding
+ if self.with_conv:
+ x = F.pad(x, (0, 1, 0, 1))
+ x = self.Conv_0(x)
+ else:
+ x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
+
+ assert x.shape == (B, C, H // 2, W // 2)
+ return x
+
+
+ class ResnetBlockDDPM(nn.Module):
+ """The ResNet Blocks used in DDPM."""
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
+ super().__init__()
+ if out_ch is None:
+ out_ch = in_ch
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
+ self.act = act
+ self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
+ if in_ch != out_ch:
+ if conv_shortcut:
+ self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
+ else:
+ self.NIN_0 = NIN(in_ch, out_ch)
+ self.out_ch = out_ch
+ self.in_ch = in_ch
+ self.conv_shortcut = conv_shortcut
+
+ def forward(self, x, temb=None):
+ B, C, H, W = x.shape
+ assert C == self.in_ch
+ out_ch = self.out_ch if self.out_ch else self.in_ch
+ h = self.act(self.GroupNorm_0(x))
+ h = self.Conv_0(h)
+ # Add bias to each feature map conditioned on the time embedding
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+ if C != out_ch:
+ if self.conv_shortcut:
+ x = self.Conv_2(x)
+ else:
+ x = self.NIN_0(x)
+ return x + h
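`get_timestep_embedding` above is the standard transformer-style sinusoidal embedding. A quick sanity check, assuming this file is importable under the path used in this upload:

import torch
from sgmse.backbones.ncsnpp_utils.layers import get_timestep_embedding

t = torch.arange(16)                        # 1-D integer timesteps
emb = get_timestep_embedding(t, embedding_dim=128)
print(emb.shape)                            # torch.Size([16, 128])
# Columns 0..63 hold the sin terms and 64..127 the matching cos terms,
# with frequencies spaced geometrically from 1 down to 1/max_positions.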
sgmse/backbones/ncsnpp_utils/layerspp.py ADDED
@@ -0,0 +1,274 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Layers for defining NCSN++.
+ """
+ from . import layers
+ from . import up_or_down_sampling
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+ conv1x1 = layers.ddpm_conv1x1
+ conv3x3 = layers.ddpm_conv3x3
+ NIN = layers.NIN
+ default_init = layers.default_init
+
+
+ class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size=256, scale=1.0):
+ super().__init__()
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ def forward(self, x):
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+ class Combine(nn.Module):
+ """Combine information from skip connections."""
+
+ def __init__(self, dim1, dim2, method='cat'):
+ super().__init__()
+ self.Conv_0 = conv1x1(dim1, dim2)
+ self.method = method
+
+ def forward(self, x, y):
+ h = self.Conv_0(x)
+ if self.method == 'cat':
+ return torch.cat([h, y], dim=1)
+ elif self.method == 'sum':
+ return h + y
+ else:
+ raise ValueError(f'Method {self.method} not recognized.')
+
+
+ class AttnBlockpp(nn.Module):
+ """Channel-wise self-attention block. Modified from DDPM."""
+
+ def __init__(self, channels, skip_rescale=False, init_scale=0.):
+ super().__init__()
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+ eps=1e-6)
+ self.NIN_0 = NIN(channels, channels)
+ self.NIN_1 = NIN(channels, channels)
+ self.NIN_2 = NIN(channels, channels)
+ self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+ self.skip_rescale = skip_rescale
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = self.GroupNorm_0(x)
+ q = self.NIN_0(h)
+ k = self.NIN_1(h)
+ v = self.NIN_2(h)
+
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+ w = torch.reshape(w, (B, H, W, H * W))
+ w = F.softmax(w, dim=-1)
+ w = torch.reshape(w, (B, H, W, H, W))
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
+ h = self.NIN_3(h)
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.)
+
+
+ class Upsample(nn.Module):
+ def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
+ fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ if not fir:
+ if with_conv:
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ else:
+ if with_conv:
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
+ kernel=3, up=True,
+ resample_kernel=fir_kernel,
+ use_bias=True,
+ kernel_init=default_init())
+ self.fir = fir
+ self.with_conv = with_conv
+ self.fir_kernel = fir_kernel
+ self.out_ch = out_ch
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ if not self.fir:
+ h = F.interpolate(x, (H * 2, W * 2), 'nearest')
+ if self.with_conv:
+ h = self.Conv_0(h)
+ else:
+ if not self.with_conv:
+ h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = self.Conv2d_0(x)
+
+ return h
+
+
+ class Downsample(nn.Module):
+ def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
+ fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ if not fir:
+ if with_conv:
+ self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
+ else:
+ if with_conv:
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
+ kernel=3, down=True,
+ resample_kernel=fir_kernel,
+ use_bias=True,
+ kernel_init=default_init())
+ self.fir = fir
+ self.fir_kernel = fir_kernel
+ self.with_conv = with_conv
+ self.out_ch = out_ch
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ if not self.fir:
+ if self.with_conv:
+ x = F.pad(x, (0, 1, 0, 1))
+ x = self.Conv_0(x)
+ else:
+ x = F.avg_pool2d(x, 2, stride=2)
+ else:
+ if not self.with_conv:
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+ class ResnetBlockDDPMpp(nn.Module):
+ """ResBlock adapted from DDPM."""
+
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
+ dropout=0.1, skip_rescale=False, init_scale=0.):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+ if in_ch != out_ch:
+ if conv_shortcut:
+ self.Conv_2 = conv3x3(in_ch, out_ch)
+ else:
+ self.NIN_0 = NIN(in_ch, out_ch)
+
+ self.skip_rescale = skip_rescale
+ self.act = act
+ self.out_ch = out_ch
+ self.conv_shortcut = conv_shortcut
+
+ def forward(self, x, temb=None):
+ h = self.act(self.GroupNorm_0(x))
+ h = self.Conv_0(h)
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+ if x.shape[1] != self.out_ch:
+ if self.conv_shortcut:
+ x = self.Conv_2(x)
+ else:
+ x = self.NIN_0(x)
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.)
+
+
+ class ResnetBlockBigGANpp(nn.Module):
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
+ dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
+ skip_rescale=True, init_scale=0.):
+ super().__init__()
+
+ out_ch = out_ch if out_ch else in_ch
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+ self.up = up
+ self.down = down
+ self.fir = fir
+ self.fir_kernel = fir_kernel
+
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+ if in_ch != out_ch or up or down:
+ self.Conv_2 = conv1x1(in_ch, out_ch)
+
+ self.skip_rescale = skip_rescale
+ self.act = act
+ self.in_ch = in_ch
+ self.out_ch = out_ch
+
+ def forward(self, x, temb=None):
+ h = self.act(self.GroupNorm_0(x))
+
+ if self.up:
+ if self.fir:
+ h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
+ x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
+ x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
+ elif self.down:
+ if self.fir:
+ h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
+ x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
+
+ h = self.Conv_0(h)
+ # Add bias to each feature map conditioned on the time embedding
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+
+ if self.in_ch != self.out_ch or self.up or self.down:
+ x = self.Conv_2(x)
+
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.)
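`GaussianFourierProjection` maps one scalar per sample (e.g. a log noise level) to a fixed random Fourier feature vector. A small usage sketch, assuming the import path from this upload:

import torch
from sgmse.backbones.ncsnpp_utils.layerspp import GaussianFourierProjection

proj = GaussianFourierProjection(embedding_size=256, scale=16.0)
log_sigmas = torch.randn(8)       # one scalar noise level per batch element
emb = proj(log_sigmas)
print(emb.shape)                  # torch.Size([8, 512]): sin half + cos half
# proj.W has requires_grad=False, so the random frequencies stay fixed
# during training.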
sgmse/backbones/ncsnpp_utils/normalization.py ADDED
@@ -0,0 +1,215 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Normalization layers."""
+ import torch.nn as nn
+ import torch
+ import functools
+
+
+ def get_normalization(config, conditional=False):
+ """Obtain normalization modules from the config file."""
+ norm = config.model.normalization
+ if conditional:
+ if norm == 'InstanceNorm++':
+ return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
+ else:
+ raise NotImplementedError(f'{norm} not implemented yet.')
+ else:
+ if norm == 'InstanceNorm':
+ return nn.InstanceNorm2d
+ elif norm == 'InstanceNorm++':
+ return InstanceNorm2dPlus
+ elif norm == 'VarianceNorm':
+ return VarianceNorm2d
+ elif norm == 'GroupNorm':
+ return nn.GroupNorm
+ else:
+ raise ValueError('Unknown normalization: %s' % norm)
+
+
+ class ConditionalBatchNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.bn = nn.BatchNorm2d(num_features, affine=False)
+ if self.bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ out = self.bn(x)
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=1)
+ out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * out
+ return out
+
+
+ class ConditionalInstanceNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ h = self.instance_norm(x)
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+ class ConditionalVarianceNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=False):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.normal_(1, 0.02)
+
+ def forward(self, x, y):
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
+ h = x / torch.sqrt(vars + 1e-5)
+
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+ class VarianceNorm2d(nn.Module):
+ def __init__(self, num_features, bias=False):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.alpha = nn.Parameter(torch.zeros(num_features))
+ self.alpha.data.normal_(1, 0.02)
+
+ def forward(self, x):
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
+ h = x / torch.sqrt(vars + 1e-5)
+
+ out = self.alpha.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+ class ConditionalNoneNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
+ out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * x
+ return out
+
+
+ class NoneNorm2d(nn.Module):
+ def __init__(self, num_features, bias=True):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+
+ class InstanceNorm2dPlus(nn.Module):
+ def __init__(self, num_features, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+ self.alpha = nn.Parameter(torch.zeros(num_features))
+ self.gamma = nn.Parameter(torch.zeros(num_features))
+ self.alpha.data.normal_(1, 0.02)
+ self.gamma.data.normal_(1, 0.02)
+ if bias:
+ self.beta = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x):
+ means = torch.mean(x, dim=(2, 3))
+ m = torch.mean(means, dim=-1, keepdim=True)
+ v = torch.var(means, dim=-1, keepdim=True)
+ means = (means - m) / (torch.sqrt(v + 1e-5))
+ h = self.instance_norm(x)
+
+ if self.bias:
+ h = h + means[..., None, None] * self.alpha[..., None, None]
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
+ else:
+ h = h + means[..., None, None] * self.alpha[..., None, None]
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+ class ConditionalInstanceNorm2dPlus(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 3)
+ self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
+ self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, 2 * num_features)
+ self.embed.weight.data.normal_(1, 0.02)
+
+ def forward(self, x, y):
+ means = torch.mean(x, dim=(2, 3))
+ m = torch.mean(means, dim=-1, keepdim=True)
+ v = torch.var(means, dim=-1, keepdim=True)
+ means = (means - m) / (torch.sqrt(v + 1e-5))
+ h = self.instance_norm(x)
+
+ if self.bias:
+ gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
+ h = h + means[..., None, None] * alpha[..., None, None]
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+ else:
+ gamma, alpha = self.embed(y).chunk(2, dim=-1)
+ h = h + means[..., None, None] * alpha[..., None, None]
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
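`ConditionalInstanceNorm2dPlus` looks its affine parameters up in an embedding table indexed by a class (or noise-level index) label. A minimal shape check, assuming the import path from this upload:

import torch
from sgmse.backbones.ncsnpp_utils.normalization import ConditionalInstanceNorm2dPlus

norm = ConditionalInstanceNorm2dPlus(num_features=64, num_classes=10)
x = torch.randn(4, 64, 32, 32)
y = torch.randint(0, 10, (4,))    # one label per sample
out = norm(x, y)
print(out.shape)                  # torch.Size([4, 64, 32, 32])
# gamma/alpha/beta come from self.embed(y), so every label gets its own
# affine correction on top of plain instance normalization.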
sgmse/backbones/ncsnpp_utils/op/__init__.py ADDED
@@ -0,0 +1 @@
+ from .upfirdn2d import upfirdn2d
sgmse/backbones/ncsnpp_utils/op/fused_act.py ADDED
@@ -0,0 +1,97 @@
+ import os
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+
+ module_path = os.path.dirname(__file__)
+ fused = load(
+ "fused",
+ sources=[
+ os.path.join(module_path, "fused_bias_act.cpp"),
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
+ ],
+ )
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused.fused_bias_act(
+ grad_output, empty, out, 3, 1, negative_slope, scale
+ )
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
+
+ return gradgrad_out, None, None, None
+
+
+ class FusedLeakyReLUFunction(Function):
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale
+ )
+
+ return grad_input, grad_bias, None, None
+
+
+ class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ if input.device.type == "cpu":
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ return (
+ F.leaky_relu(
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+ )
+ * scale
+ )
+
+ else:
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
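The CPU branch of `fused_leaky_relu` above is plain PyTorch, so its effect can be reproduced without compiling the CUDA extension. A sketch of that equivalent computation with the default negative_slope=0.2 and scale=sqrt(2):

import math
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16, 16)
bias = torch.randn(8)

# Broadcast the per-channel bias, apply leaky ReLU, then rescale by sqrt(2)
# to keep the activation roughly variance-preserving.
out = F.leaky_relu(x + bias.view(1, 8, 1, 1), negative_slope=0.2) * math.sqrt(2)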
sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+ }
sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+ //
+ // This work is made available under the Nvidia Source Code License-NC.
+ // To view a copy of this license, visit
+ // https://nvlabs.github.io/stylegan2/license.html
+
+ #include <torch/types.h>
+
+ #include <ATen/ATen.h>
+ #include <ATen/AccumulateType.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+
+ template <typename scalar_t>
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+ }
+
+
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+ }
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
+ #include <torch/extension.h>
+
+
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+ }
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py ADDED
@@ -0,0 +1,203 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+ if torch.cuda.is_available():
12
+ upfirdn2d_op = load(
13
+ "upfirdn2d",
14
+ sources=[
15
+ os.path.join(module_path, "upfirdn2d.cpp"),
16
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
17
+ ],
18
+ )
19
+ else:
20
+ upfirdn2d_op = None
21
+
22
+ class UpFirDn2dBackward(Function):
23
+ @staticmethod
24
+ def forward(
25
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
26
+ ):
27
+
28
+ up_x, up_y = up
29
+ down_x, down_y = down
30
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
31
+
32
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
33
+
34
+ grad_input = upfirdn2d_op.upfirdn2d(
35
+ grad_output,
36
+ grad_kernel,
37
+ down_x,
38
+ down_y,
39
+ up_x,
40
+ up_y,
41
+ g_pad_x0,
42
+ g_pad_x1,
43
+ g_pad_y0,
44
+ g_pad_y1,
45
+ )
46
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
47
+
48
+ ctx.save_for_backward(kernel)
49
+
50
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
51
+
52
+ ctx.up_x = up_x
53
+ ctx.up_y = up_y
54
+ ctx.down_x = down_x
55
+ ctx.down_y = down_y
56
+ ctx.pad_x0 = pad_x0
57
+ ctx.pad_x1 = pad_x1
58
+ ctx.pad_y0 = pad_y0
59
+ ctx.pad_y1 = pad_y1
60
+ ctx.in_size = in_size
61
+ ctx.out_size = out_size
62
+
63
+ return grad_input
64
+
65
+ @staticmethod
66
+ def backward(ctx, gradgrad_input):
67
+ kernel, = ctx.saved_tensors
68
+
69
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
70
+
71
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
72
+ gradgrad_input,
73
+ kernel,
74
+ ctx.up_x,
75
+ ctx.up_y,
76
+ ctx.down_x,
77
+ ctx.down_y,
78
+ ctx.pad_x0,
79
+ ctx.pad_x1,
80
+ ctx.pad_y0,
81
+ ctx.pad_y1,
82
+ )
83
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
84
+ gradgrad_out = gradgrad_out.view(
85
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
86
+ )
87
+
88
+ return gradgrad_out, None, None, None, None, None, None, None, None
89
+
90
+
91
+ class UpFirDn2d(Function):
92
+ @staticmethod
93
+ def forward(ctx, input, kernel, up, down, pad):
94
+ up_x, up_y = up
95
+ down_x, down_y = down
96
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
97
+
98
+ kernel_h, kernel_w = kernel.shape
99
+ batch, channel, in_h, in_w = input.shape
100
+ ctx.in_size = input.shape
101
+
102
+ input = input.reshape(-1, in_h, in_w, 1)
103
+
104
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
105
+
106
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
107
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
108
+ ctx.out_size = (out_h, out_w)
109
+
110
+ ctx.up = (up_x, up_y)
111
+ ctx.down = (down_x, down_y)
112
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
113
+
114
+ g_pad_x0 = kernel_w - pad_x0 - 1
115
+ g_pad_y0 = kernel_h - pad_y0 - 1
116
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
117
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
118
+
119
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
120
+
121
+ out = upfirdn2d_op.upfirdn2d(
122
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
123
+ )
124
+ # out = out.view(major, out_h, out_w, minor)
125
+ out = out.view(-1, channel, out_h, out_w)
126
+
127
+ return out
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ kernel, grad_kernel = ctx.saved_tensors
132
+
133
+ grad_input = UpFirDn2dBackward.apply(
134
+ grad_output,
135
+ kernel,
136
+ grad_kernel,
137
+ ctx.up,
138
+ ctx.down,
139
+ ctx.pad,
140
+ ctx.g_pad,
141
+ ctx.in_size,
142
+ ctx.out_size,
143
+ )
144
+
145
+ return grad_input, None, None, None, None
146
+
147
+
148
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
149
+ if input.device.type == "cpu":
150
+ out = upfirdn2d_native(
151
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
152
+ )
153
+
154
+ else:
155
+ out = UpFirDn2d.apply(
156
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
157
+ )
158
+
159
+ return out
160
+
161
+
162
+ def upfirdn2d_native(
163
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
164
+ ):
165
+ _, channel, in_h, in_w = input.shape
166
+ input = input.reshape(-1, in_h, in_w, 1)
167
+
168
+ _, in_h, in_w, minor = input.shape
169
+ kernel_h, kernel_w = kernel.shape
170
+
171
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
172
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
173
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
174
+
175
+ out = F.pad(
176
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
177
+ )
178
+ out = out[
179
+ :,
180
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
181
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
182
+ :,
183
+ ]
184
+
185
+ out = out.permute(0, 3, 1, 2)
186
+ out = out.reshape(
187
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
188
+ )
189
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
190
+ out = F.conv2d(out, w)
191
+ out = out.reshape(
192
+ -1,
193
+ minor,
194
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
195
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
196
+ )
197
+ out = out.permute(0, 2, 3, 1)
198
+ out = out[:, ::down_y, ::down_x, :]
199
+
200
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
201
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
202
+
203
+ return out.view(-1, channel, out_h, out_w)
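For orientation, `upfirdn2d` fuses upsampling (zero insertion), padding, FIR filtering, and downsampling into one pass. A minimal usage sketch (the 2x2 box filter here is illustrative, not taken from this repository); run on CPU, this exercises the `upfirdn2d_native` fallback:

import torch
from sgmse.backbones.ncsnpp_utils.op.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 8, 8)        # NCHW input
k = torch.full((2, 2), 0.25)       # 2x2 box filter (illustrative)
# 2x upsampling: out_h = (8*2 + 1 + 0 - 2) // 1 + 1 = 16
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 0))
assert y.shape == (1, 3, 16, 16)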
sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+ //
+ // This work is made available under the Nvidia Source Code License-NC.
+ // To view a copy of this license, visit
+ // https://nvlabs.github.io/stylegan2/license.html
+
+ #include <torch/types.h>
+
+ #include <ATen/ATen.h>
+ #include <ATen/AccumulateType.h>
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+   int c = a / b;
+
+   if (c * b > a) {
+     c--;
+   }
+
+   return c;
+ }
+
+ struct UpFirDn2DKernelParams {
+   int up_x;
+   int up_y;
+   int down_x;
+   int down_y;
+   int pad_x0;
+   int pad_x1;
+   int pad_y0;
+   int pad_y1;
+
+   int major_dim;
+   int in_h;
+   int in_w;
+   int minor_dim;
+   int kernel_h;
+   int kernel_w;
+   int out_h;
+   int out_w;
+   int loop_major;
+   int loop_x;
+ };
+
+ template <typename scalar_t>
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+                                        const scalar_t *kernel,
+                                        const UpFirDn2DKernelParams p) {
+   int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+   int out_y = minor_idx / p.minor_dim;
+   minor_idx -= out_y * p.minor_dim;
+   int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+   int major_idx_base = blockIdx.z * p.loop_major;
+
+   if (out_x_base >= p.out_w || out_y >= p.out_h ||
+       major_idx_base >= p.major_dim) {
+     return;
+   }
+
+   int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+   int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+   int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+   int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+   for (int loop_major = 0, major_idx = major_idx_base;
+        loop_major < p.loop_major && major_idx < p.major_dim;
+        loop_major++, major_idx++) {
+     for (int loop_x = 0, out_x = out_x_base;
+          loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+       int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+       int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+       int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+       int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+       const scalar_t *x_p =
+           &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+                  minor_idx];
+       const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+       int x_px = p.minor_dim;
+       int k_px = -p.up_x;
+       int x_py = p.in_w * p.minor_dim;
+       int k_py = -p.up_y * p.kernel_w;
+
+       scalar_t v = 0.0f;
+
+       for (int y = 0; y < h; y++) {
+         for (int x = 0; x < w; x++) {
+           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+           x_p += x_px;
+           k_p += k_px;
+         }
+
+         x_p += x_py - w * x_px;
+         k_p += k_py - w * k_px;
+       }
+
+       out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+           minor_idx] = v;
+     }
+   }
+ }
+
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+                                  const scalar_t *kernel,
+                                  const UpFirDn2DKernelParams p) {
+   const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+   const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+   __shared__ volatile float sk[kernel_h][kernel_w];
+   __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+   int minor_idx = blockIdx.x;
+   int tile_out_y = minor_idx / p.minor_dim;
+   minor_idx -= tile_out_y * p.minor_dim;
+   tile_out_y *= tile_out_h;
+   int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+   int major_idx_base = blockIdx.z * p.loop_major;
+
+   if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+       major_idx_base >= p.major_dim) {
+     return;
+   }
+
+   for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+        tap_idx += blockDim.x) {
+     int ky = tap_idx / kernel_w;
+     int kx = tap_idx - ky * kernel_w;
+     scalar_t v = 0.0;
+
+     if (kx < p.kernel_w & ky < p.kernel_h) {
+       v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+     }
+
+     sk[ky][kx] = v;
+   }
+
+   for (int loop_major = 0, major_idx = major_idx_base;
+        loop_major < p.loop_major & major_idx < p.major_dim;
+        loop_major++, major_idx++) {
+     for (int loop_x = 0, tile_out_x = tile_out_x_base;
+          loop_x < p.loop_x & tile_out_x < p.out_w;
+          loop_x++, tile_out_x += tile_out_w) {
+       int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+       int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+       int tile_in_x = floor_div(tile_mid_x, up_x);
+       int tile_in_y = floor_div(tile_mid_y, up_y);
+
+       __syncthreads();
+
+       for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+            in_idx += blockDim.x) {
+         int rel_in_y = in_idx / tile_in_w;
+         int rel_in_x = in_idx - rel_in_y * tile_in_w;
+         int in_x = rel_in_x + tile_in_x;
+         int in_y = rel_in_y + tile_in_y;
+
+         scalar_t v = 0.0;
+
+         if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+           v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+                         p.minor_dim +
+                     minor_idx];
+         }
+
+         sx[rel_in_y][rel_in_x] = v;
+       }
+
+       __syncthreads();
+       for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+            out_idx += blockDim.x) {
+         int rel_out_y = out_idx / tile_out_w;
+         int rel_out_x = out_idx - rel_out_y * tile_out_w;
+         int out_x = rel_out_x + tile_out_x;
+         int out_y = rel_out_y + tile_out_y;
+
+         int mid_x = tile_mid_x + rel_out_x * down_x;
+         int mid_y = tile_mid_y + rel_out_y * down_y;
+         int in_x = floor_div(mid_x, up_x);
+         int in_y = floor_div(mid_y, up_y);
+         int rel_in_x = in_x - tile_in_x;
+         int rel_in_y = in_y - tile_in_y;
+         int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+         int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+         scalar_t v = 0.0;
+
+ #pragma unroll
+         for (int y = 0; y < kernel_h / up_y; y++)
+ #pragma unroll
+           for (int x = 0; x < kernel_w / up_x; x++)
+             v += sx[rel_in_y + y][rel_in_x + x] *
+                  sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+         if (out_x < p.out_w & out_y < p.out_h) {
+           out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+               minor_idx] = v;
+         }
+       }
+     }
+   }
+ }
+
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+                            const torch::Tensor &kernel, int up_x, int up_y,
+                            int down_x, int down_y, int pad_x0, int pad_x1,
+                            int pad_y0, int pad_y1) {
+   int curDevice = -1;
+   cudaGetDevice(&curDevice);
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+   UpFirDn2DKernelParams p;
+
+   auto x = input.contiguous();
+   auto k = kernel.contiguous();
+
+   p.major_dim = x.size(0);
+   p.in_h = x.size(1);
+   p.in_w = x.size(2);
+   p.minor_dim = x.size(3);
+   p.kernel_h = k.size(0);
+   p.kernel_w = k.size(1);
+   p.up_x = up_x;
+   p.up_y = up_y;
+   p.down_x = down_x;
+   p.down_y = down_y;
+   p.pad_x0 = pad_x0;
+   p.pad_x1 = pad_x1;
+   p.pad_y0 = pad_y0;
+   p.pad_y1 = pad_y1;
+
+   p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+             p.down_y;
+   p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+             p.down_x;
+
+   auto out =
+       at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+   int mode = -1;
+
+   int tile_out_h = -1;
+   int tile_out_w = -1;
+
+   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+       p.kernel_h <= 4 && p.kernel_w <= 4) {
+     mode = 1;
+     tile_out_h = 16;
+     tile_out_w = 64;
+   }
+
+   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+       p.kernel_h <= 3 && p.kernel_w <= 3) {
+     mode = 2;
+     tile_out_h = 16;
+     tile_out_w = 64;
+   }
+
+   if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+       p.kernel_h <= 4 && p.kernel_w <= 4) {
+     mode = 3;
+     tile_out_h = 16;
+     tile_out_w = 64;
+   }
+
+   if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+       p.kernel_h <= 2 && p.kernel_w <= 2) {
+     mode = 4;
+     tile_out_h = 16;
+     tile_out_w = 64;
+   }
+
+   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+       p.kernel_h <= 4 && p.kernel_w <= 4) {
+     mode = 5;
+     tile_out_h = 8;
+     tile_out_w = 32;
+   }
+
+   if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+       p.kernel_h <= 2 && p.kernel_w <= 2) {
+     mode = 6;
+     tile_out_h = 8;
+     tile_out_w = 32;
+   }
+
+   dim3 block_size;
+   dim3 grid_size;
+
+   if (tile_out_h > 0 && tile_out_w > 0) {
+     p.loop_major = (p.major_dim - 1) / 16384 + 1;
+     p.loop_x = 1;
+     block_size = dim3(32 * 8, 1, 1);
+     grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                      (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                      (p.major_dim - 1) / p.loop_major + 1);
+   } else {
+     p.loop_major = (p.major_dim - 1) / 16384 + 1;
+     p.loop_x = 4;
+     block_size = dim3(4, 32, 1);
+     grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                      (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                      (p.major_dim - 1) / p.loop_major + 1);
+   }
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+     switch (mode) {
+     case 1:
+       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     case 2:
+       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     case 3:
+       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     case 4:
+       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     case 5:
+       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     case 6:
+       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                  x.data_ptr<scalar_t>(),
+                                                  k.data_ptr<scalar_t>(), p);
+       break;
+
+     default:
+       upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+           out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+           k.data_ptr<scalar_t>(), p);
+     }
+   });
+
+   return out;
+ }
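Assuming the extension compiles, a quick sanity check of the CUDA kernel is to compare it against the pure-PyTorch fallback above (a sketch; it requires a CUDA device, and the filter is illustrative):

import torch
from sgmse.backbones.ncsnpp_utils.op.upfirdn2d import upfirdn2d

x = torch.randn(2, 4, 16, 16)
k = torch.outer(torch.tensor([1., 3., 3., 1.]), torch.tensor([1., 3., 3., 1.]))
k = k / k.sum()
out_cpu = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))                      # native CPU path
out_gpu = upfirdn2d(x.cuda(), k.cuda(), up=2, down=1, pad=(2, 1)).cpu()  # CUDA kernel
print(torch.allclose(out_cpu, out_gpu, atol=1e-5))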
sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py ADDED
@@ -0,0 +1,257 @@
+ """Layers used for up-sampling or down-sampling images.
+
+ Many functions are ported from https://github.com/NVlabs/stylegan2.
+ """
+
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .op import upfirdn2d
+
+
+ # Function ported from StyleGAN2
+ def get_weight(module,
+                shape,
+                weight_var='weight',
+                kernel_init=None):
+   """Get/create weight tensor for a convolution or fully-connected layer."""
+   # Note: leftover from the TF/Flax port; not used by the PyTorch modules below.
+   return module.param(weight_var, kernel_init, shape)
+
+
+ class Conv2d(nn.Module):
+   """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
+
+   def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
+                resample_kernel=(1, 3, 3, 1),
+                use_bias=True,
+                kernel_init=None):
+     super().__init__()
+     assert not (up and down)
+     assert kernel >= 1 and kernel % 2 == 1
+     self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
+     if kernel_init is not None:
+       self.weight.data = kernel_init(self.weight.data.shape)
+     if use_bias:
+       self.bias = nn.Parameter(torch.zeros(out_ch))
+
+     self.up = up
+     self.down = down
+     self.resample_kernel = resample_kernel
+     self.kernel = kernel
+     self.use_bias = use_bias
+
+   def forward(self, x):
+     if self.up:
+       x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
+     elif self.down:
+       x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
+     else:
+       x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
+
+     if self.use_bias:
+       x = x + self.bias.reshape(1, -1, 1, 1)
+
+     return x
+
+
+ def naive_upsample_2d(x, factor=2):
+   _N, C, H, W = x.shape
+   x = torch.reshape(x, (-1, C, H, 1, W, 1))
+   x = x.repeat(1, 1, 1, factor, 1, factor)
+   return torch.reshape(x, (-1, C, H * factor, W * factor))
+
+
+ def naive_downsample_2d(x, factor=2):
+   _N, C, H, W = x.shape
+   x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
+   return torch.mean(x, dim=(3, 5))
+
+
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+   """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+
+   Padding is performed only once at the beginning, not between the operations.
+   The fused op is considerably more efficient than performing the same
+   calculation using standard TensorFlow ops. It supports gradients of
+   arbitrary order.
+
+   Args:
+     x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+     w: Weight tensor of the shape `[filterH, filterW, inChannels,
+       outChannels]`. Grouped convolution can be performed by `inChannels =
+       x.shape[0] // numGroups`.
+     k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
+       default is `[1] * factor`, which corresponds to nearest-neighbor
+       upsampling.
+     factor: Integer upsampling factor (default: 2).
+     gain: Scaling factor for signal magnitude (default: 1.0).
+
+   Returns:
+     Tensor of the shape `[N, C, H * factor, W * factor]` or
+     `[N, H * factor, W * factor, C]`, and same datatype as `x`.
+   """
+
+   assert isinstance(factor, int) and factor >= 1
+
+   # Check weight shape.
+   assert len(w.shape) == 4
+   convH = w.shape[2]
+   convW = w.shape[3]
+   inC = w.shape[1]
+   outC = w.shape[0]
+
+   assert convW == convH
+
+   # Setup filter kernel.
+   if k is None:
+     k = [1] * factor
+   k = _setup_kernel(k) * (gain * (factor ** 2))
+   p = (k.shape[0] - factor) - (convW - 1)
+
+   # Determine data dimensions. `stride` is passed to `conv_transpose2d`,
+   # so it must be a 2-tuple (the TF port used a 4-element NHWC stride here).
+   stride = (factor, factor)
+   output_shape = ((_shape(x, 2) - 1) * factor + convH,
+                   (_shape(x, 3) - 1) * factor + convW)
+   output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
+                     output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
+   assert output_padding[0] >= 0 and output_padding[1] >= 0
+   num_groups = _shape(x, 1) // inC
+
+   # Transpose weights. `torch.flip` is used instead of negative-step slicing
+   # (`w[..., ::-1, ::-1]`), which torch tensors do not support.
+   w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
+   w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+   w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+
+   x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
+   ## Original TF code:
+   # x = tf.nn.conv2d_transpose(
+   #     x,
+   #     w,
+   #     output_shape=output_shape,
+   #     strides=stride,
+   #     padding='VALID',
+   #     data_format=data_format)
+
+   return upfirdn2d(x, torch.tensor(k, device=x.device),
+                    pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+
+
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+   """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+
+   Padding is performed only once at the beginning, not between the operations.
+   The fused op is considerably more efficient than performing the same
+   calculation using standard TensorFlow ops. It supports gradients of
+   arbitrary order.
+
+   Args:
+     x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+     w: Weight tensor of the shape `[filterH, filterW, inChannels,
+       outChannels]`. Grouped convolution can be performed by `inChannels =
+       x.shape[0] // numGroups`.
+     k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
+       default is `[1] * factor`, which corresponds to average pooling.
+     factor: Integer downsampling factor (default: 2).
+     gain: Scaling factor for signal magnitude (default: 1.0).
+
+   Returns:
+     Tensor of the shape `[N, C, H // factor, W // factor]` or
+     `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+   """
+
+   assert isinstance(factor, int) and factor >= 1
+   _outC, _inC, convH, convW = w.shape
+   assert convW == convH
+   if k is None:
+     k = [1] * factor
+   k = _setup_kernel(k) * gain
+   p = (k.shape[0] - factor) + (convW - 1)
+   s = [factor, factor]
+   x = upfirdn2d(x, torch.tensor(k, device=x.device),
+                 pad=((p + 1) // 2, p // 2))
+   return F.conv2d(x, w, stride=s, padding=0)
+
+
+ def _setup_kernel(k):
+   k = np.asarray(k, dtype=np.float32)
+   if k.ndim == 1:
+     k = np.outer(k, k)
+   k /= np.sum(k)
+   assert k.ndim == 2
+   assert k.shape[0] == k.shape[1]
+   return k
+
+
+ def _shape(x, dim):
+   return x.shape[dim]
+
+
+ def upsample_2d(x, k=None, factor=2, gain=1):
+   r"""Upsample a batch of 2D images with the given filter.
+
+   Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+   and upsamples each image with the given filter. The filter is normalized so
+   that if the input pixels are constant, they will be scaled by the specified
+   `gain`. Pixels outside the image are assumed to be zero, and the filter is
+   padded with zeros so that its shape is a multiple of the upsampling factor.
+
+   Args:
+     x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+     k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
+       default is `[1] * factor`, which corresponds to nearest-neighbor
+       upsampling.
+     factor: Integer upsampling factor (default: 2).
+     gain: Scaling factor for signal magnitude (default: 1.0).
+
+   Returns:
+     Tensor of the shape `[N, C, H * factor, W * factor]`
+   """
+   assert isinstance(factor, int) and factor >= 1
+   if k is None:
+     k = [1] * factor
+   k = _setup_kernel(k) * (gain * (factor ** 2))
+   p = k.shape[0] - factor
+   return upfirdn2d(x, torch.tensor(k, device=x.device),
+                    up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+
+
+ def downsample_2d(x, k=None, factor=2, gain=1):
+   r"""Downsample a batch of 2D images with the given filter.
+
+   Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+   and downsamples each image with the given filter. The filter is normalized
+   so that if the input pixels are constant, they will be scaled by the
+   specified `gain`. Pixels outside the image are assumed to be zero, and the
+   filter is padded with zeros so that its shape is a multiple of the
+   downsampling factor.
+
+   Args:
+     x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+     k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
+       default is `[1] * factor`, which corresponds to average pooling.
+     factor: Integer downsampling factor (default: 2).
+     gain: Scaling factor for signal magnitude (default: 1.0).
+
+   Returns:
+     Tensor of the shape `[N, C, H // factor, W // factor]`
+   """
+
+   assert isinstance(factor, int) and factor >= 1
+   if k is None:
+     k = [1] * factor
+   k = _setup_kernel(k) * gain
+   p = k.shape[0] - factor
+   return upfirdn2d(x, torch.tensor(k, device=x.device),
+                    down=factor, pad=((p + 1) // 2, p // 2))
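As a usage sketch of the FIR resampling helpers above (with the StyleGAN2 default filter `(1, 3, 3, 1)`): upsampling and then downsampling by the same factor restores the original spatial shape:

import torch
from sgmse.backbones.ncsnpp_utils.up_or_down_sampling import upsample_2d, downsample_2d

x = torch.randn(1, 2, 32, 32)
up = upsample_2d(x, k=(1, 3, 3, 1), factor=2)       # -> (1, 2, 64, 64)
down = downsample_2d(up, k=(1, 3, 3, 1), factor=2)  # -> (1, 2, 32, 32)
print(up.shape, down.shape)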
sgmse/backbones/ncsnpp_utils/utils.py ADDED
@@ -0,0 +1,189 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """All functions and modules related to model definition."""
+
+ import torch
+
+ import numpy as np
+ from ...sdes import OUVESDE, OUVPSDE
+
+
+ _MODELS = {}
+
+
+ def register_model(cls=None, *, name=None):
+   """A decorator for registering model classes."""
+
+   def _register(cls):
+     if name is None:
+       local_name = cls.__name__
+     else:
+       local_name = name
+     if local_name in _MODELS:
+       raise ValueError(f'Already registered model with name: {local_name}')
+     _MODELS[local_name] = cls
+     return cls
+
+   if cls is None:
+     return _register
+   else:
+     return _register(cls)
+
+
+ def get_model(name):
+   return _MODELS[name]
+
+
+ def get_sigmas(sigma_min, sigma_max, num_scales):
+   """Get sigmas --- the set of noise levels for SMLD.
+
+   Args:
+     sigma_min: The smallest noise level.
+     sigma_max: The largest noise level.
+     num_scales: The number of noise levels.
+   Returns:
+     sigmas: a numpy array of noise levels, geometrically spaced and descending
+       from `sigma_max` to `sigma_min`.
+   """
+   sigmas = np.exp(
+       np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
+
+   return sigmas
+
+
+ def get_ddpm_params(config):
+   """Get betas and alphas --- parameters used in the original DDPM paper."""
+   num_diffusion_timesteps = 1000
+   # parameters need to be adapted if the number of time steps differs from 1000
+   beta_start = config.model.beta_min / config.model.num_scales
+   beta_end = config.model.beta_max / config.model.num_scales
+   betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+
+   alphas = 1. - betas
+   alphas_cumprod = np.cumprod(alphas, axis=0)
+   sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+   sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
+
+   return {
+       'betas': betas,
+       'alphas': alphas,
+       'alphas_cumprod': alphas_cumprod,
+       'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
+       'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
+       'beta_min': beta_start * (num_diffusion_timesteps - 1),
+       'beta_max': beta_end * (num_diffusion_timesteps - 1),
+       'num_diffusion_timesteps': num_diffusion_timesteps
+   }
+
+
+ def create_model(config):
+   """Create the score model."""
+   model_name = config.model.name
+   score_model = get_model(model_name)(config)
+   score_model = score_model.to(config.device)
+   score_model = torch.nn.DataParallel(score_model)
+   return score_model
+
+
+ def get_model_fn(model, train=False):
+   """Create a function to give the output of the score-based model.
+
+   Args:
+     model: The score model.
+     train: `True` for training and `False` for evaluation.
+
+   Returns:
+     A model function.
+   """
+
+   def model_fn(x, labels):
+     """Compute the output of the score-based model.
+
+     Args:
+       x: A mini-batch of input data.
+       labels: A mini-batch of conditioning variables for time steps. Should be
+         interpreted differently for different models.
+
+     Returns:
+       A tuple of (model output, new mutable states)
+     """
+     if not train:
+       model.eval()
+       return model(x, labels)
+     else:
+       model.train()
+       return model(x, labels)
+
+   return model_fn
+
+
+ def get_score_fn(sde, model, train=False, continuous=False):
+   """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
+
+   Args:
+     sde: An `sde_lib.SDE` object that represents the forward SDE.
+     model: A score model.
+     train: `True` for training and `False` for evaluation.
+     continuous: If `True`, the score-based model is expected to directly take continuous time steps.
+
+   Returns:
+     A score function.
+   """
+   model_fn = get_model_fn(model, train=train)
+
+   if isinstance(sde, OUVPSDE):
+     def score_fn(x, t):
+       # Scale neural network output by standard deviation and flip sign
+       if continuous:
+         # For VP-trained models, t=0 corresponds to the lowest noise level.
+         # The maximum value of the time embedding is assumed to be 999 for
+         # continuously-trained models.
+         labels = t * 999
+         score = model_fn(x, labels)
+         std = sde.marginal_prob(torch.zeros_like(x), t)[1]
+       else:
+         # For VP-trained models, t=0 corresponds to the lowest noise level
+         labels = t * (sde.N - 1)
+         score = model_fn(x, labels)
+         std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
+
+       score = -score / std[:, None, None, None]
+       return score
+
+   elif isinstance(sde, OUVESDE):
+     def score_fn(x, t):
+       if continuous:
+         labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+       else:
+         # For VE-trained models, t=0 corresponds to the highest noise level
+         labels = sde.T - t
+         labels *= sde.N - 1
+         labels = torch.round(labels).long()
+
+       score = model_fn(x, labels)
+       return score
+
+   else:
+     raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
+
+   return score_fn
+
+
+ def to_flattened_numpy(x):
+   """Flatten a torch tensor `x` and convert it to numpy."""
+   return x.detach().cpu().numpy().reshape((-1,))
+
+
+ def from_flattened_numpy(x, shape):
+   """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+   return torch.from_numpy(x.reshape(shape))
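The registry above is used as a decorator; a minimal hypothetical example (model name and class are placeholders for illustration, not part of this repository):

import torch
from sgmse.backbones.ncsnpp_utils.utils import register_model, get_model

@register_model(name='my_tiny_score_net')  # hypothetical name
class MyTinyScoreNet(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = torch.nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, labels):
        return self.net(x)

model_cls = get_model('my_tiny_score_net')  # -> MyTinyScoreNet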
sgmse/backbones/shared.py ADDED
@@ -0,0 +1,123 @@
+ import functools
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from sgmse.util.registry import Registry
+
+
+ BackboneRegistry = Registry("Backbone")
+
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian random features for encoding time steps."""
+
+     def __init__(self, embed_dim, scale=16, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case, the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
+
+     def forward(self, t):
+         t_proj = t[:, None] * self.W[None, :] * 2*np.pi
+         if self.complex_valued:
+             return torch.exp(1j * t_proj)
+         else:
+             return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
+
+
+ class DiffusionStepEmbedding(nn.Module):
+     """Diffusion-step embedding as in DiffWave / Vaswani et al. 2017."""
+
+     def __init__(self, embed_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # Same halving rationale as for `GaussianFourierProjection` above.
+             embed_dim = embed_dim // 2
+         self.embed_dim = embed_dim
+
+     def forward(self, t):
+         fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
+         inner = t[:, None] * fac[None, :]
+         if self.complex_valued:
+             return torch.exp(1j * inner)
+         else:
+             return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
+
+
+ class ComplexLinear(nn.Module):
+     """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
+
+     def __init__(self, input_dim, output_dim, complex_valued):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if self.complex_valued:
+             self.re = nn.Linear(input_dim, output_dim)
+             self.im = nn.Linear(input_dim, output_dim)
+         else:
+             self.lin = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         if self.complex_valued:
+             return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
+         else:
+             return self.lin(x)
+
+
+ class FeatureMapDense(nn.Module):
+     """A fully connected layer that reshapes outputs to feature maps."""
+
+     def __init__(self, input_dim, output_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
+
+     def forward(self, x):
+         return self.dense(x)[..., None, None]
+
+
+ def torch_complex_from_reim(re, im):
+     return torch.view_as_complex(torch.stack([re, im], dim=-1))
+
+
+ class ArgsComplexMultiplicationWrapper(nn.Module):
+     """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
+
+     Make a complex-valued module `F` from a real-valued module `f` by applying
+     complex multiplication rules:
+
+     F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
+
+     where `f1`, `f2` are two instances of `f` that do *not* share weights.
+
+     Args:
+         module_cls (callable): A class or function that returns a Torch module/functional.
+             Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
+             to construct the real and imaginary component modules.
+     """
+
+     def __init__(self, module_cls, *args, **kwargs):
+         super().__init__()
+         self.re_module = module_cls(*args, **kwargs)
+         self.im_module = module_cls(*args, **kwargs)
+
+     def forward(self, x, *args, **kwargs):
+         return torch_complex_from_reim(
+             self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
+             self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
+         )
+
+
+ ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
+ ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
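A short sketch of how the complex-valued wrappers behave: inputs and outputs are complex tensors, while internally two real-valued modules are combined via the multiplication rule above.

import torch
from sgmse.backbones.shared import ComplexLinear, ComplexConv2d

lin = ComplexLinear(4, 2, complex_valued=True)
z = torch.randn(3, 4, dtype=torch.cfloat)
out = lin(z)
print(out.shape, out.dtype)    # torch.Size([3, 2]) torch.complex64

conv = ComplexConv2d(1, 8, kernel_size=3, padding=1)
spec = torch.randn(2, 1, 16, 16, dtype=torch.cfloat)
print(conv(spec).shape)        # torch.Size([2, 8, 16, 16])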
sgmse/data_module.py ADDED
@@ -0,0 +1,236 @@
+ from os.path import join
+ import torch
+ import pytorch_lightning as pl
+ from torch.utils.data import Dataset
+ from torch.utils.data import DataLoader
+ from glob import glob
+ from torchaudio import load
+ import numpy as np
+ import torch.nn.functional as F
+
+
+ def get_window(window_type, window_length):
+     if window_type == 'sqrthann':
+         return torch.sqrt(torch.hann_window(window_length, periodic=True))
+     elif window_type == 'hann':
+         return torch.hann_window(window_length, periodic=True)
+     else:
+         raise NotImplementedError(f"Window type {window_type} not implemented!")
+
+
+ class Specs(Dataset):
+     def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
+             format='default', normalize="noisy", spec_transform=None,
+             stft_kwargs=None, **ignored_kwargs):
+
+         # Read file paths according to file naming format.
+         if format == "default":
+             self.clean_files = []
+             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
+             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
+             self.noisy_files = []
+             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
+             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
+         elif format == "reverb":
+             self.clean_files = []
+             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
+             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
+             self.noisy_files = []
+             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
+             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
+         else:
+             # Feel free to add your own directory format
+             raise NotImplementedError(f"Directory format {format} unknown!")
+
+         self.dummy = dummy
+         self.num_frames = num_frames
+         self.shuffle_spec = shuffle_spec
+         self.normalize = normalize
+         self.spec_transform = spec_transform
+
+         assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
+         self.stft_kwargs = stft_kwargs
+         self.hop_length = self.stft_kwargs["hop_length"]
+         assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
+
+     def __getitem__(self, i):
+         x, _ = load(self.clean_files[i])
+         y, _ = load(self.noisy_files[i])
+
+         # formula applies for center=True
+         target_len = (self.num_frames - 1) * self.hop_length
+         current_len = x.size(-1)
+         pad = max(target_len - current_len, 0)
+         if pad == 0:
+             # extract a random part of the audio file
+             if self.shuffle_spec:
+                 start = int(np.random.uniform(0, current_len-target_len))
+             else:
+                 start = int((current_len-target_len)/2)
+             x = x[..., start:start+target_len]
+             y = y[..., start:start+target_len]
+         else:
+             # pad audio if the length T is smaller than num_frames
+             x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
+             y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
+
+         # Normalize w.r.t. the noisy signal, the clean signal, or not at all,
+         # to ensure the same clean signal power in x and y.
+         if self.normalize == "noisy":
+             normfac = y.abs().max()
+         elif self.normalize == "clean":
+             normfac = x.abs().max()
+         elif self.normalize == "not":
+             normfac = 1.0
+         x = x / normfac
+         y = y / normfac
+
+         X = torch.stft(x, **self.stft_kwargs)
+         Y = torch.stft(y, **self.stft_kwargs)
+
+         X, Y = self.spec_transform(X), self.spec_transform(Y)
+         return X, Y
+
+     def __len__(self):
+         if self.dummy:
+             # for debugging, shrink the dataset size
+             return int(len(self.clean_files)/200)
+         else:
+             return len(self.clean_files)
+
+
+ class SpecsDataModule(pl.LightningDataModule):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
+         parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
+         parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
+         parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")   # to assure 256 freq bins
+         parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
+         parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
+         parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
+         parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
+         parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
+         parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
+         parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
+         parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
+         parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectrogram transformation for input representation.")
+         return parser
+
+     def __init__(
+         self, base_dir, format='default', batch_size=8,
+         n_fft=510, hop_length=128, num_frames=256, window='hann',
+         num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
+         gpu=True, normalize='noisy', transform_type="exponent", **kwargs
+     ):
+         super().__init__()
+         self.base_dir = base_dir
+         self.format = format
+         self.batch_size = batch_size
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.num_frames = num_frames
+         self.window = get_window(window, self.n_fft)
+         self.windows = {}
+         self.num_workers = num_workers
+         self.dummy = dummy
+         self.spec_factor = spec_factor
+         self.spec_abs_exponent = spec_abs_exponent
+         self.gpu = gpu
+         self.normalize = normalize
+         self.transform_type = transform_type
+         self.kwargs = kwargs
+
+     def setup(self, stage=None):
+         specs_kwargs = dict(
+             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
+             spec_transform=self.spec_fwd, **self.kwargs
+         )
+         if stage == 'fit' or stage is None:
+             self.train_set = Specs(data_dir=self.base_dir, subset='train',
+                 dummy=self.dummy, shuffle_spec=True, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+             self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+         if stage == 'test' or stage is None:
+             self.test_set = Specs(data_dir=self.base_dir, subset='test',
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+
+     def spec_fwd(self, spec):
+         if self.transform_type == "exponent":
+             if self.spec_abs_exponent != 1:
+                 # only do this calculation if spec_abs_exponent != 1, otherwise it's
+                 # wasted computation and introduces numerical error
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**e * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "log":
+             spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
+     def spec_back(self, spec):
+         if self.transform_type == "exponent":
+             spec = spec / self.spec_factor
+             if self.spec_abs_exponent != 1:
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "log":
+             spec = spec / self.spec_factor
+             spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
+     @property
+     def stft_kwargs(self):
+         return {**self.istft_kwargs, "return_complex": True}
+
+     @property
+     def istft_kwargs(self):
+         return dict(
+             n_fft=self.n_fft, hop_length=self.hop_length,
+             window=self.window, center=True
+         )
+
+     def _get_window(self, x):
+         """
+         Retrieve an appropriate window for the given tensor x, matching the device.
+         Caches the retrieved windows so that only one window tensor will be allocated per device.
+         """
+         window = self.windows.get(x.device, None)
+         if window is None:
+             window = self.window.to(x.device)
+             self.windows[x.device] = window
+         return window
+
+     def stft(self, sig):
+         window = self._get_window(sig)
+         return torch.stft(sig, **{**self.stft_kwargs, "window": window})
+
+     def istft(self, spec, length=None):
+         window = self._get_window(spec)
+         return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.valid_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self.test_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
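The `spec_fwd`/`spec_back` pair is an invertible magnitude compression of the complex STFT; a round-trip sketch (the `base_dir` is a hypothetical placeholder and is only accessed once `setup()` is called):

import torch
from sgmse.data_module import SpecsDataModule

dm = SpecsDataModule(base_dir="path/to/dataset")  # hypothetical path
spec = torch.randn(1, 256, 256, dtype=torch.cfloat)
compressed = dm.spec_fwd(spec)       # |z|**0.5 * exp(1j*angle(z)) * 0.15 by default
restored = dm.spec_back(compressed)  # inverse transform
print(torch.allclose(spec, restored, atol=1e-4))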
sgmse/model.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from math import ceil
3
+ import warnings
4
+
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ from torch_ema import ExponentialMovingAverage
8
+
9
+ from sgmse import sampling
10
+ from sgmse.sdes import SDERegistry
11
+ from sgmse.backbones import BackboneRegistry
12
+ from sgmse.util.inference import evaluate_model
13
+ from sgmse.util.other import pad_spec
14
+
15
+
16
+ class ScoreModel(pl.LightningModule):
17
+ @staticmethod
18
+ def add_argparse_args(parser):
19
+ parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
20
+ parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
21
+ parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
22
+ parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
23
+ parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
24
+ return parser
25
+
26
+ def __init__(
27
+ self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
28
+ num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
29
+ ):
30
+ """
31
+ Create a new ScoreModel.
32
+
33
+ Args:
34
+ backbone: Backbone DNN that serves as a score-based model.
35
+ sde: The SDE that defines the diffusion process.
36
+ lr: The learning rate of the optimizer. (1e-4 by default).
37
+ ema_decay: The decay constant of the parameter EMA (0.999 by default).
38
+ t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
39
+ loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
40
+ """
41
+ super().__init__()
42
+ # Initialize Backbone DNN
43
+ self.backbone = backbone
44
+ dnn_cls = BackboneRegistry.get_by_name(backbone)
45
+ self.dnn = dnn_cls(**kwargs)
46
+ # Initialize SDE
47
+ sde_cls = SDERegistry.get_by_name(sde)
48
+ self.sde = sde_cls(**kwargs)
49
+ # Store hyperparams and save them
50
+ self.lr = lr
51
+ self.ema_decay = ema_decay
52
+ self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
53
+ self._error_loading_ema = False
54
+ self.t_eps = t_eps
55
+ self.loss_type = loss_type
56
+ self.num_eval_files = num_eval_files
57
+
58
+ self.save_hyperparameters(ignore=['no_wandb'])
59
+ self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
60
+
61
+ def configure_optimizers(self):
62
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
63
+ return optimizer
64
+
65
+ def optimizer_step(self, *args, **kwargs):
66
+ # Method overridden so that the EMA params are updated after each optimizer step
67
+ super().optimizer_step(*args, **kwargs)
68
+ self.ema.update(self.parameters())
69
+
70
+ # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
71
+ def on_load_checkpoint(self, checkpoint):
72
+ ema = checkpoint.get('ema', None)
73
+ if ema is not None:
74
+ self.ema.load_state_dict(checkpoint['ema'])
75
+ else:
76
+ self._error_loading_ema = True
77
+ warnings.warn("EMA state_dict not found in checkpoint!")
78
+
79
+ def on_save_checkpoint(self, checkpoint):
80
+ checkpoint['ema'] = self.ema.state_dict()
81
+
82
+ def train(self, mode, no_ema=False):
83
+ res = super().train(mode) # call the standard `train` method with the given mode
84
+ if not self._error_loading_ema:
85
+ if mode == False and not no_ema:
86
+ # eval
87
+ self.ema.store(self.parameters()) # store current params in EMA
88
+ self.ema.copy_to(self.parameters()) # copy EMA parameters over current params for evaluation
89
+ else:
90
+ # train
91
+ if self.ema.collected_params is not None:
92
+ self.ema.restore(self.parameters()) # restore the EMA weights (if stored)
93
+ return res
94
+
95
+ def eval(self, no_ema=False):
96
+ return self.train(False, no_ema=no_ema)
97
+
98
+ def _loss(self, err):
99
+ if self.loss_type == 'mse':
100
+ losses = torch.square(err.abs())
101
+ elif self.loss_type == 'mae':
102
+ losses = err.abs()
103
+ # taken from reduce_op function: sum over channels and position and mean over batch dim
104
+ # presumably only important for absolute loss number, not for gradients
105
+ loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
106
+ return loss
107
+
108
+ def _step(self, batch, batch_idx):
109
+ x, y = batch
110
+ t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
111
+ mean, std = self.sde.marginal_prob(x, t, y)
112
+ z = torch.randn_like(x) # i.i.d. normal distributed with var=0.5
113
+ sigmas = std[:, None, None, None]
114
+ perturbed_data = mean + sigmas * z
115
+ score = self(perturbed_data, t, y)
116
+ err = score * sigmas + z
117
+ loss = self._loss(err)
118
+ return loss
119
+
120
+ def training_step(self, batch, batch_idx):
121
+ loss = self._step(batch, batch_idx)
122
+ self.log('train_loss', loss, on_step=True, on_epoch=True)
123
+ return loss
124
+
125
+ def validation_step(self, batch, batch_idx):
126
+ loss = self._step(batch, batch_idx)
127
+ self.log('valid_loss', loss, on_step=False, on_epoch=True)
128
+
129
+ # Evaluate speech enhancement performance
130
+ if batch_idx == 0 and self.num_eval_files != 0:
131
+ pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
132
+ self.log('pesq', pesq, on_step=False, on_epoch=True)
133
+ self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
134
+ self.log('estoi', estoi, on_step=False, on_epoch=True)
135
+
136
+ return loss
137
+
138
+ def forward(self, x, t, y):
139
+ # Concatenate y as an extra channel
140
+         dnn_input = torch.cat([x, y], dim=1)
+
+         # the minus is most likely unimportant here - taken from Song's repo
+         score = -self.dnn(dnn_input, t)
+         return score
+
+     def to(self, *args, **kwargs):
+         """Override PyTorch .to() to also transfer the EMA of the model weights."""
+         self.ema.to(*args, **kwargs)
+         return super().to(*args, **kwargs)
+
+     def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
+         N = self.sde.N if N is None else N
+         sde = self.sde.copy()
+         sde.N = N
+
+         kwargs = {"eps": self.t_eps, **kwargs}
+         if minibatch is None:
+             return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
+         else:
+             M = y.shape[0]
+             def batched_sampling_fn():
+                 samples, ns = [], []
+                 for i in range(int(ceil(M / minibatch))):
+                     y_mini = y[i*minibatch:(i+1)*minibatch]
+                     sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
+                     sample, n = sampler()
+                     samples.append(sample)
+                     ns.append(n)
+                 samples = torch.cat(samples, dim=0)
+                 return samples, ns
+             return batched_sampling_fn
+
+     def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
+         N = self.sde.N if N is None else N
+         sde = self.sde.copy()
+         sde.N = N
+
+         kwargs = {"eps": self.t_eps, **kwargs}
+         if minibatch is None:
+             return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
+         else:
+             M = y.shape[0]
+             def batched_sampling_fn():
+                 samples, ns = [], []
+                 for i in range(int(ceil(M / minibatch))):
+                     y_mini = y[i*minibatch:(i+1)*minibatch]
+                     sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
+                     sample, n = sampler()
+                     samples.append(sample)
+                     ns.append(n)
+                 samples = torch.cat(samples, dim=0)
+                 return samples, ns
+             return batched_sampling_fn
+
+     def train_dataloader(self):
+         return self.data_module.train_dataloader()
+
+     def val_dataloader(self):
+         return self.data_module.val_dataloader()
+
+     def test_dataloader(self):
+         return self.data_module.test_dataloader()
+
+     def setup(self, stage=None):
+         return self.data_module.setup(stage=stage)
+
+     def to_audio(self, spec, length=None):
+         return self._istft(self._backward_transform(spec), length)
+
+     def _forward_transform(self, spec):
+         return self.data_module.spec_fwd(spec)
+
+     def _backward_transform(self, spec):
+         return self.data_module.spec_back(spec)
+
+     def _stft(self, sig):
+         return self.data_module.stft(sig)
+
+     def _istft(self, spec, length=None):
+         return self.data_module.istft(spec, length)
+
+     def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
+                 corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
+                 **kwargs):
+         """One-call speech enhancement of noisy speech `y`, for convenience."""
+         sr = 16000
+         start = time.time()
+         T_orig = y.size(1)
+         norm_factor = y.abs().max().item()
+         y = y / norm_factor
+         Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
+         Y = pad_spec(Y)
+         if sampler_type == "pc":
+             sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
+                 corrector_steps=corrector_steps, snr=snr, intermediate=False,
+                 **kwargs)
+         elif sampler_type == "ode":
+             sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
+         else:
+             raise ValueError("{} is not a valid sampler type!".format(sampler_type))
+         sample, nfe = sampler()
+         x_hat = self.to_audio(sample.squeeze(), T_orig)
+         x_hat = x_hat * norm_factor
+         x_hat = x_hat.squeeze().cpu().numpy()
+         end = time.time()
+         if timeit:
+             rtf = (end - start) / (len(x_hat) / sr)
+             return x_hat, nfe, rtf
+         else:
+             return x_hat
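A minimal usage sketch of `enhance` (illustrative only: it assumes a trained instance `model` of the score model class above, with CUDA available, and a 16 kHz mono recording; the file names are hypothetical):

    import torch
    from torchaudio import load, save

    y, sr = load("noisy.wav")                          # shape (1, T), sr assumed to be 16000
    x_hat = model.enhance(y, sampler_type="pc", N=30)  # numpy array of length T
    save("enhanced.wav", torch.from_numpy(x_hat)[None, :], sr)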
sgmse/sampling/__init__.py ADDED
@@ -0,0 +1,143 @@
+ # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
+ """Various sampling methods."""
+ from scipy import integrate
+ import torch
+
+ from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
+ from .correctors import Corrector, CorrectorRegistry
+
+
+ __all__ = [
+     'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
+     'get_pc_sampler', 'get_ode_sampler'
+ ]
+
+
+ def to_flattened_numpy(x):
+     """Flatten a torch tensor `x` and convert it to numpy."""
+     return x.detach().cpu().numpy().reshape((-1,))
+
+
+ def from_flattened_numpy(x, shape):
+     """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+     return torch.from_numpy(x.reshape(shape))
+
+
+ def get_pc_sampler(
+     predictor_name, corrector_name, sde, score_fn, y,
+     denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
+     intermediate=False, **kwargs
+ ):
+     """Create a Predictor-Corrector (PC) sampler.
+
+     Args:
+         predictor_name: The name of a registered `sampling.Predictor`.
+         corrector_name: The name of a registered `sampling.Corrector`.
+         sde: An `sdes.SDE` object representing the forward SDE.
+         score_fn: A function (typically a learned model) that predicts the score.
+         y: A `torch.Tensor` representing the (non-white-)noisy starting point(s) to condition the prior on.
+         denoise: If `True`, add one-step denoising to the final samples.
+         eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
+         snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
+
+     Returns:
+         A sampling function that returns samples and the number of function evaluations during sampling.
+     """
+     predictor_cls = PredictorRegistry.get_by_name(predictor_name)
+     corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
+     predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
+     corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
+
+     def pc_sampler():
+         """The PC sampler function."""
+         with torch.no_grad():
+             xt = sde.prior_sampling(y.shape, y).to(y.device)
+             timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
+             for i in range(sde.N):
+                 t = timesteps[i]
+                 if i != len(timesteps) - 1:
+                     stepsize = t - timesteps[i+1]
+                 else:
+                     stepsize = timesteps[-1]  # from eps to 0
+                 vec_t = torch.ones(y.shape[0], device=y.device) * t
+                 xt, xt_mean = corrector.update_fn(xt, vec_t, y)
+                 xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
+             x_result = xt_mean if denoise else xt
+             ns = sde.N * (corrector.n_steps + 1)
+             return x_result, ns
+
+     return pc_sampler
+
+
+ def get_ode_sampler(
+     sde, score_fn, y, inverse_scaler=None,
+     denoise=True, rtol=1e-5, atol=1e-5,
+     method='RK45', eps=3e-2, device='cuda', **kwargs
+ ):
+     """Probability flow ODE sampler with the black-box ODE solver.
+
+     Args:
+         sde: An `sdes.SDE` object representing the forward SDE.
+         score_fn: A function (typically a learned model) that predicts the score.
+         y: A `torch.Tensor` representing the (non-white-)noisy starting point(s) to condition the prior on.
+         inverse_scaler: The inverse data normalizer.
+         denoise: If `True`, add one-step denoising to final samples.
+         rtol: A `float` number. The relative tolerance level of the ODE solver.
+         atol: A `float` number. The absolute tolerance level of the ODE solver.
+         method: A `str`. The algorithm used for the black-box ODE solver.
+             See the documentation of `scipy.integrate.solve_ivp`.
+         eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
+         device: PyTorch device.
+
+     Returns:
+         A sampling function that returns samples and the number of function evaluations during sampling.
+     """
+     predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
+     rsde = sde.reverse(score_fn, probability_flow=True)
+
+     def denoise_update_fn(x):
+         vec_eps = torch.ones(x.shape[0], device=x.device) * eps
+         # One final reverse-diffusion step over the remaining interval [0, eps].
+         stepsize = torch.tensor(eps, device=x.device)
+         _, x = predictor.update_fn(x, vec_eps, y, stepsize)
+         return x
+
+     def drift_fn(x, t, y):
+         """Get the drift function of the reverse-time SDE."""
+         return rsde.sde(x, t, y)[0]
+
+     def ode_sampler(z=None, **kwargs):
+         """The probability flow ODE sampler with black-box ODE solver.
+
+         Args:
+             z: If present, generate samples from latent code `z`.
+                (Currently unused: sampling always starts from the SDE prior conditioned on `y`.)
+         Returns:
+             samples, number of function evaluations.
+         """
+         with torch.no_grad():
+             # Sample the latent code from the prior distribution of the SDE.
+             x = sde.prior_sampling(y.shape, y).to(device)
+
+             def ode_func(t, x):
+                 x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
+                 vec_t = torch.ones(y.shape[0], device=x.device) * t
+                 drift = drift_fn(x, vec_t, y)
+                 return to_flattened_numpy(drift)
+
+             # Black-box ODE solver for the probability flow ODE
+             solution = integrate.solve_ivp(
+                 ode_func, (sde.T, eps), to_flattened_numpy(x),
+                 rtol=rtol, atol=atol, method=method, **kwargs
+             )
+             nfe = solution.nfev
+             x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
+
+             # Denoising is equivalent to running one predictor step without adding noise
+             if denoise:
+                 x = denoise_update_fn(x)
+
+             if inverse_scaler is not None:
+                 x = inverse_scaler(x)
+             return x, nfe
+
+     return ode_sampler
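For orientation, a hedged usage sketch of `get_pc_sampler` (assuming an `OUVESDE` instance `sde`, a trained score model `score_fn`, and a padded complex spectrogram batch `Y` on the correct device; all variable names are illustrative):

    sampler = get_pc_sampler("reverse_diffusion", "ald", sde=sde, score_fn=score_fn,
                             y=Y, denoise=True, snr=0.5, corrector_steps=1)
    x_hat, nfe = sampler()  # final samples and the number of score-function evaluations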
sgmse/sampling/correctors.py ADDED
@@ -0,0 +1,96 @@
+ import abc
+ import torch
+
+ from sgmse import sdes
+ from sgmse.util.registry import Registry
+
+
+ CorrectorRegistry = Registry("Corrector")
+
+
+ class Corrector(abc.ABC):
+     """The abstract class for a corrector algorithm."""
+
+     def __init__(self, sde, score_fn, snr, n_steps):
+         super().__init__()
+         self.rsde = sde.reverse(score_fn)
+         self.score_fn = score_fn
+         self.snr = snr
+         self.n_steps = n_steps
+
+     @abc.abstractmethod
+     def update_fn(self, x, t, *args):
+         """One update of the corrector.
+
+         Args:
+             x: A PyTorch tensor representing the current state.
+             t: A PyTorch tensor representing the current time step.
+             *args: Possibly additional arguments, in particular `y` for OU processes.
+
+         Returns:
+             x: A PyTorch tensor of the next state.
+             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+         """
+         pass
+
+
+ @CorrectorRegistry.register(name='langevin')
+ class LangevinCorrector(Corrector):
+     def __init__(self, sde, score_fn, snr, n_steps):
+         super().__init__(sde, score_fn, snr, n_steps)
+         self.score_fn = score_fn
+         self.n_steps = n_steps
+         self.snr = snr
+
+     def update_fn(self, x, t, *args):
+         target_snr = self.snr
+         for _ in range(self.n_steps):
+             grad = self.score_fn(x, t, *args)
+             noise = torch.randn_like(x)
+             grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
+             noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+             step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
+             x_mean = x + step_size[:, None, None, None] * grad
+             x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
+
+         return x, x_mean
+
+
+ @CorrectorRegistry.register(name='ald')
+ class AnnealedLangevinDynamics(Corrector):
+     """The original annealed Langevin dynamics corrector in NCSN/NCSNv2."""
+     def __init__(self, sde, score_fn, snr, n_steps):
+         super().__init__(sde, score_fn, snr, n_steps)
+         if not isinstance(sde, (sdes.OUVESDE,)):
+             raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
+         self.sde = sde
+         self.score_fn = score_fn
+         self.snr = snr
+         self.n_steps = n_steps
+
+     def update_fn(self, x, t, *args):
+         n_steps = self.n_steps
+         target_snr = self.snr
+         std = self.sde.marginal_prob(x, t, *args)[1]
+
+         for _ in range(n_steps):
+             grad = self.score_fn(x, t, *args)
+             noise = torch.randn_like(x)
+             step_size = (target_snr * std) ** 2 * 2
+             x_mean = x + step_size[:, None, None, None] * grad
+             x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
+
+         return x, x_mean
+
+
+ @CorrectorRegistry.register(name='none')
+ class NoneCorrector(Corrector):
+     """An empty corrector that does nothing."""
+
+     def __init__(self, *args, **kwargs):
+         self.snr = 0
+         self.n_steps = 0
+
+     def update_fn(self, x, t, *args):
+         return x, x
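In equation form, each annealed Langevin step above performs (reading the update off the code, with $\sigma(t)$ the marginal standard deviation from `marginal_prob`, $r$ the target SNR, $s_\theta$ the score model, and $z \sim \mathcal{N}(0, I)$):

$$x_{\mathrm{mean}} = x + \epsilon_t\, s_\theta(x, t, y), \qquad x \leftarrow x_{\mathrm{mean}} + \sqrt{2\epsilon_t}\, z, \qquad \epsilon_t = 2\,(r\,\sigma(t))^2$$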
sgmse/sampling/predictors.py ADDED
@@ -0,0 +1,76 @@
+ import abc
+
+ import torch
+
+ from sgmse.util.registry import Registry
+
+
+ PredictorRegistry = Registry("Predictor")
+
+
+ class Predictor(abc.ABC):
+     """The abstract class for a predictor algorithm."""
+
+     def __init__(self, sde, score_fn, probability_flow=False):
+         super().__init__()
+         self.sde = sde
+         self.rsde = sde.reverse(score_fn)
+         self.score_fn = score_fn
+         self.probability_flow = probability_flow
+
+     @abc.abstractmethod
+     def update_fn(self, x, t, *args):
+         """One update of the predictor.
+
+         Args:
+             x: A PyTorch tensor representing the current state.
+             t: A PyTorch tensor representing the current time step.
+             *args: Possibly additional arguments, in particular `y` for OU processes.
+
+         Returns:
+             x: A PyTorch tensor of the next state.
+             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+         """
+         pass
+
+     def debug_update_fn(self, x, t, *args):
+         raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
+
+
+ @PredictorRegistry.register('euler_maruyama')
+ class EulerMaruyamaPredictor(Predictor):
+     def __init__(self, sde, score_fn, probability_flow=False):
+         super().__init__(sde, score_fn, probability_flow=probability_flow)
+
+     def update_fn(self, x, t, y, stepsize):
+         # Use the sampler-provided stepsize, so the signature matches the
+         # stepsize-based PC sampler; the reverse-time step is dt = -stepsize.
+         dt = -stepsize
+         z = torch.randn_like(x)
+         f, g = self.rsde.sde(x, t, y)
+         x_mean = x + f * dt
+         x = x_mean + g[:, None, None, None] * torch.sqrt(-dt) * z
+         return x, x_mean
+
+
+ @PredictorRegistry.register('reverse_diffusion')
+ class ReverseDiffusionPredictor(Predictor):
+     def __init__(self, sde, score_fn, probability_flow=False):
+         super().__init__(sde, score_fn, probability_flow=probability_flow)
+
+     def update_fn(self, x, t, y, stepsize):
+         f, g = self.rsde.discretize(x, t, y, stepsize)
+         z = torch.randn_like(x)
+         x_mean = x - f
+         x = x_mean + g[:, None, None, None] * z
+         return x, x_mean
+
+
+ @PredictorRegistry.register('none')
+ class NonePredictor(Predictor):
+     """An empty predictor that does nothing."""
+
+     def __init__(self, *args, **kwargs):
+         pass
+
+     def update_fn(self, x, t, *args):
+         return x, x
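In symbols, one reverse-diffusion step uses the discretization $f_i, G_i$ returned by `rsde.discretize` (a transcription of the code above, with $s_\theta$ the score and $z \sim \mathcal{N}(0, I)$; in the forward discretization $f_i = \mathrm{drift} \cdot \Delta t$ and $G_i = \mathrm{diffusion} \cdot \sqrt{\Delta t}$):

$$x_{\mathrm{mean}} = x - \left(f_i - G_i^2\, s_\theta(x, t, y)\right), \qquad x \leftarrow x_{\mathrm{mean}} + G_i\, z$$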
sgmse/sdes.py ADDED
@@ -0,0 +1,310 @@
+ """
+ Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
+
+ Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
+ """
+ import abc
+ import warnings
+
+ import numpy as np
+ from sgmse.util.tensors import batch_broadcast
+ import torch
+
+ from sgmse.util.registry import Registry
+
+
+ SDERegistry = Registry("SDE")
+
+
+ class SDE(abc.ABC):
+     """SDE abstract class. Functions are designed for a mini-batch of inputs."""
+
+     def __init__(self, N):
+         """Construct an SDE.
+
+         Args:
+             N: number of discretization time steps.
+         """
+         super().__init__()
+         self.N = N
+
+     @property
+     @abc.abstractmethod
+     def T(self):
+         """End time of the SDE."""
+         pass
+
+     @abc.abstractmethod
+     def sde(self, x, t, *args):
+         pass
+
+     @abc.abstractmethod
+     def marginal_prob(self, x, t, *args):
+         """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
+         pass
+
+     @abc.abstractmethod
+     def prior_sampling(self, shape, *args):
+         """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
+         pass
+
+     @abc.abstractmethod
+     def prior_logp(self, z):
+         """Compute log-density of the prior distribution.
+
+         Useful for computing the log-likelihood via probability flow ODE.
+
+         Args:
+             z: latent code
+         Returns:
+             log probability density
+         """
+         pass
+
+     @staticmethod
+     @abc.abstractmethod
+     def add_argparse_args(parent_parser):
+         """Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser."""
+         pass
+
+     def discretize(self, x, t, y, stepsize):
+         """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
+
+         Useful for reverse diffusion sampling and probability flow sampling.
+         Defaults to Euler-Maruyama discretization.
+
+         Args:
+             x: a torch tensor
+             t: a torch float representing the time step (from 0 to `self.T`)
+             y: a torch tensor, the "steady-state mean" to condition on
+             stepsize: the discretization step size
+
+         Returns:
+             f, G
+         """
+         dt = stepsize
+         drift, diffusion = self.sde(x, t, y)
+         f = drift * dt
+         G = diffusion * torch.sqrt(dt)
+         return f, G
+
+     def reverse(oself, score_model, probability_flow=False):
+         """Create the reverse-time SDE/ODE.
+
+         Args:
+             score_model: A function that takes x, t and y and returns the score.
+             probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
+         """
+         N = oself.N
+         T = oself.T
+         sde_fn = oself.sde
+         discretize_fn = oself.discretize
+
+         # Build the class for reverse-time SDE.
+         class RSDE(oself.__class__):
+             def __init__(self):
+                 self.N = N
+                 self.probability_flow = probability_flow
+
+             @property
+             def T(self):
+                 return T
+
+             def sde(self, x, t, *args):
+                 """Create the drift and diffusion functions for the reverse SDE/ODE."""
+                 rsde_parts = self.rsde_parts(x, t, *args)
+                 total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
+                 return total_drift, diffusion
+
+             def rsde_parts(self, x, t, *args):
+                 sde_drift, sde_diffusion = sde_fn(x, t, *args)
+                 score = score_model(x, t, *args)
+                 score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
+                 diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
+                 total_drift = sde_drift + score_drift
+                 return {
+                     'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
+                     'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
+                 }
+
+             def discretize(self, x, t, y, stepsize):
+                 """Create discretized iteration rules for the reverse diffusion sampler."""
+                 f, G = discretize_fn(x, t, y, stepsize)
+                 rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
+                 rev_G = torch.zeros_like(G) if self.probability_flow else G
+                 return rev_f, rev_G
+
+         return RSDE()
+
+     @abc.abstractmethod
+     def copy(self):
+         pass
+
+
+ @SDERegistry.register("ouve")
+ class OUVESDE(SDE):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 1000 by default.")
+         parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
+         parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
+         parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
+         return parser
+
+     def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
+         """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
+
+         Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
+         to the methods which require it (e.g., `sde` or `marginal_prob`).
+
+         dx = theta (y-x) dt + sigma(t) dw
+
+         with
+
+         sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
+
+         Args:
+             theta: stiffness parameter.
+             sigma_min: smallest sigma.
+             sigma_max: largest sigma.
+             N: number of discretization steps
+         """
+         super().__init__(N)
+         self.theta = theta
+         self.sigma_min = sigma_min
+         self.sigma_max = sigma_max
+         self.logsig = np.log(self.sigma_max / self.sigma_min)
+         self.N = N
+
+     def copy(self):
+         return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
+
+     @property
+     def T(self):
+         return 1
+
+     def sde(self, x, t, y):
+         drift = self.theta * (y - x)
+         # The sqrt(2*logsig) factor is required here so that logsig does not, in the end, affect the perturbation
+         # kernel standard deviation. This can be understood by solving the integral of [exp(2s) * g(s)^2] from s=0
+         # to t with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
+         # unless this sqrt(2*logsig) factor is included.
+         sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
+         diffusion = sigma * np.sqrt(2 * self.logsig)
+         return drift, diffusion
+
+     def _mean(self, x0, t, y):
+         theta = self.theta
+         exp_interp = torch.exp(-theta * t)[:, None, None, None]
+         return exp_interp * x0 + (1 - exp_interp) * y
+
+     def alpha(self, t):
+         return torch.exp(-self.theta * t)
+
+     def _std(self, t):
+         # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
+         sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
+         # could maybe replace the two torch.exp(... * t) terms here by cached values **t
+         return torch.sqrt(
+             (
+                 sigma_min**2
+                 * torch.exp(-2 * theta * t)
+                 * (torch.exp(2 * (theta + logsig) * t) - 1)
+                 * logsig
+             )
+             /
+             (theta + logsig)
+         )
+
+     def marginal_prob(self, x0, t, y):
+         return self._mean(x0, t, y), self._std(t)
+
+     def prior_sampling(self, shape, y):
+         if shape != y.shape:
+             warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
+         std = self._std(torch.ones((y.shape[0],), device=y.device))
+         x_T = y + torch.randn_like(y) * std[:, None, None, None]
+         return x_T
+
+     def prior_logp(self, z):
+         raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
+
+
+ @SDERegistry.register("ouvp")
+ class OUVPSDE(SDE):
+     # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--sde-n", type=int, default=1000,
+             help="The number of timesteps in the SDE discretization. 1000 by default.")
+         parser.add_argument("--beta-min", type=float, required=True,
+             help="The minimum beta to use.")
+         parser.add_argument("--beta-max", type=float, required=True,
+             help="The maximum beta to use.")
+         parser.add_argument("--stiffness", type=float, default=1,
+             help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
+         return parser
+
+     def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
+         """
+         !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+
+         Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
+
+         dx = 1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
+
+         with
+
+         beta(t) = beta_min + t(beta_max - beta_min)
+
+         Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
+         to the methods which require it (e.g., `sde` or `marginal_prob`).
+
+         Args:
+             beta_min: smallest beta.
+             beta_max: largest beta.
+             stiffness: stiffness factor of the drift. 1 by default.
+             N: number of discretization steps
+         """
+         super().__init__(N)
+         self.beta_min = beta_min
+         self.beta_max = beta_max
+         self.stiffness = stiffness
+         self.N = N
+
+     def copy(self):
+         return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
+
+     @property
+     def T(self):
+         return 1
+
+     def _beta(self, t):
+         return self.beta_min + t * (self.beta_max - self.beta_min)
+
+     def sde(self, x, t, y):
+         drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
+         diffusion = torch.sqrt(self._beta(t))
+         return drift, diffusion
+
+     def _mean(self, x0, t, y):
+         b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+         x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
+         return y + x0y_fac * (x0 - y)
+
+     def _std(self, t):
+         b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+         return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
+
+     def marginal_prob(self, x0, t, y):
+         return self._mean(x0, t, y), self._std(t)
+
+     def prior_sampling(self, shape, y):
+         if shape != y.shape:
+             warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
+         std = self._std(torch.ones((y.shape[0],), device=y.device))
+         x_T = y + torch.randn_like(y) * std[:, None, None, None]
+         return x_T
+
+     def prior_logp(self, z):
+         raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
sgmse/util/inference.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from torchaudio import load
+
+ from pesq import pesq
+ from pystoi import stoi
+
+ from .other import si_sdr, pad_spec
+
+ # Settings
+ sr = 16000
+ snr = 0.5
+ N = 30
+ corrector_steps = 1
+
+
+ def evaluate_model(model, num_eval_files):
+     clean_files = model.data_module.valid_set.clean_files
+     noisy_files = model.data_module.valid_set.noisy_files
+
+     # Select test files uniformly across the validation files
+     total_num_files = len(clean_files)
+     indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
+     clean_files = [clean_files[i] for i in indices]
+     noisy_files = [noisy_files[i] for i in indices]
+
+     _pesq = 0
+     _si_sdr = 0
+     _estoi = 0
+     # Iterate over files
+     for (clean_file, noisy_file) in zip(clean_files, noisy_files):
+         # Load wavs
+         x, _ = load(clean_file)
+         y, _ = load(noisy_file)
+         T_orig = x.size(1)
+
+         # Normalize per utterance
+         norm_factor = y.abs().max()
+         y = y / norm_factor
+
+         # Prepare DNN input
+         Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
+         Y = pad_spec(Y)
+         y = y * norm_factor
+
+         # Reverse sampling
+         sampler = model.get_pc_sampler(
+             'reverse_diffusion', 'ald', Y.cuda(), N=N,
+             corrector_steps=corrector_steps, snr=snr)
+         sample, _ = sampler()
+
+         x_hat = model.to_audio(sample.squeeze(), T_orig)
+         x_hat = x_hat * norm_factor
+
+         x_hat = x_hat.squeeze().cpu().numpy()
+         x = x.squeeze().cpu().numpy()
+         y = y.squeeze().cpu().numpy()
+
+         _si_sdr += si_sdr(x, x_hat)
+         _pesq += pesq(sr, x, x_hat, 'wb')
+         _estoi += stoi(x, x_hat, sr, extended=True)
+
+     return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
sgmse/util/other.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import torch
+ import numpy as np
+ import scipy.stats
+ from scipy.signal import butter, sosfilt
+
+ from pesq import pesq
+ from pystoi import stoi
+
+
+ def si_sdr_components(s_hat, s, n):
+     # s_target
+     alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
+     s_target = alpha_s * s
+
+     # e_noise
+     alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
+     e_noise = alpha_n * n
+
+     # e_art
+     e_art = s_hat - s_target - e_noise
+
+     return s_target, e_noise, e_art
+
+
+ def energy_ratios(s_hat, s, n):
+     s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
+
+     si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
+     si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
+     si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
+
+     return si_sdr, si_sir, si_sar
+
+
+ def mean_conf_int(data, confidence=0.95):
+     a = 1.0 * np.array(data)
+     n = len(a)
+     m, se = np.mean(a), scipy.stats.sem(a)
+     h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
+     return m, h
+
+
+ class Method:
+     def __init__(self, name, base_dir, metrics):
+         self.name = name
+         self.base_dir = base_dir
+         self.metrics = {metric: [] for metric in metrics}
+
+     def append(self, metric, value):
+         self.metrics[metric].append(value)
+
+     def get_mean_ci(self, metric):
+         return mean_conf_int(np.array(self.metrics[metric]))
+
+
+ def hp_filter(signal, cut_off=80, order=10, sr=16000):
+     factor = cut_off / sr * 2
+     sos = butter(order, factor, 'hp', output='sos')
+     filtered = sosfilt(sos, signal)
+     return filtered
+
+
+ def si_sdr(s, s_hat):
+     alpha = np.dot(s_hat, s) / np.linalg.norm(s)**2
+     sdr = 10*np.log10(np.linalg.norm(alpha*s)**2 / np.linalg.norm(alpha*s - s_hat)**2)
+     return sdr
+
+
+ def snr_dB(s, n):
+     s_power = 1/len(s) * np.sum(s**2)
+     n_power = 1/len(n) * np.sum(n**2)
+     snr_dB = 10*np.log10(s_power/n_power)
+     return snr_dB
+
+
+ def pad_spec(Y, mode="zero_pad"):
+     T = Y.size(3)
+     if T % 64 != 0:
+         num_pad = 64 - T % 64
+     else:
+         num_pad = 0
+     if mode == "zero_pad":
+         pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
+     elif mode == "reflection":
+         pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0, 0))
+     elif mode == "replication":
+         pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0, 0))
+     else:
+         raise ValueError(f"Unknown padding mode: {mode}")
+     return pad2d(Y)
+
+
+ def ensure_dir(file_path):
+     directory = file_path
+     if not os.path.exists(directory):
+         os.makedirs(directory)
+
+
+ def print_metrics(x, y, x_hat_list, labels, sr=16000):
+     _si_sdr_mix = si_sdr(x, y)
+     _pesq_mix = pesq(sr, x, y, 'wb')
+     _estoi_mix = stoi(x, y, sr, extended=True)
+     print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
+     for i, x_hat in enumerate(x_hat_list):
+         _si_sdr = si_sdr(x, x_hat)
+         _pesq = pesq(sr, x, x_hat, 'wb')
+         _estoi = stoi(x, x_hat, sr, extended=True)
+         print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
+
+
+ def mean_std(data):
+     data = data[~np.isnan(data)]
+     mean = np.mean(data)
+     std = np.std(data)
+     return mean, std
+
+
+ def print_mean_std(data, decimal=2):
+     data = np.array(data)
+     data = data[~np.isnan(data)]
+     mean = np.mean(data)
+     std = np.std(data)
+     return f'{mean:.{decimal}f} ± {std:.{decimal}f}'
+
+
+ def set_torch_cuda_arch_list():
+     if not torch.cuda.is_available():
+         print("CUDA is not available. No GPUs found.")
+         return
+
+     num_gpus = torch.cuda.device_count()
+     compute_capabilities = []
+
+     for i in range(num_gpus):
+         cc_major, cc_minor = torch.cuda.get_device_capability(i)
+         cc = f"{cc_major}.{cc_minor}"
+         compute_capabilities.append(cc)
+
+     cc_string = ";".join(compute_capabilities)
+     os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
+     print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
sgmse/util/registry.py ADDED
@@ -0,0 +1,34 @@
+ import warnings
+ from typing import Callable
+
+
+ class Registry:
+     def __init__(self, managed_thing: str):
+         """
+         Create a new registry.
+
+         Args:
+             managed_thing: A string describing what type of thing is managed by this registry. Will be used for
+                 warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
+         """
+         self.managed_thing = managed_thing
+         self._registry = {}
+
+     def register(self, name: str) -> Callable:
+         def inner_wrapper(wrapped_class) -> Callable:
+             if name in self._registry:
+                 warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
+             self._registry[name] = wrapped_class
+             return wrapped_class
+         return inner_wrapper
+
+     def get_by_name(self, name: str):
+         """Get a managed thing by name."""
+         if name in self._registry:
+             return self._registry[name]
+         else:
+             raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
+
+     def get_all_names(self):
+         """Get the list of things' names registered to this registry."""
+         return list(self._registry.keys())
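A hedged usage sketch of the `Registry` helper (all names hypothetical):

    ThingRegistry = Registry("Thing")

    @ThingRegistry.register("my_thing")
    class MyThing:
        pass

    assert ThingRegistry.get_by_name("my_thing") is MyThing
    assert ThingRegistry.get_all_names() == ["my_thing"]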
sgmse/util/tensors.py ADDED
@@ -0,0 +1,16 @@
+ def batch_broadcast(a, x):
+     """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""
+     if len(a.shape) != 1:
+         a = a.squeeze()
+     if len(a.shape) != 1:
+         raise ValueError(
+             f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
+         )
+
+     if a.shape[0] != x.shape[0] and a.shape[0] != 1:
+         raise ValueError(
+             f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimensions do not match")
+
+     out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
+     return out
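A short illustrative use of `batch_broadcast` (hypothetical shapes), e.g. for per-batch scalars that must multiply 4D spectrogram tensors:

    import torch
    t = torch.rand(8)               # one scalar per batch element, shape (8,)
    x = torch.randn(8, 1, 256, 64)  # batch of spectrograms
    t_b = batch_broadcast(t, x)     # shape (8, 1, 1, 1), broadcastable against x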