KingNish committed
Commit b8975f7 (1 parent: 4c0ea3b)

Update app.py

Files changed (1): app.py (+180 -117)
app.py CHANGED
@@ -1,70 +1,119 @@
 import gradio as gr
 import numpy as np
-import random
-from diffusers import DiffusionPipeline
-from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler
-import torch
 import spaces

-device = "cuda"
-
-torch.cuda.max_memory_allocated(device=device)
-pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash")
-pipe = pipe.to(device)
-pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 4096
-
-@spaces.GPU(duration=120, queue=False)
-def infer(
-        prompt: str,
-        negative_prompt: str = "",
-        seed: int = 24,
-        randomize_seed: bool = False,
-        width: int = 1024,
-        height: int = 1024,
-        guidance_scale = 3,
-        num_inference_steps: int = 9,
-        progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed)
-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
-    return image

 examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-    "An alien grasping a sign board contain word 'Flash'",
-    "Kids going to school, Anime style"
 ]

 css = '''
 .gradio-container{max-width: 560px !important}
 h1{text-align:center}
-footer {
-    visibility: hidden
-}
 '''
-
-with gr.Blocks(title="SDXL Flash", css=css) as demo:
-    with gr.Column():
-        gr.Markdown("""# SDXL Flash
-        ### Super fast text to Image Generator.
-        ### <span style='color: red;'>You may change the steps from 5 to 8 or 10, if you didn't get satisfied results.
-        ### First Image processing takes time then images generate faster.""")
         with gr.Row():
-
             prompt = gr.Text(
                 label="Prompt",
                 show_label=False,
@@ -72,79 +121,93 @@ with gr.Blocks(title="SDXL Flash", css=css) as demo:
                 placeholder="Enter your prompt",
                 container=False,
             )
-
             run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result")
-
-        with gr.Accordion("Advanced Settings", open=False):
-
             negative_prompt = gr.Text(
                 label="Negative prompt",
-                max_lines=5,
-                lines=4,
                 placeholder="Enter a negative prompt",
-                value = "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
             )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
                 step=1,
-                value=0,
             )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=8,
-                    value=1024,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=8,
-                    value=1024,
-                )
-
-            with gr.Row():
-
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=1.0,
-                    maximum=6.0,
-                    step=0.1,
-                    value=3.0,
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=15,
-                    step=1,
-                    value=5,
-                )
-
-        gr.Examples(
-            examples = examples,
-            inputs = prompt,
-            outputs = result,
-            fn=infer,
-            cache_examples=True
-        )

-    run_button.click(
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result]
     )

-demo.queue(max_size=20).launch()
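The two hunks above are the old-file side of the diff; the rewritten file follows below. One notable removal is the explicit sampler override: the old app pinned DPMSolverSinglestepScheduler with trailing timestep spacing, while the new code loads sd-community/sdxl-flash through DiffusionPipeline and keeps whatever scheduler ships in that repo. A minimal sketch of the removed setup for running it outside the Space (a CUDA GPU and the fp16 loading options the new code adopts are assumed; this snippet is not part of the commit):

```python
# Sketch only: reproduces the removed scheduler override outside the Space.
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "sd-community/sdxl-flash",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")
# Trailing timestep spacing is commonly paired with few-step SDXL checkpoints.
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
image = pipe(
    "an astronaut riding a horse in space",
    num_inference_steps=7,
    guidance_scale=3,
).images[0]
image.save("flash.png")
```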
 
 
+import os
+import random
+import uuid
+import json
+
 import gradio as gr
 import numpy as np
+from PIL import Image
 import spaces
+import torch
+from diffusers import DiffusionPipeline

+DESCRIPTION = """# SDXL Flash"""
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"

 MAX_SEED = np.iinfo(np.int32).max
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+NUM_IMAGES_PER_PROMPT = 1
+
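The new module-level flags above are read from environment variables once at import time (CACHE_EXAMPLES additionally requires CUDA). A hypothetical local-run sketch, not part of the commit, showing how the flags would be set before app.py is imported:

```python
# Hypothetical local driver (assumes this file sits next to app.py).
import os

# Set the flags before app.py runs, because they are evaluated at import time.
os.environ["ENABLE_CPU_OFFLOAD"] = "1"   # lower VRAM use at some speed cost
os.environ["MAX_IMAGE_SIZE"] = "2048"    # cap the width/height sliders
os.environ["CACHE_EXAMPLES"] = "0"       # skip pre-rendering the example prompts

import app  # importing app.py builds the Blocks and loads the pipeline

app.demo.queue(max_size=20).launch()
```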
+if torch.cuda.is_available():
+    pipe = DiffusionPipeline.from_pretrained(
+        "sd-community/sdxl-flash",
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+        add_watermarker=False,
+        variant="fp16"
+    )
+    if ENABLE_CPU_OFFLOAD:
+        pipe.enable_model_cpu_offload()
+    else:
+        pipe.to(device)
+        print("Loaded on Device!")
+
+    if USE_TORCH_COMPILE:
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+        pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
+        print("Model Compiled!")
+
+
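A caveat in the hunk above: only `pipe` exists in this file, so the `pipe2.unet` line would raise NameError whenever USE_TORCH_COMPILE=1; it looks like a leftover from a two-pipeline template. A minimal corrected sketch (not part of the commit):

```python
# Sketch: the compile branch with the stray pipe2 reference dropped.
if USE_TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    print("Model Compiled!")
```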
+def save_image(img):
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+@spaces.GPU(enable_queue=False)
+def generate(
+    prompt: str,
+    negative_prompt: str = "",
+    use_negative_prompt: bool = False,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    num_inference_steps: int = 9,
+    randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
+    progress=gr.Progress(track_tqdm=True),
+):
+    pipe.to(device)
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    generator = torch.Generator().manual_seed(seed)
+
+    if not use_negative_prompt:
+        negative_prompt = ""  # type: ignore
+    negative_prompt += default_negative
+
+    options = {
+        "prompt":prompt,
+        "negative_prompt":negative_prompt,
+        "width":width,
+        "height":height,
+        "guidance_scale":guidance_scale,
+        "num_inference_steps":num_inference_steps,
+        "generator":generator,
+        "num_images_per_prompt":NUM_IMAGES_PER_PROMPT,
+        "use_resolution_binning":use_resolution_binning,
+        "output_type":"pil",
+
+    }
+
+    images = pipe(**options).images
+
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
+
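Two caveats in generate() above: `default_negative` is never defined in app.py, so the function would raise NameError as soon as it reaches that line, and on a CPU-only machine `pipe` itself is never created (it is only built inside the `if torch.cuda.is_available()` block), so `pipe.to(device)` would fail even earlier. A hedged sketch of the likely intent for the negative-prompt handling, using a hypothetical module-level constant that the commit does not actually define:

```python
# Sketch only: DEFAULT_NEGATIVE stands in for the undefined `default_negative`.
DEFAULT_NEGATIVE = ""  # e.g. a house-style quality suffix appended to every request


def build_negative_prompt(negative_prompt: str, use_negative_prompt: bool) -> str:
    """Mirror the intent of generate(): clear the user text if disabled, then append the default."""
    if not use_negative_prompt:
        negative_prompt = ""
    return negative_prompt + DEFAULT_NEGATIVE
```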

 examples = [
+    "neon holography crystal cat",
+    "a cat eating a piece of cheese",
+    "an astronaut riding a horse in space",
+    "a cartoon of a boy playing with a tiger",
+    "a cute robot artist painting on an easel, concept art",
+    #"a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
 ]

 css = '''
 .gradio-container{max-width: 560px !important}
 h1{text-align:center}
 '''
+with gr.Blocks(css=css) as demo:
+    gr.Markdown(DESCRIPTION)
+    with gr.Group():
         with gr.Row():
             prompt = gr.Text(
                 label="Prompt",
                 show_label=False,
                 placeholder="Enter your prompt",
                 container=False,
             )
             run_button = gr.Button("Run", scale=0)
+        result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
+    with gr.Accordion("Advanced options", open=False):
+        with gr.Row():
+            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
         negative_prompt = gr.Text(
             label="Negative prompt",
+            max_lines=1,
             placeholder="Enter a negative prompt",
+            value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
+            visible=True,
+        )
+        seed = gr.Slider(
+            label="Seed",
+            minimum=0,
+            maximum=MAX_SEED,
+            step=1,
+            value=0,
+        )
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        with gr.Row(visible=True):
+            width = gr.Slider(
+                label="Width",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=1024,
+            )
+            height = gr.Slider(
+                label="Height",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=1024,
             )
+        with gr.Row():
+            guidance_scale = gr.Slider(
+                label="Guidance Scale",
+                minimum=0.1,
+                maximum=6,
+                step=0.1,
+                value=3.0,
+            )
+            num_inference_steps = gr.Slider(
+                label="Number of inference steps",
+                minimum=1,
+                maximum=15,
                 step=1,
+                value=5,
             )

+    gr.Examples(
+        examples=examples,
+        inputs=prompt,
+        outputs=[result, seed],
+        fn=generate,
+        cache_examples=CACHE_EXAMPLES,
+    )
+
+    use_negative_prompt.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_negative_prompt,
+        outputs=negative_prompt,
+        api_name=False,
+    )
+
+    gr.on(
+        triggers=[
+            prompt.submit,
+            negative_prompt.submit,
+            run_button.click,
+        ],
+        fn=generate,
+        inputs=[
+            prompt,
+            negative_prompt,
+            use_negative_prompt,
+            seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+            randomize_seed,
+        ],
+        outputs=[result, seed],
+        api_name="run",
     )

+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
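Because gr.on(...) registers the handler with api_name="run", the rewritten demo can also be driven programmatically. A sketch using gradio_client follows; the Space id is assumed, not stated anywhere in this diff:

```python
# Sketch only: "KingNish/SDXL-Flash" is an assumed Space id.
from gradio_client import Client

client = Client("KingNish/SDXL-Flash")
gallery, used_seed = client.predict(
    "a cute robot artist painting on an easel, concept art",  # prompt
    "",      # negative_prompt
    False,   # use_negative_prompt
    0,       # seed
    1024,    # width
    1024,    # height
    3.0,     # guidance_scale
    8,       # num_inference_steps
    True,    # randomize_seed
    api_name="/run",
)
print(used_seed)   # the seed actually used (randomized on the server)
print(gallery)     # list describing the generated image(s)
```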