File size: 1,405 Bytes
0c1540a
 
138ee92
0c1540a
 
9fc47f2
 
 
0c1540a
64548be
01fed6f
64548be
0c1540a
 
9fc47f2
138ee92
 
 
0c1540a
9fc47f2
 
 
eeb734b
5f4f504
0c1540a
 
 
a6279fd
602d244
a6279fd
 
c47b6ff
602d244
0c1540a
9fc47f2
0c1540a
 
 
 
7971f77
 
 
 
9fc47f2
0c1540a
 
 
 
 
7971f77
 
 
0c1540a
 
 
bc4149f
0c1540a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from utils import (
    attn_maps,
    cross_attn_init,
    init_pipeline,
    save_attention_maps
)
# from transformers.utils.hub import move_cache

# move_cache()

# Project hook from `utils`, called before the pipeline is built.
# NOTE(review): presumably this patches the attention processors so the
# cross-attention maps end up in `attn_maps`; confirm in `utils`.
cross_attn_init()

# Pick the device first so the weight dtype can follow it. float16 halves
# memory and is fast on CUDA, but half precision on CPU is unsupported for
# some ops in older PyTorch and slow in general, so use float32 without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_dtype = torch.float16 if device.type == 'cuda' else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=weight_dtype,
)

# Project helper from `utils`; its return value replaces the pipeline object.
pipe = init_pipeline(pipe)
pipe = pipe.to(device)


def inference(prompt, num_inference_steps=15):
    """Generate one image for *prompt* and collect its cross-attention maps.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        num_inference_steps: Number of denoising steps. Defaults to 15, the
            value this demo always used; exposed so callers can trade speed
            for quality without editing the function.

    Returns:
        A ``(image, total_attn_maps)`` tuple: the first image produced by the
        pipeline and whatever ``save_attention_maps`` returns for this prompt
        (the UI routes it to the gallery component).
    """
    # The pipeline takes a batch of prompts: wrap the single prompt in a list
    # and keep only the first (and only) generated image.
    result = pipe([prompt], num_inference_steps=num_inference_steps)
    image = result.images[0]

    # `attn_maps` is module-level state imported from `utils`, presumably
    # populated by the patched attention processors during the call above.
    # NOTE(review): confirm save_attention_maps (or utils) clears it between
    # calls; otherwise maps from earlier requests may accumulate.
    total_attn_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])

    return image, total_attn_maps


with gr.Blocks() as demo:
    # Page header rendered above the controls.
    gr.Markdown(
        """
        # 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
        """
    )

    # Inputs: the prompt text box and the button that triggers generation.
    prompt = gr.Textbox(
        value="A capybara holding a sign that reads Hello World.",
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    # Outputs, side by side: the generated image and a gallery showing the
    # attention maps that `inference` returns as its second value.
    with gr.Row():
        image = gr.Image(type="pil", width=512, height=512)
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs prompt -> inference -> (image, gallery).
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])


def _launch_demo():
    """Start the Gradio app with ``share=True`` (Gradio's public share link)."""
    demo.launch(share=True)


if __name__ == "__main__":
    _launch_demo()