File size: 9,966 Bytes
cd1e8dc
628e6c3
cd1e8dc
 
 
 
 
 
 
 
 
 
628e6c3
35f97ba
628e6c3
ba29a7c
 
 
 
 
 
 
a30c911
 
 
 
 
 
ba29a7c
 
 
 
 
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bff
cd1e8dc
 
 
 
 
a5e9129
 
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35f97ba
 
be5cb04
f2819ff
ce09356
 
ccf1a03
cd1e8dc
 
 
 
 
 
 
 
4598830
cd1e8dc
 
 
4598830
 
 
cd1e8dc
 
 
e57f9cc
cd1e8dc
 
9702a1f
cd1e8dc
 
a5e9129
 
cd1e8dc
 
9702a1f
a5e9129
cd1e8dc
 
 
 
7b66f42
cd1e8dc
 
 
 
 
 
 
 
 
788a013
dc72f49
788a013
 
 
 
cd1e8dc
788a013
 
 
 
cd1e8dc
 
68696f0
71f4cfe
 
5461399
cd1e8dc
 
 
 
 
 
 
71f4cfe
 
cd1e8dc
 
 
71f4cfe
 
 
783c45d
cd1e8dc
 
 
783c45d
cd1e8dc
 
 
a5e9129
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
cd1e8dc
 
628e6c3
 
 
 
 
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
 
 
 
cd1e8dc
 
 
 
 
 
628e6c3
 
 
 
 
 
 
 
3f2581f
628e6c3
7b6aea1
4598830
 
2d0240c
628e6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba29a7c
 
 
a30c911
5e4a2d1
a30c911
 
ba29a7c
 
 
 
 
 
628e6c3
 
 
cd1e8dc
628e6c3
cd1e8dc
 
 
628e6c3
 
4598830
628e6c3
cd1e8dc
628e6c3
 
 
 
 
 
88bd8ea
 
628e6c3
 
a5e9129
 
 
628e6c3
a5e9129
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
import functools
import os

from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr

# Markdown text rendered next to the trajectory video in create_demo().
description = """
Our project uses a diffusion model to change the texture of our robotic arm simulation.

To do so, we first get our simulated images. We then process these images to get Canny edge maps. Finally, we generate brand new images from these edge maps using ControlNet.

Therefore, we are able to change our simulation texture while still keeping the image composition.


Our objective for the sprint is to perform data augmentation using ControlNet. We aim to have a model that can augment an image quickly.
For now, we benchmarked our model on a node of 4 Titan RTX 24GB GPUs. We were able to generate a batch of 4 images in an average time of 1.3 seconds!
We also have access to nodes composed of 8 A100 80GB GPUs. The benchmark on one of these nodes will come soon.
 

"""





def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed``."""
    key = jax.random.PRNGKey(seed)
    return key

def load_controlnet(controlnet_version):
    """Load a Flax ControlNet from the ``Baptlem/baptlem-controlnet`` repository.

    Args:
        controlnet_version: Name of the subfolder inside the repository
            (e.g. ``"coyo-500k"``).

    Returns:
        ``(model, params)`` as produced by ``FlaxControlNetModel.from_pretrained``,
        with the model configured for float32.
    """
    model, model_params = FlaxControlNetModel.from_pretrained(
        "Baptlem/baptlem-controlnet",
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return model, model_params


@functools.lru_cache(maxsize=1)
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build a Flax Stable Diffusion + ControlNet pipeline and its parameters.

    The result is memoized for the most recently requested
    ``(controlnet_version, sb_path)``. Repeated requests for the same model
    therefore reuse the already-loaded weights instead of reloading them from
    the hub or disk. Because the same pipeline object is reused, this
    presumably also avoids re-tracing the jitted sampler (TODO confirm).
    With ``maxsize=1`` only one model set stays referenced by the cache, so
    switching models drops the cache's reference to the previous one.

    Callers must not mutate the returned ``params`` dict in place: it is
    shared across calls.

    Args:
        controlnet_version: Subfolder of the ControlNet repository to load
            (see ``load_controlnet``).
        sb_path: Stable Diffusion base model repository; it must provide a
            ``scheduler`` subfolder and a ``flax`` revision.

    Returns:
        ``(pipe, params)``. ``pipe`` uses a DPM-Solver multistep scheduler, and
        ``params`` includes the ``"controlnet"`` and ``"scheduler"`` entries.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler"
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16
    )

    # Swap in the DPM-Solver scheduler and plug the separately loaded
    # ControlNet / scheduler state into the pipeline's parameter tree.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params

    

# ControlNet checkpoint repository and default subfolder. These are not read
# by the functions in this file; load_controlnet hard-codes the same repo.
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Canny hysteresis thresholds used by preprocess_canny.
low_threshold, high_threshold = 100, 200

# Startup diagnostics: working directory, its contents, and the Gradio version.
print(os.path.abspath(os.curdir))
print(os.listdir(os.curdir))
print(f"Gradio version: {gr.__version__}")
# NOTE: no model is loaded at import time; pipelines are loaded inside
# pipe_inference (via load_sb_pipe), despite what this message says.
print("Loaded models...")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
    ):
    """Generate ``num_samples`` ControlNet-conditioned images.

    Args:
        image: Conditioning image, as a numpy array or anything ``np.array``
            accepts (e.g. a PIL image).
        prompt: Text prompt, repeated once per generated sample.
        is_canny: If True, ``image`` is already an edge map and is only
            resized. Otherwise a Canny edge map is computed from it.
        num_samples: Number of images to return. Any integer >= 1 is
            accepted, even if it is not a multiple of the device count.
        resolution: Target length of the shorter image side, in pixels.
        num_inference_steps: Number of diffusion sampling steps.
        guidance_scale: Classifier-free guidance scale.
        model: ControlNet subfolder to load (see ``load_controlnet``).
        seed: Integer seed for the JAX PRNG.
        negative_prompt: Negative prompt, repeated once per sample.

    Returns:
        A list of ``num_samples`` PIL images.
    """
    # Gradio sliders may hand numbers over as floats. List repetition, PRNG
    # seeding, cv2 target sizes and step counts all require ints.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)

    processed_image = resize_image(image, resolution)  # -> PIL.Image

    if not is_canny:
        # preprocess_canny also returns the un-edged image, which is unused here.
        _, processed_image = preprocess_canny(processed_image, resolution)

    # shard() splits the leading batch axis evenly across devices, so the
    # batch must be a multiple of the device count. Round up, then drop the
    # surplus images after inference.
    num_devices = jax.device_count()
    batch_size = -(-num_samples // num_devices) * num_devices

    rng = create_key(seed)
    rng = jax.random.split(rng, num_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * batch_size)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * batch_size)
    processed_image = pipe.prepare_image_inputs([processed_image] * batch_size)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)
    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Flatten the per-device leading axes into a single batch axis, then
    # discard any padding samples.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images[:num_samples])

def resize_image(image, resolution):
    """Resize ``image`` so its shorter side is ``resolution`` pixels.

    The aspect ratio is kept: the longer side is scaled proportionally and
    truncated to an int. Nearest-neighbour interpolation is used.

    Args:
        image: A numpy array, or anything ``np.array`` can convert
            (e.g. a PIL image).
        resolution: Target length of the shorter side.

    Returns:
        The resized image as a ``PIL.Image.Image``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height

    if aspect > 1:
        # Landscape: pin the height, scale the width.
        target_size = (int(resolution * aspect), resolution)
    elif aspect < 1:
        # Portrait: pin the width, scale the height.
        target_size = (resolution, int(resolution / aspect))
    else:
        target_size = (resolution, resolution)

    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(pixels, target_size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized)
    
    
