Spaces: Runtime error

Update app.py
app.py CHANGED

@@ -1,339 +1,37 @@
-# import gradio as gr
-# import numpy as np
-# import random
-# from diffusers import DiffusionPipeline
-# import torch
-# import transformers
-
-# # Perform cache migration
-# transformers.utils.move_cache()
-
-# device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# if torch.cuda.is_available():
-#     torch.cuda.max_memory_allocated(device=device)
-#     pipe = DiffusionPipeline.from_pretrained(
-#         "stabilityai/sdxl-turbo",
-#         torch_dtype=torch.float16,
-#         variant="fp16",
-#         use_safetensors=True,
-#     )
-#     pipe.enable_xformers_memory_efficient_attention()
-#     pipe = pipe.to(device)
-# else:
-#     pipe = DiffusionPipeline.from_pretrained(
-#         "stabilityai/sdxl-turbo", use_safetensors=True
-#     )
-#     pipe = pipe.to(device)
-
-# # Quantize the model
-# pipe.unet = torch.quantization.convert(pipe.unet, inplace=True)
-
-# MAX_SEED = np.iinfo(np.int32).max
-# MAX_IMAGE_SIZE = 512
-
-
-# def generate_image(
-#     seed, prompt, negative_prompt, guidance_scale, num_inference_steps, width, height
-# ):
-#     try:
-#         generator = torch.Generator().manual_seed(seed)
-#         image = pipe(
-#             prompt=prompt,
-#             negative_prompt=negative_prompt,
-#             guidance_scale=guidance_scale,
-#             num_inference_steps=num_inference_steps,
-#             width=width,
-#             height=height,
-#             generator=generator,
-#         ).images[0]
-#         return image
-#     except Exception as e:
-#         print(f"Error generating image with seed {seed}: {e}")
-#         return None
-
-
-# def infer(
-#     prompt,
-#     negative_prompt,
-#     seed,
-#     randomize_seed,
-#     width,
-#     height,
-#     guidance_scale,
-#     num_inference_steps,
-# ):
-
-#     if randomize_seed:
-#         seeds = [random.randint(0, MAX_SEED) for _ in range(2)]
-#     else:
-#         seeds = [seed, seed + 1]
-
-#     images = []
-#     for seed in seeds:
-#         image = generate_image(
-#             seed,
-#             prompt,
-#             negative_prompt,
-#             guidance_scale,
-#             num_inference_steps,
-#             width,
-#             height,
-#         )
-#         images.append(image)
-
-#     return images
-
-
-# examples = [
-#     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-#     "An astronaut riding a green horse",
-#     "A delicious ceviche cheesecake slice",
-# ]
-
-# css = """
-# #col-container {
-#     margin: 0 auto;
-#     max-width: 520px;
-# }
-# """
-
-# if torch.cuda.is_available():
-#     power_device = "GPU"
-# else:
-#     power_device = "CPU"
-
-# with gr.Blocks(css=css) as demo:
-
-#     with gr.Column(elem_id="col-container"):
-#         gr.Markdown(
-#             f"""
-#             # Text-to-Image Gradio Template
-#             Currently running on {power_device}.
-#             """
-#         )
-
-#         with gr.Row():
-
-#             prompt = gr.Text(
-#                 label="Prompt",
-#                 show_label=False,
-#                 max_lines=1,
-#                 placeholder="Enter your prompt",
-#                 container=False,
-#             )
-
-#             run_button = gr.Button("Run", scale=0)
-
-#         result1 = gr.Image(label="Result 1", show_label=False)
-#         result2 = gr.Image(label="Result 2", show_label=False)
-
-#         with gr.Accordion("Advanced Settings", open=False):
-
-#             negative_prompt = gr.Text(
-#                 label="Negative prompt",
-#                 max_lines=1,
-#                 placeholder="Enter a negative prompt",
-#                 visible=False,
-#             )
-
-#             seed = gr.Slider(
-#                 label="Seed",
-#                 minimum=0,
-#                 maximum=MAX_SEED,
-#                 step=1,
-#                 value=0,
-#             )
-
-#             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-#             with gr.Row():
-
-#                 width = gr.Slider(
-#                     label="Width",
-#                     minimum=256,
-#                     maximum=MAX_IMAGE_SIZE,
-#                     step=32,
-#                     value=512,
-#                 )
-
-#                 height = gr.Slider(
-#                     label="Height",
-#                     minimum=256,
-#                     maximum=MAX_IMAGE_SIZE,
-#                     step=32,
-#                     value=512,
-#                 )
-
-#             with gr.Row():
-
-#                 guidance_scale = gr.Slider(
-#                     label="Guidance scale",
-#                     minimum=0.0,
-#                     maximum=10.0,
-#                     step=0.1,
-#                     value=0.0,
-#                 )
-
-#                 num_inference_steps = gr.Slider(
-#                     label="Number of inference steps",
-#                     minimum=1,
-#                     maximum=50,  # Ensure the number of steps is reasonable
-#                     step=1,
-#                     value=2,
-#                 )
-
-#         gr.Examples(examples=examples, inputs=[prompt])
-
-#     run_button.click(
-#         fn=infer,
-#         inputs=[
-#             prompt,
-#             negative_prompt,
-#             seed,
-#             randomize_seed,
-#             width,
-#             height,
-#             guidance_scale,
-#             num_inference_steps,
-#         ],
-#         outputs=[result1, result2],
-#     )
-
-# demo.queue().launch()
-
 import gradio as gr
 import numpy as np
-from PIL import Image
-import requests
-from io import BytesIO
 import random
 from diffusers import DiffusionPipeline
 import torch
 import transformers
-from tqdm import tqdm
 
 # Perform cache migration
 transformers.utils.move_cache()
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-#################### our model ####################
-import warnings
-import torch.utils
-import torch.utils.checkpoint
-from transformers import BitsAndBytesConfig, InstructBlipProcessor, InstructBlipForConditionalGeneration
-
-
-# Filter out specific warnings by message
-warnings.filterwarnings("ignore", message="Repo card metadata block was not found. Setting CardData to empty.")
-warnings.filterwarnings(
-    "ignore",
-    message="torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly.",
-    category=UserWarning
-)
-
-model_checkpoint = "Salesforce/instructblip-vicuna-7b"
-lora_weights_repo_id = "instructblip-vicuna-7b-peft-lora"
-# peft_model_id = "NoyHanan/instructblip-vicuna-7b-peft-lora-6400"
-peft_model_id = "NoyHanan/instructblip-vicuna-7b-peft-lora-1600"
-
-prompt_format = """###USER:ֿ\nHere is an image. Please analyze the image and enhance the base prompt by integrating detailed observations, including colors, textures, lighting, and key visual elements, while staying true to the original description. The goal is to produce a more vibrant, detailed, and visually appealing image. The base prompt is: "{base_prompt}"\n###ASSISTANT:\n"""
-
-text = prompt_format.format(base_prompt="enter prompt here")
-
-
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16,
-)
-
-def load_prompt_optimization_model_and_processor():
-    prompt_optimizer_processor = InstructBlipProcessor.from_pretrained(model_checkpoint, legacy=False, quantization_config=bnb_config, device_map="auto")
-    prompt_optimizer_model = InstructBlipForConditionalGeneration.from_pretrained(model_checkpoint, quantization_config=bnb_config, device_map="auto")
-    prompt_optimizer_model.load_adapter(peft_model_id)
-    prompt_optimizer_model.tie_weights()
-
-    return prompt_optimizer_processor, prompt_optimizer_model
-
-processor, model = load_prompt_optimization_model_and_processor()
-
-def get_enhanced_prompts_from_our_model(prompt, selected_image):
-    print("Start generating prompts using our model")
-    model.eval()
-
-    enhanced_prompts_list = []
-
-    text = prompt_format.format(base_prompt=prompt)
-    inputs = processor(images=selected_image, text=text, return_tensors="pt").to(device)
-
-
-
-    # # outputs = trainer.model.generate(
-    # outputs = model.generate(
-    #     **inputs,
-    #     do_sample=True,
-    #     # num_beams=5,
-    #     num_beams = 3,
-    #     # max_length=516,
-    #     max_length=256,
-    #     min_length=1,
-    #     top_p=0.9,
-    #     repetition_penalty=1.5,
-    #     # length_penalty=1.0,
-    #     length_penalty = 0.5,
-    #     temperature=0.1, # give 3 of the same prompt
-    # )
-
-    for i in tqdm(range(3)):
-        # outputs = trainer.model.generate(
-        outputs = model.generate(
-            **inputs,
-            do_sample=True,
-            # num_beams=5,
-            num_beams = 3,
-            # max_length=516,
-            max_length=256,
-            min_length=1,
-            top_p=0.9,
-            repetition_penalty=1.5,
-            # length_penalty=1.0,
-            length_penalty = 0.5,
-            # temperature=0.1, # give 3 of the same prompt
-            temperature=0.8
-        )
-        res = processor.batch_decode(outputs, skip_special_tokens=True)
-        generated_text = res[0].strip()
-        enhanced_prompts_list.append(generated_text)
-        print(generated_text)
-        torch.cuda.empty_cache()
-
-    print("Finish generating prompts using our model")
-    return enhanced_prompts_list
-###################################################
+if torch.cuda.is_available():
+    torch.cuda.max_memory_allocated(device=device)
+    pipe = DiffusionPipeline.from_pretrained(
+        "stabilityai/sdxl-turbo",
+        torch_dtype=torch.float16,
+        variant="fp16",
+        use_safetensors=True,
+    )
+    pipe.enable_xformers_memory_efficient_attention()
+    pipe = pipe.to(device)
+else:
+    pipe = DiffusionPipeline.from_pretrained(
+        "stabilityai/sdxl-turbo", use_safetensors=True
+    )
+    pipe = pipe.to(device)
+
+# Quantize the model
+pipe.unet = torch.quantization.convert(pipe.unet, inplace=True)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE =
+MAX_IMAGE_SIZE = 512
+
 
 def generate_image(
     seed, prompt, negative_prompt, guidance_scale, num_inference_steps, width, height

@@ -343,10 +41,10 @@ def generate_image(
         image = pipe(
             prompt=prompt,
             negative_prompt=negative_prompt,
-            guidance_scale=
-            num_inference_steps=
-            width=
-            height=
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
             generator=generator,
         ).images[0]
         return image

@@ -355,30 +53,23 @@ def generate_image(
         return None
 
 
-
-
-
-
-
-
-
-
-
-
-def create_4_images_from_original_prompt(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-
-    global first_4_images
-    global seeds
-    global prompt_entered_by_the_user
-
-    prompt_entered_by_the_user = prompt
+def infer(
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    num_inference_steps,
+):
 
     if randomize_seed:
-        seeds = [random.randint(0, MAX_SEED) for _ in range(
+        seeds = [random.randint(0, MAX_SEED) for _ in range(2)]
     else:
-        seeds = [seed, seed + 1
+        seeds = [seed, seed + 1]
 
-
+    images = []
     for seed in seeds:
         image = generate_image(
             seed,

@@ -389,73 +80,12 @@ def create_4_images_from_original_prompt(prompt, negative_prompt, seed, randomiz
             width,
             height,
         )
-
-
-    grid = create_image_grid(first_4_images, rows=2, cols=2)
-    return grid
-
-def enhance_prompt_and_create_new_images(prompt, selected_index, width, height, guidance_scale, num_inference_steps, negative_prompt):
-
-    global improved_prompts
-
-    selected_image = first_4_images[selected_index]
-
-    improved_prompts = get_enhanced_prompts_from_our_model(prompt, selected_image)
-
-    seed = seeds[selected_index]
-    final_4_images = [None] * 4
-    images_from_improved_prompt = []
-
-    for prompt in improved_prompts:
-        image = generate_image(
-            seed,
-            prompt,
-            negative_prompt,
-            guidance_scale,
-            num_inference_steps,
-            width,
-            height,
-        )
-        if image is not None:
-            images_from_improved_prompt.append(image)
-
-    for i in range(4):
-        if i == selected_index:
-            final_4_images[i] = selected_image
-        elif images_from_improved_prompt:
-            final_4_images[i] = images_from_improved_prompt.pop(0)
-        else:
-            final_4_images[i] = selected_image
-    grid = create_image_grid(final_4_images, rows=2, cols=2)
-    return grid
-
+        images.append(image)
 
+    return images
 
 
-def create_image_grid(images, rows, cols):
-    assert len(images) == rows * cols
-    w, h = images[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(images):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-def get_final_prompt(prompt, selected_index_left, selected_index_right):
-    global improved_prompts
-
-    prompt = f"Prompt did not improved.\n\nOriginal Prompt = {prompt}"
-
-    if selected_index_left == selected_index_right:
-        return prompt
-    else:
-        improved_prompts.insert(selected_index_left, prompt)
-        new_prompt = improved_prompts[selected_index_right]
-        improved_prompts.remove(prompt)
-        return f"Prompt improved succesfuly.\n\nEnhanced Prompt = {new_prompt}"
-
 examples = [
-    "Light emerald green wine jar with deer head lid, gold lines, small cracks",
-    "Textured tempura painting of a friendly waitress serving coffee",
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
     "A delicious ceviche cheesecake slice",

@@ -463,183 +93,102 @@ examples = [
 
 css = """
 #col-container {
-    margin: 0;
-    max-width:
-}
-body {
-    display: flex;
-    justify-content: space-between; # Space between left and right columns
-}
-.spacer {
-    height: 130px; # Large spacer
+    margin: 0 auto;
+    max-width: 520px;
 }
 """
 
-
+if torch.cuda.is_available():
+    power_device = "GPU"
+else:
+    power_device = "CPU"
 
 with gr.Blocks(css=css) as demo:
 
-    with gr.
-
-
-
-
-
-
-    """
-    )
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown(
+            f"""
+            # Text-to-Image Gradio Template
+            Currently running on {power_device}.
+            """
+        )
 
-
-    original_prompt = gr.Text(
-        label="Prompt",
-        show_label=False,
-        max_lines=1,
-        placeholder="Enter your prompt",
-        container=False,
-    )
+        with gr.Row():
 
-
-
-        label="
-
-
-
-            "2 - Lower Left",
-            "3 - Lower Right",
-        ],
-        type="index",
-    )
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                container=False,
+            )
 
-
-
-
-
-
-
-        fn=update_selected_index,
-        inputs=selected_index_left,
-        outputs=selected_text_left,
-    )
+            run_button = gr.Button("Run", scale=0)
 
-
+        result1 = gr.Image(label="Result 1", show_label=False)
+        result2 = gr.Image(label="Result 2", show_label=False)
 
-
-        label="Negative prompt",
-        max_lines=1,
-        placeholder="Enter a negative prompt",
-        visible=False,
-    )
+        with gr.Accordion("Advanced Settings", open=False):
+
+            negative_prompt = gr.Text(
+                label="Negative prompt",
+                max_lines=1,
+                placeholder="Enter a negative prompt",
+                visible=False,
+            )
 
-
-        label="
-        minimum=
-        maximum=
-        step=
-        value=
-    )
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
+                step=1,
+                value=0,
+            )
 
-
-
-
-
-
-
-
-            maximum=MAX_IMAGE_SIZE,
-            step=32,
-            value=512,
-        )
-
-        height = gr.Slider(
-            label="Height",
-            minimum=256,
-            maximum=MAX_IMAGE_SIZE,
-            step=32,
-            value=512,
-        )
-
-        with gr.Row():
-
-            guidance_scale = gr.Slider(
-                label="Guidance scale",
-                minimum=0.0,
-                maximum=10.0,
-                step=0.1,
-                value=0.0,
-            )
-
-            num_inference_steps = gr.Slider(
-                label="Number of inference steps",
-                minimum=1,
-                maximum=12,
-                step=1,
-                value=2,
-            )
-
-    gr.Examples(examples=examples, inputs=[original_prompt])
-    ##################################################################
-
-    #################### Final 4 Images Component ####################
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(
-            f"""
-            # Select The Best Image
-            you can choose the same one as before if you want
-            """
-        )
-
-        # Adding empty rows for spacing
-        gr.Markdown(" ", elem_classes=["spacer"])
-
-        result_right = gr.Image(label="Result", show_label=False)
-
-        selected_index_right = gr.Radio(
-            label="Select Image",
-            choices=[
-                "0 - Upper Left",
-                "1 - Upper Right",
-                "2 - Lower Left",
-                "3 - Lower Right",
-            ],
-            type="index",
-        )
-
-
-
-
-
-
-
-
-
-
-    ##################### Final Prompt Component #####################
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(
-            f"""
-            # Enhance Prompt
-            """
-        )
-
-
-
-
-
-        fn=get_final_prompt,
-        inputs=[original_prompt, selected_index_left, selected_index_right],
-        outputs=[enhanced_prompt_output],
-    )
-    ##################################################################
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+            with gr.Row():
+
+                width = gr.Slider(
+                    label="Width",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512,
+                )
+
+                height = gr.Slider(
+                    label="Height",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512,
+                )
+
+            with gr.Row():
+
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=0.0,
+                    maximum=10.0,
+                    step=0.1,
+                    value=0.0,
+                )
+
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,  # Ensure the number of steps is reasonable
+                    step=1,
+                    value=2,
+                )
+
+        gr.Examples(examples=examples, inputs=[prompt])
 
-
-        fn=
+    run_button.click(
+        fn=infer,
         inputs=[
-
+            prompt,
             negative_prompt,
             seed,
             randomize_seed,

@@ -648,22 +197,7 @@ with gr.Blocks(css=css) as demo:
             guidance_scale,
             num_inference_steps,
         ],
-        outputs=[
+        outputs=[result1, result2],
     )
 
-
-    selected_index_left.change(
-        fn=enhance_prompt_and_create_new_images,
-        inputs=[
-            original_prompt,
-            selected_index_left,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-            negative_prompt,
-        ],
-        outputs=[result_right],
-    )
-
-demo.queue().launch(share=True)
+demo.queue().launch()
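For anyone verifying the fix locally, here is a minimal sketch of the generation call that the restored app.py wires up, runnable outside Gradio. The model id and sampler defaults mirror the diff above; the prompt, seed, and output filename are illustrative assumptions, and the CPU branch of the template is used for simplicity.

```python
# Minimal smoke test mirroring the restored template's pipe(...) call.
# Assumptions: diffusers and torch installed; prompt/seed/filename are
# placeholders, not part of the commit.
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", use_safetensors=True
).to(device)

generator = torch.Generator().manual_seed(0)
image = pipe(
    prompt="An astronaut riding a green horse",
    negative_prompt="",
    guidance_scale=0.0,       # SDXL-Turbo is intended to run without guidance
    num_inference_steps=2,    # template default
    width=512,
    height=512,
    generator=generator,
).images[0]
image.save("result.png")
```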