|
import math |
|
import inspect |
|
import numpy as np |
|
from typing import Any, Dict, Optional, Tuple, Union, List, Callable |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
|
|
from diffusers.models.attention import _chunked_feed_forward |
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput |
|
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput |
|
from diffusers.pipelines.flux.pipeline_flux import ( |
|
retrieve_timesteps, |
|
replace_example_docstring, |
|
EXAMPLE_DOC_STRING, |
|
calculate_shift, |
|
XLA_AVAILABLE, |
|
FluxPipelineOutput |
|
) |
|
|
|
from diffusers.utils import ( |
|
deprecate, |
|
BaseOutput, |
|
is_torch_version, |
|
logging, |
|
USE_PEFT_BACKEND, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
from diffusers.models.attention_processor import ( |
|
Attention, |
|
AttnProcessor, |
|
AttnProcessor2_0, |
|
) |
|
|
|
|
|
# Module logger obtained through diffusers' logging helper.
logger = logging.get_logger(__name__)

# Module-level mutable store, empty at import time. Nothing in the visible part of this
# file reads or writes it; presumably attention processors defined elsewhere record
# attention maps here -- TODO confirm against the processors that use it.
attn_maps = dict()
|
|
|
|
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def FluxPipeline_call(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Replacement for `FluxPipeline.__call__` that additionally forwards the (pixel) `height` of the
    requested image to `self.transformer` on every denoising step.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            will be used instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image. Also passed to the transformer as `height`.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
            argument in their `set_timesteps` method. If not defined, the default behavior when
            `num_inference_steps` is passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Guidance value. When `self.transformer.config.guidance_embeds` is true it is broadcast over the
            latent batch and passed to the transformer as `guidance`; otherwise it is only stored on the
            pipeline as `self._guidance_scale`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, to be used as inputs for image generation. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, pooled text embeddings will be generated
            from `prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the packed latents unchanged;
            anything else is decoded by the VAE and handed to `self.image_processor.postprocess`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the transformer. Its `"scale"` entry, if
            any, is also used as the LoRA scale when encoding the prompt.
        callback_on_step_end (`Callable`, *optional*):
            A function called at the end of each denoising step as
            `callback_on_step_end(self, step, timestep, callback_kwargs)`. It may return a dict whose
            `"latents"` / `"prompt_embeds"` entries replace the current values.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Names of local variables to pass in `callback_kwargs`. You will only be able to include
            variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple` whose first element is the generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct.
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompt(s).
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps; the schedule shift `mu` depends on the latent sequence length.
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # Guidance is only fed to transformers that were trained with a guidance embedding.
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # `xm` is not imported at the top of this file; diffusers' pipeline_flux module defines it exactly
    # when XLA_AVAILABLE is true, so fetch it from there (previously this raised NameError on XLA).
    if XLA_AVAILABLE:
        from diffusers.pipelines.flux.pipeline_flux import xm

    # 6. Denoising loop.
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # Broadcast the scalar timestep over the batch dimension.
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
                # Custom addition: the patched transformer forward needs the image height.
                height=height,
            )[0]

            # Compute the previous noisy sample x_t -> x_t-1.
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # Cast back only when an MPS backend is present (mirrors upstream diffusers, which
                    # works around a PyTorch dtype issue on Apple MPS there).
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents
    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models.
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)
|
|
|
|
|
def UNet2DConditionModelForward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`UNet2DConditionModel`] forward method, patched so that attention processors receive the raw
    `timestep` through `cross_attention_kwargs["timestep"]`. The caller's `cross_attention_kwargs` dict is
    never modified; a new dict is built instead.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep
            embeddings (or concatenated when `config.class_embeddings_concat` is set).
        timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep, passed to `self.time_embedding` together with the
            timestep embedding.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            A mask of shape `(batch, key_tokens)`; `1` keeps a token, `0` discards it. It is converted into an
            additive bias (`0` / `-10000.0`) with a singleton query dimension.
        cross_attention_kwargs (`dict`, *optional*):
            Extra kwargs forwarded to the blocks' attention processors. A `"gligen"` entry is replaced by the
            position-net output, a `"scale"` entry is removed and used as the LoRA scale, and `"timestep"` is
            set to `timestep`.
        added_cond_kwargs: (`dict`, *optional*):
            Additional embeddings passed to `self.get_aug_embed` and `self.process_encoder_hidden_states`.
        down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
            A tuple of tensors that if specified are added to the residuals of down unet blocks (ControlNet).
        mid_block_additional_residual: (`torch.Tensor`, *optional*):
            A tensor that if specified is added to the residual of the middle unet block (ControlNet).
        down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side
            model(s). Entries are consumed with `.pop(0)`, so a mutable list is expected.
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)`, converted to a bias the same way as
            `attention_mask`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a
            plain tuple.

    Returns:
        [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is
            returned, otherwise a `tuple` is returned where the first element is the sample tensor.
    """
    # Samples have to be a multiple of the overall upsampling factor (2**num_upsamplers); otherwise the
    # upsample size must be forwarded explicitly to force the interpolation output size.
    default_overall_up_factor = 2**self.num_upsamplers

    forward_upsample_size = False
    upsample_size = None

    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            forward_upsample_size = True
            break

    # Convert the mask (1 = keep, 0 = discard) into an additive bias (keep = 0, discard = -10000.0) and add
    # a singleton query_tokens dimension.
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # Convert encoder_attention_mask to a bias the same way as attention_mask.
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 0. Center input if necessary.
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. Time (plus optional class / additional embeddings).
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    aug_emb = self.get_aug_embed(
        emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )
    if self.config.addition_embed_type == "image_hint":
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )

    # 2. Pre-process.
    sample = self.conv_in(sample)

    # 2.5 GLIGEN position net.
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

    # Custom addition: expose the raw timestep to the attention processors. Build a fresh dict so the
    # caller's kwargs are not mutated (previously `timestep` was written into the caller's dict).
    cross_attention_kwargs = {**(cross_attention_kwargs or {}), "timestep": timestep}

    # 3. Down.
    # `scale` is popped rather than read so it is not propagated to the inner blocks. The dict above is
    # owned by this call, so popping from it is safe.
    lora_scale = cross_attention_kwargs.pop("scale", 1.0)

    if USE_PEFT_BACKEND:
        # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
        scale_lora_layers(self, lora_scale)

    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    # T2I-Adapters use `down_intrablock_additional_residuals`, distinguishing them from ControlNets.
    is_adapter = down_intrablock_additional_residuals is not None
    # Backward compatibility for legacy usage where T2I-Adapter and ControlNet both used
    # `down_block_additional_residuals` (but only one of them at a time).
    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
            "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
            "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True

    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For T2I-Adapter CrossAttnDownBlock2D.
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. Mid.
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = self.mid_block(sample, emb)

        # To support T2I-Adapter-XL: consume a residual whose shape matches the mid-block output.
        if (
            is_adapter
            and len(down_intrablock_additional_residuals) > 0
            and sample.shape == down_intrablock_additional_residuals[0].shape
        ):
            sample += down_intrablock_additional_residuals.pop(0)

    if is_controlnet:
        sample = sample + mid_block_additional_residual

    # 5. Up.
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # Before the final block, forward the upsample size when the input is not a clean multiple.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )

    # 6. Post-process.
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if USE_PEFT_BACKEND:
        # Remove `lora_scale` from each PEFT layer.
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
|
|
|
|
|
def SD3Transformer2DModelForward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    pooled_projections: torch.FloatTensor = None,
    timestep: torch.LongTensor = None,
    block_controlnet_hidden_states: List = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`SD3Transformer2DModel`] forward method, patched so that each transformer block (on the
    non-checkpointed path) also receives `timestep` and the height of the patch-token grid
    (`height // self.config.patch_size`).

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`; its last two dimensions are read as the spatial height and width.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
            Embeddings projected from the embeddings of input conditions.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step. Fed to `self.time_text_embed` and, unchanged, to each block.
        block_controlnet_hidden_states (`list` of `torch.Tensor`, *optional*):
            Tensors added to the residuals of transformer blocks whose `context_pre_only` is False. Any list
            length is accepted: block `i` of `N` uses entry `i * len(list) // N`.
        joint_attention_kwargs (`dict`, *optional*):
            Only its `"scale"` entry is used here (as the LoRA scale under the PEFT backend); the caller's dict
            is not modified and the remaining entries are not forwarded to the blocks by this forward.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor of shape
        `(batch, self.out_channels, height, width)`.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        # Record this before popping: checking the copy after `pop` (as before) could never be true,
        # which made the non-PEFT warning below unreachable.
        scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        scale_was_passed = False
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
        scale_lora_layers(self, lora_scale)
    elif scale_was_passed:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    patch_size = self.config.patch_size
    token_height = height // patch_size
    token_width = width // patch_size
    num_blocks = len(self.transformer_blocks)

    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            else:
                return module(*inputs)

        return custom_forward

    for index_block, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing:
            # NOTE(review): this path calls the block without `timestep`/`height`; presumably the patched
            # blocks tolerate their absence -- confirm.
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                timestep=timestep,
                height=token_height,
            )

        # ControlNet residual. Map block index proportionally onto the residual list; the previous
        # `index // (num_blocks // len)` indexed out of range (or divided by zero) for non-divisible counts.
        if block_controlnet_hidden_states is not None and block.context_pre_only is False:
            control_index = index_block * len(block_controlnet_hidden_states) // num_blocks
            hidden_states = hidden_states + block_controlnet_hidden_states[control_index]

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # Unpatchify: (n, h*w, p*p*c) -> (n, c, h*p, w*p).
    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], token_height, token_width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, token_height * patch_size, token_width * patch_size)
    )

    if USE_PEFT_BACKEND:
        # Remove `lora_scale` from each PEFT layer.
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
|
|
|
|
def FluxTransformer2DModelForward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
    height: int = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`FluxTransformer2DModel`] forward method, patched so that each dual-stream block (on the
    non-checkpointed path) also receives the scaled `timestep` and `height // self.config.patch_size`.

    Args:
        hidden_states (`torch.Tensor`):
            Image tokens, passed through `self.x_embedder`. They are concatenated with the text tokens along
            dim 1 before the single-stream blocks, so dim 1 is the token axis.
        encoder_hidden_states (`torch.Tensor`):
            Text tokens, passed through `self.context_embedder`.
        pooled_projections (`torch.Tensor`):
            Pooled text embeddings fed to `self.time_text_embed`.
        timestep (`torch.LongTensor`):
            Denoising step; multiplied by 1000 before use (callers pass `t / 1000`).
        img_ids (`torch.Tensor`):
            Positional ids of the image tokens. A 3D tensor is accepted (deprecated) and its first batch
            entry is used.
        txt_ids (`torch.Tensor`):
            Positional ids of the text tokens; same 3D handling as `img_ids`.
        guidance (`torch.Tensor`, *optional*):
            Guidance value, multiplied by 1000 and embedded with the timestep when given.
        joint_attention_kwargs (`dict`, *optional*):
            Forwarded to the blocks (non-checkpointed path). Its `"scale"` entry is removed and used as the
            LoRA scale under the PEFT backend; the caller's dict is not modified.
        controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
            Residuals added to the image tokens after each dual-stream block.
        controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
            Residuals added to the image-token part after each single-stream block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.
        controlnet_blocks_repeat (`bool`, *optional*, defaults to `False`):
            If True, `controlnet_block_samples` are cycled (`index % len`) instead of spread evenly.
        height (`int`):
            Required on the non-checkpointed path: `height // self.config.patch_size` is passed to every
            dual-stream block, so `None` raises `TypeError` there.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        # Record this before popping: checking the copy after `pop` (as before) could never be true,
        # which made the non-PEFT warning below unreachable.
        scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        scale_was_passed = False
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
        scale_lora_layers(self, lora_scale)
    elif scale_was_passed:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    hidden_states = self.x_embedder(hidden_states)

    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    # Shared by both checkpointed loops below (previously redefined on every iteration).
    def create_custom_forward(module, return_dict=None):
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            else:
                return module(*inputs)

        return custom_forward

    # Dual-stream blocks: text and image tokens kept separate.
    for index_block, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # NOTE(review): this path passes neither joint_attention_kwargs nor `timestep`/`height`.
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
                timestep=timestep,
                height=height // self.config.patch_size,
            )

        # ControlNet residual.
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
            interval_control = int(np.ceil(interval_control))
            if controlnet_blocks_repeat:
                hidden_states = (
                    hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                )
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

    # Single-stream blocks operate on [text tokens, image tokens] concatenated along dim 1.
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # ControlNet residual, applied to the image-token part only.
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

    # Drop the text tokens again.
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # Remove `lora_scale` from each PEFT layer.
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
|
|
|
|
def Transformer2DModelForward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """
    The [`Transformer2DModel`] forward method, patched to expose the latent grid size to attention processors.

    Besides the regular forward pass, the grid ``height``/``width`` computed from the input are added to the
    ``cross_attention_kwargs`` handed to every transformer block, so patched attention processors can reshape
    attention probabilities back into spatial maps. Vectorized inputs do not compute a grid size here, so they
    forward ``None`` for both. The caller's ``cross_attention_kwargs`` dict is never mutated, and ``None`` is
    accepted.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep ( `torch.LongTensor`, *optional*):
            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        added_cond_kwargs (`Dict[str, torch.Tensor]`, *optional*):
            Passed to `_operate_on_patched_inputs` for patched inputs.
        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerZeroNorm`.
        cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
            Extra kwargs for the attention processors. A copy extended with `height` and `width` is passed to
            each transformer block. A `scale` entry only triggers a deprecation warning.
        attention_mask ( `torch.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        encoder_attention_mask ( `torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformers.transformer_2d.Transformer2DModelOutput`] instead of
            a plain tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
        otherwise a `tuple` where the first element is the sample tensor.
    """
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    # Convert 2-D keep(1)/discard(0) masks into additive biases of shape (batch, 1, key_tokens).
    if attention_mask is not None and attention_mask.ndim == 2:
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # The vectorized branch does not compute a grid size; it keeps None, which is also the
    # attention processors' default for these parameters.
    height = width = None

    # 1. Input
    if self.is_input_continuous:
        batch_size, _, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
    elif self.is_input_vectorized:
        hidden_states = self.latent_image_embedding(hidden_states)
    elif self.is_input_patches:
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
            hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
        )

    # Build a fresh dict: the caller's dict may be None, and it may be shared across many blocks.
    cross_attention_kwargs = {**(cross_attention_kwargs or {}), "height": height, "width": width}

    # 2. Blocks
    for block in self.transformer_blocks:
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            # Positional order matches BasicTransformerBlockForward's signature.
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                timestep,
                cross_attention_kwargs,
                class_labels,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

    # 3. Output
    if self.is_input_continuous:
        output = self._get_output_for_continuous_inputs(
            hidden_states=hidden_states,
            residual=residual,
            batch_size=batch_size,
            height=height,
            width=width,
            inner_dim=inner_dim,
        )
    elif self.is_input_vectorized:
        output = self._get_output_for_vectorized_inputs(hidden_states)
    elif self.is_input_patches:
        output = self._get_output_for_patched_inputs(
            hidden_states=hidden_states,
            timestep=timestep,
            class_labels=class_labels,
            embedded_timestep=embedded_timestep,
            height=height,
            width=width,
        )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
