import torch
import gradio as gr

from diffusers import DiffusionPipeline

from utils import (
    attn_maps,
    cross_attn_init,
    init_pipeline,
    save_attention_maps
)
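
# Hook cross-attention so the maps computed during denoising are recorded
# into `attn_maps` (helper from the local utils module).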
cross_attn_init()
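
# Load the SDXL base pipeline with half-precision weights.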
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
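
# Attach the attention-map bookkeeping to this specific pipeline
# (init_pipeline is also part of utils).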
pipe = init_pipeline(pipe)
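
# Prefer the GPU when available; the fp16 weights may be slow or unsupported on CPU.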
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe = pipe.to(device)


def inference(prompt):
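    # Generate one image; the hooked attention layers record their maps
    # into `attn_maps` as a side effect.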
    image = pipe(
        [prompt],
        num_inference_steps=15,
    ).images[0]
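
    # Collect the attention maps recorded for this prompt (returned for display in the gallery).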
    total_attn_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])

    return image, total_attn_maps


with gr.Blocks() as demo:
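    # Simple UI: a prompt box, a button, the generated image, and a gallery
    # for the returned attention maps.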
    gr.Markdown(
        """
        # 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
        """
    )
    prompt = gr.Textbox(value="A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grass in front of the Sydney Opera House holding a sign on the chest that says 'SDXL'!.", label="Prompt", lines=2)
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None, label="Generated images", show_label=False,
            elem_id="gallery", object_fit="contain", height="auto"
        )
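
    # Wire the button: run inference and send its outputs to the image and gallery.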
    btn.click(inference, prompt, [image, gallery])
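

# share=True makes Gradio create a temporary public link in addition to serving locally.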
if __name__ == "__main__":
    demo.launch(share=True)