SunderAli17 committed

Commit 7fe60bd
1 Parent(s): 7612a7b

Create vae.py

Files changed (1):
  1. module/diffusers_vae/vae.py  +978  -0
module/diffusers_vae/vae.py ADDED
@@ -0,0 +1,978 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.unet_2d_blocks import (
    AutoencoderTinyBlock,
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


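# Usage sketch (illustrative; the SD-style hyperparameters below are assumed, not taken
# from any checkpoint): with `double_z=True` the encoder emits 2 * out_channels channels
# (mean and log-variance stacked along dim=1), and every non-final down block halves the
# spatial resolution, so four blocks map a 256x256 image to a 32x32 feature map.
#
#   enc = Encoder(
#       in_channels=3,
#       out_channels=4,
#       down_block_types=("DownEncoderBlock2D",) * 4,
#       block_out_channels=(128, 256, 512, 512),
#       layers_per_block=2,
#   )
#   moments = enc(torch.randn(1, 3, 256, 256))  # -> shape (1, 8, 32, 32)

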
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample)
        sample = sample.to(torch.float32)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class UpSample(nn.Module):
    r"""
    The `UpSample` layer of a variational autoencoder that upsamples its input.
    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `UpSample` class."""
        x = torch.relu(x)
        x = self.deconv(x)
        return x


class MaskConditionEncoder(nn.Module):
    """
    used in AsymmetricAutoencoderKL
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int = 192,
        res_ch: int = 768,
        stride: int = 16,
    ) -> None:
        super().__init__()

        channels = []
        while stride > 1:
            stride = stride // 2
            in_ch_ = out_ch * 2
            if out_ch > res_ch:
                out_ch = res_ch
            if stride == 1:
                in_ch_ = res_ch
            channels.append((in_ch_, out_ch))
            out_ch *= 2

        out_channels = []
        for _in_ch, _out_ch in channels:
            out_channels.append(_out_ch)
        out_channels.append(channels[-1][0])

        layers = []
        in_ch_ = in_ch
        for l in range(len(out_channels)):
            out_ch_ = out_channels[l]
            if l == 0 or l == 1:
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
            else:
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
            in_ch_ = out_ch_

        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
        r"""The forward method of the `MaskConditionEncoder` class."""
        out = {}
        for l in range(len(self.layers)):
            layer = self.layers[l]
            x = layer(x)
            out[str(tuple(x.shape))] = x
            x = torch.relu(x)
        return out


class MaskConditionDecoder(nn.Module):
    r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
    decoder with a conditioner on the mask and masked image.
    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # condition encoder
        self.condition_encoder = MaskConditionEncoder(
            in_ch=out_channels,
            out_ch=block_out_channels[0],
            res_ch=block_out_channels[-1],
        )

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        z: torch.FloatTensor,
        image: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.FloatTensor] = None,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `MaskConditionDecoder` class."""
        sample = z
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # condition encoder
                if image is not None and mask is not None:
                    masked_image = (1 - mask) * image
                    im_x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.condition_encoder),
                        masked_image,
                        mask,
                        use_reentrant=False,
                    )

                # up
                for up_block in self.up_blocks:
                    if image is not None and mask is not None:
                        sample_ = im_x[str(tuple(sample.shape))]
                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                        sample = sample * mask_ + sample_ * (1 - mask_)
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
                if image is not None and mask is not None:
                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # condition encoder
                if image is not None and mask is not None:
                    masked_image = (1 - mask) * image
                    im_x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.condition_encoder),
                        masked_image,
                        mask,
                    )

                # up
                for up_block in self.up_blocks:
                    if image is not None and mask is not None:
                        sample_ = im_x[str(tuple(sample.shape))]
                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                        sample = sample * mask_ + sample_ * (1 - mask_)
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
                if image is not None and mask is not None:
                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # condition encoder
            if image is not None and mask is not None:
                masked_image = (1 - mask) * image
                im_x = self.condition_encoder(masked_image, mask)

            # up
            for up_block in self.up_blocks:
                if image is not None and mask is not None:
                    sample_ = im_x[str(tuple(sample.shape))]
                    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                    sample = sample * mask_ + sample_ * (1 - mask_)
                sample = up_block(sample, latent_embeds)
            if image is not None and mask is not None:
                sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self,
        n_e: int,
        vq_embed_dim: int,
        beta: float,
        remap=None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
    ):
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.used: torch.Tensor
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q: torch.FloatTensor = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q: torch.FloatTensor = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


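# Usage sketch (illustrative; codebook size and latent shape are assumed): quantize a
# latent feature map whose channel count equals `vq_embed_dim`. The straight-through
# estimator in `forward` keeps gradients flowing to the encoder, while `loss` carries the
# codebook/commitment terms weighted by `beta`.
#
#   quantizer = VectorQuantizer(n_e=512, vq_embed_dim=4, beta=0.25, sane_index_shape=True)
#   z = torch.randn(1, 4, 32, 32)                 # encoder output
#   z_q, vq_loss, (_, _, indices) = quantizer(z)  # z_q: (1, 4, 32, 32), indices: (1, 32, 32)

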
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean


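# Usage sketch (illustrative; shapes are assumed): the encoder's `double_z` output is
# interpreted as the concatenated mean and log-variance of a diagonal Gaussian posterior.
# `sample()` draws a reparameterized latent, `mode()` returns the mean, and `kl()` gives
# the per-sample KL term used as the VAE regularizer.
#
#   moments = torch.randn(1, 8, 32, 32)            # e.g. Encoder(...) output with out_channels=4
#   posterior = DiagonalGaussianDistribution(moments)
#   z = posterior.sample()                         # (1, 4, 32, 32)
#   kl = posterior.kl()                            # (1,)

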
class EncoderTiny(nn.Module):
    r"""
    The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        act_fn: str,
    ):
        super().__init__()

        layers = []
        for i, num_block in enumerate(num_blocks):
            num_channels = block_out_channels[i]

            if i == 0:
                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
            else:
                layers.append(
                    nn.Conv2d(
                        num_channels,
                        num_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    )
                )

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderTiny` class."""
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

        else:
            # scale image from [-1, 1] to [0, 1] to match TAESD convention
            x = self.layers(x.add(1).div(2))

        return x


class DecoderTiny(nn.Module):
    r"""
    The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        upsampling_scaling_factor (`int`):
            The scaling factor to use for upsampling.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
    ):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            get_activation(act_fn),
        ]

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            if not is_final_block:
                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))

            conv_out_channel = num_channels if not is_final_block else out_channels
            layers.append(
                nn.Conv2d(
                    num_channels,
                    conv_out_channel,
                    kernel_size=3,
                    padding=1,
                    bias=is_final_block,
                )
            )

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `DecoderTiny` class."""
        # Clamp.
        x = torch.tanh(x / 3) * 3

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

        else:
            x = self.layers(x)

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        return x.mul(2).sub(1)
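

if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative; the hyperparameters below are assumed,
    # not taken from any particular checkpoint): encode an image to posterior moments,
    # sample a latent with the diagonal Gaussian, and decode it back to image space.
    encoder = Encoder(
        in_channels=3,
        out_channels=4,
        down_block_types=("DownEncoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
    )
    decoder = Decoder(
        in_channels=4,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
    )

    image = torch.randn(1, 3, 256, 256)
    posterior = DiagonalGaussianDistribution(encoder(image))
    latent = posterior.sample()        # (1, 4, 32, 32)
    reconstruction = decoder(latent)   # (1, 3, 256, 256)
    print(latent.shape, reconstruction.shape)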