|
|
|
|
def BasicTransformerBlockForward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
    """
    Patched forward for a basic transformer block.

    The block runs self-attention (``attn1``), optional cross-attention (``attn2``) and a feed-forward network
    (``ff``), each added back residually.

    The normalization/modulation before each sub-layer depends on ``self.norm_type``. Supported values are
    "ada_norm", "ada_norm_zero", "layer_norm", "layer_norm_i2vgen", "ada_norm_continuous" and
    "ada_norm_single"; anything else raises ``ValueError``.

    ``cross_attention_kwargs`` is copied before use, so the caller's dict is not mutated. Its optional
    ``"gligen"`` entry is popped and fed to ``self.fuser``. For ``attn1`` the remaining kwargs are filtered
    down to the parameter names declared by ``self.attn1.processor.__call__``. Extra keys such as
    ``height``/``width``/``timestep`` therefore only reach ``attn1`` when its processor names them.
    ``attn2`` receives all remaining kwargs.

    Args:
        hidden_states: Input tokens. ``hidden_states.shape[0]`` is used as the batch size.
        attention_mask: Mask passed to ``attn1``.
        encoder_hidden_states: Conditioning passed to ``attn2``. It also goes to ``attn1`` when
            ``self.only_cross_attention`` is set.
        encoder_attention_mask: Mask passed to ``attn2``.
        timestep: Conditioning for the ada-norm variants. For "ada_norm_single" it is reshaped to
            ``(batch_size, 6, -1)`` and added to ``self.scale_shift_table``.
        cross_attention_kwargs: Extra attention-processor kwargs (see above). A ``"scale"`` entry only triggers
            a deprecation warning.
        class_labels: Passed to ``norm1`` for "ada_norm_zero".
        added_cond_kwargs: Must contain ``"pooled_text_emb"`` when ``norm_type == "ada_norm_continuous"``.

    Returns:
        The updated hidden states. A 4-D result is squeezed on dim 1.
    """
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    batch_size = hidden_states.shape[0]

    # 1. Self-attention pre-norm. The gate/shift/scale values are only bound for the
    # "ada_norm_zero" and "ada_norm_single" branches and are read again further below.
    if self.norm_type == "ada_norm":
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.norm_type == "ada_norm_zero":
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
        norm_hidden_states = self.norm1(hidden_states)
    elif self.norm_type == "ada_norm_continuous":
        norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
    elif self.norm_type == "ada_norm_single":
        # Six modulation tensors come from a learned table plus the timestep embedding.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.squeeze(1)
    else:
        raise ValueError("Incorrect norm used")

    if self.pos_embed is not None:
        norm_hidden_states = self.pos_embed(norm_hidden_states)

    # Work on a copy so popping "gligen" does not affect the caller.
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    # Parameter names declared by attn1's processor; used to drop kwargs it does not name.
    attn_parameters = set(inspect.signature(self.attn1.processor.__call__).parameters.keys())

    attn_output = self.attn1(
        norm_hidden_states,
        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
        attention_mask=attention_mask,
        **{k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters},
    )
    if self.norm_type == "ada_norm_zero":
        attn_output = gate_msa.unsqueeze(1) * attn_output
    elif self.norm_type == "ada_norm_single":
        attn_output = gate_msa * attn_output

    hidden_states = attn_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    # Optional GLIGEN fusion step.
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # 2. Cross-attention (only when the block has an attn2 module).
    if self.attn2 is not None:
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm2(hidden_states, timestep)
        elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm2(hidden_states)
        elif self.norm_type == "ada_norm_single":
            # No normalization before cross-attention for this norm type.
            norm_hidden_states = hidden_states
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
        else:
            raise ValueError("Incorrect norm")

        if self.pos_embed is not None and self.norm_type != "ada_norm_single":
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # attn2 receives every remaining kwarg (e.g. height/width/timestep), unfiltered.
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 3. Feed-forward pre-norm: norm3 for all types except "ada_norm_single", which reuses norm2 below.
    if self.norm_type == "ada_norm_continuous":
        norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
    elif not self.norm_type == "ada_norm_single":
        norm_hidden_states = self.norm3(hidden_states)

    if self.norm_type == "ada_norm_zero":
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self.norm_type == "ada_norm_single":
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

    # Feed-forward is optionally processed in chunks along self._chunk_dim.
    if self._chunk_size is not None:
        ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.norm_type == "ada_norm_zero":
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    elif self.norm_type == "ada_norm_single":
        ff_output = gate_mlp * ff_output

    hidden_states = ff_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    return hidden_states
