DeciDiffusion-v2-0 / pipeline.py
Borys Tymchenko
Initial commit
ae2e28c
raw
history blame
44.6 kB
import itertools
from functools import partial
from typing import Any, Dict, Tuple, Callable
from typing import Union, Optional, List
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import replace_example_docstring
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the guided prediction with a copy rescaled to the text prediction's standard deviation.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    :param noise_cfg: the classifier-free-guidance prediction tensor
    :param noise_pred_text: the text-conditioned prediction tensor
    :param guidance_rescale: weight of the rescaled prediction in the final blend
    :return: `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`
    """
    # Standard deviations are taken over every dimension except the first one.
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # Match the guided prediction's spread to the text prediction's spread (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mix with the un-rescaled guidance result to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def custom_sort_order(obj):
    """
    Sort key that orders modules for execution in the blocks' forward methods.

    Resnet blocks map to 0 and transformer blocks map to 1, so sorting with this key places
    every resnet ahead of every transformer. The exact class of `obj` is inspected (subclasses
    are not matched); any other class yields `None`.
    """
    module_class = obj.__class__
    if module_class is ResnetBlock2D:
        return 0
    if module_class is Transformer2DModel or module_class is FlexibleTransformer2DModel:
        return 1
    return None
def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Shorten a descending timestep schedule to `n` entries.

    The first `i` entries are copied from `timestep_spacing`; the remaining `n - i` entries are
    evenly spaced multiples of `k`, counting down to `k`, where `k` is chosen so that the last
    copied timestep is (approximately, integer division) `(n - i + 1) * k`.

    :param timestep_spacing: the timestep_spacing array we want to squeeze (at least `i` entries)
    :param n: the size of the squeezed array
    :param i: the index we start squeezing from, must satisfy `0 < i < n`
    :return: squeezed timestep_spacing as an integer numpy array of length `n`
    :raises ValueError: if `i` is outside `(0, n)` or `timestep_spacing` has fewer than `i` entries

    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
        n = 10, i = 6
        Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # Explicit validation instead of `assert`, which is stripped when Python runs with -O.
    # i == 0 is rejected too: `squeezed[i - 1]` would wrap around to the last element and give k == 0.
    if not 0 < i < n:
        raise ValueError(f"Expected 0 < i < n, got i={i}, n={n}")
    if len(timestep_spacing) < i:
        raise ValueError(f"`timestep_spacing` must contain at least i={i} entries, got {len(timestep_spacing)}")

    squeezed = np.flip(np.arange(n)) + 1  # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]
    # The tail currently holds [n - i, ..., 2, 1]; scale it so it continues below the last copied timestep
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k
    return squeezed
# Maps a "<n>,<i>" key to a squeezer that only needs the timestep array as its argument.
# Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
PREDEFINED_TIMESTEP_SQUEEZERS = {
    f"{target_len},{start_index}": partial(squeeze_to_len_n_starting_from_index_i, target_len, start_index)
    for target_len, start_index in ((10, 6), (11, 7))
}
# Architecture description read by FlexibleUNet2DConditionModel when it builds its blocks.
FlexibleUnetConfigurations = dict(
    # General parameters for all blocks
    sample_size=64,
    temb_dim=320 * 4,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    cross_attention_dim=768,
    # Controls modules execute order in unet's forward
    mix_block_in_forward=True,
    # Down blocks parameters (one list entry per down block)
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, False],
    # Middle block parameters (zero resnets and zero attentions selects the identity mid block)
    add_upsample_mid_block=None,
    mid_num_resnets=0,
    mid_num_attentions=0,
    # Up block parameters (one list entry per up block)
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[5, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)
class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    `DPMSolverMultistepScheduler` variant that can shorten ("squeeze") its timestep schedule.

    Differences from the parent scheduler:
    * Several constructor defaults are changed to accommodate DeciDiffusion (marked below)
    * An optional `squeeze_mode` selects one of `PREDEFINED_TIMESTEP_SQUEEZERS`, which rewrites the
      timesteps produced by `set_timesteps`

    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",  # NOTE THIS DEFAULT VALUE
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",  # NOTE THIS DEFAULT VALUE
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",  # NOTE THIS DEFAULT VALUE
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -3.0,  # NOTE THIS DEFAULT VALUE
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,  # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS` defined above
    ):
        """
        Store the squeezer selected by `squeeze_mode` and forward everything else to the parent.

        A `squeeze_mode` that is not a key of `PREDEFINED_TIMESTEP_SQUEEZERS` (including `None`)
        results in no squeezer, i.e. the parent's timesteps are used unchanged.

        Raises:
            NotImplementedError: if `use_karras_sigmas` is truthy.
        """
        # `.get` returns None for an unknown key, which disables squeezing in `set_timesteps`
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,  # always disabled, see the check above
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        The parent computes the regular schedule first. When a squeezer is configured, the
        timesteps, sigmas and `num_inference_steps` are then overwritten with the squeezed schedule.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)

        if self._squeezer is None:
            return

        squeezed_timesteps = self._squeezer(self.timesteps.cpu())

        # Sigma for every training timestep, then interpolated at the squeezed timesteps
        train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        train_indices = np.arange(0, len(train_sigmas))
        step_sigmas = np.interp(squeezed_timesteps, train_indices, train_sigmas)

        # Append the sigma of training timestep 0 as the final entry
        final_sigma = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        step_sigmas = np.concatenate([step_sigmas, [final_sigma]]).astype(np.float32)

        self.sigmas = torch.from_numpy(step_sigmas)
        self.timesteps = torch.from_numpy(squeezed_timesteps).to(device=device, dtype=torch.int64)
        self.num_inference_steps = len(squeezed_timesteps)
class FlexibleIdentityBlock(nn.Module):
    """
    Pass-through block: accepts the same call arguments as the cross-attention blocks in this
    file but returns `hidden_states` untouched. It has no parameters of its own.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Return `hidden_states` as-is; every other argument is accepted and ignored."""
        return hidden_states
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down, mid and up blocks are replaced by the Flexible* blocks
    of this file, laid out according to the class-level `configurations` dictionary.
    """

    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        """Build the parent UNet, then overwrite `down_blocks`, `mid_block` and `up_blocks`."""
        cfg = self.configurations

        super().__init__(
            sample_size=cfg.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=cfg.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        # Keyword arguments that every block receives unchanged
        shared_kwargs = dict(
            temb_channels=cfg.get("temb_dim"),
            resnet_eps=cfg.get("resnet_eps"),
            resnet_act_fn=cfg.get("resnet_act_fn"),
            num_attention_heads=cfg.get("num_attention_heads"),
            cross_attention_dim=cfg.get("cross_attention_dim"),
            mix_block_in_forward=cfg.get("mix_block_in_forward"),
        )

        ###############
        # Down blocks #
        ###############
        down_in_channels = cfg.get("down_blocks_in_channels")
        down_out_channels = cfg.get("down_blocks_out_channels")
        final_down_index = len(down_in_channels) - 1
        down_specs = zip(
            down_in_channels,
            down_out_channels,
            cfg.get("down_blocks_num_resnets"),
            cfg.get("down_blocks_num_attentions"),
            cfg.get("add_downsample"),
        )
        self.down_blocks = nn.ModuleList(
            [
                FlexibleCrossAttnDownBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    add_downsample=add_down,
                    last_block=(index == final_down_index),
                    **shared_kwargs,
                )
                for index, (in_c, out_c, n_res, n_att, add_down) in enumerate(down_specs)
            ]
        )

        ###############
        # Mid blocks  #
        ###############
        mid_num_resnets = cfg.get("mid_num_resnets")
        mid_num_attentions = cfg.get("mid_num_attentions")
        mid_is_empty = mid_num_resnets == mid_num_attentions == 0
        if mid_is_empty:
            # Nothing configured for the middle: use a pass-through block
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_out_channels[-1],
                num_resnets=mid_num_resnets,
                num_attentions=mid_num_attentions,
                add_upsample=cfg.get("add_upsample_mid_block"),
                **shared_kwargs,
            )

        ###############
        # Up blocks   #
        ###############
        # Up blocks walk the down-block channel lists in reverse order
        up_specs = zip(
            reversed(down_in_channels),
            reversed(down_out_channels),
            cfg.get("prev_output_channels"),
            cfg.get("up_blocks_num_resnets"),
            cfg.get("up_blocks_num_attentions"),
            cfg.get("add_upsample"),
        )
        self.up_blocks = nn.ModuleList(
            [
                FlexibleCrossAttnUpBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    prev_output_channel=prev_out,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    add_upsample=add_up,
                    **shared_kwargs,
                )
                for in_c, out_c, prev_out, n_res, n_att, add_up in up_specs
            ]
        )
