# Author: gokaygokay
# Commit: bd1238e (verified) - "Update app.py"
import gradio as gr
import requests
from PIL import Image
import io
import os
from fal_client import submit
def set_fal_key(api_key):
    """Store *api_key* in the ``FAL_KEY`` environment variable.

    The key is written process-wide so that subsequent ``fal_client`` calls in
    this process pick it up.

    Args:
        api_key: The fal.ai API key string.

    Returns:
        A human-readable confirmation message.
    """
    os.environ.update({"FAL_KEY": api_key})
    confirmation = "FAL API key set successfully!"
    return confirmation
def _build_fal_request(model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
    """Return ``(fal_model_id, arguments)`` for the UI model name *model*.

    Only the parameters each model uses are included. Any model name other than
    "Flux Ultra", "Flux Pro" or "Flux Dev" is treated as "Flux Schnell".
    """
    if model == "Flux Ultra":
        # Ultra takes an aspect ratio instead of a size preset and has no
        # step-count / guidance controls.
        arguments = {
            "prompt": prompt,
            "num_images": int(num_images),
            "enable_safety_checker": enable_safety_checker,
            "safety_tolerance": safety_tolerance,
            "aspect_ratio": image_size,  # For Ultra, we pass the aspect ratio directly
        }
        fal_model = "fal-ai/flux-pro/v1.1-ultra"
    else:
        # Gradio sliders deliver floats; the API counts are integers.
        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            "num_inference_steps": int(num_inference_steps),
            "num_images": int(num_images),
        }
        if model == "Flux Pro":
            arguments["guidance_scale"] = guidance_scale
            arguments["safety_tolerance"] = safety_tolerance
            fal_model = "fal-ai/flux-pro"
        elif model == "Flux Dev":
            arguments["guidance_scale"] = guidance_scale
            arguments["enable_safety_checker"] = enable_safety_checker
            fal_model = "fal-ai/flux/dev"
        else:  # Flux Schnell
            arguments["enable_safety_checker"] = enable_safety_checker
            fal_model = "fal-ai/flux/schnell"
    # -1 (or an empty Number field, which Gradio delivers as None) means
    # "let the service pick a seed". gr.Number yields a float unless
    # precision=0, so cast to int before sending.
    if seed is not None and seed != -1:
        arguments["seed"] = int(seed)
    return fal_model, arguments


