wooyeolbaek committed on
Commit 9fc47f2
1 Parent(s): fb5d7d9
Files changed (1):
  1. app.py +14 -20
app.py CHANGED
@@ -1,25 +1,25 @@
 import torch
 import gradio as gr
-from diffusers import StableDiffusionXLPipeline
+from diffusers import StableDiffusion3Pipeline
 from utils import (
-    cross_attn_init,
-    register_cross_attention_hook,
     attn_maps,
-    get_net_attn_map,
-    resize_net_attn_map,
-    return_net_attn_map,
+    cross_attn_init,
+    init_pipeline,
+    save_attention_maps
 )
 # from transformers.utils.hub import move_cache
 
 # move_cache()
 
 cross_attn_init()
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0",
-    # "stabilityai/sdxl-turbo",
-    torch_dtype=torch.float16,
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3-medium-diffusers",
+    torch_dtype=torch.bfloat16
 )
-pipe.unet = register_cross_attention_hook(pipe.unet)
+
+pipe = init_pipeline(pipe)
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 pipe = pipe.to(device)
 
@@ -29,15 +29,10 @@ def inference(prompt):
         prompt,
         num_inference_steps=15,
     ).images[0]
-    net_attn_maps = get_net_attn_map(image.size)
-    net_attn_maps = resize_net_attn_map(net_attn_maps, image.size)
-    net_attn_maps = return_net_attn_map(net_attn_maps, pipe.tokenizer, prompt)
 
-    # remove sos and eos
-    net_attn_maps = [attn_map for attn_map in net_attn_maps if attn_map[1].split('_')[-1] != "<<|startoftext|>>"]
-    net_attn_maps = [attn_map for attn_map in net_attn_maps if attn_map[1].split('_')[-1] != "<<|endoftext|>>"]
+    total_attn_maps = save_attention_maps(attn_maps, tokenizer, prompts)
 
-    return image, net_attn_maps
+    return image, total_attn_maps
 
 
 with gr.Blocks() as demo:
@@ -46,8 +41,7 @@ with gr.Blocks() as demo:
         # 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
         """
     )
-    # prompt = gr.Textbox(value="A photo of a black puppy, christmas atmosphere", label="Prompt", lines=2)
-    prompt = gr.Textbox(value="A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grass in front of the Sydney Opera House holding a sign on the chest that says 'SDXL'!.", label="Prompt", lines=2)
+    prompt = gr.Textbox(value="A capybara holding a sign that reads Hello World.", label="Prompt", lines=2)
     btn = gr.Button("Generate images", scale=0)
 
     with gr.Row():
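
Taken together, the commit switches the Space from SDXL (hook registration on pipe.unet) to Stable Diffusion 3 with an init_pipeline/save_attention_maps flow. A minimal sketch of the resulting inference path is below; it assumes the Space's utils module provides cross_attn_init, init_pipeline, attn_maps and save_attention_maps as imported in this diff, and it guesses that tokenizer and prompts come from the pipeline tokenizer and a single-element prompt list (their definitions lie outside the hunks shown).

import torch
from diffusers import StableDiffusion3Pipeline
from utils import attn_maps, cross_attn_init, init_pipeline, save_attention_maps

# Patch the attention processors so cross-attention maps are recorded during sampling.
cross_attn_init()

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe = init_pipeline(pipe)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

def inference(prompt):
    # Generate the image; the patched processors fill `attn_maps` as a side effect.
    image = pipe(prompt, num_inference_steps=15).images[0]
    # Assumption: the helper takes the pipeline tokenizer and a list of prompts,
    # matching the save_attention_maps(attn_maps, tokenizer, prompts) call in the diff.
    total_attn_maps = save_attention_maps(attn_maps, pipe.tokenizer, [prompt])
    return image, total_attn_maps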