hysts HF staff commited on
Commit
93fd2ea
1 Parent(s): 522d80d

Use diffusers implementation

Browse files
Files changed (6) hide show
  1. .gitmodules +0 -3
  2. .vscode/settings.json +18 -0
  3. Attend-and-Excite +0 -1
  4. app.py +6 -11
  5. model.py +45 -55
  6. requirements.txt +5 -8
.gitmodules CHANGED
@@ -1,3 +0,0 @@
1
- [submodule "Attend-and-Excite"]
2
- path = Attend-and-Excite
3
- url = https://github.com/AttendAndExcite/Attend-and-Excite
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python.linting.enabled": true,
3
+ "python.linting.flake8Enabled": true,
4
+ "python.linting.pylintEnabled": false,
5
+ "python.linting.lintOnSave": true,
6
+ "python.formatting.provider": "yapf",
7
+ "python.formatting.yapfArgs": [
8
+ "--style={based_on_style: pep8, indent_width: 4, blank_line_before_nested_class_or_def: false, spaces_before_comment: 2, split_before_logical_operator: true}"
9
+ ],
10
+ "[python]": {
11
+ "editor.formatOnType": true,
12
+ "editor.codeActionsOnSave": {
13
+ "source.organizeImports": true
14
+ }
15
+ },
16
+ "editor.formatOnSave": true,
17
+ "files.insertFinalNewline": true
18
+ }
Attend-and-Excite DELETED
@@ -1 +0,0 @@
1
- Subproject commit 41620338367f980b9d73752360ffd2557d8ddf5d
 
 
app.py CHANGED
@@ -24,12 +24,11 @@ def process_example(
24
  seed: int,
25
  apply_attend_and_excite: bool,
26
  ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
27
- model_id = 'CompVis/stable-diffusion-v1-4'
28
  num_steps = 50
29
  guidance_scale = 7.5
30
 
31
- token_table = model.get_token_table(model_id, prompt)
32
- result = model.run(model_id, prompt, indices_to_alter_str, seed,
33
  apply_attend_and_excite, num_steps, guidance_scale)
34
  return token_table, result
35
 
@@ -39,9 +38,6 @@ with gr.Blocks(css='style.css') as demo:
39
 
40
  with gr.Row():
41
  with gr.Column():
42
- model_id = gr.Text(label='Model ID',
43
- value='CompVis/stable-diffusion-v1-4',
44
- visible=False)
45
  prompt = gr.Text(
46
  label='Prompt',
47
  max_lines=1,
@@ -171,13 +167,12 @@ with gr.Blocks(css='style.css') as demo:
171
 
172
  show_token_indices_button.click(
173
  fn=model.get_token_table,
174
- inputs=[model_id, prompt],
175
  outputs=token_indices_table,
176
  queue=False,
177
  )
178
 
179
  inputs = [
180
- model_id,
181
  prompt,
182
  token_indices_str,
183
  seed,
@@ -187,7 +182,7 @@ with gr.Blocks(css='style.css') as demo:
187
  ]
188
  prompt.submit(
189
  fn=model.get_token_table,
190
- inputs=[model_id, prompt],
191
  outputs=token_indices_table,
192
  queue=False,
193
  ).then(
@@ -197,7 +192,7 @@ with gr.Blocks(css='style.css') as demo:
197
  )
198
  token_indices_str.submit(
199
  fn=model.get_token_table,
200
- inputs=[model_id, prompt],
201
  outputs=token_indices_table,
202
  queue=False,
203
  ).then(
@@ -207,7 +202,7 @@ with gr.Blocks(css='style.css') as demo:
207
  )
208
  run_button.click(
209
  fn=model.get_token_table,
210
- inputs=[model_id, prompt],
211
  outputs=token_indices_table,
212
  queue=False,
213
  ).then(
 
24
  seed: int,
25
  apply_attend_and_excite: bool,
26
  ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
 
27
  num_steps = 50
28
  guidance_scale = 7.5
29
 
30
+ token_table = model.get_token_table(prompt)
31
+ result = model.run(prompt, indices_to_alter_str, seed,
32
  apply_attend_and_excite, num_steps, guidance_scale)
33
  return token_table, result
34
 
 
38
 
39
  with gr.Row():
40
  with gr.Column():
 
 
 
41
  prompt = gr.Text(
42
  label='Prompt',
43
  max_lines=1,
 
167
 
168
  show_token_indices_button.click(
169
  fn=model.get_token_table,
170
+ inputs=prompt,
171
  outputs=token_indices_table,
172
  queue=False,
173
  )
174
 
175
  inputs = [
 
176
  prompt,
177
  token_indices_str,
178
  seed,
 
182
  ]
183
  prompt.submit(
184
  fn=model.get_token_table,
185
+ inputs=prompt,
186
  outputs=token_indices_table,
187
  queue=False,
188
  ).then(
 
192
  )
193
  token_indices_str.submit(
194
  fn=model.get_token_table,
195
+ inputs=prompt,
196
  outputs=token_indices_table,
197
  queue=False,
198
  ).then(
 
202
  )
203
  run_button.click(
204
  fn=model.get_token_table,
205
+ inputs=prompt,
206
  outputs=token_indices_table,
207
  queue=False,
208
  ).then(
model.py CHANGED
@@ -1,83 +1,73 @@
1
  from __future__ import annotations
2
 
3
- import sys
4
-
5
- import gradio as gr
6
  import PIL.Image
7
  import torch
8
-
9
- sys.path.append('Attend-and-Excite')
10
-
11
- from config import RunConfig
12
- from pipeline_attend_and_excite import AttendAndExcitePipeline
13
- from run import run_on_prompt
14
- from utils.ptp_utils import AttentionStore
15
 
16
 
17
  class Model:
18
  def __init__(self):
19
  self.device = torch.device(
20
  'cuda:0' if torch.cuda.is_available() else 'cpu')
21
- self.model_id = ''
22
- self.model = None
23
- self.tokenizer = None
24
-
25
- self.load_model('CompVis/stable-diffusion-v1-4')
 
 
 
 
 
 
 
26
 
27
- def load_model(self, model_id: str) -> None:
28
- if model_id == self.model_id:
29
- return
30
- self.model = AttendAndExcitePipeline.from_pretrained(model_id).to(
31
- self.device)
32
- self.tokenizer = self.model.tokenizer
33
- self.model_id = model_id
34
-
35
- def get_token_table(self, model_id: str, prompt: str):
36
- self.load_model(model_id)
37
  tokens = [
38
- self.tokenizer.decode(t)
39
- for t in self.tokenizer(prompt)['input_ids']
40
  ]
41
  tokens = tokens[1:-1]
42
  return list(enumerate(tokens, start=1))
43
 
44
  def run(
45
  self,
46
- model_id: str,
47
  prompt: str,
48
  indices_to_alter_str: str,
49
- seed: int,
50
- apply_attend_and_excite: bool,
51
- num_steps: int,
52
- guidance_scale: float,
53
  scale_factor: int = 20,
54
  thresholds: dict[int, float] = {
55
  10: 0.5,
56
- 20: 0.8
57
  },
58
  max_iter_to_alter: int = 25,
59
  ) -> PIL.Image.Image:
60
  generator = torch.Generator(device=self.device).manual_seed(seed)
61
- try:
62
- indices_to_alter = list(map(int, indices_to_alter_str.split(',')))
63
- except:
64
- raise gr.Error('Invalid token indices.')
65
-
66
- self.load_model(model_id)
67
-
68
- controller = AttentionStore()
69
- config = RunConfig(prompt=prompt,
70
- n_inference_steps=num_steps,
71
- guidance_scale=guidance_scale,
72
- run_standard_sd=not apply_attend_and_excite,
73
- scale_factor=scale_factor,
74
- thresholds=thresholds,
75
- max_iter_to_alter=max_iter_to_alter)
76
- image = run_on_prompt(model=self.model,
77
- prompt=[prompt],
78
- controller=controller,
79
- token_indices=indices_to_alter,
80
- seed=generator,
81
- config=config)
82
 
83
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
 
 
 
3
  import PIL.Image
4
  import torch
5
+ from diffusers import (StableDiffusionAttendAndExcitePipeline,
6
+ StableDiffusionPipeline)
 
 
 
 
 
7
 
8
 
9
  class Model:
10
  def __init__(self):
11
  self.device = torch.device(
12
  'cuda:0' if torch.cuda.is_available() else 'cpu')
13
+ model_id = 'CompVis/stable-diffusion-v1-4'
14
+ if self.device.type == 'cuda':
15
+ self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
16
+ model_id, torch_dtype=torch.float16)
17
+ self.ax_pipe.to(self.device)
18
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(
19
+ model_id, torch_dtype=torch.float16)
20
+ self.sd_pipe.to(self.device)
21
+ else:
22
+ self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
23
+ model_id)
24
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
25
 
26
+ def get_token_table(self, prompt: str):
 
 
 
 
 
 
 
 
 
27
  tokens = [
28
+ self.ax_pipe.tokenizer.decode(t)
29
+ for t in self.ax_pipe.tokenizer(prompt)['input_ids']
30
  ]
31
  tokens = tokens[1:-1]
32
  return list(enumerate(tokens, start=1))
33
 
34
  def run(
35
  self,
 
36
  prompt: str,
37
  indices_to_alter_str: str,
38
+ seed: int = 0,
39
+ apply_attend_and_excite: bool = True,
40
+ num_steps: int = 50,
41
+ guidance_scale: float = 7.5,
42
  scale_factor: int = 20,
43
  thresholds: dict[int, float] = {
44
  10: 0.5,
45
+ 20: 0.8,
46
  },
47
  max_iter_to_alter: int = 25,
48
  ) -> PIL.Image.Image:
49
  generator = torch.Generator(device=self.device).manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ if apply_attend_and_excite:
52
+ try:
53
+ token_indices = list(map(int, indices_to_alter_str.split(',')))
54
+ except Exception:
55
+ raise ValueError('Invalid token indices.')
56
+ out = self.ax_pipe(
57
+ prompt=prompt,
58
+ token_indices=token_indices,
59
+ guidance_scale=guidance_scale,
60
+ generator=generator,
61
+ num_inference_steps=num_steps,
62
+ max_iter_to_alter=max_iter_to_alter,
63
+ thresholds=thresholds,
64
+ scale_factor=scale_factor,
65
+ )
66
+ else:
67
+ out = self.sd_pipe(
68
+ prompt=prompt,
69
+ guidance_scale=guidance_scale,
70
+ generator=generator,
71
+ num_inference_steps=num_steps,
72
+ )
73
+ return out.images[0]
requirements.txt CHANGED
@@ -1,8 +1,5 @@
1
- accelerate==0.19.0
2
- diffusers==0.12.1
3
- ftfy==6.1.1
4
- jupyter
5
- opencv-python-headless==4.7.0.68
6
- pyrallis==0.3.1
7
- torch==1.13.1
8
- transformers==4.29.2
 
1
+ accelerate==0.20.3
2
+ diffusers==0.17.0
3
+ Pillow==9.5.0
4
+ torch==2.0.1
5
+ transformers==4.30.1