|
import os |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms import ToPILImage |
|
|
|
from diffusers.models import Transformer2DModel |
|
from diffusers.models.unets import UNet2DConditionModel |
|
from diffusers.models.transformers import SD3Transformer2DModel, FluxTransformer2DModel |
|
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock |
|
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock |
|
from diffusers import FluxPipeline |
|
from diffusers.models.attention_processor import ( |
|
AttnProcessor, |
|
AttnProcessor2_0, |
|
LoRAAttnProcessor, |
|
LoRAAttnProcessor2_0, |
|
JointAttnProcessor2_0, |
|
FluxAttnProcessor2_0 |
|
) |
|
|
|
from modules import * |
|
|
|
def cross_attn_init():
    """Monkey-patch the ``__call__`` of the supported attention processors.

    Each diffusers processor class listed below gets its ``__call__``
    replaced, at class level, by the matching replacement function. The
    replacement functions are not defined in this file; they are expected
    to be provided by ``from modules import *``.

    NOTE(review): the patch is global to the process - every instance of
    these processor classes is affected, not just one pipeline.
    """
    patched_calls = (
        (AttnProcessor, attn_call),
        (AttnProcessor2_0, attn_call2_0),
        (LoRAAttnProcessor, lora_attn_call),
        (LoRAAttnProcessor2_0, lora_attn_call2_0),
        (JointAttnProcessor2_0, joint_attn_call2_0),
        (FluxAttnProcessor2_0, flux_attn_call2_0),
    )
    for processor_cls, replacement in patched_calls:
        processor_cls.__call__ = replacement