|
|
|
|
|
def JointTransformerBlockForward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    height: int = None,
    timestep: Optional[torch.Tensor] = None,
):
    """
    Patched forward for a joint (image + context) transformer block.

    ``timestep`` and ``height`` are forwarded to ``self.attn`` so a patched joint attention processor can
    record attention maps.

    The image stream gets gated attention, an optional second gated attention branch (``attn2``, when
    ``self.use_dual_attention``) and a gated, modulated feed-forward. The context stream gets the same
    attention/feed-forward treatment unless ``self.context_pre_only`` is set. In that case ``None`` is
    returned in its place.

    Returns:
        ``(encoder_hidden_states, hidden_states)``.
    """

    def feed_forward(net, x):
        # Optionally run the feed-forward in chunks along self._chunk_dim.
        if self._chunk_size is None:
            return net(x)
        return _chunked_feed_forward(net, x, self._chunk_dim, self._chunk_size)

    dual = self.use_dual_attention
    # With dual attention, norm1 yields two extra outputs: the attn2 input and its gate.
    norm1_out = self.norm1(hidden_states, emb=temb)
    if dual:
        x_norm, gate_msa, shift_mlp, scale_mlp, gate_mlp, x_norm2, gate_msa2 = norm1_out
    else:
        x_norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm1_out

    pre_only = self.context_pre_only
    if pre_only:
        ctx_norm = self.norm1_context(encoder_hidden_states, temb)
    else:
        ctx_norm, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

    img_attn, ctx_attn = self.attn(
        hidden_states=x_norm,
        encoder_hidden_states=ctx_norm,
        timestep=timestep,
        height=height,
    )

    # Image stream: gated attention residual(s), then the modulated feed-forward residual.
    hidden_states = hidden_states + gate_msa.unsqueeze(1) * img_attn
    if dual:
        hidden_states = hidden_states + gate_msa2.unsqueeze(1) * self.attn2(hidden_states=x_norm2)

    img_ff_in = self.norm2(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    hidden_states = hidden_states + gate_mlp.unsqueeze(1) * feed_forward(self.ff, img_ff_in)

    if pre_only:
        return None, hidden_states

    # Context stream: same pattern with the context-specific norms, gates and feed-forward.
    encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * ctx_attn
    ctx_ff_in = self.norm2_context(encoder_hidden_states) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
    encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * feed_forward(self.ff_context, ctx_ff_in)

    return encoder_hidden_states, hidden_states
|
|
|
|
|
def FluxTransformerBlockForward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    joint_attention_kwargs=None,
    height: int = None,
    timestep: Optional[torch.Tensor] = None,
):
    """
    Patched forward for a Flux double-stream transformer block.

    Image tokens (``hidden_states``) and text tokens (``encoder_hidden_states``) are each modulated from
    ``temb``, attended jointly by ``self.attn``, and then updated with a gated residual attention step and a
    gated, modulated feed-forward step.

    ``timestep`` and ``height`` are forwarded to ``self.attn``, together with ``image_rotary_emb`` and any
    ``joint_attention_kwargs``, so a patched processor can record attention maps. When the text stream
    ends up in float16 it is clipped to the finite fp16 range.

    Returns:
        ``(encoder_hidden_states, hidden_states)``.
    """

    def modulate(x, scale, shift):
        return x * (1 + scale[:, None]) + shift[:, None]

    img_norm, img_gate_attn, img_shift_ff, img_scale_ff, img_gate_ff = self.norm1(hidden_states, emb=temb)
    txt_norm, txt_gate_attn, txt_shift_ff, txt_scale_ff, txt_gate_ff = self.norm1_context(
        encoder_hidden_states, emb=temb
    )

    extra_kwargs = {} if not joint_attention_kwargs else joint_attention_kwargs

    img_attn, txt_attn = self.attn(
        hidden_states=img_norm,
        encoder_hidden_states=txt_norm,
        image_rotary_emb=image_rotary_emb,
        timestep=timestep,
        height=height,
        **extra_kwargs,
    )

    # Image stream.
    hidden_states = hidden_states + img_gate_attn.unsqueeze(1) * img_attn
    img_ff = self.ff(modulate(self.norm2(hidden_states), img_scale_ff, img_shift_ff))
    hidden_states = hidden_states + img_gate_ff.unsqueeze(1) * img_ff

    # Text stream.
    encoder_hidden_states = encoder_hidden_states + txt_gate_attn.unsqueeze(1) * txt_attn
    txt_ff = self.ff_context(modulate(self.norm2_context(encoder_hidden_states), txt_scale_ff, txt_shift_ff))
    encoder_hidden_states = encoder_hidden_states + txt_gate_ff.unsqueeze(1) * txt_ff

    # Keep fp16 text activations within the representable range.
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

    return encoder_hidden_states, hidden_states
