File size: 1,900 Bytes
0c1540a 64548be 01fed6f 64548be 0c1540a 1ef8448 0c1540a eeb734b 5f4f504 0c1540a a6279fd 0c1540a c47b6ff 0c1540a 7971f77 0c1540a 7971f77 0c1540a bc4149f 0c1540a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from utils import (
cross_attn_init,
register_cross_attention_hook,
attn_maps,
get_net_attn_map,
resize_net_attn_map,
return_net_attn_map,
)
# NOTE(review): presumably a one-off Hugging Face cache migration helper for
# older installs; uncomment and run once only if needed.
# from transformers.utils.hub import move_cache
# move_cache()

# Initialise the cross-attention utilities from utils before the pipeline is built.
cross_attn_init()

# Choose the device first so the weight dtype can follow it: float16 is only
# requested on CUDA, because half-precision inference on CPU is slow or
# unsupported for some ops. Float32 is used as the CPU fallback.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch_dtype = torch.float16 if device.type == 'cuda' else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    # Alternative checkpoint: "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
    torch_dtype=torch_dtype,
)
# Register the cross-attention hooks on the UNet, then move the whole
# pipeline to the selected device.
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe = pipe.to(device)
def inference(prompt, num_inference_steps=15):
    """Generate an image for *prompt* and build its cross-attention map outputs.

    Args:
        prompt: Text prompt passed to the module-level SDXL ``pipe``.
        num_inference_steps: Number of denoising steps forwarded to ``pipe``.
            Defaults to 15, the value this demo has always used.

    Returns:
        A tuple ``(image, net_attn_maps)``. ``image`` is the first image
        returned by the pipeline. ``net_attn_maps`` is the list produced by
        ``return_net_attn_map``, with the start-of-text and end-of-text entries
        removed. Each entry is indexable and ``entry[1]`` is a string label;
        entries are presumably ``(image, label)`` pairs for ``gr.Gallery``
        (TODO: confirm against utils).
    """
    image = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
    ).images[0]

    # NOTE(review): the module imports the global `attn_maps` but never clears
    # it here. Confirm that get_net_attn_map (or the hooks) resets it between
    # calls, otherwise maps may accumulate across requests.

    # Post-process via the utils helpers:
    #   1. fetch the maps at the output image size;
    #   2. resize them to that size;
    #   3. label them with the prompt's tokens.
    net_attn_maps = get_net_attn_map(image.size)
    net_attn_maps = resize_net_attn_map(net_attn_maps, image.size)
    net_attn_maps = return_net_attn_map(net_attn_maps, pipe.tokenizer, prompt)

    # Remove the sos/eos entries in a single pass. The token is the text after
    # the last '_' in each entry's label.
    special_tokens = {"<<|startoftext|>>", "<<|endoftext|>>"}
    net_attn_maps = [
        attn_map
        for attn_map in net_attn_maps
        if attn_map[1].split('_')[-1] not in special_tokens
    ]
    return image, net_attn_maps
with gr.Blocks() as demo:
    # Page heading.
    gr.Markdown(
        """
# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
"""
    )

    # Prompt entry plus the button that kicks off generation.
    prompt = gr.Textbox(
        label="Prompt",
        value="A photo of a black puppy, christmas atmosphere",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    # Output row: the generated image next to the gallery of attention maps.
    with gr.Row():
        image = gr.Image(type="pil", width=512, height=512)
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs inference(prompt), which fills (image, gallery).
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])


if __name__ == "__main__":
    demo.launch(share=True)
|