wooyeolbaek's picture
Update app.py
321bbea verified
raw
history blame
1.54 kB
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from utils import (
attn_maps,
cross_attn_init,
init_pipeline,
save_attention_maps
)
# from transformers.utils.hub import move_cache
# move_cache()
# Called before the pipeline is loaded (same order as before); presumably this
# installs the cross-attention capture hooks defined in utils — confirm there.
cross_attn_init()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use float16 weights only on CUDA. On CPU, half-precision inference is very
# slow and many PyTorch builds lack float16 kernels for ops the UNet/VAE use
# ("... not implemented for 'Half'"), so fall back to float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
)
# Hand the pipeline to utils.init_pipeline (presumably attaches the
# attention-map instrumentation — verify in utils) before moving it to `device`.
pipe = init_pipeline(pipe)
pipe = pipe.to(device)
def inference(prompt: str, num_inference_steps: int = 15):
    """Generate one image for *prompt* and collect its cross-attention maps.

    Args:
        prompt: Text prompt passed to the module-level diffusion pipeline.
        num_inference_steps: Number of denoising steps. Defaults to 15, the
            value this demo always used. The Gradio button passes only the
            prompt, so the UI keeps this default.

    Returns:
        A ``(image, total_attn_maps)`` pair: the first image produced by the
        pipeline, and whatever ``save_attention_maps`` returns for this prompt.
        The second value feeds a Gallery, so it is presumably a list of images;
        confirm against ``utils``.
    """
    # The pipeline takes a batch, so the single prompt is wrapped in a list.
    result = pipe(
        [prompt],
        num_inference_steps=num_inference_steps,
    )
    image = result.images[0]

    # NOTE(review): ``attn_maps`` is shared module-level state imported from
    # utils, presumably filled by the hooks during the pipeline call above.
    # Confirm it is reset between requests so maps do not leak across prompts.
    total_attn_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])
    return image, total_attn_maps
# Example prompt pre-filled in the textbox.
DEFAULT_PROMPT = (
    "A portrait photo of a kangaroo wearing an orange hoodie and blue "
    "sunglasses standing on the grass in front of the Sydney Opera House "
    "holding a sign on the chest that says 'SDXL'!."
)

# UI layout: a title, the prompt box and a generate button, then a row with
# the generated image beside a gallery. The button runs `inference` on the
# prompt and fills both outputs.
with gr.Blocks() as demo:
    gr.Markdown("\n# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡\n")

    prompt = gr.Textbox(label="Prompt", lines=2, value=DEFAULT_PROMPT)
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(type="pil", width=512, height=512)
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])
def main():
    """Start the Gradio app; ``share=True`` asks Gradio for a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()