AuroFlow-v3 / app.py
fantaxy's picture
Update app.py
0133b45 verified
raw
history blame contribute delete
No virus
2.8 kB
import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import AuraFlowPipeline
# Run on the GPU when one is visible; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is the intended GPU dtype, but many float16 ops are unsupported on
# CPU (diffusers warns that fp16 pipelines moved to CPU will fail at
# inference), so use float32 when no GPU is available.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Initialize the AuraFlow v0.3 pipeline
pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch_dtype,
).to(device)
# Largest seed used by the seed slider and the random-seed draw (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 1024
@spaces.GPU
def infer(prompt,
          negative_prompt="",
          seed=42,
          randomize_seed=False,
          width=1024,
          height=1024,
          guidance_scale=5.0,
          num_inference_steps=28,
          progress=gr.Progress(track_tqdm=True)):
    """Generate a single image with the AuraFlow pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for the torch generator; replaced by a random value when
            ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed uniformly from
            ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            surface tqdm progress bars in the UI.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline
        and the integer seed that was used, so the UI can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats (e.g. 42.0 from an API client);
    # torch.Generator.manual_seed only accepts ints, and the pipeline's size
    # and step arguments are integral, so normalize them here.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    ).images[0]
    return image, seed
# Hide the default Gradio footer.
css = """
footer {
visibility: hidden;
}
"""

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column(scale=1):
            prompt = gr.Text(label="Prompt", placeholder="Enter your prompt")
            negative_prompt = gr.Text(label="Negative prompt", placeholder="Enter a negative prompt")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
            num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
            run_button = gr.Button("Generate")
        # Right column: generated image and the seed that produced it.
        with gr.Column(scale=1):
            result = gr.Image(label="Generated Image")
            seed_output = gr.Number(label="Seed used")

    # `inputs` are passed positionally, so their order must match infer()'s
    # parameters; `outputs` receive infer()'s (image, seed) return tuple.
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed_output],
    )

    # Clickable example prompts that fill the prompt box.
    gr.Examples(
        examples=[
            "A photo of a lavender cat",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
        ],
        inputs=prompt,
    )

# Only start the server when run as a script (`python app.py`), so importing
# this module (e.g. for tests or tooling) does not launch it as a side effect.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", share=False)