File size: 5,160 Bytes
258d8c9
c74095e
975dc6e
c74095e
 
 
 
 
b2b2e3c
258d8c9
fae2a55
8cc5e29
fae2a55
c74095e
 
c4964ee
430340e
edee1cd
430340e
c4964ee
3eed896
 
 
 
48c7266
258d8c9
d055ca5
2d7769b
936aa8d
a74d502
c74095e
d186b30
c74095e
 
936aa8d
258d8c9
 
090c9fa
 
c74095e
 
 
 
 
090c9fa
c74095e
febb26d
090c9fa
8118b09
c74095e
 
febb26d
090c9fa
8118b09
c74095e
 
 
 
 
 
 
090c9fa
c74095e
3bcd02d
c74095e
65e93e6
b2b2e3c
e2eb33c
8a59f67
b2b2e3c
 
3bcd02d
6f00ba2
258d8c9
a9b5cc4
70c9759
6d16f98
a9b5cc4
 
edee1cd
087b403
 
 
16481ea
edee1cd
6d16f98
466cd5e
 
fae2a55
d1b0b6b
50c2795
 
a950a05
65e93e6
 
 
85f6b9e
 
 
edee1cd
 
a9b5cc4
 
 
 
 
 
 
edee1cd
0771b33
 
bd9a3cd
65e93e6
b547a14
466cd5e
6303c40
280a8d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import gc

# Weights & Biases report for the ControlNet training run (for the optional
# wandb_report() embed).
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'
# p5.js sketch used to draw keypoint conditioning images; embedded in the UI
# through addp5sketch().
sketch_url = 'https://editor.p5js.org/kfahn/full/OshQky7RS'

def create_key(seed=0):
    """Build a JAX pseudo-random key from an integer seed.

    Args:
        seed: integer seed for the generator (defaults to 0).

    Returns:
        The key returned by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key

def _iframe_html(url, height_px):
    """Return markup for a borderless, full-width iframe of the given height.

    The URL is HTML-escaped and placed in a quoted ``src`` attribute, so
    characters such as ``&`` or ``"`` in a query string cannot break the tag.

    Args:
        url: address to load inside the iframe.
        height_px: iframe height in CSS pixels.

    Returns:
        A complete ``<iframe ...></iframe>`` element as a string.
    """
    import html

    return (
        f'<iframe src="{html.escape(url, quote=True)}" '
        f'style="border:none;height:{height_px}px;width:100%"></iframe>'
    )


def addp5sketch(url):
    """Embed the p5.js sketch at *url* in a 495px-tall ``gr.HTML`` component."""
    return gr.HTML(_iframe_html(url, 495))


def wandb_report(url):
    """Embed the W&B report at *url* in a 1024px-tall ``gr.HTML`` component."""
    return gr.HTML(_iframe_html(url, 1024))

# Placeholder path for a conditioning image (not referenced in the visible code).
control_img = 'myimage.jpg'

# Example rows of [prompt, negative prompt, conditioning-image path]; not
# wired into the Blocks UI in the visible code.
examples = [
    [
        "a yellow dog in grass",
        "lowres, two heads, bad muzzle, bad anatomy, missing ears, missing paws",
        "example1.jpg",
    ],
]

# Load the animal-pose ControlNet, then attach it to the Stable Diffusion v1.5
# Flax pipeline; both use bfloat16 weights and the safety checker is disabled.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
    safety_checker=None,
)

# Pipeline parameters (with the ControlNet weights attached) replicated onto
# every local device. Built lazily on the first request and then reused, so the
# full weight tree is not copied to the devices again on every call.
_replicated_params = None


def _get_replicated_params():
    """Return the pipeline params replicated per device, building them once."""
    global _replicated_params
    if _replicated_params is None:
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, image):
    """Generate an image from a text prompt and a keypoint conditioning image.

    Sampling is seeded with 0 on every call and runs 50 inference steps.

    Args:
        prompts: positive text prompt (a single string).
        negative_prompts: negative text prompt (a single string).
        image: conditioning image as delivered by the ``gr.Image`` input
            (presumably a uint8 numpy array), or ``None`` if nothing was given.

    Returns:
        The first generated image as a float32 numpy array.

    Raises:
        gr.Error: if no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")

    # With jit=True the pipeline runs pmapped, so every input is sharded onto a
    # leading device axis. `shard` reshapes to (device_count, -1, ...), so the
    # batch must be divisible by the device count: generate one sample per
    # device (identical to a single sample on a one-device host) and keep the
    # first one.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(0), num_samples)
    pil_image = Image.fromarray(image).convert("RGB")

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([pil_image] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images[0, 0]

    # Release per-request buffers promptly; the demo serves many requests
    # from one long-lived process.
    del pil_image, prompt_ids, negative_prompt_ids, processed_image
    gc.collect()

    return np.array(output, dtype=np.float32)

# Two-column Gradio UI: prompt inputs, conditioning image and result on the
# left; the p5.js keypoint sketch and reference links on the right.
with gr.Blocks(css=".gradio-container {background-image: linear-gradient(to bottom, #206dff 10%, #f8d0ab 90%)}") as demo:
    gr.Markdown(
        """
        <h1 style="text-align: center; font-size: 30px; color: white">
        🐕 Animal Pose Control Net 🐈
        </h1>
        <h3 style="text-align: left;"> This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning.</h3>
        <h3 style="text-align: left;"> While this is definitely a work in progress, you can still try it out by using the p5 sketch to create a keypoint image and using it as the conditioning image.</h3>
        <h3 style="text-align: left;"> The model was generated as part of the Hugging Face Jax Diffusers sprint.  Thank you to both Hugging Face and Google Cloud who provided the TPUs for training!</h3>
        <h3 style="text-align: left;"> The dataset was built using the OpenPifPaf Animalpose plugin.</h3>
        """)
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(label="Prompt", placeholder="animal standing, best quality, highres")
            negative_prompts = gr.Textbox(label="Negative Prompt", value="lowres, two heads, bad muzzle, bad anatomy, missing ears, missing paws")
            conditioning_image = gr.Image(label="Conditioning Image")
            run_btn = gr.Button("Run")
            output = gr.Image(
                label="Result",
            )
        with gr.Column():
            keypoint_tool = addp5sketch(sketch_url)
            # f-string so the W&B link stays in sync with report_url.
            gr.Markdown(
                f"""
                <h3 style="text-align: left;">Additional Information</h3>
                <a style = "color: black; font-size: 20px" href="https://openpifpaf.github.io/plugins_animalpose.html">OpenPifPaf Animalpose</a>
                <a style = "color: black; font-size: 20px" href="https://huggingface.co/datasets/JFoz/dog-cat-pose">Dataset</a>
                <a style = "color: black; font-size: 20px" href="https://huggingface.co/JFoz/dog-cat-pose">Diffusers model</a>
                <a style = "color: black; font-size: 20px" href="{report_url}"> WANDB Training Report</a>
                <a style = "color: black; font-size: 20px" href="https://github.com/fi4cr/animalpose/tree/main/scripts">Training Scripts</a>
                <a style = "color: black; font-size: 20px" href="https://p5js.org">p5.js</a>
                """)

    run_btn.click(fn=infer, inputs=[prompts, negative_prompts, conditioning_image], outputs=output)


if __name__ == "__main__":
    demo.launch(debug=True)