# Header residue from the hosting web page, kept as comments so the module parses:
# author: wooyeolbaek | commit c993844 "Update utils.py" (verified) | 7.7 kB
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
from diffusers.models import Transformer2DModel
from diffusers.models.unets import UNet2DConditionModel
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
from diffusers import FluxPipeline
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
JointAttnProcessor2_0,
FluxAttnProcessor2_0
)
from modules import *
def cross_attn_init():
    """Monkey-patch ``__call__`` on the diffusers attention-processor classes.

    Each processor class gets the replacement callable paired with it below.
    The replacements (``attn_call`` etc.) are not defined in this file;
    presumably they come from ``from modules import *``. The assignment is made
    on the class, so every existing and future instance is affected.
    """
    replacements = (
        (AttnProcessor, attn_call),
        (AttnProcessor2_0, attn_call2_0),
        (LoRAAttnProcessor, lora_attn_call),
        (LoRAAttnProcessor2_0, lora_attn_call2_0),
        (JointAttnProcessor2_0, joint_attn_call2_0),
        (FluxAttnProcessor2_0, flux_attn_call2_0),
    )
    for processor_cls, patched_call in replacements:
        processor_cls.__call__ = patched_call
def hook_function(name, detach=True):
    """Return a forward hook that moves ``module.processor.attn_map`` into ``attn_maps``.

    When the hooked module's processor carries an ``attn_map`` attribute, the
    hook stores it under ``attn_maps[processor.timestep][name]`` and then
    deletes the attribute from the processor. Otherwise it does nothing.
    ``attn_maps`` is a module-level dict not defined in this file (presumably
    provided by ``from modules import *``).

    Args:
        name: key used for this module inside the per-timestep dict.
        detach: if True, store ``attn_map.cpu()``; otherwise store the tensor
            unchanged. Note that ``.cpu()`` moves to CPU but does not detach
            from autograd.
    """
    def forward_hook(module, _inputs, _output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        per_step = attn_maps.setdefault(processor.timestep, dict())
        captured = processor.attn_map
        per_step[name] = captured.cpu() if detach else captured
        # Drop the processor's reference so the next forward starts clean.
        del processor.attn_map

    return forward_hook
def register_cross_attention_hook(model, hook_function, target_name):
    """Enable attention-map storage on matching submodules and hook them.

    For every submodule whose qualified name ends with ``target_name``:

    * if its ``processor`` is one of the supported diffusers processor
      classes, ``processor.store_attn_map`` is set to True;
    * a forward hook built by ``hook_function(name)`` is registered,
      whatever the processor type.

    Submodules that have no ``processor`` attribute are skipped. Previously
    they raised ``AttributeError`` here, and any hook that reads
    ``module.processor`` would fail on them anyway.

    Args:
        model: a ``torch.nn.Module`` whose submodules are scanned.
        hook_function: factory ``name -> hook`` passed to
            ``register_forward_hook``.
        target_name: suffix matched against qualified submodule names.

    Returns:
        ``model``, modified in place (returned for chained assignment).
    """
    supported_processors = (
        AttnProcessor,
        AttnProcessor2_0,
        LoRAAttnProcessor,
        LoRAAttnProcessor2_0,
        JointAttnProcessor2_0,
        FluxAttnProcessor2_0,
    )
    for name, module in model.named_modules():
        if not name.endswith(target_name):
            continue
        if not hasattr(module, 'processor'):
            # Name matches the suffix but this is not an attention module.
            continue
        if isinstance(module.processor, supported_processors):
            module.processor.store_attn_map = True
        module.register_forward_hook(hook_function(name))
    return model
def replace_call_method_for_unet(model):
    """Rebind ``forward`` on UNet-related modules inside ``model``.

    Walks ``model`` and every descendant:

    * any module whose class is named ``UNet2DConditionModel`` (``model``
      itself included) gets ``UNet2DConditionModelForward``;
    * any descendant, but never ``model`` itself, whose class is named
      ``Transformer2DModel`` or ``BasicTransformerBlock`` gets
      ``Transformer2DModelForward`` / ``BasicTransformerBlockForward``.

    Matching is by class name. Each replacement is bound to its instance with
    ``__get__`` so that it acts as a method. The ``*Forward`` functions are not
    defined in this file (presumably from ``from modules import *``).

    Returns:
        ``model``, modified in place.
    """
    for module in model.modules():
        kind = module.__class__.__name__
        if kind == 'UNet2DConditionModel':
            module.forward = UNet2DConditionModelForward.__get__(module, UNet2DConditionModel)
        if module is model:
            # Block-level replacements only apply to children, not the root.
            continue
        if kind == 'Transformer2DModel':
            module.forward = Transformer2DModelForward.__get__(module, Transformer2DModel)
        elif kind == 'BasicTransformerBlock':
            module.forward = BasicTransformerBlockForward.__get__(module, BasicTransformerBlock)
    return model
def replace_call_method_for_sd3(model):
    """Rebind ``forward`` on SD3 transformer modules inside ``model``.

    Walks ``model`` and every descendant:

    * any module whose class is named ``SD3Transformer2DModel`` (``model``
      itself included) gets ``SD3Transformer2DModelForward``;
    * any descendant, but never ``model`` itself, whose class is named
      ``JointTransformerBlock`` gets ``JointTransformerBlockForward``.

    Matching is by class name. Replacements are bound with ``__get__``, and
    the ``*Forward`` functions presumably come from ``from modules import *``.

    Returns:
        ``model``, modified in place.
    """
    for module in model.modules():
        kind = module.__class__.__name__
        if kind == 'SD3Transformer2DModel':
            module.forward = SD3Transformer2DModelForward.__get__(module, SD3Transformer2DModel)
        elif kind == 'JointTransformerBlock' and module is not model:
            module.forward = JointTransformerBlockForward.__get__(module, JointTransformerBlock)
    return model
def replace_call_method_for_flux(model):
    """Rebind ``forward`` on Flux transformer modules inside ``model``.

    Walks ``model`` and every descendant:

    * any module whose class is named ``FluxTransformer2DModel`` (``model``
      itself included) gets ``FluxTransformer2DModelForward``;
    * any descendant, but never ``model`` itself, whose class is named
      ``FluxTransformerBlock`` gets ``FluxTransformerBlockForward``.

    Matching is by class name. Replacements are bound with ``__get__``, and
    the ``*Forward`` functions presumably come from ``from modules import *``.

    Returns:
        ``model``, modified in place.
    """
    for module in model.modules():
        kind = module.__class__.__name__
        if kind == 'FluxTransformer2DModel':
            module.forward = FluxTransformer2DModelForward.__get__(module, FluxTransformer2DModel)
        elif kind == 'FluxTransformerBlock' and module is not model:
            module.forward = FluxTransformerBlockForward.__get__(module, FluxTransformerBlock)
    return model
def init_pipeline(pipeline):
    """Install attention-map hooks and forward replacements on a pipeline.

    * If ``pipeline`` has a ``transformer`` instance attribute:
      ``SD3Transformer2DModel`` is hooked on modules ending in ``'attn'`` and
      its forwards are replaced. ``FluxTransformer2DModel`` gets the same
      treatment, and in addition ``FluxPipeline.__call__`` is patched at class
      level, which affects every ``FluxPipeline``. Other transformer classes
      are left untouched.
    * Otherwise ``pipeline.unet`` is read. A ``UNet2DConditionModel`` is
      hooked on modules ending in ``'attn2'`` and its forwards are replaced.

    Returns:
        ``pipeline``, modified in place.
    """
    if 'transformer' not in vars(pipeline):
        if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
            pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')
            pipeline.unet = replace_call_method_for_unet(pipeline.unet)
        return pipeline

    transformer_kind = pipeline.transformer.__class__.__name__
    if transformer_kind == 'SD3Transformer2DModel':
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)
    elif transformer_kind == 'FluxTransformer2DModel':
        FluxPipeline.__call__ = FluxPipeline_call
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)
    return pipeline