|
|
|
|
|
def attn_call(
    self,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    temb: Optional[torch.Tensor] = None,
    height: int = None,
    width: int = None,
    timestep: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """
    Attention call using explicit attention scores, optionally recording the attention map.

    When the processor instance has a ``store_attn_map`` attribute, the probabilities returned by
    ``attn.get_attention_scores`` are stored on ``self.attn_map``. They are rearranged with
    ``'b (h w) d -> b d h w'`` using ``h=height``. ``self.timestep`` is set to ``int(timestep.item())``.

    For 4-D input, ``height``/``width`` are taken from the input shape and the output is reshaped back to
    4-D. Extra positional args or a ``scale`` kwarg only trigger a deprecation notice.

    Returns:
        The attention output after the output projection, optional residual, and division by
        ``attn.rescale_output_factor``.
    """
    if args or kwargs.get("scale", None) is not None:
        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
        deprecate("scale", "1.0.0", deprecation_message)

    residual = hidden_states
    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    is_image_input = hidden_states.ndim == 4
    if is_image_input:
        # Flatten (B, C, H, W) -> (B, H*W, C); the grid size overrides the height/width arguments.
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    mask_reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
    batch_size, sequence_length, _ = mask_reference.shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    query, key, value = (
        attn.head_to_batch_dim(projected)
        for projected in (query, attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states))
    )

    attention_probs = attn.get_attention_scores(query, key, attention_mask)

    if hasattr(self, "store_attn_map"):
        self.attn_map = rearrange(attention_probs, 'b (h w) d -> b d h w', h=height)
        self.timestep = int(timestep.item())

    out = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
    # to_out[0] is the linear projection, to_out[1] the dropout.
    out = attn.to_out[1](attn.to_out[0](out))

    if is_image_input:
        out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        out = out + residual

    return out / attn.rescale_output_factor
