|
import torch |
|
import gradio as gr |
|
from diffusers import StableDiffusionXLPipeline |
|
from utils import ( |
|
cross_attn_init, |
|
register_cross_attention_hook, |
|
attn_maps, |
|
get_net_attn_map, |
|
resize_net_attn_map, |
|
return_net_attn_map, |
|
) |
|
|
|
|
|
|
|
|
|
# Initialise the cross-attention helpers from utils; kept ahead of pipeline
# construction to match the original statement order.
cross_attn_init()

# Resolve the target device first so the weight precision can follow it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    # float16 is only used on GPU: diffusers warns that half-precision
    # pipelines are not expected to run on CPU, so the CPU fallback loads
    # float32 weights instead of failing (or crawling) at inference time.
    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,
)

# Replace the UNet with the hooked version returned by the utils helper.
pipe.unet = register_cross_attention_hook(pipe.unet)

pipe = pipe.to(device)
|
|
|
|
|
def inference(prompt):
    """Generate an image for ``prompt`` and collect its attention-map entries.

    Runs the module-level ``pipe`` for 15 inference steps, feeds the result
    of ``get_net_attn_map`` through ``resize_net_attn_map`` and
    ``return_net_attn_map`` (all imported from ``utils``), and finally drops
    the entries labelled as start-of-text / end-of-text markers.

    Args:
        prompt: Text prompt forwarded to the pipeline and to
            ``return_net_attn_map``.

    Returns:
        A ``(image, entries)`` tuple: the first image produced by the
        pipeline and the filtered list of entries. The element at index 1 of
        each entry is a string label; presumably each entry is an
        (image, label) pair -- confirm against ``utils.return_net_attn_map``.
    """
    generated = pipe(prompt, num_inference_steps=15)
    image = generated.images[0]

    size = image.size
    raw_maps = get_net_attn_map(size)
    sized_maps = resize_net_attn_map(raw_maps, size)
    labelled_maps = return_net_attn_map(sized_maps, pipe.tokenizer, prompt)

    # Labels end in "_<marker>"; skip the two special-token markers.
    special_markers = ("<<|startoftext|>>", "<<|endoftext|>>")
    kept = []
    for entry in labelled_maps:
        marker = entry[1].rsplit('_', 1)[-1]
        if marker not in special_markers:
            kept.append(entry)

    return image, kept
|
|
|
|
|
# Page layout: heading, prompt box and button, then a row with the generated
# image next to a gallery that receives the second output of `inference`.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
    """
    )

    prompt = gr.Textbox(
        label="Prompt",
        lines=2,
        # Adjacent literals concatenate into the same single default prompt.
        value=(
            "A portrait photo of a kangaroo wearing an orange hoodie and blue "
            "sunglasses standing on the grass in front of the Sydney Opera House "
            "holding a sign on the chest that says 'SDXL'!."
        ),
    )
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(type="pil", height=512, width=512)
        gallery = gr.Gallery(
            elem_id="gallery",
            label="Generated images",
            show_label=False,
            value=None,
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs `inference` on the prompt text and routes its
    # two return values to the image and gallery components, in that order.
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])


if __name__ == "__main__":
    # Only start the server when executed as a script; share=True is passed
    # through to Gradio's launch() exactly as before.
    demo.launch(share=True)
|
|
|
|