KingNish committed
Commit f3f26af
1 Parent(s): ef404b3

Update app.py

Files changed (1)
app.py +43 -63
app.py CHANGED
@@ -10,32 +10,26 @@ import spaces
 import torch
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
-# Use environment variables for flexibility
-MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
+
+MAX_SEED = np.iinfo(np.int32).max
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once
 
-# Determine device and load model outside of function for efficiency
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    use_safetensors=True,
-    add_watermarker=False,
-).to(device)
-pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
-# Torch compile for potential speedup (experimental)
-if USE_TORCH_COMPILE:
-    pipe.compile()
-
-# CPU offloading for larger RAM capacity (experimental)
-if ENABLE_CPU_OFFLOAD:
-    pipe.enable_model_cpu_offload()
-
-MAX_SEED = np.iinfo(np.int32).max
+if torch.cuda.is_available():
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        "sd-community/sdxl-flash",
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+        add_watermarker=False
+    )
+    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+    pipe.to("cuda")
 
 def save_image(img):
     unique_name = str(uuid.uuid4()) + ".png"
@@ -47,58 +41,50 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     seed = random.randint(0, MAX_SEED)
     return seed
 
-@spaces.GPU(duration=35, enable_queue=True)
+@spaces.GPU(duration=30, queue=False)
 def generate(
     prompt: str,
     negative_prompt: str = "",
     use_negative_prompt: bool = False,
-    seed: int = 1,
+    seed: int = 0,
     width: int = 1024,
     height: int = 1024,
     guidance_scale: float = 3,
-    num_inference_steps: int = 30,
+    num_inference_steps: int = 25,
     randomize_seed: bool = False,
-    use_resolution_binning: bool = True,
-    num_images: int = 1,  # Number of images to generate
+    use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True),
 ):
+    pipe.to(device)
     seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator(device=device).manual_seed(seed)
+    generator = torch.Generator().manual_seed(seed)
 
-    # Improved options handling
     options = {
-        "prompt": [prompt] * num_images,
-        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-        "width": width,
-        "height": height,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "generator": generator,
-        "output_type": "pil",
-    }
-
-    # Use resolution binning for faster generation with less VRAM usage
-    if use_resolution_binning:
-        options["use_resolution_binning"] = True
-
-    # Generate images potentially in batches
-    images = []
-    for i in range(0, num_images, BATCH_SIZE):
-        batch_options = options.copy()
-        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-        if "negative_prompt" in batch_options:
-            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        images.extend(pipe(**batch_options).images)
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "width": width,
+        "height": height,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "use_resolution_binning": use_resolution_binning,
+        "output_type": "pil",
+    }
+
+    images = pipe(**options).images
 
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
+
 examples = [
     "a cat eating a piece of cheese",
-    "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
-    "Ironman VS Hulk, ultrarealistic",
+    "a ROBOT riding a BLUE horse on Mars, photorealistic",
+    "a cartoon of a IRONMAN fighting with HULK, wall painting",
+    "a cute robot artist painting on an easel, concept art",
     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
-    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
+    "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed",
     "Kids going to school, Anime style"
 ]
@@ -109,9 +95,9 @@ footer {
     visibility: hidden
 }
 '''
-
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("""# SDXL Flash""")
+    gr.Markdown("""# SDXL Flash
+    ### First Image processing takes time then images generate faster.""")
     with gr.Group():
         with gr.Row():
             prompt = gr.Text(
@@ -122,15 +108,8 @@ with gr.Blocks(css=css) as demo:
                 container=False,
             )
             run_button = gr.Button("Run", scale=0)
-    result = gr.Gallery(label="Result", columns=1, show_label=False)
+    result = gr.Gallery(label="Result", columns=1)
     with gr.Accordion("Advanced options", open=False):
-        num_images = gr.Slider(
-            label="Number of Images",
-            minimum=1,
-            maximum=4,
-            step=1,
-            value=1,
-        )
         with gr.Row():
             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
             negative_prompt = gr.Text(
@@ -183,7 +162,9 @@ with gr.Blocks(css=css) as demo:
     gr.Examples(
         examples=examples,
         inputs=prompt,
-        cache_examples=False
+        outputs=[result, seed],
+        fn=generate,
+        cache_examples=CACHE_EXAMPLES,
    )
 
     use_negative_prompt.change(
@@ -210,7 +191,6 @@ with gr.Blocks(css=css) as demo:
            guidance_scale,
            num_inference_steps,
            randomize_seed,
-            num_images
        ],
        outputs=[result, seed],
        api_name="run",
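
For reference, a minimal standalone sketch of the generation path this commit converges on: the pipeline is loaded once at import time only when CUDA is available, the Euler Ancestral scheduler is swapped in, and a CPU-side torch.Generator is seeded explicitly so a given seed reproduces its image. Model ID, step count, and guidance scale are taken from the diff; the prompt and output filename are illustrative.

import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# GPU-only, module-level load, as in the commit.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "sd-community/sdxl-flash",
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    # Seeding a default (CPU) generator keeps results reproducible for a given seed.
    generator = torch.Generator().manual_seed(42)
    image = pipe(
        prompt="a cat eating a piece of cheese",
        num_inference_steps=25,
        guidance_scale=3,
        generator=generator,
    ).images[0]
    image.save("output.png")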
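The gr.Examples block now passes fn= and outputs= alongside cache_examples=CACHE_EXAMPLES, which lets Gradio precompute example outputs at startup instead of always disabling the cache. A minimal sketch of that pattern, with a hypothetical greet function standing in for generate:

import os
import gradio as gr

# Mirror the commit's gating: cache only when explicitly enabled.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"

def greet(name: str) -> str:
    return f"Hello, {name}!"

with gr.Blocks() as demo:
    name = gr.Text(label="Name")
    greeting = gr.Text(label="Greeting")
    gr.Examples(
        examples=["Ada", "Grace"],
        inputs=name,
        outputs=greeting,
        fn=greet,                       # required so cached outputs can be computed
        cache_examples=CACHE_EXAMPLES,  # precomputes outputs when True
    )

demo.launch()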
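The ZeroGPU decorator also changes from @spaces.GPU(duration=35, enable_queue=True) to @spaces.GPU(duration=30, queue=False), shortening the requested GPU window per call. A minimal sketch of the decorator in isolation, assuming the Hugging Face spaces package on a ZeroGPU Space (the function body is illustrative):

import spaces
import torch

@spaces.GPU(duration=30)  # request a GPU slot for up to ~30 seconds per call
def describe_device() -> str:
    # On ZeroGPU hardware, CUDA is attached while the decorated function runs.
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"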