|
import gradio as gr |
|
import fal_client |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
import traceback |
|
import os |
|
|
|
# Seconds to wait on each generated-image download before giving up.
_DOWNLOAD_TIMEOUT = 60


def _download_image(url):
    """Fetch ``url`` over HTTP and return the response body opened as a PIL image."""
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
    # Surface HTTP failures here; otherwise PIL would be handed an error page
    # and report a misleading "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def generate_image(api_key, prompt, image_size='landscape_4_3', num_images=1):
    """Generate images with fal.ai's FLUX1.1 [pro] model for the Gradio UI.

    Args:
        api_key: fal.ai API key; used for this request only.
        prompt: Text prompt forwarded to the model.
        image_size: Size preset name forwarded as the ``image_size`` argument.
        num_images: How many images to request; coerced to ``int``.

    Returns:
        A two-item list of ``gr.update`` objects: the first targets the image
        gallery, the second targets the error textbox. On success the gallery
        is shown with the downloaded images and the error box is hidden; on
        any failure the gallery is hidden and the error box shows the message.
    """
    # Reject blank inputs up front so no network request is made for them.
    if not api_key or not api_key.strip():
        return [
            gr.update(visible=False),
            gr.update(value="Error: API key is required.", visible=True),
        ]
    if not prompt or not prompt.strip():
        return [
            gr.update(visible=False),
            gr.update(value="Error: Prompt is required.", visible=True),
        ]

    try:
        # Use a client bound to this request's key instead of writing the key
        # into os.environ: the environment is process-global, so keys entered
        # by different users of the same server could overwrite each other.
        # NOTE(review): relies on fal_client.SyncClient(key=...) — confirm it
        # is available in the installed fal_client version.
        client = fal_client.SyncClient(key=api_key.strip())

        handler = client.submit(
            "fal-ai/flux-pro/v1.1",
            arguments={
                "prompt": prompt,
                "image_size": image_size,
                # The UI slider may deliver a float; send an integer count.
                "num_images": int(num_images),
            },
        )
        result = handler.get()

        images = [_download_image(img_info['url']) for img_info in result['images']]
        return [gr.update(value=images, visible=True), gr.update(visible=False)]
    except Exception as e:
        # UI boundary: report every failure in the error box rather than
        # letting the exception escape the event handler.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return [gr.update(visible=False), gr.update(value=error_msg, visible=True)]
|
|
|
# Size presets offered in the "Image Size" dropdown.
_IMAGE_SIZE_CHOICES = [
    "square_hd",
    "square",
    "portrait_4_3",
    "portrait_16_9",
    "landscape_4_3",
    "landscape_16_9",
]

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened — confirm the row layout matches the intent.
with gr.Blocks() as demo:
    # Page heading and a pointer to where an API key can be obtained.
    gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
    gr.Markdown("get your api key at https://fal.ai/dashboard/keys")

    # Row 1: API key, rendered as a password field.
    with gr.Row():
        api_key = gr.Textbox(
            label="API Key",
            type="password",
            placeholder="Enter your API key here",
        )

    # Row 2: the text prompt.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=2,
            placeholder="Enter your prompt here",
        )

    # Row 3: generation options.
    with gr.Row():
        image_size = gr.Dropdown(
            label="Image Size",
            choices=_IMAGE_SIZE_CHOICES,
            value="landscape_4_3",
        )
        num_images = gr.Slider(
            label="Number of Images",
            minimum=1,
            maximum=4,
            step=1,
            value=1,
        )

    generate_btn = gr.Button("Generate Image")

    # Result gallery, plus an error box that starts hidden.
    output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    error_output = gr.Textbox(label="Error Message", visible=False)

    # Clicking the button runs generate_image on the four inputs and routes
    # its two returned updates to the gallery and the error box.
    generate_btn.click(
        generate_image,
        [api_key, prompt, image_size, num_images],
        [output_gallery, error_output],
    )
|
|
|
def main():
    """Start the Gradio app defined by the module-level ``demo`` object."""
    demo.launch()


# Only launch when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()