|
|
|
|
|
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unfused scaled dot-product attention that also returns the attention weights.

    This follows the math of ``torch.nn.functional.scaled_dot_product_attention``, but materializes the
    softmax weights so callers can inspect them.

    Args:
        query: ``(..., L, E)`` tensor.
        key: ``(..., S, E)`` tensor.
        value: ``(..., S, Ev)`` tensor.
        attn_mask: Optional mask broadcastable to ``(..., L, S)``. A boolean mask means True = attend; a
            float mask is added to the scores as a bias.
        dropout_p: Dropout probability applied to the weights before they multiply ``value``.
        is_causal: Apply a lower-triangular causal mask. Must not be combined with ``attn_mask``.
        scale: Score scale; defaults to ``1 / sqrt(E)``.

    Returns:
        ``(output, attn_weight)``, where ``attn_weight`` holds the post-softmax, pre-dropout weights with
        shape ``(..., L, S)``.

    Raises:
        ValueError: If ``is_causal`` is set together with ``attn_mask``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Build the bias directly on the query's device/dtype so no host<->device copy is needed.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

    if is_causal:
        if attn_mask is not None:
            raise ValueError("`attn_mask` must be None when `is_causal` is True.")
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(causal.logical_not(), float("-inf"))

    if attn_mask is not None:
        mask = attn_mask.to(query.device)
        if mask.dtype == torch.bool:
            # Turn keep/drop booleans into an additive 0/-inf bias (without touching the caller's mask).
            mask = torch.zeros_like(mask, dtype=query.dtype).masked_fill(mask.logical_not(), float("-inf"))
        # Out-of-place add so the (L, S) bias can broadcast up to batched mask shapes.
        attn_bias = attn_bias + mask.to(query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)

    return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
