Idor980 commited on
Commit
1fe3117
1 Parent(s): 9b257f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -577
app.py CHANGED
@@ -1,339 +1,37 @@
1
- # import gradio as gr
2
- # import numpy as np
3
- # import random
4
- # from diffusers import DiffusionPipeline
5
- # import torch
6
- # import transformers
7
-
8
- # # Perform cache migration
9
- # transformers.utils.move_cache()
10
-
11
- # device = "cuda" if torch.cuda.is_available() else "cpu"
12
-
13
- # if torch.cuda.is_available():
14
- # torch.cuda.max_memory_allocated(device=device)
15
- # pipe = DiffusionPipeline.from_pretrained(
16
- # "stabilityai/sdxl-turbo",
17
- # torch_dtype=torch.float16,
18
- # variant="fp16",
19
- # use_safetensors=True,
20
- # )
21
- # pipe.enable_xformers_memory_efficient_attention()
22
- # pipe = pipe.to(device)
23
- # else:
24
- # pipe = DiffusionPipeline.from_pretrained(
25
- # "stabilityai/sdxl-turbo", use_safetensors=True
26
- # )
27
- # pipe = pipe.to(device)
28
-
29
- # # Quantize the model
30
- # pipe.unet = torch.quantization.convert(pipe.unet, inplace=True)
31
-
32
- # MAX_SEED = np.iinfo(np.int32).max
33
- # MAX_IMAGE_SIZE = 512
34
-
35
-
36
- # def generate_image(
37
- # seed, prompt, negative_prompt, guidance_scale, num_inference_steps, width, height
38
- # ):
39
- # try:
40
- # generator = torch.Generator().manual_seed(seed)
41
- # image = pipe(
42
- # prompt=prompt,
43
- # negative_prompt=negative_prompt,
44
- # guidance_scale=guidance_scale,
45
- # num_inference_steps=num_inference_steps,
46
- # width=width,
47
- # height=height,
48
- # generator=generator,
49
- # ).images[0]
50
- # return image
51
- # except Exception as e:
52
- # print(f"Error generating image with seed {seed}: {e}")
53
- # return None
54
-
55
-
56
- # def infer(
57
- # prompt,
58
- # negative_prompt,
59
- # seed,
60
- # randomize_seed,
61
- # width,
62
- # height,
63
- # guidance_scale,
64
- # num_inference_steps,
65
- # ):
66
-
67
- # if randomize_seed:
68
- # seeds = [random.randint(0, MAX_SEED) for _ in range(2)]
69
- # else:
70
- # seeds = [seed, seed + 1]
71
-
72
- # images = []
73
- # for seed in seeds:
74
- # image = generate_image(
75
- # seed,
76
- # prompt,
77
- # negative_prompt,
78
- # guidance_scale,
79
- # num_inference_steps,
80
- # width,
81
- # height,
82
- # )
83
- # images.append(image)
84
-
85
- # return images
86
-
87
-
88
- # examples = [
89
- # "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
90
- # "An astronaut riding a green horse",
91
- # "A delicious ceviche cheesecake slice",
92
- # ]
93
-
94
- # css = """
95
- # #col-container {
96
- # margin: 0 auto;
97
- # max-width: 520px;
98
- # }
99
- # """
100
-
101
- # if torch.cuda.is_available():
102
- # power_device = "GPU"
103
- # else:
104
- # power_device = "CPU"
105
-
106
- # with gr.Blocks(css=css) as demo:
107
-
108
- # with gr.Column(elem_id="col-container"):
109
- # gr.Markdown(
110
- # f"""
111
- # # Text-to-Image Gradio Template
112
- # Currently running on {power_device}.
113
- # """
114
- # )
115
-
116
- # with gr.Row():
117
-
118
- # prompt = gr.Text(
119
- # label="Prompt",
120
- # show_label=False,
121
- # max_lines=1,
122
- # placeholder="Enter your prompt",
123
- # container=False,
124
- # )
125
-
126
- # run_button = gr.Button("Run", scale=0)
127
-
128
- # result1 = gr.Image(label="Result 1", show_label=False)
129
- # result2 = gr.Image(label="Result 2", show_label=False)
130
-
131
- # with gr.Accordion("Advanced Settings", open=False):
132
-
133
- # negative_prompt = gr.Text(
134
- # label="Negative prompt",
135
- # max_lines=1,
136
- # placeholder="Enter a negative prompt",
137
- # visible=False,
138
- # )
139
-
140
- # seed = gr.Slider(
141
- # label="Seed",
142
- # minimum=0,
143
- # maximum=MAX_SEED,
144
- # step=1,
145
- # value=0,
146
- # )
147
-
148
- # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
149
-
150
- # with gr.Row():
151
-
152
- # width = gr.Slider(
153
- # label="Width",
154
- # minimum=256,
155
- # maximum=MAX_IMAGE_SIZE,
156
- # step=32,
157
- # value=512,
158
- # )
159
-
160
- # height = gr.Slider(
161
- # label="Height",
162
- # minimum=256,
163
- # maximum=MAX_IMAGE_SIZE,
164
- # step=32,
165
- # value=512,
166
- # )
167
-
168
- # with gr.Row():
169
-
170
- # guidance_scale = gr.Slider(
171
- # label="Guidance scale",
172
- # minimum=0.0,
173
- # maximum=10.0,
174
- # step=0.1,
175
- # value=0.0,
176
- # )
177
-
178
- # num_inference_steps = gr.Slider(
179
- # label="Number of inference steps",
180
- # minimum=1,
181
- # maximum=50, # Ensure the number of steps is reasonable
182
- # step=1,
183
- # value=2,
184
- # )
185
-
186
- # gr.Examples(examples=examples, inputs=[prompt])
187
-
188
- # run_button.click(
189
- # fn=infer,
190
- # inputs=[
191
- # prompt,
192
- # negative_prompt,
193
- # seed,
194
- # randomize_seed,
195
- # width,
196
- # height,
197
- # guidance_scale,
198
- # num_inference_steps,
199
- # ],
200
- # outputs=[result1, result2],
201
- # )
202
-
203
- # demo.queue().launch()
204
-
205
  import gradio as gr
