"""Gradio demo for FLUX.1 Canny Dev: generate an image guided by the Canny edges of a reference image."""

import functools
import os

import gradio as gr
import torch
from controlnet_aux import CannyDetector
from diffusers import DiffusionPipeline, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from huggingface_hub import login
from transformers import T5EncoderModel

BASE_REPO = "black-forest-labs/FLUX.1-dev"
CANNY_REPO = "black-forest-labs/FLUX.1-Canny-dev"
CANNY_NF4_REPO = "sayakpaul/FLUX.1-Canny-dev-nf4"

hf_token = os.environ.get("HF_TOKEN")
# login(token=None) falls back to an interactive prompt, which would block a
# headless server; only authenticate when a token is actually provided.
if hf_token:
    login(token=hf_token)


def _load_canny_components(four_bit):
    """Load the Canny-conditioned components that replace the base FLUX.1-dev ones.

    Args:
        four_bit: If true, load the NF4 transformer and T5 encoder from
            CANNY_NF4_REPO; otherwise load the bf16 transformer from CANNY_REPO.

    Returns:
        A dict of component overrides keyed by pipeline component name.
    """
    if four_bit:
        print("Using four bit.")
        transformer = FluxTransformer2DModel.from_pretrained(
            CANNY_NF4_REPO, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            CANNY_NF4_REPO, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
        )
        return {"transformer": transformer, "text_encoder_2": text_encoder_2}

    transformer = FluxTransformer2DModel.from_pretrained(
        CANNY_REPO,
        subfolder="transformer",
        revision="refs/pr/1",
        torch_dtype=torch.bfloat16,
    )
    return {"transformer": transformer}


@functools.lru_cache(maxsize=1)
def _build_pipeline(four_bit):
    """Build the FluxControlPipeline for the given precision, with CPU offload enabled."""
    overrides = _load_canny_components(four_bit)
    # Handing the replacement components to from_pretrained means the base
    # FLUX.1-dev transformer (and T5, in the NF4 case) is never loaded only
    # to be thrown away.
    base = DiffusionPipeline.from_pretrained(
        BASE_REPO, torch_dtype=torch.bfloat16, **overrides
    )
    pipeline = FluxControlPipeline.from_pipe(base, torch_dtype=torch.bfloat16)
    pipeline.enable_model_cpu_offload()
    return pipeline


def load_pipeline(four_bit=False):
    """Return the FLUX Canny control pipeline, reusing a cached instance when possible.

    Loading the FLUX weights dominates request latency, so the most recently
    built pipeline is cached (maxsize=1); toggling ``four_bit`` evicts it and
    builds the other variant.

    Args:
        four_bit: If true, use the NF4-quantized transformer and T5 encoder.

    Returns:
        A FluxControlPipeline with model CPU offload enabled.
    """
    # Normalize so truthy values (e.g. 1 vs True) share one cache entry.
    return _build_pipeline(bool(four_bit))


def get_canny(control_image, low_threshold=50, high_threshold=200, resolution=1024):
    """Return the Canny edge map of ``control_image``.

    Args:
        control_image: The reference image (as returned by ``load_image``).
        low_threshold: Lower Canny threshold passed to CannyDetector.
        high_threshold: Upper Canny threshold passed to CannyDetector.
        resolution: Used as both the detection and the output image resolution.

    Returns:
        Whatever CannyDetector returns for these arguments.
    """
    processor = CannyDetector()
    return processor(
        control_image,
        low_threshold=low_threshold,
        high_threshold=high_threshold,
        detect_resolution=resolution,
        image_resolution=resolution,
    )


def main(ref_filepath, prompt, use_nf4, progress=gr.Progress(track_tqdm=True)):
    """Gradio callback: generate an image conditioned on the reference's Canny edges.

    Args:
        ref_filepath: Path of the uploaded reference image (None if nothing uploaded).
        prompt: Text prompt for generation.
        use_nf4: Whether to use the NF4-quantized checkpoints.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm bar in the UI, so it is never used directly.

    Returns:
        A list ``[output_path, control_image]`` for the results Gallery: the
        saved generated image and the Canny map it was conditioned on.

    Raises:
        gr.Error: If no reference image was provided.
    """
    # Validate before the expensive pipeline load; load_image(None) would
    # otherwise fail with an opaque error.
    if not ref_filepath:
        raise gr.Error("Please upload a reference image.")

    pipe = load_pipeline(use_nf4)
    control_image = get_canny(load_image(ref_filepath))

    image = pipe(
        prompt=prompt,
        control_image=control_image,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30.0,
        max_sequence_length=512,
        # Fixed seed for reproducible outputs.
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

    # Distinct names so NF4 and bf16 runs don't overwrite each other.
    output_path = "output_4bit.png" if use_nf4 else "output.png"
    image.save(output_path)
    # The event has a single Gallery output, so return one list of images
    # rather than a tuple Gradio would have to reinterpret.
    return [output_path, control_image]


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# FLUX.1 Canny Dev")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Reference Image", type="filepath")
                prompt = gr.Textbox(label="Prompt")
                use_nf4 = gr.Checkbox(label="Use NF4 checkpoints", value=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                results = gr.Gallery(label="Results")

    submit_btn.click(
        fn=main,
        inputs=[image_input, prompt, use_nf4],
        outputs=[results],
    )

demo.launch(show_api=False, show_error=True)