def _attn_map_to_bdhw(attn_map, unconditional):
    """Reduce one stored attention map to a float32 ``(batch, attn_dim, height, width)`` tensor.

    After ``.sum(1)`` the map is laid out as ``(batch, height, width, attn_dim)``.
    Dim 1 of the stored tensor is presumably the head axis (TODO confirm
    against the processors in ``modules``). ``squeeze(1)`` is a no-op unless
    ``height == 1``. With ``unconditional`` only the second half of the batch
    (``chunk(2)[1]``) is kept.
    """
    attn_map = attn_map.sum(1).squeeze(1).permute(0, 3, 1, 2)
    if unconditional:
        # Presumably drops the unconditional half of a classifier-free-guidance batch.
        attn_map = attn_map.chunk(2)[1]
    # Work in float32: summing many half-precision maps loses precision, and
    # some torch versions lack CPU bilinear interpolation for float16.
    return attn_map.to(torch.float32)


def _token_labels(tokens):
    """Turn BPE tokens into labels that mark word boundaries.

    A token containing ``'</w>'`` ends a word: the marker is stripped and the
    label gets a trailing ``'>'``. Any other token, except the special tokens
    ``'<|startoftext|>'`` / ``'<|endoftext|>'``, opens or continues a word and
    gets a trailing ``'-'``. The first piece of a word is prefixed with
    ``'<'`` and later pieces with ``'-'``. Special tokens are returned
    unchanged.
    """
    labels = []
    start_of_word = True
    for token in tokens:
        if '</w>' in token:
            token = ('<' if start_of_word else '-') + token.replace('</w>', '') + '>'
            start_of_word = True
        elif token != '<|startoftext|>' and token != '<|endoftext|>':
            token = ('<' if start_of_word else '-') + token + '-'
            start_of_word = False
        labels.append(token)
    return labels


