svd_keyframe_interpolation / attn_ctrl /attention_control.py
fffiloni's picture
Upload 33 files
fcb4edd verified
raw
history blame
10.7 kB
import abc
import torch
from typing import Tuple, List
from einops import rearrange
class AttentionControl(abc.ABC):
    """Abstract controller that counts attention-layer calls and steps.

    Each call of the controller advances ``cur_att_layer``. Once
    ``num_att_layers + num_uncond_att_layers`` calls have been made, the
    layer counter wraps to zero, ``cur_step`` is incremented and
    ``between_steps`` runs. ``num_att_layers`` starts at -1 and is expected
    to be assigned from outside before the controller is used.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Zero the step and layer counters (``num_att_layers`` is kept)."""
        self.cur_step = self.cur_att_layer = 0

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls per cycle that are not forwarded."""
        return 0

    def step_callback(self, x_t):
        """Return ``x_t`` unchanged; subclasses may override."""
        return x_t

    def between_steps(self):
        """Hook run after the last layer call of a step; no-op here."""
        return None

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Handle one attention map; must be implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Forward ``attn`` to ``forward`` and advance the counters.

        The result of ``forward`` is discarded; this method returns ``None``.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            self.forward(attn, is_cross, place_in_unet)

        self.cur_att_layer += 1
        calls_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer != calls_per_step:
            return

        # Last layer of this step was just seen: wrap around and notify.
        self.cur_att_layer = 0
        self.cur_step += 1
        self.between_steps()
class AttentionStore(AttentionControl):
    """Controller that collects every attention map it is called with.

    Maps received during the current step are appended to ``step_store``.
    When a step completes they are published as ``attention_store`` and,
    if ``save_global_store`` is set, also summed into ``global_store``.
    """

    def __init__(self, save_global_store=False):
        """Initialize an empty AttentionStore.

        :param save_global_store: when true, ``between_steps`` keeps a
            running sum of every step's maps in ``global_store``.
        """
        super().__init__()
        self.save_global_store = save_global_store
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        # Not read anywhere in this class; kept for external users.
        self.curr_step_index = 0

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per ``<place>_<kind>`` key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Append ``attn`` to the current step's list for its key; return it."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # NOTE: every map is stored regardless of its size (a former size
        # guard was disabled), so memory grows with the number of layers.
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Publish the finished step's maps and start a new empty step."""
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if not self.global_store:
                    # Seed the running sum with detached copies. Aliasing
                    # step_store here would let the in-place additions below
                    # modify the tensors already exposed via attention_store.
                    self.global_store = {
                        key: [step_map.detach().clone() for step_map in step_maps]
                        for key, step_maps in self.step_store.items()
                    }
                else:
                    for key, totals in self.global_store.items():
                        step_maps = self.step_store[key]
                        for i, total in enumerate(totals):
                            total.add_(step_maps[i].detach())
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the maps of the most recently completed step, unmodified."""
        return self.attention_store

    def get_average_global_attention(self):
        """Return the accumulated maps divided by ``cur_step``.

        Returns an empty dict when nothing has been accumulated, e.g. when
        ``save_global_store`` is false or right after ``reset``.
        """
        return {
            key: [total / self.cur_step for total in totals]
            for key, totals in self.global_store.items()
        }

    def reset(self):
        """Reset the counters and drop all stored maps."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
