File size: 4,479 Bytes
fd1c028
 
 
 
 
 
 
 
 
 
 
 
5767af2
46d9ce2
ce70a4b
5767af2
fd1c028
 
 
 
 
 
 
 
 
 
 
5767af2
c3f2272
5767af2
 
 
 
 
 
 
 
fd1c028
 
dba288d
fd1c028
c3f2272
 
c2b4feb
5767af2
66ab16f
 
 
 
 
c3f2272
5767af2
c3f2272
5767af2
c3f2272
5767af2
66ab16f
5767af2
 
c3f2272
 
 
 
 
5767af2
c3f2272
 
 
 
1a530c6
 
c3f2272
 
 
 
fd1c028
5767af2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import sa_handler

# init models
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
    scheduler=scheduler
).to("cuda")
# Configure the pipeline for CPU offloading and VAE slicing#pipeline.enable_sequential_cpu_offload()
pipeline.enable_model_cpu_offload() 
pipeline.enable_vae_slicing()
# Initialize the style-aligned handler
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True,
                                      adain_values=False,
                                     )

handler.register(sa_args, )

# Define the function to generate style-aligned images
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt): 
    try:
        # Combine the style prompt with each initial prompt
        sets_of_prompts = [ prompt + ". " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,]]
        # Generate images using the pipeline
        images = pipeline(sets_of_prompts,).images
        return images
    except Exception as e:
        raise gr.Error(f"Error in generating images: {e}")

with gr.Blocks() as demo:
    gr.HTML('<h1 style="text-align: center;">Style-aligned SDXL</h1>')
    with gr.Group():
      with gr.Column():
        with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
          with gr.Row(variant='panel'):
            # Textboxes for initial prompts
            initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
            initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
            initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
            initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
            initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
        with gr.Row():
          # Textbox for the style prompt
          style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
        # Button to generate images
        btn = gr.Button("Generate a set of Style-aligned SDXL images",)
    # Display the generated images
    output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)

    # Button click event
    btn.click(fn=style_aligned_sdxl, 
              inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt], 
              outputs=output, 
              api_name="style_aligned_sdxl")

    # Providing Example inputs for the demo
    gr.Examples(examples=[
                    ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
                    ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
                    ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop", "minimal origami."],
                    ["a firewoman", "a Gardner", "a scientist", "a policewoman", "a saxophone player", "made of claymation, stop motion animation."],
                    ["a firewoman", "a Gardner", "a scientist", "a policewoman", "a saxophone player", "sketch, character sheet."],
                    ],
            inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
            outputs=[output],
            fn=style_aligned_sdxl)

# Launch the Gradio demo
demo.launch()