def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unconditional=True):
    """Average collected attention maps over timesteps and layers and return
    one image per prompt token.

    Despite the name, no image files are written. The directory tree
    ``base_dir/<timestep>/<layer>`` and ``base_dir/batch-<i>`` is still created
    (kept for compatibility), but the images are returned to the caller.

    Args:
        attn_maps: nested dict ``{timestep: {layer_name: tensor}}``. Each
            layer's map is bilinearly resized to the spatial size of the first
            map, and all maps are averaged.
        tokenizer: callable returning a mapping with ``'input_ids'`` that also
            provides ``convert_ids_to_tokens``.
        prompts: passed straight to ``tokenizer``.
        base_dir: root of the directory tree to create (parents are created too).
        unconditional: keep only the second half of each map's batch
            (``chunk(2)[1]``).

    Returns:
        list of ``(PIL image, f"{token_index}-{label}")`` tuples, ordered by
        batch entry, then by token.

    Raises:
        IndexError: if ``attn_maps`` or its first timestep's dict is empty.
    """
    to_pil = ToPILImage()

    token_ids = tokenizer(prompts)['input_ids']
    total_tokens = [tokenizer.convert_ids_to_tokens(token_id) for token_id in token_ids]

    os.makedirs(base_dir, exist_ok=True)

    # The first stored map fixes the accumulator's shape and output resolution.
    first_map = list(list(attn_maps.values())[0].values())[0]
    total_attn_map = torch.zeros_like(_attn_map_to_bdhw(first_map, unconditional))
    total_attn_map_shape = total_attn_map.shape[-2:]
    total_attn_map_number = 0

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)
        for layer, attn_map in layers.items():
            os.makedirs(os.path.join(timestep_dir, f'{layer}'), exist_ok=True)
            attn_map = _attn_map_to_bdhw(attn_map, unconditional)
            total_attn_map += F.interpolate(
                attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False
            )
            total_attn_map_number += 1

    # Non-zero: the first timestep holds at least one layer, else IndexError above.
    total_attn_map /= total_attn_map_number

    final_attn_map = []
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        os.makedirs(os.path.join(base_dir, f'batch-{batch}'), exist_ok=True)
        # attn_map is (attn_dim, height, width); only the first len(tokens)
        # slots of attn_dim are paired with tokens.
        for i, (label, a) in enumerate(zip(_token_labels(tokens), attn_map[:len(tokens)])):
            # NOTE(review): ToPILImage scales float tensors by 255 and casts to
            # uint8, so values outside [0, 1] will not render faithfully.
            # Confirm the maps are normalized upstream.
            final_attn_map.append((to_pil(a.to(torch.float32)), f'{i}-{label}'))
    return final_attn_map