jadechoghari committed
Commit 5aecd47
Parent(s): 813b71e

Create vae.py

Files changed (1)
  1. vae.py +490 -0
vae.py ADDED
# Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
import torch
import torch.nn as nn

import numpy as np


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


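# Note on the padding above: for an even input size H, padding right/bottom
# by one, i.e. pad = (0, 1, 0, 1), followed by a 3x3 stride-2 convolution
# gives floor((H + 1 - 3) / 2) + 1 = H / 2, exactly halving the resolution.

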
class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


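# The block uses the pre-activation layout (GroupNorm -> swish -> conv,
# twice); when the channel count changes, the identity path is matched with
# a learned 1x1 "nin" shortcut (or a 3x3 conv if conv_shortcut=True).

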
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw   w_[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


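# AttnBlock is single-head self-attention over the h*w spatial positions:
# softmax(q k^T / sqrt(c)) applied to v, plus a residual connection. On
# PyTorch >= 2.0 the two bmm/softmax steps above could equivalently use the
# fused kernel (an illustrative sketch, not part of the original file):
#
#     q, k, v = (t.reshape(b, 1, c, h * w).transpose(-1, -2) for t in (q, k, v))
#     h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#     h_ = h_.squeeze(1).transpose(-1, -2).reshape(b, c, h, w)

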
class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        double_z=True,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


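# Shape check with the defaults: ch_mult=(1, 1, 2, 2, 4) gives five
# resolution levels and four Downsample steps, so a 256x256 input reaches
# 16x16 at the bottleneck (where the attn_resolutions=(16,) attention fires);
# with double_z=True, conv_out emits 2 * z_channels channels (mean and
# log-variance), e.g. (B, 3, 256, 256) -> (B, 32, 16, 16) for z_channels=16.

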
class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        attn_resolutions=(),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        give_pre_end=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


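# Unlike the Encoder, each Decoder level runs num_res_blocks + 1 ResnetBlocks,
# and levels are built from the lowest resolution upward while being inserted
# at the front of self.up, so indexing by i_level matches the Encoder's order.

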
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


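# sample() is the reparameterization trick, x = mean + std * eps with
# eps ~ N(0, I), and kl(other=None) is the closed form
#     KL(N(mean, var) || N(0, I)) = 0.5 * sum(mean^2 + var - 1 - logvar)
# summed over the channel and spatial dimensions of each latent.

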
class AutoencoderKL(nn.Module):
    def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
        super().__init__()
        self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
        self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)
        self.use_variational = use_variational
        mult = 2 if self.use_variational else 1
        self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1)
        self.embed_dim = embed_dim
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def init_from_ckpt(self, path):
        sd = torch.load(path, map_location="cpu")["model"]
        msg = self.load_state_dict(sd, strict=False)
        print("Loading pre-trained KL-VAE")
        print("Missing keys:")
        print(msg.missing_keys)
        print("Unexpected keys:")
        print(msg.unexpected_keys)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        if not self.use_variational:
            moments = torch.cat((moments, torch.ones_like(moments)), 1)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, inputs, disable=True, train=True, optimizer_idx=0):
        # NOTE: training_step and validation_step are not defined in this file;
        # they are expected to be provided elsewhere (e.g. by a training wrapper
        # or subclass). Use encode()/decode() directly for inference.
        if train:
            return self.training_step(inputs, disable, optimizer_idx)
        else:
            return self.validation_step(inputs, disable)
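

if __name__ == "__main__":
    # Minimal shape-check sketch (illustrative, not part of the original
    # commit): random weights and a random image, just to show the
    # encode -> sample -> decode round trip. embed_dim=16 matches the
    # Encoder/Decoder z_channels default.
    vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4))
    img = torch.randn(1, 3, 256, 256)
    posterior = vae.encode(img)  # DiagonalGaussianDistribution
    z = posterior.sample()       # (1, 16, 16, 16), reparameterized draw
    rec = vae.decode(z)          # (1, 3, 256, 256)
    print(z.shape, rec.shape, posterior.kl())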