def _fetch_image(url, timeout=60):
    """Download *url* and decode the response body as a PIL image."""
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def generate_image(api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
    """Generate images with the selected FLUX model on fal.ai.

    Args:
        api_key: fal.ai API key. It is stored in ``FAL_KEY`` via ``set_fal_key``
            before the request.
        model: "Flux Ultra", "Flux Pro", "Flux Dev", or anything else, which is
            treated as "Flux Schnell".
        prompt: Text prompt.
        image_size: Aspect-ratio string (e.g. "16:9") for Ultra. For the other
            models, a size preset name (e.g. "landscape_16_9").
        num_inference_steps: Denoising steps. Not sent for Ultra.
        guidance_scale: Guidance scale. Sent only for Pro and Dev.
        num_images: Number of images to request.
        safety_tolerance: Tolerance level string. Sent for Ultra and Pro.
        enable_safety_checker: Safety-checker flag. Sent for Ultra, Dev and Schnell.
        seed: Seed value. -1 or None means no seed is sent.

    Returns:
        A list of PIL images. On any failure the error is printed and a single
        512x512 black placeholder image is returned instead (best-effort, so the
        UI never receives an exception).
    """
    set_fal_key(api_key)
    try:
        fal_model, arguments = _build_fal_request(
            model, prompt, image_size, num_inference_steps, guidance_scale,
            num_images, safety_tolerance, enable_safety_checker, seed,
        )
        handler = submit(fal_model, arguments=arguments)
        result = handler.get()
        return [_fetch_image(img_info["url"]) for img_info in result["images"]]
    except Exception as e:
        print(f"Error: {str(e)}")
        return [Image.new('RGB', (512, 512), color='black')]
def update_visible_components(model):
    """Return Gradio updates that show/hide and reset the model-specific controls.

    The returned list is ordered as
    (num_inference_steps, guidance_scale, safety_tolerance, enable_safety_checker).
    Any model name that is not "Flux Ultra", "Flux Pro" or "Flux Dev" gets the
    Flux Schnell layout.
    """
    hide = {"visible": False}
    layouts = {
        # Ultra uses neither the step count nor the guidance scale.
        "Flux Ultra": (
            hide,
            hide,
            {"visible": True, "value": "2"},
            {"visible": True, "value": True},
        ),
        # Pro has no safety-checker toggle.
        "Flux Pro": (
            {"visible": True, "value": 28},
            {"visible": True, "value": 3.5},
            {"visible": True, "value": "2"},
            hide,
        ),
        # Dev has no safety-tolerance level.
        "Flux Dev": (
            {"visible": True, "value": 28},
            {"visible": True, "value": 3.5},
            hide,
            {"visible": True, "value": True},
        ),
    }
    # Fallback (Flux Schnell): 4 steps, no guidance, no tolerance level.
    schnell_layout = (
        {"visible": True, "value": 4},
        hide,
        hide,
        {"visible": True, "value": True},
    )
    chosen = layouts.get(model, schnell_layout)
    return [gr.update(**settings) for settings in chosen]
# Build the Gradio UI. Inside each layout context manager, component creation
# order determines the on-screen order, so the statements below are
# order-sensitive.
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    gr.HTML("""
<h1 align="center">FLUX 1.1 Ultra Image Generation</h1>
<p align="center">
<a href="https://blackforestlabs.ai/" target="_blank">[Black Forest Labs]</a>
<a href="https://blackforestlabs.ai/announcing-black-forest-labs/" target="_blank">[Blog]</a>
<a href="https://fal.ai/models/fal-ai/flux-pro/v1.1-ultra" target="_blank">[FLUX 1.1 Ultra Model FAL]</a>
<a href="https://fal.ai/dashboard/keys" target="_blank">[GET YOUR API KEY HERE]</a>
</p>
""")
    with gr.Row():
        # Left column: credentials, prompt and generation controls.
        with gr.Column(scale=1):
            api_key = gr.Textbox(type="password", label="FAL API Key")
            model = gr.Dropdown(
                label="Model",
                choices=["Flux Ultra", "Flux Pro", "Flux Dev", "Flux Schnell"],
                value="Flux Ultra"
            )
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Add your prompt here")
            # Different aspect ratio options based on model: Ultra takes ratio
            # strings, the other models take size preset names.
            ultra_sizes = ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16", "9:21"]
            other_sizes = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
            # Starts with the Ultra choices because "Flux Ultra" is the default model.
            image_size = gr.Dropdown(
                choices=ultra_sizes,
                label="Aspect Ratio",
                value="16:9"
            )
            with gr.Accordion("Advanced settings", open=False):
                # Steps and guidance start hidden because the default (Ultra)
                # model does not use them. update_visible_components toggles them.
                num_inference_steps = gr.Slider(1, 100, 28, step=1, label="Number of Inference Steps", visible=False)
                guidance_scale = gr.Slider(0, 20, 3.5, step=0.1, label="Guidance Scale", visible=False)
                num_images = gr.Slider(1, 10, 1, step=1, label="Number of Images")
                safety_tolerance = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6"], label="Safety Tolerance", value="2")
                enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
                # precision=0 makes the component return an int rather than a
                # float. -1 means no seed is sent to the API.
                seed = gr.Number(label="Seed", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image")
        # Right column: generated results.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    def update_model_options(model):
        """Swap the size choices for *model* and refresh the model-specific controls.

        Returns updates in the order of the ``model.change`` outputs:
        image_size first, then the four controls updated by
        ``update_visible_components``.
        """
        if model == "Flux Ultra":
            return [
                gr.update(choices=ultra_sizes, value="16:9", label="Aspect Ratio"),
                *update_visible_components(model)
            ]
        else:
            return [
                gr.update(choices=other_sizes, value="landscape_16_9", label="Image Size"),
                *update_visible_components(model)
            ]

    model.change(
        update_model_options,
        inputs=[model],
        outputs=[image_size, num_inference_steps, guidance_scale, safety_tolerance, enable_safety_checker]
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[
            api_key, model, prompt, image_size, num_inference_steps,
            guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed
        ],
        outputs=[output_gallery]
    )

# Start the server only when run as a script, so importing this module (e.g.
# for tests or `gradio app.py` reload mode) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()