File size: 1,405 Bytes
0c1540a 138ee92 0c1540a 9fc47f2 0c1540a 64548be 01fed6f 64548be 0c1540a 9fc47f2 138ee92 0c1540a 9fc47f2 eeb734b 5f4f504 0c1540a a6279fd 602d244 a6279fd c47b6ff 602d244 0c1540a 9fc47f2 0c1540a 7971f77 9fc47f2 0c1540a 7971f77 0c1540a bc4149f 0c1540a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from utils import (
attn_maps,
cross_attn_init,
init_pipeline,
save_attention_maps
)
# from transformers.utils.hub import move_cache
# move_cache()
# Module-level setup: build the SDXL pipeline once at import time so every
# request served by the UI reuses the same loaded weights.

# Called before the pipeline is created, as in the original ordering.
# NOTE(review): presumably installs the attention hooks that fill `attn_maps`;
# confirm against utils.
cross_attn_init()

# Resolve the device first so the weights can be loaded in a precision that
# device can actually run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# float16 halves GPU memory use, but a half-precision pipeline is not expected
# to run on CPU, so fall back to float32 when no GPU is present.
weights_dtype = torch.float16 if device.type == 'cuda' else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=weights_dtype,
)
pipe = init_pipeline(pipe)
pipe = pipe.to(device)
def inference(prompt):
    """Generate an image for ``prompt`` and collect its attention maps.

    Runs the module-level ``pipe`` for 15 inference steps on a one-item
    batch, then passes the global ``attn_maps``, the pipeline tokenizer and
    the same one-item batch to ``save_attention_maps``.

    Args:
        prompt: The text prompt; it is wrapped in a single-element list.

    Returns:
        A ``(image, total_attn_maps)`` tuple: the first image produced by the
        pipeline and whatever ``save_attention_maps`` returns.
    """
    # The pipeline and the map exporter both receive a batch of one prompt.
    prompts = [prompt]

    result = pipe(prompts, num_inference_steps=15)
    generated = result.images[0]

    exported_maps = save_attention_maps(attn_maps, pipe.tokenizer, prompts)
    return generated, exported_maps
# Gradio UI: a prompt box and a button on top, with the generated image and
# the attention-map gallery below.
# NOTE(review): the nesting here is reconstructed from a source whose
# indentation was lost -- confirm that both outputs belong inside the row.
with gr.Blocks() as demo:
    # Page heading; same string value as the original triple-quoted literal.
    gr.Markdown("\n# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡\n")

    prompt = gr.Textbox(
        value="A capybara holding a sign that reads Hello World.",
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs `inference` on the prompt text and routes its
    # two return values to the image and gallery components, in that order.
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=True)
|