206
  import numpy as np
207
- from PIL import Image
208
- import requests
209
- from io import BytesIO
210
  import random
211
  from diffusers import DiffusionPipeline
212
  import torch
213
  import transformers
214
- from tqdm import tqdm
215
 
216
  # Perform cache migration
217
  transformers.utils.move_cache()
218
 
219
  device = "cuda" if torch.cuda.is_available() else "cpu"
220
 
221
- def load_pipeline():
222
- if device == "cuda":
223
- pipe = DiffusionPipeline.from_pretrained(
224
- "stabilityai/sdxl-turbo",
225
- torch_dtype=torch.float16,
226
- variant="fp16",
227
- use_safetensors=True,
228
- )
229
- pipe.enable_xformers_memory_efficient_attention()
230
- else:
231
- pipe = DiffusionPipeline.from_pretrained(
232
- "stabilityai/sdxl-turbo", use_safetensors=True
233
- )
234
- # pipe.unet = torch.quantization.convert(pipe.unet, inplace=True)
235
- return pipe.to(device)
236
-
237
- pipe = load_pipeline()
238
-
239
- #################### our model ####################
240
- import warnings
241
- import torch.utils
242
- import torch.utils.checkpoint
243
- from transformers import BitsAndBytesConfig, InstructBlipProcessor, InstructBlipForConditionalGeneration
244
-
245
-
246
- # Filter out specific warnings by message
247
- warnings.filterwarnings("ignore", message="Repo card metadata block was not found. Setting CardData to empty.")
248
- warnings.filterwarnings(
249
- "ignore",
250
- message="torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly.",
251
- category=UserWarning
252
- )
253
-
254
- model_checkpoint = "Salesforce/instructblip-vicuna-7b"
255
- lora_weights_repo_id = "instructblip-vicuna-7b-peft-lora"
256
- # peft_model_id = "NoyHanan/instructblip-vicuna-7b-peft-lora-6400"
257
- peft_model_id = "NoyHanan/instructblip-vicuna-7b-peft-lora-1600"
258
-
259
- prompt_format = """###USER:ֿ\nHere is an image. Please analyze the image and enhance the base prompt by integrating detailed observations, including colors, textures, lighting, and key visual elements, while staying true to the original description. The goal is to produce a more vibrant, detailed, and visually appealing image. The base prompt is: "{base_prompt}"\n###ASSISTANT:\n"""
260
-
261
- text = prompt_format.format(base_prompt="enter prompt here")
262
-
263
-
264
- bnb_config = BitsAndBytesConfig(
265
- load_in_4bit=True,
266
- bnb_4bit_use_double_quant=True,
267
- bnb_4bit_quant_type="nf4",
268
- bnb_4bit_compute_dtype=torch.bfloat16,
269
- )
270
-
271
- def load_prompt_optimization_model_and_processor():
272
- prompt_optimizer_processor = InstructBlipProcessor.from_pretrained(model_checkpoint, legacy=False, quantization_config=bnb_config, device_map="auto")
273
- prompt_optimizer_model = InstructBlipForConditionalGeneration.from_pretrained(model_checkpoint, quantization_config=bnb_config, device_map="auto")
274
- prompt_optimizer_model.load_adapter(peft_model_id)
275
- prompt_optimizer_model.tie_weights()
276
-
277
- return prompt_optimizer_processor, prompt_optimizer_model
278
-
279
- processor, model = load_prompt_optimization_model_and_processor()
280
-
281
- def get_enhanced_prompts_from_our_model(prompt, selected_image):
282
- print("Start generating prompts using our model")
283
- model.eval()
284
-
285
- enhanced_prompts_list = []
286
-
287
- text = prompt_format.format(base_prompt=prompt)
288
- inputs = processor(images=selected_image, text=text, return_tensors="pt").to(device)
289
-
290
-
291
-
292
- # # outputs = trainer.model.generate(
293
- # outputs = model.generate(
294
- # **inputs,
295
- # do_sample=True,
296
- # # num_beams=5,
297
- # num_beams = 3,
298
- # # max_length=516,
299
- # max_length=256,
300
- # min_length=1,
301
- # top_p=0.9,
302
- # repetition_penalty=1.5,
303
- # # length_penalty=1.0,
304
- # length_penalty = 0.5,
305
- # temperature=0.1, # give 3 of the same prompt
306
- # )
307
-
308
- for i in tqdm(range(3)):
309
- # outputs = trainer.model.generate(
310
- outputs = model.generate(
311
- **inputs,
312
- do_sample=True,
313
- # num_beams=5,
314
- num_beams = 3,
315
- # max_length=516,
316
- max_length=256,
317
- min_length=1,
318
- top_p=0.9,
319
- repetition_penalty=1.5,
320
- # length_penalty=1.0,
321
- length_penalty = 0.5,
322
- # temperature=0.1, # give 3 of the same prompt
323
- temperature=0.8
324
- )
325
- res = processor.batch_decode(outputs, skip_special_tokens=True)
326
- generated_text = res[0].strip()
327
- enhanced_prompts_list.append(generated_text)
328
- print(generated_text)
329
- torch.cuda.empty_cache()
330
-
331
- print("Finish generating prompts using our model")
332
- return enhanced_prompts_list
333
- ###################################################
334
 
