|
import torch |
|
import gradio as gr |
|
from diffusers import StableDiffusion3Pipeline |
|
from utils import ( |
|
attn_maps, |
|
cross_attn_init, |
|
init_pipeline, |
|
save_attention_maps |
|
) |
|
|
|
|
|
|
|
|
|
# Called before the pipeline is constructed; keep this ordering.
# NOTE(review): presumably this patches the attention processors so that
# `attn_maps` gets populated during inference — confirm in utils.
cross_attn_init()

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the SD3 medium weights in bfloat16, pass the pipeline through
# `init_pipeline`, then move the result to the selected device.
pipe = init_pipeline(
    StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.bfloat16,
    )
).to(device)
|
|
|
|
|
def inference(prompt):
    """Generate one image for *prompt* and build its attention-map outputs.

    Runs the module-level ``pipe`` for 15 inference steps, keeps the first
    returned image, then hands the shared ``attn_maps`` object, the
    pipeline's tokenizer and the prompt to ``save_attention_maps``.

    Args:
        prompt: Text prompt, as supplied by the Gradio textbox.

    Returns:
        A ``(image, total_attn_maps)`` tuple: the first generated image and
        whatever ``save_attention_maps`` returns (wired to the gallery).
    """
    image = pipe(
        prompt,
        num_inference_steps=15,
    ).images[0]

    # The original call referenced the undefined names `tokenizer` and
    # `prompts`, which raised NameError on every request. Use the pipeline's
    # own tokenizer and wrap the single prompt in a list.
    # NOTE(review): assumes save_attention_maps expects a list of prompts and
    # the first (`pipe.tokenizer`) of the pipeline's tokenizers — confirm
    # against utils.save_attention_maps.
    total_attn_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])

    return image, total_attn_maps
|
|
|
|
|
# Gradio front end: a prompt box and a button, followed by a row that shows
# the generated image next to a gallery. Both outputs come from `inference`.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
"""
    )
    prompt = gr.Textbox(
        value="A capybara holding a sign that reads Hello World.",
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button sends the textbox value to `inference` and routes
    # its two return values to the image and the gallery, in that order.
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])


# Only start the web server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
|
|
|