|
|
|
|
|
def attn_call2_0(
    self,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    temb: Optional[torch.Tensor] = None,
    height: int = None,
    width: int = None,
    timestep: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """
    SDPA-based attention call that can additionally record the attention map.

    When the processor instance has a ``store_attn_map`` attribute, attention is computed with the
    module-level ``scaled_dot_product_attention``, which also returns the softmax weights. Those weights are
    stored on ``self.attn_map`` reshaped to ``(batch, heads, height, width, key_len)``, and ``self.timestep``
    is set to ``int(timestep.item())``. Otherwise ``F.scaled_dot_product_attention`` is used.

    Args:
        attn: The attention module providing projections, norms and mask preparation.
        hidden_states: Query-side input, 3-D ``(B, N, C)`` or 4-D ``(B, C, H, W)``.
        encoder_hidden_states: Key/value-side input; defaults to ``hidden_states`` (self-attention).
        attention_mask: Optional mask, prepared by ``attn.prepare_attention_mask``.
        temb: Passed to ``attn.spatial_norm`` when present.
        height: Query-grid height used to unflatten the stored map. Overwritten from the input for 4-D input.
        width: Overwritten from the input for 4-D input; otherwise unused.
        timestep: Scalar tensor; ``.item()`` is called on it when the map is stored.

    Returns:
        The attention output in the input's layout (4-D inputs are reshaped back).
    """
    if len(args) > 0 or kwargs.get("scale", None) is not None:
        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
        deprecate("scale", "1.0.0", deprecation_message)

    residual = hidden_states
    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # SDPA expects the mask as (batch, heads, query_len, key_len).
        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    # (B, N, heads*head_dim) -> (B, heads, N, head_dim)
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # Optional per-head query/key normalization (qk-norm). Skipping it would silently change outputs for
    # modules built with it, since this function replaces the processor's __call__. `getattr` keeps older
    # `Attention` versions that lack these attributes working.
    norm_q = getattr(attn, "norm_q", None)
    if norm_q is not None:
        query = norm_q(query)
    norm_k = getattr(attn, "norm_k", None)
    if norm_k is not None:
        key = norm_k(key)

    if hasattr(self, "store_attn_map"):
        hidden_states, attention_probs = scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        self.attn_map = rearrange(attention_probs, 'batch attn_head (h w) attn_dim -> batch attn_head h w attn_dim ', h=height)
        self.timestep = int(timestep.item())
    else:
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # Linear projection, then dropout.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        hidden_states = hidden_states + residual

    hidden_states = hidden_states / attn.rescale_output_factor

    return hidden_states
|
|
|
|
|
def lora_attn_call(self, attn: Attention, hidden_states, height, width, *args, **kwargs):
    """
    Legacy LoRA processor call: move the LoRA layers onto ``attn`` and delegate to ``attn_call``.

    The LoRA sub-layers held by this processor are attached to ``attn``'s projections. The processor is
    then replaced with a fresh ``AttnProcessor`` whose ``__call__`` is bound to ``attn_call``, and the
    ``store_attn_map`` flag is propagated if set. The call is finally forwarded with ``height``/``width``
    passed by keyword.

    Note: ``deprecate`` is called with removal version "0.26.0". Depending on the installed diffusers
    version it may raise instead of warn — TODO confirm against the pinned version.
    """
    self_cls_name = self.__class__.__name__
    deprecate(
        self_cls_name,
        "0.26.0",
        (
            f"Make sure use {self_cls_name[4:]} instead by setting "
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        ),
    )
    attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
    attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
    attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
    attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)

    # Unregister the old processor sub-module so a plain processor object can take its place.
    attn._modules.pop("processor")
    attn.processor = AttnProcessor()

    attn.processor.__call__ = attn_call.__get__(attn.processor, AttnProcessor)

    if hasattr(self, "store_attn_map"):
        attn.processor.store_attn_map = True

    # Look up `__call__` on the instance explicitly: `processor(...)` would ignore the instance-level
    # binding above. Pass height/width by keyword; positionally they would land in the
    # `encoder_hidden_states`/`attention_mask` slots of `attn_call`.
    return attn.processor.__call__(attn, hidden_states, *args, height=height, width=width, **kwargs)