def preprocess_canny(image, resolution=128):
    """Compute a 3-channel Canny edge map of ``image``.

    The thresholds come from the module-level ``low_threshold`` and
    ``high_threshold``. ``resolution`` is accepted for call-site
    compatibility but is not used; no resizing happens here.

    Args:
        image: A numpy array, or anything ``np.array`` can convert.
        resolution: Unused.

    Returns:
        ``(input_image, edge_map)`` as PIL images. The edge map repeats the
        single-channel Canny output across three channels.
    """
    pixels = np.array(image) if not isinstance(image, np.ndarray) else image

    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Replicate the single edge channel into RGB.
    edges_rgb = np.repeat(edges[:, :, np.newaxis], 3, axis=2)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)


def create_demo(process, max_images=12, default_num_images=4):
    """Build (but do not launch) the Gradio Blocks UI for the ControlNet demo.

    Args:
        process: Callable run when the prompt is submitted or the Run button
            is clicked. It receives the widget values in the order of the
            ``inputs`` list below and must return images for the output
            gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The constructed ``gr.Blocks`` object.
    """
    # Component creation order inside the context managers defines the layout.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            # Left column: inputs and advanced options.
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    # When checked, the uploaded image is treated as an edge map
                    # and Canny preprocessing is skipped.
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # NOTE: the string below is a disabled Canny-threshold UI. It is
                    # a no-op expression statement; the thresholds currently come
                    # from the module-level low_threshold / high_threshold.
                    """
                    canny_low_threshold = gr.Slider(
                        label='Canny low threshold',
                        minimum=1,
                        maximum=255,
                        value=100,
                        step=1)
                    canny_high_threshold = gr.Slider(
                        label='Canny high threshold',
                        minimum=1,
                        maximum=255,
                        value=200,
                        step=1)
                    """
                    # minimum == maximum, so the resolution is effectively fixed at 128.
                    resolution = gr.Slider(label='Resolution',
                                          minimum=128,
                                          maximum=128,
                                          value=128,
                                          step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    # Choices are ControlNet subfolder names passed through to the model loader.
                    model = gr.Dropdown(choices=["coyo-500k", "bridge-2M", "coyo2M-bridge3M"],
                                        value="coyo-500k",
                                        label="Model used for inference", 
                                        info="Find every models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    # NOTE(review): a seed of -1 is allowed here, but nothing in this
                    # file treats -1 specially (e.g. as "random seed"); confirm intent.
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            # Right column: generated images.
            with gr.Column():
                # NOTE(review): .style(...) and Image(source=...) look like Gradio
                # 3.x APIs; confirm before upgrading Gradio.
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        with gr.Row():
            gr.Markdown(description)

            # Relative path: resolved against the process's working directory.
            gr.Video(value=".trajectory_hf/trajectory.avi",
                    format="avi",
                    interactive=False)

        # Order must match process's positional parameters (for pipe_inference:
        # image, prompt, is_canny, num_samples, resolution, num_inference_steps,
        # guidance_scale, model, seed, negative_prompt).
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            #canny_low_threshold,
            #canny_high_threshold,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        # Pressing Enter in the prompt box and clicking Run trigger the same call.
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    
    return demo

if __name__ == '__main__':
    # Build the UI around the inference function, enable Gradio's request
    # queue, and start the server.
    demo = create_demo(pipe_inference)
    demo.queue().launch()
    # gr.Interface(create_demo).launch()