File size: 4,731 Bytes
258d8c9 c74095e 975dc6e c74095e b2b2e3c 258d8c9 fae2a55 8cc5e29 fae2a55 c74095e c4964ee 430340e edee1cd 430340e c4964ee 3eed896 48c7266 258d8c9 d055ca5 2d7769b 936aa8d a74d502 c74095e d186b30 c74095e 936aa8d 258d8c9 090c9fa c74095e 090c9fa c74095e febb26d 090c9fa 8118b09 c74095e febb26d 090c9fa 8118b09 c74095e 090c9fa c74095e 3bcd02d c74095e 65e93e6 b2b2e3c e2eb33c 8a59f67 b2b2e3c 3bcd02d 6f00ba2 258d8c9 1bb6ad6 70c9759 6d16f98 edee1cd 1bb6ad6 edee1cd 087b403 16481ea edee1cd 6d16f98 466cd5e fae2a55 d1b0b6b 50c2795 a950a05 65e93e6 85f6b9e edee1cd 7b8dae3 087b403 7b8dae3 edee1cd 0771b33 bd9a3cd 65e93e6 b547a14 466cd5e 6303c40 280a8d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import gc
# Weights & Biases run page holding the training report for this model.
report_url = "https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5"
# p5.js editor sketch for drawing keypoint (pose) conditioning images.
sketch_url = "https://editor.p5js.org/kfahn/full/OshQky7RS"
def create_key(seed=0):
    """Build and return a JAX PRNG key from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key
def _iframe_html(url, height_px):
    """Build a borderless, full-width ``<iframe>`` tag that embeds *url*.

    The URL is HTML-escaped and quoted so characters such as ``&`` or ``"``
    cannot break out of the ``src`` attribute. The element is closed with a
    proper ``</iframe>`` end tag. The previous markup ended in ``/frame>``,
    which left the iframe unclosed.
    """
    import html

    src = html.escape(url, quote=True)
    return (
        f'<iframe src="{src}" '
        f'style="border:none;height:{height_px}px;width:100%"></iframe>'
    )


def addp5sketch(url, height=495):
    """Return a ``gr.HTML`` component embedding the p5.js sketch at *url*.

    Args:
        url: Address of the sketch page to embed.
        height: Iframe height in pixels (default 495, the original value).
    """
    return gr.HTML(_iframe_html(url, height))


def wandb_report(url, height=1024):
    """Return a ``gr.HTML`` component embedding the W&B report at *url*.

    Args:
        url: Address of the report page to embed.
        height: Iframe height in pixels (default 1024, the original value).
    """
    return gr.HTML(_iframe_html(url, height))
# Sample conditioning-image filename; not referenced elsewhere in this file.
control_img = "myimage.jpg"
# Example rows of [prompt, negative prompt, conditioning image path].
# NOTE(review): not wired into the UI in this file; kept for reference.
examples = [
    [
        "a yellow dog in grass",
        "lowres, two heads, bad muzzle, bad anatomy, missing ears, missing paws",
        "example1.jpg",
    ]
]
# Load the animal-pose ControlNet and the Stable Diffusion v1.5 Flax pipeline,
# both in bfloat16. The safety checker is disabled (safety_checker=None).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose", dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
    safety_checker=None,
)
# The ControlNet weights are loaded separately from the pipeline's params.
# Attach them once here so `params` is complete, rather than patching the
# dict on every inference request.
params["controlnet"] = controlnet_params
# Cache of the pipeline params replicated across local devices. Filled
# lazily by _get_replicated_params() on the first request.
_replicated_params = None


def _get_replicated_params():
    """Return the pipeline params, with ControlNet weights attached, replicated per device.

    Replication copies the full set of weights to every local device. It is
    done once and cached in ``_replicated_params``, instead of once per request.
    """
    global _replicated_params
    if _replicated_params is None:
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate one image from a text prompt and a keypoint (pose) conditioning image.

    Args:
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt.
        image: Conditioning image, presumably the numpy array delivered by
            ``gr.Image``. It is ``None`` when the user supplied no image.
        seed: Seed for the JAX PRNG (default 0, the previous hard-coded value).
        num_inference_steps: Number of denoising steps (default 50).

    Returns:
        The first generated image, converted to a float32 numpy array.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")

    # The pmapped pipeline needs one sample per device. shard() reshapes the
    # leading axis to (device_count, -1, ...), so a batch of 1 fails on
    # multi-device hosts. The PRNG keys are split to match.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    pil_image = Image.fromarray(image)
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([pil_image] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images[0, 0]  # first device shard, first sample

    # Release the per-request inputs before converting the result.
    del pil_image, prompt_ids, negative_prompt_ids, processed_image
    gc.collect()
    return np.array(output, dtype=np.float32)
# Page header rendered through gr.Markdown. The emoji are written as escapes
# (service dog: 1F415+ZWJ+1F9BA, black cat: 1F408+ZWJ+2B1B). This replaces
# the mis-encoded text in the old header and closes every <h3> correctly.
HEADER_MD = """
<h1 style="text-align: center;">
\U0001F415\u200d\U0001F9BA Animal Pose Control Net \U0001F408\u200d\u2B1B
</h1>
<h3 style="text-align: left;"> This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning.</h3>
<h3 style="text-align: left;"> While this is definitely a work in progress, you can still try it out by using the p5 sketch to create a keypoint image and using it as the conditioning image.</h3>
<h3 style="text-align: left;"> The model was generated as part of the Hugging Face JAX Diffusers sprint. Thank you to both Hugging Face and Google Cloud who provided the TPUs for training!</h3>
<h3 style="text-align: left;"> The dataset was built using the OpenPifPaf Animalpose plugin.</h3>
"""

# Reference links shown under the p5.js sketch. The W&B link reuses report_url
# so the URL is defined in one place.
LINKS_MD = f"""
- [OpenPifPaf Animalpose](https://openpifpaf.github.io/plugins_animalpose.html)
- [Dataset](https://huggingface.co/datasets/JFoz/dog-cat-pose)
- [Diffusers model](https://huggingface.co/JFoz/dog-cat-pose)
- [WANDB Training Report]({report_url})
- [Training Scripts](https://github.com/fi4cr/animalpose/tree/main/scripts)
- [p5.js](https://p5js.org)
"""

with gr.Blocks(css=".gradio-container {background-color: #f8d0ab}") as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # Left column: inputs, run button and the generated result.
        with gr.Column():
            prompts = gr.Textbox(
                label="Prompt",
                placeholder="animal standing, best quality, highres",
            )
            negative_prompts = gr.Textbox(
                label="Negative Prompt",
                value="lowres, two heads, bad muzzle, bad anatomy, missing ears, missing paws",
            )
            conditioning_image = gr.Image(label="Conditioning Image")
            run_btn = gr.Button("Run")
            output = gr.Image(label="Result")
        # Right column: embedded keypoint-drawing sketch plus reference links.
        with gr.Column():
            keypoint_tool = addp5sketch(sketch_url)
            gr.Markdown(LINKS_MD)
    run_btn.click(
        fn=infer,
        inputs=[prompts, negative_prompts, conditioning_image],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch(debug=True)