import gradio as gr

from src.const import MODEL_CHOICES
from src.example import EXAMPLES
from src.inference import inference


def build_interface() -> gr.Blocks:
    """Build the Stable Diffusion demo UI.

    Layout: a row with two columns. The left column holds the prompt, the
    model selector, a collapsible "Additional Settings" panel and a
    collapsible "Notes" panel. The right column shows the generated image
    and the "Generate" button. Below the row, an examples table is built
    from ``EXAMPLES``.

    Clicking "Generate" or selecting an example calls
    ``src.inference.inference`` with the values of ``inputs``, passed in
    list order. The result is shown in the PIL image output.

    Returns:
        The assembled ``gr.Blocks`` app. It has not been queued or launched.
    """
    theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    # NOTE(review): nesting of the layout containers below was reconstructed
    # from a whitespace-collapsed copy of this file; confirm against the
    # intended UI.
    with gr.Blocks(theme=theme) as interface:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")
                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="")

                    # gr.Number passes a float to the callback unless
                    # precision=0, which rounds and converts to int. These
                    # fields are pixel sizes, an image count and an RNG seed,
                    # so they are emitted as ints.
                    with gr.Row():
                        width = gr.Number(
                            label="Width", value=512, step=64,
                            minimum=64, maximum=2048, precision=0,
                        )
                        height = gr.Number(
                            label="Height", value=512, step=64,
                            minimum=64, maximum=2048, precision=0,
                        )
                    num_images = gr.Number(
                        label="Num Images", value=4, minimum=1, maximum=10,
                        step=1, precision=0,
                    )
                    seed = gr.Number(label="Seed", value=8888, step=1, precision=0)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", value=7.5, step=0.5,
                        minimum=0, maximum=10,
                    )
                    # step=1 rather than 2: with minimum=1 a step of 2 only
                    # allows odd values, so the default of 50 was off-grid.
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50,
                        minimum=1, maximum=100, step=1,
                    )
                    with gr.Row():
                        use_safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')
                        use_model_offload = gr.Checkbox(value=False, label='Use Model Offload')

                with gr.Accordion(label='Notes', open=False):
                    # NOTE(review): this text refers to "the following tokens"
                    # but no token list follows; the markup listing them may
                    # have been lost. Confirm the intended content.
                    # language=HTML
                    notes = gr.HTML(
                        """
If you want to use negative embedding, use the following tokens in the prompt.
"""
                    )

            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

                # Values are passed to `inference` in this order. It
                # presumably must match inference's parameter order and the
                # column order of each row in EXAMPLES. Verify when editing.
                inputs = [
                    prompt,
                    model_id,
                    negative_prompt,
                    width,
                    height,
                    guidance_scale,
                    num_inference_step,
                    num_images,
                    use_safety_checker,
                    use_model_offload,
                    seed,
                ]

                btn = gr.Button("Generate", variant='primary')
                btn.click(fn=inference, inputs=inputs, outputs=output_image)

        # cache_examples='lazy' runs `inference` for an example the first
        # time it is selected, rather than when the app starts.
        gr.Examples(
            examples=EXAMPLES,
            inputs=inputs,
            outputs=output_image,
            fn=inference,
            cache_examples='lazy',
        )

    return interface


if __name__ == "__main__":
    iface = build_interface()
    iface.queue().launch()