fffiloni commited on
Commit
a0976c0
1 Parent(s): 97d96ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from diffusers import DiffusionPipeline, FluxControlPipeline, FluxTransformer2DModel
4
+ import torch
5
+ from transformers import T5EncoderModel
6
+ from controlnet_aux import CannyDetector
7
+ from diffusers.utils import load_image
8
+
9
+ from huggingface_hub import login
10
+ hf_token = os.environ.get["HF_TOKEN"]
11
+ login(hf_token)
12
+
13
+ def load_pipeline(four_bit=False):
14
+ orig_pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
15
+ if four_bit:
16
+ print("Using four bit.")
17
+ transformer = FluxTransformer2DModel.from_pretrained(
18
+ "sayakpaul/FLUX.1-Canny-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
19
+ )
20
+ text_encoder_2 = T5EncoderModel.from_pretrained(
21
+ "sayakpaul/FLUX.1-Canny-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
22
+ )
23
+ pipeline = FluxControlPipeline.from_pipe(
24
+ orig_pipeline, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16
25
+ )
26
+ else:
27
+ transformer = FluxTransformer2DModel.from_pretrained(
28
+ "black-forest-labs/FLUX.1-Canny-dev",
29
+ subfolder="transformer",
30
+ revision="refs/pr/1",
31
+ torch_dtype=torch.bfloat16,
32
+ )
33
+ pipeline = FluxControlPipeline.from_pipe(orig_pipeline, transformer=transformer, torch_dtype=torch.bfloat16)
34
+
35
+ pipeline.enable_model_cpu_offload()
36
+ return pipeline
37
+
38
+ def get_canny(control_image):
39
+ processor = CannyDetector()
40
+ control_image = processor(
41
+ control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
42
+ )
43
+ return control_image
44
+
45
+ def main(ref_filepath, prompt, use_nf4):
46
+ pipe = load_pipeline(use_nf4)
47
+ control_image = load_image(ref_filepath)
48
+ control_image = get_canny(control_image)
49
+ image = pipe(
50
+ prompt=prompt,
51
+ control_image=control_image,
52
+ height=1024,
53
+ width=1024,
54
+ num_inference_steps=50,
55
+ guidance_scale=30.0,
56
+ max_sequence_length=512,
57
+ generator=torch.Generator("cpu").manual_seed(0),
58
+ ).images[0]
59
+ filename = "output_"
60
+ filename += "_4bit" if four_bit else ""
61
+ image.save(f"{filename}.png")
62
+ return f"{filename}.png", control_image
63
+
64
+ with gr.Blocks() as demo:
65
+ with gr.Column():
66
+ gr.Markdown("# FLUX.1 Canny Dev")
67
+ with gr.Row():
68
+ with gr.Column():
69
+ image_input = gr.Image(label="Reference Image", type="filepath")
70
+ prompt = gr.Textbox(label="Prompt")
71
+ use_nf4 = gr.Checkbox(label="Use NF4 checkpoints", value=True)
72
+ submit_btn = gr.Button("Submit")
73
+ with gr.Column():
74
+ results= gr.Gallery(label="Results")
75
+
76
+ submit_btn.click(
77
+ fn = main,
78
+ inputs = [image_input, prompt, use_nf4],
79
+ outputs = [results]
80
+ )
81
+
82
+ demo.launch(show_api=False, show_error=True)