class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Down block made of `num_resnets` resnets and `num_attentions` cross-attention transformers.

    Layers are created interleaved (resnet, transformer, resnet, transformer, ...); whichever kind
    runs out first simply stops contributing. With `mix_block_in_forward=False` the layers are
    re-ordered so that all resnets come before all transformers. An optional downsampler follows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []
        for position in range(max(num_resnets, num_attentions)):
            if position < num_resnets:
                modules.append(
                    ResnetBlock2D(
                        # Only the very first position sees the block's input channel count
                        in_channels=in_channels if position == 0 else out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if position < num_attentions:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: resnets first, transformers after, relative order otherwise kept
            modules = sorted(modules, key=custom_sort_order)
        self.modules_list = nn.ModuleList(modules)

        self.downsamplers = (
            nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the layers in order and collect the skip connections.

        Returns:
            `(hidden_states, output_states)` where `output_states` holds the output of every resnet
            and, when a downsampler exists and this is not the last block, the downsampled output.

        Raises:
            ValueError: if `modules_list` contains something other than a resnet or a transformer.
        """
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
                # Only resnet outputs are recorded as skip connections
                output_states += (hidden_states,)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")

        if self.downsamplers is None:
            return hidden_states, output_states

        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        if not self.last_block:
            output_states += (hidden_states,)

        return hidden_states, output_states
class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Up block made of `num_resnets` resnets and `num_attentions` cross-attention transformers.

    Layers are created interleaved (resnet, transformer, ...). Every resnet consumes one skip
    connection, which is concatenated to its input along the channel dimension. With
    `mix_block_in_forward=False` all resnets are moved ahead of all transformers. An optional
    upsampler follows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []

        # WARNING: This parameter is filled with number of resnets and used within StableDiffusionPipeline
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        final_resnet_position = num_resnets - 1
        for position in range(max(num_resnets, num_attentions)):
            if position < num_resnets:
                # The final resnet receives a skip connection with `in_channels` channels, the others `out_channels`
                skip_channels = in_channels if position == final_resnet_position else out_channels
                # The first resnet receives the previous block's output, the others this block's own output
                incoming_channels = prev_output_channel if position == 0 else out_channels

                self.resnets.append(True)
                modules.append(
                    ResnetBlock2D(
                        in_channels=incoming_channels + skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if position < num_attentions:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: resnets first, transformers after
            modules = sorted(modules, key=custom_sort_order)
        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) if add_upsample else None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the layers in order, feeding each resnet the most recent unused skip connection.

        `res_hidden_states_tuple` is consumed from its end: each resnet pops the last entry and
        concatenates it to `hidden_states` along dim 1. Returns the (optionally upsampled) hidden states.
        """
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Pop the last skip connection
                res_hidden_states_tuple, skip = res_hidden_states_tuple[:-1], res_hidden_states_tuple[-1]
                hidden_states = module(torch.cat([hidden_states, skip], dim=1), temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Middle block: one leading resnet followed by interleaved (transformer, resnet) layers.

    `num_attentions` transformers and `num_resnets` additional resnets are created in the order
    transformer, resnet, transformer, resnet, ...; whichever kind runs out first stops
    contributing. With `mix_block_in_forward=False` all resnets are moved ahead of all
    transformers. The block keeps the channel count at `in_channels` and ends with either an
    `Upsample2D` (when `add_upsample` is truthy) or an identity.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        # zip_longest pads the shorter list with False so the longer kind keeps being added
        for add_resnet, add_cross_attention in itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: resnets first (the leading resnet stays at index 0), transformers after
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Apply every layer of `modules_list` exactly once, in order, then the upsampler(s).

        Resnets receive `(hidden_states, temb)`; transformers receive the attention-related
        keyword arguments. Returns the resulting hidden states.
        """
        # The leading resnet is applied explicitly ...
        hidden_states = self.modules_list[0](hidden_states, temb)

        # ... so the loop must start at index 1. Iterating the full list here would run the
        # leading resnet a second time.
        for module in self.modules_list[1:]:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    Spatial transformer for 4-D feature maps: GroupNorm, input projection, a stack of
    `BasicTransformerBlock`s, output projection back to `in_channels`, plus a residual connection.

    The transformer width is `num_attention_heads * attention_head_dim`. The projections are
    `nn.Linear` layers when `use_linear_projection` is set and 1x1 `nn.Conv2d` layers otherwise.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        # NOTE: `out_channels` is only recorded here; the output projection always maps back to
        # `in_channels` so that the residual addition in `forward` is shape-compatible.
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """
        Apply the transformer to a `(batch, channels, height, width)` tensor.

        Args:
            hidden_states: input feature map; its shape is unpacked as `(batch, _, height, width)`.
            encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask,
            encoder_attention_mask: forwarded unchanged to every `BasicTransformerBlock`.
            return_dict: when `True` a `Transformer2DModelOutput` is returned, otherwise a
                one-element tuple. Note that the default here is `False`.

        Returns:
            `Transformer2DModelOutput(sample=output)` or `(output,)`, where `output` is the
            transformed feature map plus the input (residual connection).
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # Conv projection works on the 4-D map, so project first and flatten afterwards
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Linear projection works on the last dim, so flatten first and project afterwards.
            # `inner_dim` is therefore the *input* channel count in this branch.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        # The branches used to be swapped (tuple for return_dict=True, output object for False).
        # Callers that pass return_dict=False and index `[0]` get the same tensor either way.
        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
class DeciDiffusionPipeline(StableDiffusionPipeline):
    """
    `StableDiffusionPipeline` that always runs with `FlexibleUNet2DConditionModel` and
    `SqueezedDPMSolverMultistepScheduler`, whatever unet/scheduler is passed to the constructor.
    """

    # Key into PREDEFINED_TIMESTEP_SQUEEZERS used for the scheduler created in __init__
    deci_default_squeeze_mode = "10,6"
    # NOTE(review): the two values below are not read anywhere in this class; `__call__` hard-codes
    # the same numbers as its `num_inference_steps` and `guidance_rescale` defaults.
    deci_default_number_of_iterations = 16
    deci_default_guidance_rescale = 0.7

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """
        Build the pipeline.

        The `unet` and `scheduler` arguments are discarded: a freshly constructed
        `FlexibleUNet2DConditionModel` and a `SqueezedDPMSolverMultistepScheduler` using
        `deci_default_squeeze_mode` take their place. No weights are loaded into the new unet
        here. All other components are passed to the parent unchanged.
        """
        # Replace UNet with Deci`s unet
        del unet
        unet = FlexibleUNet2DConditionModel()

        # Replace with custom scheduler
        del scheduler
        scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        # Register the same components again, now explicitly with the replaced unet and scheduler
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 16,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.7,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 16):
                The number of steps handed to the scheduler's `set_timesteps`. When the scheduler has a
                squeezer configured (the default squeeze mode is `"10,6"`), the squeezer rewrites the schedule
                and the denoising loop runs over the squeezed timesteps instead (10 for `"10,6"`).
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
                `"latent"` the VAE decoding and the safety checker are skipped and the latents are returned.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.7):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR. Only applied when guidance is enabled and the factor is `> 0`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # No usable prompt: the batch size comes from the pre-computed embeddings
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        # The scheduler may return fewer timesteps than requested when a squeezer is configured
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        # NOTE(review): this uses the *requested* step count; if squeezing shortened `timesteps`
        # the value can be negative, in which case the `(i + 1) > num_warmup_steps` test below is
        # always true - confirm this is intended.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    # First half of the batch is the unconditional prediction, second half the text one
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            # Undo the VAE scaling before decoding, then screen the decoded images
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            # Images flagged by the safety checker are not denormalized
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)