335
  MAX_SEED = np.iinfo(np.int32).max
336
- MAX_IMAGE_SIZE = 1024
 
337
 
338
  def generate_image(
339
  seed, prompt, negative_prompt, guidance_scale, num_inference_steps, width, height
@@ -343,10 +41,10 @@ def generate_image(
343
  image = pipe(
344
  prompt=prompt,
345
  negative_prompt=negative_prompt,
346
- guidance_scale=float(guidance_scale),
347
- num_inference_steps=int(num_inference_steps),
348
- width=int(width),
349
- height=int(height),
350
  generator=generator,
351
  ).images[0]
352
  return image
@@ -355,30 +53,23 @@ def generate_image(
355
  return None
356
 
357
 
358
- # def get_enhanced_prompts_from_our_model(prompt, selected_image):
359
- # # This prompts will be generated by our model
360
- # return ["cat", "frog", "giraffe"]
361
-
362
- first_4_images = []
363
- seeds = []
364
- prompt_entered_by_the_user = ""
365
- improved_prompts = []
366
-
367
-
368
- def create_4_images_from_original_prompt(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
369
-
370
- global first_4_images
371
- global seeds
372
- global prompt_entered_by_the_user
373
-
374
- prompt_entered_by_the_user = prompt
375
 
376
  if randomize_seed:
377
- seeds = [random.randint(0, MAX_SEED) for _ in range(4)]
378
  else:
379
- seeds = [seed, seed + 1, seed + 2, seed + 3]
380
 
381
- first_4_images = []
382
  for seed in seeds:
383
  image = generate_image(
384
  seed,
@@ -389,73 +80,12 @@ def create_4_images_from_original_prompt(prompt, negative_prompt, seed, randomiz
389
  width,
390
  height,
391
  )
392
- first_4_images.append(image)
393
-
394
- grid = create_image_grid(first_4_images, rows=2, cols=2)
395
- return grid
396
-
397
- def enhance_prompt_and_create_new_images(prompt, selected_index, width, height, guidance_scale, num_inference_steps, negative_prompt):
398
-
399
- global improved_prompts
400
-
401
- selected_image = first_4_images[selected_index]
402
-
403
- improved_prompts = get_enhanced_prompts_from_our_model(prompt, selected_image)
404
-
405
- seed = seeds[selected_index]
406
- final_4_images = [None] * 4
407
- images_from_improved_prompt = []
408
-
409
- for prompt in improved_prompts:
410
- image = generate_image(
411
- seed,
412
- prompt,
413
- negative_prompt,
414
- guidance_scale,
415
- num_inference_steps,
416
- width,
417
- height,
418
- )
419
- if image is not None:
420
- images_from_improved_prompt.append(image)
421
-
422
- for i in range(4):
423
- if i == selected_index:
424
- final_4_images[i] = selected_image
425
- elif images_from_improved_prompt:
426
- final_4_images[i] = images_from_improved_prompt.pop(0)
427
- else:
428
- final_4_images[i] = selected_image
429
- grid = create_image_grid(final_4_images, rows=2, cols=2)
430
- return grid
431
-
432
 
 
433
 
434
 
435
- def create_image_grid(images, rows, cols):
436
- assert len(images) == rows * cols
437
- w, h = images[0].size
438
- grid = Image.new("RGB", size=(cols * w, rows * h))
439
- for i, img in enumerate(images):
440
- grid.paste(img, box=(i % cols * w, i // cols * h))
441
- return grid
442
-
443
- def get_final_prompt(prompt, selected_index_left, selected_index_right):
444
- global improved_prompts
445
-
446
- prompt = f"Prompt did not improved.\n\nOriginal Prompt = {prompt}"
447
-
448
- if selected_index_left == selected_index_right:
449
- return prompt
450
- else:
451
- improved_prompts.insert(selected_index_left, prompt)
452
- new_prompt = improved_prompts[selected_index_right]
453
- improved_prompts.remove(prompt)
454
- return f"Prompt improved succesfuly.\n\nEnhanced Prompt = {new_prompt}"
455
-
456
  examples = [
457
- "Light emerald green wine jar with deer head lid, gold lines, small cracks",
458
- "Textured tempura painting of a friendly waitress serving coffee",
459
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
460
  "An astronaut riding a green horse",
461
  "A delicious ceviche cheesecake slice",
@@ -463,183 +93,102 @@ examples = [
463
 
464
  css = """
465
  #col-container {
466
- margin: 0;
467
- max-width: 20px; # Smaller width
468
- }
469
- body {
470
- display: flex;
471
- justify-content: space-between; # Space between left and right columns
472
- }
473
- .spacer {
474
- height: 130px; # Large spacer
475
  }
476
  """
477
 
478
- power_device = device
 
 
 
479
 
480
  with gr.Blocks(css=css) as demo:
481
 
482
- with gr.Row():
 
 
 
 
 
 
 
 
483
 
484
- #################### First 4 Images Component ####################
485
- with gr.Column(elem_id="col-container"):
486
- gr.Markdown(
487
- f"""
488
- # Enter Prompt & Select The Best Image
489
- Currently running on {power_device}.
490
- """
491
  )
492
 
493
- with gr.Row():
494
- original_prompt = gr.Text(
495
- label="Prompt",
496
- show_label=False,
497
- max_lines=1,
498
- placeholder="Enter your prompt",
499
- container=False,
500
- )
501
 
502
- enter_prompt_buttom = gr.Button("Run", scale=0)
 
503
 
504
- result_left = gr.Image(label="Result", show_label=False)
505
 
506
- selected_index_left = gr.Radio(
507
- label="Select Image",
508
- choices=[
509
- "0 - Upper Left",
510
- "1 - Upper Right",
511
- "2 - Lower Left",
512
- "3 - Lower Right",
513
- ],
514
- type="index",
515
  )
516
 
517
- selected_text_left = gr.Text(label="Selected Image Index", visible=False)
518
-
519
- def update_selected_index(index):
520
- return str(index)
521
-
522
- selected_index_left.change(
523
- fn=update_selected_index,
524
- inputs=selected_index_left,
525
- outputs=selected_text_left,
526
  )
527
 
528
- with gr.Accordion("Advanced Settings", open=False):
529
 
530
- negative_prompt = gr.Text(
531
- label="Negative prompt",
532
- max_lines=1,
533
- placeholder="Enter a negative prompt",
534
- visible=False,
535
- )
536
 
537
- seed = gr.Slider(
538
- label="Seed",
539
- minimum=0,
540
- maximum=MAX_SEED,
541
- step=1,
542
- value=0,
543
  )
544
 
545
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
546
-
547
- with gr.Row():
548
-
549
- width = gr.Slider(
550
- label="Width",
551
- minimum=256,
552
- maximum=MAX_IMAGE_SIZE,
553
- step=32,
554
- value=512,
555
- )
556
-
557
- height = gr.Slider(
558
- label="Height",
559
- minimum=256,
560
- maximum=MAX_IMAGE_SIZE,
561
- step=32,
562
- value=512,
563
- )
564
-
565
- with gr.Row():
566
-
567
- guidance_scale = gr.Slider(
568
- label="Guidance scale",
569
- minimum=0.0,
570
- maximum=10.0,
571
- step=0.1,
572
- value=0.0,
573
- )
574
-
575
- num_inference_steps = gr.Slider(
576
- label="Number of inference steps",
577
- minimum=1,
578
- maximum=12,
579
- step=1,
580
- value=2,
581
- )
582
-
583
- gr.Examples(examples=examples, inputs=[original_prompt])
584
- ##################################################################
585
-
586
- #################### Final 4 Images Component ####################
587
- with gr.Column(elem_id="col-container"):
588
- gr.Markdown(
589
- f"""
590
- # Select The Best Image
591
- you can choose the same one as before if you want
592
- """
593
- )
594
-
595
- # Adding empty rows for spacing
596
- gr.Markdown(" ", elem_classes=["spacer"])
597
-
598
- result_right = gr.Image(label="Result", show_label=False)
599
-
600
- selected_index_right = gr.Radio(
601
- label="Select Image",
602
- choices=[
603
- "0 - Upper Left",
604
- "1 - Upper Right",
605
- "2 - Lower Left",
606
- "3 - Lower Right",
607
- ],
608
- type="index",
609
- )
610
 
611
- selected_text_right = gr.Text(label="Selected Image Index", visible=False)
612
 
613
- selected_index_right.change(
614
- fn=update_selected_index,
615
- inputs=selected_index_right,
616
- outputs=selected_text_right,
617
- )
618
- ##################################################################
619
-
620
- ##################### Final Prompt Component #####################
621
- with gr.Column(elem_id="col-container"):
622
- gr.Markdown(
623
- f"""
624
- # Enhance Prompt
625
- """
626
- )
627
 
628
- enhanced_prompt_output = gr.Textbox(
629
- label="Final Prompt", interactive=False
630
- )
 
 
 
 
631
 
632
- selected_index_right.change(
633
- fn=get_final_prompt,
634
- inputs=[original_prompt, selected_index_left, selected_index_right],
635
- outputs=[enhanced_prompt_output],
636
- )
637
- ##################################################################
638
 
639
- enter_prompt_buttom.click(
640
- fn=create_4_images_from_original_prompt,
641
  inputs=[
642
- original_prompt,
643
  negative_prompt,
644
  seed,
645
  randomize_seed,
@@ -648,22 +197,7 @@ with gr.Blocks(css=css) as demo:
648
  guidance_scale,
649
  num_inference_steps,
650
  ],
651
- outputs=[result_left],
652
  )
653
 
654
-
655
- selected_index_left.change(
656
- fn=enhance_prompt_and_create_new_images,
657
- inputs=[
658
- original_prompt,
659
- selected_index_left,
660
- width,
661
- height,
662
- guidance_scale,
663
- num_inference_steps,
664
- negative_prompt,
665
- ],
666
- outputs=[result_right],
667
- )
668
-
669
- demo.queue().launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
 
 
 
3
  import random
4
  from diffusers import DiffusionPipeline
5
  import torch
6
  import transformers
 
7
 
8
  # Perform cache migration
9
  transformers.utils.move_cache()
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ if torch.cuda.is_available():
14
+ torch.cuda.max_memory_allocated(device=device)
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ "stabilityai/sdxl-turbo",
17
+ torch_dtype=torch.float16,
18
+ variant="fp16",
19
+ use_safetensors=True,
20
+ )
21
+ pipe.enable_xformers_memory_efficient_attention()
22
+ pipe = pipe.to(device)
23
+ else:
24
+ pipe = DiffusionPipeline.from_pretrained(
25
+ "stabilityai/sdxl-turbo", use_safetensors=True
26
+ )
27
+ pipe = pipe.to(device)
28
+
29
+ # Quantize the model
30
+ pipe.unet = torch.quantization.convert(pipe.unet, inplace=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
+ MAX_IMAGE_SIZE = 512
34
+
35
 
36
  def generate_image(
37
  seed, prompt, negative_prompt, guidance_scale, num_inference_steps, width, height
 
41
  image = pipe(
42
  prompt=prompt,
43
  negative_prompt=negative_prompt,
44
+ guidance_scale=guidance_scale,
45
+ num_inference_steps=num_inference_steps,
46
+ width=width,
47
+ height=height,
48
  generator=generator,
49
  ).images[0]
50
  return image
 
53
  return None
54
 
55
 
56
+ def infer(
57
+ prompt,
58
+ negative_prompt,
59
+ seed,
60
+ randomize_seed,
61
+ width,
62
+ height,
63
+ guidance_scale,
64
+ num_inference_steps,
65
+ ):
 
 
 
 
 
 
 
66
 
67
  if randomize_seed:
68
+ seeds = [random.randint(0, MAX_SEED) for _ in range(2)]
69
  else:
70
+ seeds = [seed, seed + 1]
71
 
72
+ images = []
73
  for seed in seeds:
74
  image = generate_image(
75
  seed,
 
80
  width,
81
  height,
82
  )
83
+ images.append(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ return images
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  examples = [
 
 
89
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
90
  "An astronaut riding a green horse",
91
  "A delicious ceviche cheesecake slice",
 
93
 
94
  css = """
95
  #col-container {
96
+ margin: 0 auto;
97
+ max-width: 520px;
 
 
 
 
 
 
 
98
  }
99
  """
100
 
101
+ if torch.cuda.is_available():
102
+ power_device = "GPU"
103
+ else:
104
+ power_device = "CPU"
105
 
106
  with gr.Blocks(css=css) as demo:
107
 
108
+ with gr.Column(elem_id="col-container"):
109
+ gr.Markdown(
110
+ f"""
111
+ # Text-to-Image Gradio Template
112
+ Currently running on {power_device}.
113
+ """
114
+ )
115
+
116
+ with gr.Row():
117
 
118
+ prompt = gr.Text(
119
+ label="Prompt",
120
+ show_label=False,
121
+ max_lines=1,
122
+ placeholder="Enter your prompt",
123
+ container=False,
 
124
  )
125
 
126
+ run_button = gr.Button("Run", scale=0)
 
 
 
 
 
 
 
127
 
128
+ result1 = gr.Image(label="Result 1", show_label=False)
129
+ result2 = gr.Image(label="Result 2", show_label=False)
130
 
131
+ with gr.Accordion("Advanced Settings", open=False):
132
 
133
+ negative_prompt = gr.Text(
134
+ label="Negative prompt",
135
+ max_lines=1,
136
+ placeholder="Enter a negative prompt",
137
+ visible=False,
 
 
 
 
138
  )
139
 
140
+ seed = gr.Slider(
141
+ label="Seed",
142
+ minimum=0,
143
+ maximum=MAX_SEED,
144
+ step=1,
145
+ value=0,
 
 
 
146
  )
147
 
148
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
149
 
150
+ with gr.Row():
 
 
 
 
 
151
 
152
+ width = gr.Slider(
153
+ label="Width",
154
+ minimum=256,
155
+ maximum=MAX_IMAGE_SIZE,
156
+ step=32,
157
+ value=512,
158
  )
159
 
160
+ height = gr.Slider(
161
+ label="Height",
162
+ minimum=256,
163
+ maximum=MAX_IMAGE_SIZE,
164
+ step=32,
165
+ value=512,
166
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ with gr.Row():
169
 
170
+ guidance_scale = gr.Slider(
171
+ label="Guidance scale",
172
+ minimum=0.0,
173
+ maximum=10.0,
174
+ step=0.1,
175
+ value=0.0,
176
+ )
 
 
 
 
 
 
 
177
 
178
+ num_inference_steps = gr.Slider(
179
+ label="Number of inference steps",
180
+ minimum=1,
181
+ maximum=50, # Ensure the number of steps is reasonable
182
+ step=1,
183
+ value=2,
184
+ )
185
 
186
+ gr.Examples(examples=examples, inputs=[prompt])
 
 
 
 
 
187
 
188
+ run_button.click(
189
+ fn=infer,
190
  inputs=[
191
+ prompt,
192
  negative_prompt,
193
  seed,
194
  randomize_seed,
 
197
  guidance_scale,
198
  num_inference_steps,
199
  ],
200
+ outputs=[result1, result2],
201
  )
202
 
203
+ demo.queue().launch()