|
import gradio as gr |
|
import requests |
|
from PIL import Image |
|
import io |
|
import os |
|
from fal_client import submit |
|
|
|
def set_fal_key(api_key):
    """Store *api_key* in the ``FAL_KEY`` environment variable.

    Args:
        api_key: The fal API key to store in this process's environment.

    Returns:
        A fixed confirmation message.
    """
    os.environ.update({"FAL_KEY": api_key})
    return "FAL API key set successfully!"
|
|
|
def generate_image(api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
    """Generate images with the selected FLUX model on fal and return them.

    Args:
        api_key: fal API key; passed to ``set_fal_key`` before the request.
        model: One of "Flux Ultra", "Flux Pro", "Flux Dev"; any other value
            is routed to the schnell endpoint.
        prompt: Text prompt sent to the model.
        image_size: Sent as ``aspect_ratio`` for Flux Ultra and as
            ``image_size`` for every other model.
        num_inference_steps: Sent for every model except Flux Ultra.
        guidance_scale: Sent for Flux Pro and Flux Dev only.
        num_images: Number of images requested.
        safety_tolerance: Sent for Flux Ultra and Flux Pro only.
        enable_safety_checker: Sent for every model except Flux Pro.
        seed: Seed for the request. ``-1`` or ``None`` (an emptied input
            field) means "do not send a seed"; other values are sent as int.

    Returns:
        A list of PIL images. If submitting the job or downloading any image
        fails, the error is printed and a single black 512x512 placeholder
        image is returned instead.
    """
    # NOTE: set_fal_key writes the key into the process environment, so it is
    # shared by every request handled by this process.
    set_fal_key(api_key)

    if model == "Flux Ultra":
        arguments = {
            "prompt": prompt,
            "num_images": num_images,
            "enable_safety_checker": enable_safety_checker,
            "safety_tolerance": safety_tolerance,
            "aspect_ratio": image_size,
        }
        fal_model = "fal-ai/flux-pro/v1.1-ultra"
    else:
        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            "num_inference_steps": num_inference_steps,
            "num_images": num_images,
        }
        if model == "Flux Pro":
            arguments["guidance_scale"] = guidance_scale
            arguments["safety_tolerance"] = safety_tolerance
            fal_model = "fal-ai/flux-pro"
        elif model == "Flux Dev":
            arguments["guidance_scale"] = guidance_scale
            arguments["enable_safety_checker"] = enable_safety_checker
            fal_model = "fal-ai/flux/dev"
        else:
            arguments["enable_safety_checker"] = enable_safety_checker
            fal_model = "fal-ai/flux/schnell"

    # A cleared numeric field arrives as None and UI numbers may arrive as
    # floats; only forward a real seed, and always as an integer.
    if seed is not None and int(seed) != -1:
        arguments["seed"] = int(seed)

    try:
        handler = submit(fal_model, arguments=arguments)
        result = handler.get()
        images = []
        for img_info in result["images"]:
            # Bound the download so a stalled connection cannot hang the
            # request forever, and fail loudly on HTTP errors instead of
            # handing an error page to the image decoder.
            img_response = requests.get(img_info["url"], timeout=60)
            img_response.raise_for_status()
            images.append(Image.open(io.BytesIO(img_response.content)))
        return images
    except Exception as e:
        # Deliberate best-effort: report the problem and keep the UI alive
        # with a placeholder image rather than raising.
        print(f"Error: {str(e)}")
        return [Image.new('RGB', (512, 512), color='black')]