|
|
|
|
|
def lora_attn_call2_0(self, attn: Attention, hidden_states, height, width, *args, **kwargs):
    """
    Legacy LoRA (SDPA) processor call: move the LoRA layers onto ``attn`` and delegate to ``attn_call2_0``.

    The LoRA sub-layers held by this processor are attached to ``attn``'s projections. The processor is
    then replaced with a fresh ``AttnProcessor2_0`` whose ``__call__`` is bound to ``attn_call2_0``, the
    SDPA counterpart matching that processor class, and the ``store_attn_map`` flag is propagated if set.
    The call is finally forwarded with ``height``/``width`` passed by keyword.

    Note: ``deprecate`` is called with removal version "0.26.0". Depending on the installed diffusers
    version it may raise instead of warn — TODO confirm against the pinned version.
    """
    self_cls_name = self.__class__.__name__
    deprecate(
        self_cls_name,
        "0.26.0",
        (
            f"Make sure use {self_cls_name[4:]} instead by setting "
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        ),
    )
    attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
    attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
    attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
    attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)

    # Unregister the old processor sub-module so a plain processor object can take its place.
    attn._modules.pop("processor")
    attn.processor = AttnProcessor2_0()

    attn.processor.__call__ = attn_call2_0.__get__(attn.processor, AttnProcessor2_0)

    if hasattr(self, "store_attn_map"):
        attn.processor.store_attn_map = True

    # Look up `__call__` on the instance explicitly: `processor(...)` would ignore the instance-level
    # binding above. Pass height/width by keyword; positionally they would land in the
    # `encoder_hidden_states`/`attention_mask` slots of `attn_call2_0`.
    return attn.processor.__call__(attn, hidden_states, *args, height=height, width=width, **kwargs)