class AttentionStoreProcessor:
    """Attention processor that reports attention probabilities to a controller.

    Assumes ``attn`` exposes a diffusers-style attention module interface
    (``to_q``/``to_k``/``to_v``, ``head_to_batch_dim``, ``to_out`` ...) --
    TODO confirm against the UNet this is registered on.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """Run attention, handing the probabilities to ``attnstore`` on the way.

        The probabilities are passed to the controller with their leading
        dimension split as ``(b h) -> b h`` (``b`` = batch size) and with
        ``is_cross`` always ``False``. Returns the processed hidden states.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) here
        # and restored to their original layout before returning.
        is_4d = hidden_states.ndim == 4
        if is_4d:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            context_shape = hidden_states.shape
        else:
            context_shape = encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Without an explicit context, keys and values come from the
        # hidden states themselves.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)
        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        stored_probs = rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size)
        self.attnstore(stored_probs, False, self.place_in_unet)

        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class AttentionFlipCtrlProcessor:
    """Attention processor that swaps in flipped maps from a reference store.

    The attention probabilities of this layer are replaced by the matching
    map recorded in ``attnstore_ref.attention_store``, reversed along its
    last two axes. The resulting map is reported to ``attnstore``.

    Assumes ``attn`` exposes a diffusers-style attention module interface
    -- TODO confirm against the UNet this is registered on.
    """

    def __init__(self, attnstore, attnstore_ref, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        # The attribute name (with its extra "r") is kept as in the original.
        self.attnrstore_ref = attnstore_ref
        self.place_in_unet = place_in_unet

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """Run attention using the flipped reference probabilities.

        Returns the processed hidden states.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) here
        # and restored to their original layout before returning.
        is_4d = hidden_states.ndim == 4
        if is_4d:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            context_shape = hidden_states.shape
        else:
            context_shape = encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)
        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # The controller's layer counter runs across all places, while the
        # reference maps are grouped per place. Subtract the number of
        # reference maps stored for the places counted before this one.
        reference_maps = self.attnrstore_ref.attention_store
        if self.place_in_unet == 'mid':
            preceding = len(reference_maps["down_self"])
        elif self.place_in_unet == 'up':
            preceding = len(reference_maps["down_self"]) + len(reference_maps["mid_self"])
        else:
            preceding = 0
        layer_index = self.attnstore.cur_att_layer - preceding

        reference_probs = reference_maps[f"{self.place_in_unet}_{'self'}"][layer_index]
        reference_probs = rearrange(reference_probs, 'b h i j -> (b h) i j')

        # The layer's own probabilities are weighted by 0.0 and the
        # reference map, reversed along its last two axes, by 1.0.
        attention_probs = 0.0 * attention_probs + 1.0 * torch.flip(reference_probs, dims=(-2, -1))

        stored_probs = rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size)
        self.attnstore(stored_probs, False, self.place_in_unet)

        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
def register_temporal_self_attention_control(unet, controller):
    """Install ``AttentionStoreProcessor`` on the UNet's temporal self-attention.

    Every processor whose name ends with
    ``temporal_transformer_blocks.0.attn1.processor`` and starts with
    ``mid_block``, ``up_blocks`` or ``down_blocks`` is replaced by an
    ``AttentionStoreProcessor`` bound to ``controller``. All other
    processors are passed through unchanged.

    :param unet: object exposing ``attn_processors`` (name -> processor
        mapping) and ``set_attn_processor``.
    :param controller: store that receives the attention maps; its
        ``num_att_layers`` is set to the number of replaced processors.
    """
    target_suffix = "temporal_transformer_blocks.0.attn1.processor"
    places_by_prefix = (("mid_block", "mid"), ("up_blocks", "up"), ("down_blocks", "down"))

    attn_procs = {}
    temporal_self_att_count = 0
    # Read the mapping once rather than on every iteration.
    for name, current_processor in unet.attn_processors.items():
        place_in_unet = None
        if name.endswith(target_suffix):
            place_in_unet = next(
                (place for prefix, place in places_by_prefix if name.startswith(prefix)),
                None,
            )

        if place_in_unet is None:
            # Keep the existing processor, including for a matching suffix
            # under an unknown prefix, so that the dict handed to
            # set_attn_processor still has an entry for every layer.
            attn_procs[name] = current_processor
            continue

        temporal_self_att_count += 1
        attn_procs[name] = AttentionStoreProcessor(
            attnstore=controller, place_in_unet=place_in_unet
        )

    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count
def register_temporal_self_attention_flip_control(unet, controller, controller_ref):
    """Install ``AttentionFlipCtrlProcessor`` on the UNet's temporal self-attention.

    Every processor whose name ends with
    ``temporal_transformer_blocks.0.attn1.processor`` and starts with
    ``mid_block``, ``up_blocks`` or ``down_blocks`` is replaced by an
    ``AttentionFlipCtrlProcessor`` bound to ``controller`` and
    ``controller_ref``. All other processors are passed through unchanged.

    :param unet: object exposing ``attn_processors`` (name -> processor
        mapping) and ``set_attn_processor``.
    :param controller: store that receives the attention maps; its
        ``num_att_layers`` is set to the number of replaced processors.
    :param controller_ref: store whose recorded maps the new processors
        read from.
    """
    target_suffix = "temporal_transformer_blocks.0.attn1.processor"
    places_by_prefix = (("mid_block", "mid"), ("up_blocks", "up"), ("down_blocks", "down"))

    attn_procs = {}
    temporal_self_att_count = 0
    # Read the mapping once rather than on every iteration.
    for name, current_processor in unet.attn_processors.items():
        place_in_unet = None
        if name.endswith(target_suffix):
            place_in_unet = next(
                (place for prefix, place in places_by_prefix if name.startswith(prefix)),
                None,
            )

        if place_in_unet is None:
            # Keep the existing processor, including for a matching suffix
            # under an unknown prefix, so that the dict handed to
            # set_attn_processor still has an entry for every layer.
            attn_procs[name] = current_processor
            continue

        temporal_self_att_count += 1
        attn_procs[name] = AttentionFlipCtrlProcessor(
            attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet
        )

    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count