Safety Checker

#1
by TheOneHong - opened

I tried to disable the safety checker by passing parameters, but it still seems to run no matter what, even though I'm fairly sure my code is correct (diff below).

--- original
+++ latest
@@ -5,11 +5,25 @@
 import traceback
 import os
 
-def generate_image(api_key, prompt, image_size='landscape_4_3', num_images=1):
+def generate_image(api_key, prompt, image_size='landscape_4_3', num_images=1, enable_safety_checker=True, safety_tolerance=2):
     try:
         # Set the API key as an environment variable
         os.environ['FAL_KEY'] = api_key
 
+        arguments = {
+            "prompt": prompt,
+            "image_size": image_size,
+            "num_images": num_images,
+        }
+        
+        if enable_safety_checker:
+            arguments["enable_safety_checker"] = True
+            arguments["safety_tolerance"] = safety_tolerance
+        else:
+            arguments["enable_safety_checker"] = False
+
         handler = fal_client.submit(
             "fal-ai/flux-pro/v1.1",
-            arguments={
-                "prompt": prompt,
-                "image_size": image_size,
-                "num_images": num_images,
-            },
+            arguments=arguments,
         )
         result = handler.get()
         images = []
@@ -24,11 +38,15 @@
         print(error_msg)
         return [gr.update(visible=False), gr.update(value=error_msg, visible=True)]
 
+def update_safety_tolerance_visibility(enable_safety):
+    return gr.update(visible=enable_safety)
+
 with gr.Blocks() as demo:
     gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
-    gr.Markdown("get your api key at https://fal.ai/dashboard/keys")
+    gr.Markdown("Get your API key at https://fal.ai/dashboard/keys")
 
     with gr.Row():
         api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key here")
     with gr.Row():
         prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")
     with gr.Row():
@@ -38,15 +56,30 @@
             value="landscape_4_3"
         )
         num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)
+    with gr.Row():
+        enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
+        safety_tolerance = gr.Slider(label="Safety Tolerance", minimum=0, maximum=5, step=1, value=2, visible=True)
+    gr.Markdown("**Note:** The effectiveness of the safety checker and tolerance settings may vary depending on the API's implementation. For the most accurate information, please consult the fal.ai documentation or contact their support.")
+    
     generate_btn = gr.Button("Generate Image")
     output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
     error_output = gr.Textbox(label="Error Message", visible=False)
 
+    enable_safety_checker.change(
+        fn=update_safety_tolerance_visibility,
+        inputs=[enable_safety_checker],
+        outputs=[safety_tolerance]
+    )
+
     generate_btn.click(
         fn=generate_image,
-        inputs=[api_key, prompt, image_size, num_images],
+        inputs=[api_key, prompt, image_size, num_images, enable_safety_checker, safety_tolerance],
         outputs=[output_gallery, error_output]
     )
 
 if __name__ == "__main__":
     demo.launch()

Sign up or log in to comment