"""Gradio demo that generates images from text prompts with AuraFlow v0.3."""

# NOTE(review): `spaces` is imported first in the original file; that order is
# preserved here in case it has to precede torch -- confirm before regrouping.
import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import AuraFlowPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on GPU. On the CPU fallback, float16
# inference tends to be slow or unsupported for some ops, so use float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize the AuraFlow v0.3 pipeline
pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch_dtype,
).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU
def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` and report the seed that was used.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what the image should avoid.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker (tracks tqdm output); not used
            directly in the body.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline
        and the integer seed actually used to produce it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Values coming from UI sliders or API calls may arrive as floats, while
    # `manual_seed` and the size/step arguments expect integers.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    ).images[0]

    return image, seed


# Custom CSS handed to gr.Blocks; hides the page footer.
css = """ footer { visibility: hidden; } """

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            prompt = gr.Text(label="Prompt", placeholder="Enter your prompt")
            negative_prompt = gr.Text(label="Negative prompt", placeholder="Enter a negative prompt")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
            num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
            run_button = gr.Button("Generate")
        # Right column: generated image and the seed that produced it.
        with gr.Column(scale=1):
            result = gr.Image(label="Generated Image")
            seed_output = gr.Number(label="Seed used")

    # The input order must match the positional parameters of `infer`.
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed_output],
    )

    # Example prompts that fill the prompt box.
    gr.Examples(
        examples=[
            "A photo of a lavender cat",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
        ],
        inputs=prompt,
    )

demo.queue().launch(server_name="0.0.0.0", share=False)