|
|
|
def update_visible_components(model):
    """Build the visibility/value updates for the model-specific controls.

    Args:
        model: Name of the selected model.

    Returns:
        A list of four ``gr.update(...)`` results. Judging by the default
        values and by the ``model.change`` outputs wiring, the order is:
        inference steps, guidance scale, safety tolerance, safety checker.
        Any model other than "Flux Ultra", "Flux Pro" or "Flux Dev" gets the
        fallback layout.
    """
    hidden = {"visible": False}

    def shown(value):
        # Keyword arguments for a visible control reset to *value*.
        return {"visible": True, "value": value}

    per_model = {
        "Flux Ultra": (hidden, hidden, shown("2"), shown(True)),
        "Flux Pro": (shown(28), shown(3.5), shown("2"), hidden),
        "Flux Dev": (shown(28), shown(3.5), hidden, shown(True)),
    }
    fallback = (shown(4), hidden, hidden, shown(True))

    return [gr.update(**kwargs) for kwargs in per_model.get(model, fallback)]
|
|
|
# Choices for the size dropdown: aspect ratios are offered when "Flux Ultra"
# is selected, named presets for every other model.
ultra_sizes = ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16", "9:21"]
other_sizes = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]

with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    # Page header with reference links.
    gr.HTML("""
        <h1 align="center">FLUX 1.1 Ultra Image Generation</h1>
        <p align="center">
            <a href="https://blackforestlabs.ai/" target="_blank">[Black Forest Labs]</a>
            <a href="https://blackforestlabs.ai/announcing-black-forest-labs/" target="_blank">[Blog]</a>
            <a href="https://fal.ai/models/fal-ai/flux-pro/v1.1-ultra" target="_blank">[FLUX 1.1 Ultra Model FAL]</a>
            <a href="https://fal.ai/dashboard/keys" target="_blank">[GET YOUR API KEY HERE]</a>
        </p>
    """)

    with gr.Row():
        # Left column: credentials, model choice, prompt and settings.
        with gr.Column(scale=1):
            api_key = gr.Textbox(type="password", label="FAL API Key")
            model = gr.Dropdown(
                label="Model",
                choices=["Flux Ultra", "Flux Pro", "Flux Dev", "Flux Schnell"],
                value="Flux Ultra",
            )
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Add your prompt here")

            # Starts in the "Flux Ultra" configuration to match the default model.
            image_size = gr.Dropdown(
                choices=ultra_sizes,
                label="Aspect Ratio",
                value="16:9",
            )

            with gr.Accordion("Advanced settings", open=False):
                # Steps and guidance start hidden; the model.change handler
                # below shows them for the models that use them.
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=100, value=28, step=1,
                    label="Number of Inference Steps", visible=False,
                )
                guidance_scale = gr.Slider(
                    minimum=0, maximum=20, value=3.5, step=0.1,
                    label="Guidance Scale", visible=False,
                )
                num_images = gr.Slider(
                    minimum=1, maximum=10, value=1, step=1,
                    label="Number of Images",
                )
                safety_tolerance = gr.Dropdown(
                    choices=["1", "2", "3", "4", "5", "6"],
                    label="Safety Tolerance",
                    value="2",
                )
                enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
                seed = gr.Number(label="Seed", value=-1)

            generate_btn = gr.Button("Generate Image")

        # Right column: results.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    def update_model_options(model):
        """Return updates for the size dropdown plus the model-specific controls."""
        if model == "Flux Ultra":
            size_update = gr.update(choices=ultra_sizes, value="16:9", label="Aspect Ratio")
        else:
            size_update = gr.update(choices=other_sizes, value="landscape_16_9", label="Image Size")
        return [size_update, *update_visible_components(model)]

    # Reconfigure the form whenever a different model is picked.
    model.change(
        update_model_options,
        inputs=[model],
        outputs=[
            image_size,
            num_inference_steps,
            guidance_scale,
            safety_tolerance,
            enable_safety_checker,
        ],
    )

    # Input order must match the positional parameters of generate_image.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            api_key,
            model,
            prompt,
            image_size,
            num_inference_steps,
            guidance_scale,
            num_images,
            safety_tolerance,
            enable_safety_checker,
            seed,
        ],
        outputs=[output_gallery],
    )

demo.launch()