Spaces:
Runtime error
Runtime error
File size: 4,479 Bytes
fd1c028 5767af2 46d9ce2 ce70a4b 5767af2 fd1c028 5767af2 c3f2272 5767af2 fd1c028 dba288d fd1c028 c3f2272 c2b4feb 5767af2 66ab16f c3f2272 5767af2 c3f2272 5767af2 c3f2272 5767af2 66ab16f 5767af2 c3f2272 5767af2 c3f2272 1a530c6 c3f2272 fd1c028 5767af2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import sa_handler
# init models
# DDIM scheduler configured explicitly (scaled-linear betas, no sample
# clipping) and handed to the SDXL pipeline below.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# The pipeline is intentionally NOT moved with `.to("cuda")` here:
# `enable_model_cpu_offload()` below manages device placement itself, and an
# eager transfer beforehand would only copy all weights to the GPU to have
# them sent straight back to the CPU.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)

# Configure the pipeline for CPU offloading and VAE slicing.
# (`pipeline.enable_sequential_cpu_offload()` is the alternative offloading
# mode that was previously tried here and left disabled.)
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# Initialize the style-aligned handler and register its settings on the
# pipeline: attention is shared, with AdaIN applied to queries and keys only;
# group-norm and layer-norm sharing are switched off.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
    """Generate one style-aligned SDXL image per non-blank initial prompt.

    Each non-blank initial prompt is combined with the shared style prompt as
    ``"<prompt>. <style_prompt>"`` and the whole batch is passed to the
    module-level ``pipeline`` in a single call.

    Args:
        initial_prompt1..initial_prompt5: Subject prompts. Blank or ``None``
            entries are skipped, so between one and five prompts may be given.
        style_prompt: Style description appended to every subject prompt.
            When blank, the subject prompts are used on their own.

    Returns:
        The ``images`` attribute of the pipeline output for the batch.

    Raises:
        gr.Error: If every initial prompt is blank, or if the pipeline call
            fails (the original exception is chained as the cause).
    """
    candidates = (initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5)
    # The UI invites "upto 5" prompts, so unused textboxes must not be turned
    # into subject-less ". <style>" prompts.
    subjects = [prompt for prompt in candidates if prompt and prompt.strip()]
    if not subjects:
        # Raised outside the try block so it is not re-wrapped below.
        raise gr.Error("Please enter at least one initial prompt.")

    # Combine the style prompt with each initial prompt
    if style_prompt and style_prompt.strip():
        sets_of_prompts = [prompt + ". " + style_prompt for prompt in subjects]
    else:
        sets_of_prompts = subjects

    # Generate images using the pipeline; only this call is guarded so that
    # unrelated programming errors are not masked as generation failures.
    try:
        images = pipeline(sets_of_prompts).images
    except Exception as e:
        raise gr.Error(f"Error in generating images: {e}") from e
    return images
# Example rows for the demo: five subject prompts followed by a style prompt.
_TOY_SUBJECTS = ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat"]
_PROFESSION_SUBJECTS = ["a firewoman", "a Gardner", "a scientist", "a policewoman", "a saxophone player"]
_EXAMPLE_ROWS = [
    [*_TOY_SUBJECTS, "macro photo. 3d game asset."],
    [*_TOY_SUBJECTS, "BW logo. high contrast."],
    ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop", "minimal origami."],
    [*_PROFESSION_SUBJECTS, "made of claymation, stop motion animation."],
    [*_PROFESSION_SUBJECTS, "sketch, character sheet."],
]

# NOTE(review): the container nesting below was reconstructed from a listing
# whose indentation had been flattened -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.HTML('<h1 style="text-align: center;">Style-aligned SDXL</h1>')
    with gr.Group():
        with gr.Column():
            with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
                with gr.Row(variant='panel'):
                    # One unlabeled textbox per subject prompt; the placeholder
                    # text hints at a sample value.
                    initial_prompt1 = gr.Textbox(
                        label='Initial prompt 1', value='', show_label=False,
                        container=False, placeholder='a toy train',
                    )
                    initial_prompt2 = gr.Textbox(
                        label='Initial prompt 2', value='', show_label=False,
                        container=False, placeholder='a toy airplane',
                    )
                    initial_prompt3 = gr.Textbox(
                        label='Initial prompt 3', value='', show_label=False,
                        container=False, placeholder='a toy bicycle',
                    )
                    initial_prompt4 = gr.Textbox(
                        label='Initial prompt 4', value='', show_label=False,
                        container=False, placeholder='a toy car',
                    )
                    initial_prompt5 = gr.Textbox(
                        label='Initial prompt 5', value='', show_label=False,
                        container=False, placeholder='a toy boat',
                    )
            with gr.Row():
                # Shared style prompt applied to every subject prompt.
                style_prompt = gr.Textbox(
                    label="Enter a style prompt",
                    placeholder='macro photo, 3d game asset',
                )

    # Trigger button and the gallery that shows the generated set.
    btn = gr.Button("Generate a set of Style-aligned SDXL images")
    output = gr.Gallery(
        label="Style-Aligned SDXL Images",
        elem_id="gallery",
        columns=5,
        rows=1,
        object_fit="contain",
        height="auto",
    )

    # The click handler and the examples widget take the same six inputs,
    # in the order of style_aligned_sdxl's parameters.
    _demo_inputs = [
        initial_prompt1,
        initial_prompt2,
        initial_prompt3,
        initial_prompt4,
        initial_prompt5,
        style_prompt,
    ]

    # Wire the button to the generation function.
    btn.click(
        fn=style_aligned_sdxl,
        inputs=_demo_inputs,
        outputs=output,
        api_name="style_aligned_sdxl",
    )

    # Clickable example inputs for the demo.
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=_demo_inputs,
        outputs=[output],
        fn=style_aligned_sdxl,
    )

# Start the Gradio app.
demo.launch()