import abc import torch from typing import Tuple, List from einops import rearrange class AttentionControl(abc.ABC): def step_callback(self, x_t): return x_t def between_steps(self): return @property def num_uncond_att_layers(self): return 0 @abc.abstractmethod def forward(self, attn, is_cross: bool, place_in_unet: str): raise NotImplementedError def __call__(self, attn, is_cross: bool, place_in_unet: str): if self.cur_att_layer >= self.num_uncond_att_layers: self.forward(attn, is_cross, place_in_unet) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: self.cur_att_layer = 0 self.cur_step += 1 self.between_steps() def reset(self): self.cur_step = 0 self.cur_att_layer = 0 def __init__(self): self.cur_step = 0 self.num_att_layers = -1 self.cur_att_layer = 0 class AttentionStore(AttentionControl): @staticmethod def get_empty_store(): return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} def forward(self, attn, is_cross: bool, place_in_unet: str): key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" #if attn.shape[1] <= 32 ** 2: # avoid memory overhead self.step_store[key].append(attn) return attn def between_steps(self): self.attention_store = self.step_store if self.save_global_store: with torch.no_grad(): if len(self.global_store) == 0: self.global_store = self.step_store else: for key in self.global_store: for i in range(len(self.global_store[key])): self.global_store[key][i] += self.step_store[key][i].detach() self.step_store = self.get_empty_store() self.step_store = self.get_empty_store() def get_average_attention(self): average_attention = self.attention_store return average_attention def get_average_global_attention(self): average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in self.attention_store} return average_attention def reset(self): super(AttentionStore, self).reset() self.step_store = self.get_empty_store() self.attention_store = {} self.global_store = {} def __init__(self, save_global_store=False): ''' Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion process ''' super(AttentionStore, self).__init__() self.save_global_store = save_global_store self.step_store = self.get_empty_store() self.attention_store = {} self.global_store = {} self.curr_step_index = 0 class AttentionStoreProcessor: def __init__(self, attnstore, place_in_unet): super().__init__() self.attnstore = attnstore self.place_in_unet = place_in_unet def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttentionFlipCtrlProcessor: def __init__(self, attnstore, attnstore_ref, place_in_unet): super().__init__() self.attnstore = attnstore self.attnrstore_ref = attnstore_ref self.place_in_unet = place_in_unet def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) if self.place_in_unet == 'mid': cur_att_layer = self.attnstore.cur_att_layer-len(self.attnrstore_ref.attention_store["down_self"]) elif self.place_in_unet == 'up': cur_att_layer = self.attnstore.cur_att_layer-(len(self.attnrstore_ref.attention_store["down_self"])+len(self.attnrstore_ref.attention_store["mid_self"])) else: cur_att_layer = self.attnstore.cur_att_layer attention_probs_ref = self.attnrstore_ref.attention_store[f"{self.place_in_unet}_{'self'}"][cur_att_layer] attention_probs_ref = rearrange(attention_probs_ref, 'b h i j -> (b h) i j') attention_probs = 0.0 * attention_probs + 1.0 * torch.flip(attention_probs_ref, dims=(-2, -1)) self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def register_temporal_self_attention_control(unet, controller): attn_procs = {} temporal_self_att_count = 0 for name in unet.attn_processors.keys(): if name.endswith("temporal_transformer_blocks.0.attn1.processor"): if name.startswith("mid_block"): place_in_unet = "mid" elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) place_in_unet = "up" elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) place_in_unet = "down" else: continue temporal_self_att_count += 1 attn_procs[name] = AttentionStoreProcessor( attnstore=controller, place_in_unet=place_in_unet ) else: attn_procs[name] = unet.attn_processors[name] unet.set_attn_processor(attn_procs) controller.num_att_layers = temporal_self_att_count def register_temporal_self_attention_flip_control(unet, controller, controller_ref): attn_procs = {} temporal_self_att_count = 0 for name in unet.attn_processors.keys(): if name.endswith("temporal_transformer_blocks.0.attn1.processor"): if name.startswith("mid_block"): place_in_unet = "mid" elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) place_in_unet = "up" elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) place_in_unet = "down" else: continue temporal_self_att_count += 1 attn_procs[name] = AttentionFlipCtrlProcessor( attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet ) else: attn_procs[name] = unet.attn_processors[name] unet.set_attn_processor(attn_procs) controller.num_att_layers = temporal_self_att_count