|
|
|
|
|
def joint_attn_call2_0(
    self,
    attn: Attention,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    height: int = None,
    timestep: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.FloatTensor:
    """
    Joint (image + text) attention call that can additionally record an image-to-text attention map.

    Image tokens (``hidden_states``) and, when given, text tokens (``encoder_hidden_states``) are projected
    separately. They are concatenated along the sequence axis in the order ``[image, text]``, attended in one
    pass, and the output is split back into the two streams.

    When the processor has a ``store_attn_map`` attribute and text tokens are present, the weights of image
    queries over text keys are moved to CPU and stored on ``self.attn_map``. Their shape is
    ``(batch, heads, height, width, text_len)``. ``self.timestep`` is set from ``timestep[0]``. Without text
    tokens there is no image-to-text block, so nothing is recorded.

    Args:
        attn: The attention module providing projections and optional q/k norms.
        hidden_states: Image tokens ``(B, N_img, C)``.
        encoder_hidden_states: Optional text tokens ``(B, N_txt, C)``.
        attention_mask: Optional mask passed to the attention kernel.
        height: Height of the image-token grid, used to unflatten the stored map.
        timestep: Tensor whose first element is recorded when the map is stored.
        *args, **kwargs: Accepted and ignored.

    Returns:
        ``(hidden_states, encoder_hidden_states)`` when text tokens are given, otherwise ``hidden_states``.
        The text output only goes through ``attn.to_add_out`` when ``attn.context_pre_only`` is falsy.
    """
    residual = hidden_states

    batch_size = hidden_states.shape[0]

    # Image ("sample") projections.
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # Text ("context") projections, appended after the image tokens.
    if encoder_hidden_states is not None:
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

    # The image->text map only exists when text tokens were concatenated above.
    if hasattr(self, "store_attn_map") and encoder_hidden_states is not None:
        hidden_states, attention_probs = scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Sequence layout is [image, text]: keep image-query rows x text-key columns.
        image_length = query.shape[2] - encoder_hidden_states_query_proj.shape[2]
        attention_probs = attention_probs[:, :, :image_length, image_length:].cpu()

        self.attn_map = rearrange(
            attention_probs,
            'batch attn_head (height width) attn_dim -> batch attn_head height width attn_dim',
            height = height
        )
        self.timestep = timestep[0].cpu().item()
    else:
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    if encoder_hidden_states is not None:
        # Split the joint sequence back into image and text parts.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

    # Linear projection, then dropout.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    if encoder_hidden_states is not None:
        return hidden_states, encoder_hidden_states
    else:
        return hidden_states
|
|
|
|
|
|
|
def flux_attn_call2_0(
    self,
    attn: Attention,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
    height: int = None,
    timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
    """
    Flux attention call that can additionally record an image-to-text attention map.

    When text tokens (``encoder_hidden_states``) are given, they are projected and concatenated **before**
    the image tokens, so the sequence layout is ``[text, image]``. Rotary embeddings are applied to the
    joint query/key when provided.

    When the processor has a ``store_attn_map`` attribute and text tokens are present, the weights of image
    queries over text keys are moved to CPU and stored on ``self.attn_map``. Their shape is
    ``(batch, heads, height, width, text_len)``. ``self.timestep`` is set from ``timestep[0]``. Without text
    tokens nothing is recorded.

    Args:
        attn: The attention module providing projections and optional q/k norms.
        hidden_states: Image tokens ``(B, N_img, C)``, or the full sequence when no text tokens are given.
        encoder_hidden_states: Optional text tokens ``(B, N_txt, C)``.
        attention_mask: Optional mask passed to the attention kernel.
        image_rotary_emb: Optional rotary embedding applied to query and key.
        height: Height of the image-token grid, used to unflatten the stored map.
        timestep: Tensor whose first element is recorded when the map is stored.

    Returns:
        ``(hidden_states, encoder_hidden_states)`` after output projections when text tokens are given,
        otherwise the un-projected attention output.
    """
    batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    if encoder_hidden_states is not None:
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        # Text tokens go FIRST: the joint sequence is [text, image].
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)

    # The image->text map only exists when text tokens were concatenated above.
    if hasattr(self, "store_attn_map") and encoder_hidden_states is not None:
        hidden_states, attention_probs = scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Layout is [text, image], so image queries are rows [text_length:] and text keys are
        # columns [:text_length]. Slicing [:image_length, image_length:] would mix text/image rows
        # and pick image keys.
        text_length = encoder_hidden_states_query_proj.shape[2]
        attention_probs = attention_probs[:, :, text_length:, :text_length].cpu()

        self.attn_map = rearrange(
            attention_probs,
            'batch attn_head (height width) attn_dim -> batch attn_head height width attn_dim',
            height = height
        )
        self.timestep = timestep[0].cpu().item()
    else:
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    if encoder_hidden_states is not None:
        # Split [text, image] back into the two streams.
        encoder_hidden_states, hidden_states = (
            hidden_states[:, : encoder_hidden_states.shape[1]],
            hidden_states[:, encoder_hidden_states.shape[1] :],
        )

        # Linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
    else:
        return hidden_states