File size: 1,535 Bytes
0c1540a 138ee92 0c1540a 9fc47f2 0c1540a 64548be 01fed6f 64548be 0c1540a 9fc47f2 138ee92 0c1540a 9fc47f2 eeb734b 5f4f504 0c1540a a6279fd 602d244 a6279fd c47b6ff 602d244 0c1540a 9fc47f2 0c1540a 7971f77 321bbea 0c1540a 7971f77 0c1540a bc4149f 0c1540a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from utils import (
attn_maps,
cross_attn_init,
init_pipeline,
save_attention_maps
)
# from transformers.utils.hub import move_cache
# move_cache()

# Called before the pipeline is constructed, as in the original ordering.
# NOTE(review): presumably this patches the attention modules so that
# `attn_maps` gets populated during sampling; confirm in utils.
cross_attn_init()

# Resolve the target device first so the weights can be loaded in a dtype
# that the device actually handles well.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    # Half precision is only appropriate on GPU; on CPU many ops are either
    # unsupported or extremely slow in float16, so fall back to float32.
    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,
)
pipe = init_pipeline(pipe)
pipe = pipe.to(device)
def inference(prompt):
    """Run the pipeline on a single prompt and gather its attention maps.

    Args:
        prompt: Text prompt; it is wrapped in a one-element list before
            being handed to the pipeline and to ``save_attention_maps``.

    Returns:
        A ``(image, total_attn_maps)`` tuple: the first image produced by
        the pipeline and whatever ``save_attention_maps`` returns for the
        module-level ``attn_maps`` and the pipeline's tokenizer.
    """
    # Fixed 15-step sampling run over a batch of one prompt.
    output = pipe([prompt], num_inference_steps=15)
    generated = output.images[0]

    # NOTE(review): `attn_maps` is module-level state imported from utils;
    # whether it is reset between calls is not visible here; confirm.
    maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])
    return generated, maps
# Demo page: a prompt box and a button, with the generated image and the
# attention-map gallery displayed in one row.
# NOTE(review): the nesting below is reconstructed from syntax because the
# source view carried no indentation; confirm that both `image` and
# `gallery` belong inside the row.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
"""
    )
    prompt = gr.Textbox(
        value="A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grass in front of the Sydney Opera House holding a sign on the chest that says 'SDXL'!.",
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button feeds the prompt text to `inference` and routes
    # its two return values to the image and gallery components, in order.
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|