Support SDXL-Lightning and fix some errors in the baseline.

#1
Files changed (1)
  1. app.py +68 -46
app.py CHANGED
```diff
@@ -2,25 +2,35 @@
 
 
 import os
+
 os.system("pip install -U peft")
 import random
 
 import gradio as gr
 import numpy as np
 import PIL.Image
+
 import spaces
 import torch
-from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
+from diffusers import (
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+    EulerDiscreteScheduler,
+)
 from huggingface_hub import hf_hub_download
-from diffusers.models.attention_processor import AttnProcessor2_0
+from safetensors.torch import load_file
 
 DESCRIPTION = """
 # Res-Adapter :Domain Consistent Resolution Adapter for Diffusion Models
 **Demo by [ameer azam] - [Twitter](https://twitter.com/Ameerazam18) - [GitHub](https://github.com/AMEERAZAM08)) - [Hugging Face](https://huggingface.co/ameerazam08)**
-This is a demo of https://huggingface.co/jiaxiangc/res-adapter LORAs by ByteDance
+This is a demo of https://huggingface.co/jiaxiangc/res-adapter ResAdapter by ByteDance.
+
+ByteDance provide a demo of [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) with [SDXL-Lightning-Step4](https://huggingface.co/ByteDance/SDXL-Lightning) to expand resolution range from 1024-only to 256~1024.
 """
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<h1>Running on CPU πŸ₯Ά This demo does not work on CPU.</a> instead</h1>"
+    DESCRIPTION += (
+        "\n<h1>Running on CPU πŸ₯Ά This demo does not work on CPU.</a> instead</h1>"
+    )
 
 MAX_SEED = np.iinfo(np.int32).max
 CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
@@ -29,21 +39,26 @@ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-pipe = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',use_safetensors=True)# torch_dtype=torch.float16, variant="safetensors")
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
 
+base = "stabilityai/stable-diffusion-xl-base-1.0"
+repo = "ByteDance/SDXL-Lightning"
+ckpt = "sdxl_lightning_4step_unet.safetensors"  # Use the correct ckpt for your step setting!
 
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe = pipe.to(device)
 
+# Load resadapter
 pipe.load_lora_weights(
     hf_hub_download(
-    repo_id="jiaxiangc/res-adapter",
-    subfolder="sdxl-i",
+        repo_id="jiaxiangc/res-adapter",
+        subfolder="sdxl-i",
         filename="resolution_lora.safetensors",
     ),
     adapter_name="res_adapter",
 )
-pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
-pipe = pipe.to(device)
 
 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
@@ -63,11 +78,11 @@ def generate(
     seed: int = 0,
     width: int = 1024,
     height: int = 1024,
-    guidance_scale_base: float = 5.0,
-    num_inference_steps_base: int = 20,
+    guidance_scale: float = 0,
+    num_inference_steps: int = 4,
     progress=gr.Progress(track_tqdm=True),
 ) -> PIL.Image.Image:
-    print(f"** Generating image for: \"{prompt}\" **")
+    print(f'** Generating image for: "{prompt}" **')
     generator = torch.Generator().manual_seed(seed)
 
     if not use_negative_prompt:
@@ -76,46 +91,51 @@ def generate(
         prompt_2 = None  # type: ignore
     if not use_negative_prompt_2:
         negative_prompt_2 = None  # type: ignore
-    res_adapt=pipe(
+
+    pipe.set_adapters(["res_adapter"], adapter_weights=[0.0])
+    base_image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
         prompt_2=prompt_2,
         negative_prompt_2=negative_prompt_2,
         width=width,
         height=height,
-        guidance_scale=guidance_scale_base,
-        num_inference_steps=num_inference_steps_base,
-        generator=generator,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
         output_type="pil",
+        generator=generator,
     ).images[0]
 
-    pipe.unet.set_attn_processor(AttnProcessor2_0())
-    base_image = pipe(
+
+    pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
+    res_adapt = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
         prompt_2=prompt_2,
         negative_prompt_2=negative_prompt_2,
         width=width,
         height=height,
-        guidance_scale=guidance_scale_base,
-        num_inference_steps=num_inference_steps_base,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        output_type="pil",
         generator=generator,
-        output_type="pil").images[0]
+    ).images[0]
 
-
-
-
-    return [res_adapt,base_image]
+    return [res_adapt, base_image]
 
 
 examples = [
-    "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
-    "An astronaut riding a green horse",
-    "cinematic film still, photo of a girl, cyberpunk, neonpunk, headset, city at night, sony fe 12-24mm f/2.8 gm, close up, 32k uhd, wallpaper, analog film grain, SONY headset"
+    "A girl smiling",
+    "A boy smiling",
 ]
 
 theme = gr.themes.Base(
-    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+    font=[
+        gr.themes.GoogleFont("Libre Franklin"),
+        gr.themes.GoogleFont("Public Sans"),
+        "system-ui",
+        "sans-serif",
+    ],
 )
 with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
     gr.Markdown(DESCRIPTION)
@@ -136,13 +156,15 @@ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
         # result = gr.Gallery(label="Right is Res-Adapt-LORA and Left is Base"),
         with gr.Accordion("Advanced options", open=False):
             with gr.Row():
-                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
+                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
                 use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
-                use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
+                use_negative_prompt_2 = gr.Checkbox(
+                    label="Use negative prompt 2", value=False
+                )
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
-                placeholder="ugly, deformed, noisy, blurry, nsfw, low contrast, text, BadDream, 3d, cgi, render, fake, anime, open mouth, big forehead, long neck",
+                placeholder="Enter your prompt",
                 visible=True,
             )
             prompt_2 = gr.Text(
@@ -182,19 +204,19 @@ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
                     value=512,
                 )
             with gr.Row():
-                guidance_scale_base = gr.Slider(
-                    label="Guidance scale for base",
-                    minimum=1,
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=0,
                     maximum=20,
                     step=0.1,
-                    value=9.5,
+                    value=0,
                 )
-                num_inference_steps_base = gr.Slider(
-                    label="Number of inference steps for base",
-                    minimum=10,
-                    maximum=100,
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,
                     step=1,
-                    value=25,
+                    value=4,
                 )
         gr.Examples(
             examples=examples,
@@ -251,12 +273,12 @@ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
             seed,
             width,
             height,
-            guidance_scale_base,
-            num_inference_steps_base,
+            guidance_scale,
+            num_inference_steps,
         ],
-        outputs=gr.Gallery(label="Left is Res-Adapt-LORA and Right is Base"),
+        outputs=gr.Gallery(label="Left is ResAdapter and Right is Base"),
         api_name="run",
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20, api_open=False).launch(show_api=False)
+    demo.queue(max_size=20, api_open=False).launch(show_api=False)
```