|
|
|
|
|
def hook_function(name, detach=True):
    """Build a forward hook that collects a module's attention map.

    Args:
        name: Key under which the map is filed for the current timestep.
        detach: When true the stored tensor is ``attn_map.cpu()``;
            otherwise the tensor is stored as-is.

    Returns:
        A ``forward_hook(module, input, output)`` callable. It does nothing
        unless ``module.processor`` currently has an ``attn_map`` attribute.
        When it does, the map is stored in the module-level ``attn_maps``
        mapping as ``attn_maps[timestep][name]`` and then removed from the
        processor.

    NOTE(review): ``attn_maps`` is not defined in this file; it is assumed
    to be a dict provided by ``from modules import *``.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return

        timestep = processor.timestep

        # One inner dict per timestep, created on first use.
        per_timestep = attn_maps.setdefault(timestep, dict())
        if detach:
            per_timestep[name] = processor.attn_map.cpu()
        else:
            per_timestep[name] = processor.attn_map

        # Drop the processor's reference so the map is reported only once.
        del processor.attn_map

    return forward_hook
|
|
|
|
|
def register_cross_attention_hook(model, hook_function, target_name):
    """Attach attention-collecting forward hooks to matching submodules.

    Every submodule whose qualified name ends with ``target_name`` gets a
    forward hook built by ``hook_function(qualified_name)``. If the
    submodule's ``processor`` is an instance of one of the supported
    processor classes, ``store_attn_map = True`` is also set on it.

    Args:
        model: Object exposing ``named_modules()``; matching submodules
            must have a ``processor`` attribute.
        hook_function: Factory taking the submodule name and returning a
            forward hook.
        target_name: Name suffix selecting the submodules to hook.

    Returns:
        The same ``model`` object.
    """
    supported_processors = (
        AttnProcessor,
        AttnProcessor2_0,
        LoRAAttnProcessor,
        LoRAAttnProcessor2_0,
        JointAttnProcessor2_0,
        FluxAttnProcessor2_0,
    )

    for qualified_name, submodule in model.named_modules():
        if qualified_name.endswith(target_name):
            if isinstance(submodule.processor, supported_processors):
                submodule.processor.store_attn_map = True

            # The hook is registered even when the processor type is not
            # one of the supported classes.
            submodule.register_forward_hook(hook_function(qualified_name))

    return model
|
|
|
|
|
def replace_call_method_for_unet(model):
    """Recursively swap in replacement ``forward`` methods for a UNet tree.

    ``model`` itself is patched only when its class is named
    ``UNet2DConditionModel``. Each direct child named ``Transformer2DModel``
    or ``BasicTransformerBlock`` is patched, and the function then recurses
    into every child regardless of its type.

    The replacement functions (``*Forward``) are not defined in this file;
    they are expected to come from ``from modules import *``. ``__get__``
    binds each one to the instance it is assigned to.

    Returns:
        The same ``model`` object.
    """
    if model.__class__.__name__ == 'UNet2DConditionModel':
        model.forward = UNet2DConditionModelForward.__get__(model, UNet2DConditionModel)

    for _child_name, child in model.named_children():
        child_kind = child.__class__.__name__

        if child_kind == 'Transformer2DModel':
            child.forward = Transformer2DModelForward.__get__(child, Transformer2DModel)
        elif child_kind == 'BasicTransformerBlock':
            child.forward = BasicTransformerBlockForward.__get__(child, BasicTransformerBlock)

        # Descend unconditionally so nested blocks are reached as well.
        replace_call_method_for_unet(child)

    return model
|
|
|
|
|
def replace_call_method_for_sd3(model):
    """Recursively swap in replacement ``forward`` methods for an SD3 tree.

    ``model`` itself is patched only when its class is named
    ``SD3Transformer2DModel``. Each direct child named
    ``JointTransformerBlock`` is patched, and the function then recurses
    into every child regardless of its type.

    The replacement functions (``*Forward``) are not defined in this file;
    they are expected to come from ``from modules import *``.

    Returns:
        The same ``model`` object.
    """
    if model.__class__.__name__ == 'SD3Transformer2DModel':
        model.forward = SD3Transformer2DModelForward.__get__(model, SD3Transformer2DModel)

    for _child_name, child in model.named_children():
        is_joint_block = child.__class__.__name__ == 'JointTransformerBlock'
        if is_joint_block:
            child.forward = JointTransformerBlockForward.__get__(child, JointTransformerBlock)

        # Descend unconditionally so nested blocks are reached as well.
        replace_call_method_for_sd3(child)

    return model
|
|
|
|
|
def replace_call_method_for_flux(model):
    """Recursively swap in replacement ``forward`` methods for a Flux tree.

    ``model`` itself is patched only when its class is named
    ``FluxTransformer2DModel``. Each direct child named
    ``FluxTransformerBlock`` is patched, and the function then recurses
    into every child regardless of its type.

    The replacement functions (``*Forward``) are not defined in this file;
    they are expected to come from ``from modules import *``.

    Returns:
        The same ``model`` object.
    """
    if model.__class__.__name__ == 'FluxTransformer2DModel':
        model.forward = FluxTransformer2DModelForward.__get__(model, FluxTransformer2DModel)

    for _child_name, child in model.named_children():
        is_flux_block = child.__class__.__name__ == 'FluxTransformerBlock'
        if is_flux_block:
            child.forward = FluxTransformerBlockForward.__get__(child, FluxTransformerBlock)

        # Descend unconditionally so nested blocks are reached as well.
        replace_call_method_for_flux(child)

    return model
|
|
|
|
|
def init_pipeline(pipeline):
    """Instrument a pipeline's denoiser so its attention maps are collected.

    Dispatch is by instance attributes and class names:

    * no ``transformer`` attribute on the instance: if ``pipeline.unet`` is a
      ``UNet2DConditionModel``, hooks are registered on modules whose name
      ends with ``'attn2'`` and the UNet forwards are replaced;
    * ``transformer`` is an ``SD3Transformer2DModel``: hooks go on modules
      ending with ``'attn'`` and the SD3 forwards are replaced;
    * ``transformer`` is a ``FluxTransformer2DModel``: same as SD3 with the
      Flux forwards, and ``FluxPipeline.__call__`` is additionally replaced
      at class level by ``FluxPipeline_call`` (not defined in this file;
      expected from ``from modules import *``).

    Any other model type is left untouched.

    Returns:
        The same ``pipeline`` object.
    """
    if 'transformer' not in vars(pipeline).keys():
        # UNet-based pipeline.
        if pipeline.unet.__class__.__name__ == 'UNet2DConditionModel':
            pipeline.unet = register_cross_attention_hook(pipeline.unet, hook_function, 'attn2')
            pipeline.unet = replace_call_method_for_unet(pipeline.unet)
        return pipeline

    transformer_kind = pipeline.transformer.__class__.__name__

    if transformer_kind == 'SD3Transformer2DModel':
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)
    elif transformer_kind == 'FluxTransformer2DModel':
        # Class-level patch: affects every FluxPipeline instance.
        FluxPipeline.__call__ = FluxPipeline_call
        pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')
        pipeline.transformer = replace_call_method_for_flux(pipeline.transformer)

    return pipeline
|
|
|
|
|
def _label_token(token, startofword):
    """Decorate one tokenizer token with word-boundary markers.

    Returns ``(label, startofword)`` where the flag says whether the *next*
    token begins a new word.
    """
    if '</w>' in token:
        # '</w>' marks the last piece of a word: close it with '>'.
        stem = token.replace('</w>', '')
        opener = '<' if startofword else '-'
        return opener + stem + '>', True

    if token != '<|startoftext|>' and token != '<|endoftext|>':
        # A non-final piece: the word continues, so end with '-'.
        opener = '<' if startofword else '-'
        return opener + token + '-', False

    # Start/end markers are passed through unchanged.
    return token, startofword


def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unconditional=True):
    """Average all collected attention maps and return one image per token.

    Every map in ``attn_maps`` is summed over dim 1, resized bilinearly to
    the spatial size of the first map of the first timestep, and averaged
    over all timesteps and layers. The average is then split per prompt and
    per token.

    Directories ``base_dir/<timestep>/<layer>`` and ``base_dir/batch-<i>``
    are created (including missing parents), but this function does not
    write any image files itself; the images are returned to the caller.

    Args:
        attn_maps: Mapping ``{timestep: {layer_name: tensor}}``. Each tensor
            must be 5-D: dim 1 is summed out and the last dim is moved to
            position 1, where it is iterated token by token. Presumably
            (batch, heads, height, width, tokens) - confirm with the code
            that fills this mapping.
        tokenizer: Callable returning a mapping with ``'input_ids'`` for
            ``prompts`` and providing ``convert_ids_to_tokens``.
        prompts: Passed unchanged to ``tokenizer``.
        base_dir: Root of the directory tree created by this call.
        unconditional: When true, only the second half of the batch
            dimension (``chunk(2)[1]``) is kept - presumably the
            prompt-conditioned half of a classifier-free-guidance batch.

    Returns:
        A list of ``(image, label)`` tuples over all prompts in order, where
        ``image`` is ``ToPILImage()`` applied to the float32 map of one token
        and ``label`` is ``'<token index>-<decorated token>'``.
    """
    to_pil = ToPILImage()

    token_ids = tokenizer(prompts)['input_ids']
    total_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_ids]

    # makedirs also creates missing parents and tolerates an existing
    # directory, unlike an exists() check followed by mkdir().
    os.makedirs(base_dir, exist_ok=True)

    # The first stored map only serves as a shape/dtype/device template.
    template = list(list(attn_maps.values())[0].values())[0].sum(1)
    if unconditional:
        template = template.chunk(2)[1]
    template = template.permute(0, 3, 1, 2)
    total_attn_map = torch.zeros_like(template)
    total_attn_map_shape = total_attn_map.shape[-2:]
    total_attn_map_number = 0

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            layer_dir = os.path.join(timestep_dir, f'{layer}')
            os.makedirs(layer_dir, exist_ok=True)

            attn_map = attn_map.sum(1).squeeze(1)
            attn_map = attn_map.permute(0, 3, 1, 2)

            if unconditional:
                attn_map = attn_map.chunk(2)[1]

            # Layers may differ in spatial size; bring each to the
            # template size before accumulating.
            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            total_attn_map += resized_attn_map
            total_attn_map_number += 1

    total_attn_map /= total_attn_map_number

    final_attn_map = []
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        os.makedirs(batch_dir, exist_ok=True)

        startofword = True
        print('tokens',tokens)
        # Maps beyond the number of tokens of this prompt are ignored.
        for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
            token, startofword = _label_token(token, startofword)
            final_attn_map.append((to_pil(a.to(torch.float32)), f'{i}-{token}'))

    return final_attn_map