Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
3 |
import torch
|
4 |
-
import mediapy
|
5 |
import sa_handler
|
6 |
|
7 |
# init models
|
@@ -11,10 +10,10 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
|
|
11 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
|
12 |
scheduler=scheduler
|
13 |
).to("cuda")
|
14 |
-
#pipeline.enable_sequential_cpu_offload()
|
15 |
pipeline.enable_model_cpu_offload()
|
16 |
pipeline.enable_vae_slicing()
|
17 |
-
|
18 |
handler = sa_handler.Handler(pipeline)
|
19 |
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
20 |
share_layer_norm=False,
|
@@ -26,43 +25,43 @@ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
|
26 |
|
27 |
handler.register(sa_args, )
|
28 |
|
29 |
-
#
|
30 |
-
sets_of_prompts = [
|
31 |
-
"a toy train. macro photo. 3d game asset",
|
32 |
-
"a toy airplane. macro photo. 3d game asset",
|
33 |
-
"a toy bicycle. macro photo. 3d game asset",
|
34 |
-
"a toy car. macro photo. 3d game asset",
|
35 |
-
"a toy boat. macro photo. 3d game asset",
|
36 |
-
]
|
37 |
-
|
38 |
-
# run StyleAligned
|
39 |
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
45 |
|
46 |
with gr.Blocks() as demo:
|
47 |
with gr.Group():
|
48 |
with gr.Column():
|
49 |
with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
|
50 |
with gr.Row(variant='panel'):
|
|
|
51 |
initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
|
52 |
initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
|
53 |
initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
|
54 |
initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
|
55 |
initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
|
56 |
with gr.Row():
|
|
|
57 |
style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
|
|
|
58 |
btn = gr.Button("Generate a set of Style-aligned SDXL images",)
|
|
|
59 |
output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
|
60 |
-
|
|
|
61 |
btn.click(fn=style_aligned_sdxl,
|
62 |
inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
|
63 |
outputs=output,
|
64 |
api_name="style_aligned_sdxl")
|
65 |
|
|
|
66 |
gr.Examples(examples=[
|
67 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
|
68 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
|
@@ -74,5 +73,5 @@ with gr.Blocks() as demo:
|
|
74 |
outputs=[output],
|
75 |
fn=style_aligned_sdxl)
|
76 |
|
77 |
-
demo
|
78 |
-
|
|
|
1 |
import gradio as gr
|
2 |
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
3 |
import torch
|
|
|
4 |
import sa_handler
|
5 |
|
6 |
# init models
|
|
|
10 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
|
11 |
scheduler=scheduler
|
12 |
).to("cuda")
|
13 |
+
# Configure the pipeline for CPU offloading and VAE slicing#pipeline.enable_sequential_cpu_offload()
|
14 |
pipeline.enable_model_cpu_offload()
|
15 |
pipeline.enable_vae_slicing()
|
16 |
+
# Initialize the style-aligned handler
|
17 |
handler = sa_handler.Handler(pipeline)
|
18 |
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
19 |
share_layer_norm=False,
|
|
|
25 |
|
26 |
handler.register(sa_args, )
|
27 |
|
28 |
+
# Define the function to generate style-aligned images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
|
30 |
+
try:
|
31 |
+
# Combine the style prompt with each initial prompt
|
32 |
+
sets_of_prompts = [ prompt + ". " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,]]
|
33 |
+
# Generate images using the pipeline
|
34 |
+
images = pipeline(sets_of_prompts,).images
|
35 |
+
return images
|
36 |
+
except Exception as e:
|
37 |
+
raise gr.Error(f"Error in generating images: {e}")
|
38 |
|
39 |
with gr.Blocks() as demo:
|
40 |
with gr.Group():
|
41 |
with gr.Column():
|
42 |
with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
|
43 |
with gr.Row(variant='panel'):
|
44 |
+
# Textboxes for initial prompts
|
45 |
initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
|
46 |
initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
|
47 |
initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
|
48 |
initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
|
49 |
initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
|
50 |
with gr.Row():
|
51 |
+
# Textbox for the style prompt
|
52 |
style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
|
53 |
+
# Button to generate images
|
54 |
btn = gr.Button("Generate a set of Style-aligned SDXL images",)
|
55 |
+
# Display the generated images
|
56 |
output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
|
57 |
+
|
58 |
+
# Button click event
|
59 |
btn.click(fn=style_aligned_sdxl,
|
60 |
inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
|
61 |
outputs=output,
|
62 |
api_name="style_aligned_sdxl")
|
63 |
|
64 |
+
# Providing Example inputs for the demo
|
65 |
gr.Examples(examples=[
|
66 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
|
67 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
|
|
|
73 |
outputs=[output],
|
74 |
fn=style_aligned_sdxl)
|
75 |
|
76 |
+
# Launch the Gradio demo
|
77 |
+
demo.launch()
|