Spaces:
Running
on
Zero
Running
on
Zero
import abc | |
import torch | |
from typing import Tuple, List | |
from einops import rearrange | |
class AttentionControl(abc.ABC):
    """Base class for controllers that observe attention maps layer by layer.

    An attention processor calls the controller once per hooked attention layer.
    After ``num_att_layers + num_uncond_att_layers`` calls the layer counter wraps
    back to 0, ``cur_step`` is incremented and :meth:`between_steps` runs.
    """

    def __init__(self):
        self.cur_step = 0
        # Placeholder until the register_* helpers set the real layer count; while it
        # is -1 the layer counter never matches the wrap-around condition in __call__.
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook for post-processing latents after a step; identity by default."""
        return x_t

    def between_steps(self):
        """Hook run once after every full pass over the hooked layers; no-op here."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading layer calls per step that skip :meth:`forward` (0 here).

        Must be a property: ``__call__`` compares and adds it as an integer.
        """
        return 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention map; subclasses must override."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Dispatch one layer's attention map and advance the layer/step counters.

        The return value of :meth:`forward` is discarded; this always returns None.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # Every hooked layer has been seen once: one denoising step is complete.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_step = 0
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):
    """Controller that collects attention maps keyed by UNet location and type.

    During a step, maps are appended to ``step_store``. At each step boundary,
    ``attention_store`` is replaced by the maps of the step just finished. When
    ``save_global_store`` is true, ``global_store`` also accumulates the
    element-wise sum of the maps over all steps.
    """

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per (location, attention type)."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Append ``attn`` under ``"<place_in_unet>_<cross|self>"`` and return it unchanged."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # NOTE: every map is kept regardless of size; an earlier resolution cap
        # (attn.shape[1] <= 32 ** 2) was disabled here.
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Publish the finished step's maps and optionally fold them into ``global_store``."""
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    # Start the running sum from detached maps in a dict of our own, so
                    # the accumulator neither aliases ``attention_store`` nor keeps the
                    # first step's autograd graph alive (later steps already detach).
                    self.global_store = {
                        key: [attn.detach() for attn in maps]
                        for key, maps in self.step_store.items()
                    }
                else:
                    for key, accumulated in self.global_store.items():
                        for i in range(len(accumulated)):
                            accumulated[i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the maps of the most recently completed step (not averaged)."""
        return self.attention_store

    def get_average_global_attention(self):
        """Return ``global_store`` divided by the number of completed steps.

        Keys are taken from ``attention_store``. This assumes ``save_global_store``
        was enabled and at least one step has completed (``cur_step > 0``).
        """
        return {key: [item / self.cur_step for item in self.global_store[key]]
                for key in self.attention_store}

    def reset(self):
        """Rewind the counters and drop every stored map."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, save_global_store=False):
        """Initialize an empty AttentionStore.

        :param save_global_store: if true, also accumulate the sum of maps over all
            steps in ``global_store``.
        """
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        # Not read anywhere in this module; kept so external readers keep working.
        self.curr_step_index = 0
class AttentionStoreProcessor:
    """Attention processor that reports its attention probabilities to a controller.

    It evaluates attention with the module's own projections (``to_q``/``to_k``/
    ``to_v``, ``get_attention_scores``, ``torch.bmm``). Before applying the
    probabilities to the values, it passes them to ``attnstore`` with the leading
    ``(b h)`` axis split into ``b h``, using ``b = batch_size``. The ``is_cross``
    flag sent to the controller is always ``False``.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.place_in_unet = place_in_unet  # location tag forwarded to the controller
        self.attnstore = attnstore          # called with every attention map

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        shortcut = hidden_states
        feats = hidden_states if attn.spatial_norm is None else attn.spatial_norm(hidden_states, temb)

        # 4-D feature maps are flattened to (batch, height * width, channel).
        is_spatial = feats.ndim == 4
        if is_spatial:
            n, channel, height, width = feats.shape
            feats = feats.view(n, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            feats if encoder_hidden_states is None else encoder_hidden_states
        ).shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            feats = attn.group_norm(feats.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(feats)
        if encoder_hidden_states is None:
            context = feats
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        query, key, value = map(attn.head_to_batch_dim, (query, attn.to_k(context), attn.to_v(context)))

        probs = attn.get_attention_scores(query, key, attention_mask)
        self.attnstore(rearrange(probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)

        # to_out[0] is the output projection, to_out[1] the dropout.
        out = attn.batch_to_head_dim(torch.bmm(probs, value))
        out = attn.to_out[1](attn.to_out[0](out))

        if is_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            out = out + shortcut
        return out / attn.rescale_output_factor
class AttentionFlipCtrlProcessor:
    """Attention processor that substitutes a flipped reference attention map.

    For each call it fetches the self-attention probabilities that
    ``attnstore_ref`` recorded for the same UNet location and layer, flips them
    along both trailing axes, and uses them to aggregate the values. They can
    optionally be blended with this pass's own attention (see ``blend_weight``).
    The probabilities actually used are reported to ``attnstore`` with
    ``is_cross=False``.
    """

    def __init__(self, attnstore, attnstore_ref, place_in_unet, blend_weight=0.0):
        """
        :param attnstore: controller called with the probabilities used at each layer;
            its ``cur_att_layer`` selects which reference map to use.
        :param attnstore_ref: controller whose ``attention_store`` holds the reference
            maps, shaped ``b h i j`` per entry.
        :param place_in_unet: ``"down"``, ``"mid"`` or ``"up"``.
        :param blend_weight: weight of this pass's own attention probabilities; the
            flipped reference gets ``1 - blend_weight``. The default 0.0 uses the
            reference only (as before) and skips the query/key computation entirely.
        """
        super().__init__()
        self.attnstore = attnstore
        # Attribute name (including its spelling) kept for backward compatibility.
        self.attnrstore_ref = attnstore_ref
        self.place_in_unet = place_in_unet
        self.blend_weight = blend_weight

    def _reference_probs(self):
        """Return the flipped reference map for the current layer, shaped ``(b h) i j``."""
        ref_store = self.attnrstore_ref.attention_store
        # Convert the controller's running layer counter into an index within this
        # location's list by subtracting the layers of the locations before it
        # (assumes calls arrive in down -> mid -> up order, as these offsets imply).
        offset = 0
        if self.place_in_unet in ("mid", "up"):
            offset += len(ref_store["down_self"])
        if self.place_in_unet == "up":
            offset += len(ref_store["mid_self"])
        ref = ref_store[f"{self.place_in_unet}_self"][self.attnstore.cur_att_layer - offset]
        # 'b h i j -> (b h) i j'
        ref = ref.flatten(0, 1)
        return torch.flip(ref, dims=(-2, -1))

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        # Match the values' dtype/device so bmm also works when the reference pass
        # stored its maps in a different precision or on another device.
        attention_probs = self._reference_probs().to(device=value.device, dtype=value.dtype)
        if self.blend_weight:
            # Own attention is only needed when it contributes; the mask only
            # affects the own scores.
            if attention_mask is not None:
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            query = attn.head_to_batch_dim(attn.to_q(hidden_states))
            key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
            own_probs = attn.get_attention_scores(query, key, attention_mask)
            attention_probs = self.blend_weight * own_probs + (1.0 - self.blend_weight) * attention_probs

        # '(b h) i j -> b h i j' with b = batch_size
        self.attnstore(
            attention_probs.reshape(batch_size, -1, *attention_probs.shape[1:]),
            False,
            self.place_in_unet,
        )
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
def register_temporal_self_attention_control(unet, controller):
    """Install an ``AttentionStoreProcessor`` on every temporal self-attention layer.

    A layer is hooked when its processor name ends with
    ``temporal_transformer_blocks.0.attn1.processor`` and it lives under
    ``mid_block``, ``up_blocks`` or ``down_blocks``. Every other layer keeps its
    current processor. Afterwards ``controller.num_att_layers`` is set to the
    number of hooked layers.

    :param unet: model exposing ``attn_processors`` and ``set_attn_processor``.
    :param controller: controller the installed processors report to.
    """
    # Read once: in diffusers this property walks the whole module tree on every access.
    current = unet.attn_processors
    prefixes = (("mid_block", "mid"), ("up_blocks", "up"), ("down_blocks", "down"))
    attn_procs = {}
    temporal_self_att_count = 0
    for name, processor in current.items():
        place_in_unet = None
        if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            place_in_unet = next((place for prefix, place in prefixes if name.startswith(prefix)), None)
        if place_in_unet is None:
            # Keep the existing processor: set_attn_processor needs an entry for
            # every attention layer.
            attn_procs[name] = processor
            continue
        temporal_self_att_count += 1
        attn_procs[name] = AttentionStoreProcessor(
            attnstore=controller, place_in_unet=place_in_unet
        )
    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count
def register_temporal_self_attention_flip_control(unet, controller, controller_ref):
    """Install an ``AttentionFlipCtrlProcessor`` on every temporal self-attention layer.

    A layer is hooked when its processor name ends with
    ``temporal_transformer_blocks.0.attn1.processor`` and it lives under
    ``mid_block``, ``up_blocks`` or ``down_blocks``. Every other layer keeps its
    current processor. Afterwards ``controller.num_att_layers`` is set to the
    number of hooked layers; ``controller_ref`` is not modified.

    :param unet: model exposing ``attn_processors`` and ``set_attn_processor``.
    :param controller: controller the installed processors report to.
    :param controller_ref: controller whose stored maps the processors flip and reuse.
    """
    # Read once: in diffusers this property walks the whole module tree on every access.
    current = unet.attn_processors
    prefixes = (("mid_block", "mid"), ("up_blocks", "up"), ("down_blocks", "down"))
    attn_procs = {}
    temporal_self_att_count = 0
    for name, processor in current.items():
        place_in_unet = None
        if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            place_in_unet = next((place for prefix, place in prefixes if name.startswith(prefix)), None)
        if place_in_unet is None:
            # Keep the existing processor: set_attn_processor needs an entry for
            # every attention layer.
            attn_procs[name] = processor
            continue
        temporal_self_att_count += 1
        attn_procs[name] = AttentionFlipCtrlProcessor(
            attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet
        )
    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count