File size: 3,711 Bytes
258d8c9 c74095e 975dc6e c74095e b2b2e3c 258d8c9 fae2a55 8cc5e29 fae2a55 c74095e c4964ee 430340e edee1cd 430340e c4964ee 3eed896 48c7266 258d8c9 c74095e d186b30 c74095e d186b30 258d8c9 090c9fa c74095e 090c9fa c74095e febb26d 090c9fa 8118b09 c74095e febb26d 090c9fa 8118b09 c74095e 090c9fa c74095e 3bcd02d c74095e 65e93e6 b2b2e3c 3bcd02d 6f00ba2 258d8c9 c131c56 70c9759 6d16f98 edee1cd 6d16f98 466cd5e fae2a55 d1b0b6b bf379ef a950a05 edee1cd 466cd5e 430340e 65e93e6 edee1cd 0771b33 bd9a3cd 65e93e6 b547a14 466cd5e 6303c40 280a8d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import gc
# Weights & Biases run page for the ControlNet training; embedded in the UI.
report_url = "https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5"
# p5.js editor sketch embedded in the UI as the keypoint (pose) drawing tool.
sketch_url = "https://editor.p5js.org/kfahn/full/OshQky7RS"
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Value passed straight to ``jax.random.PRNGKey`` (default ``0``).

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
def _iframe_html(url, height):
    """Return markup for a borderless, full-width iframe ``height`` pixels tall.

    The URL is HTML-escaped and placed in a quoted ``src`` attribute so that
    characters such as ``&`` or quotes cannot break the tag. The element is
    closed with a proper ``</iframe>`` end tag.
    """
    import html

    src = html.escape(url, quote=True)
    return (
        f'<iframe src="{src}" '
        f'style="border:none;height:{height}px;width:100%"></iframe>'
    )


def addp5sketch(url, height=495):
    """Embed the p5.js sketch at ``url`` in a ``gr.HTML`` component.

    Args:
        url: Address of the sketch page to embed.
        height: Iframe height in pixels (default 495).

    Returns:
        The ``gr.HTML`` component wrapping the iframe.
    """
    return gr.HTML(_iframe_html(url, height))


def wandb_report(url, height=1024):
    """Embed the report page at ``url`` in a ``gr.HTML`` component.

    Args:
        url: Address of the report page to embed.
        height: Iframe height in pixels (default 1024).

    Returns:
        The ``gr.HTML`` component wrapping the iframe.
    """
    return gr.HTML(_iframe_html(url, height))
# Sample conditioning image path; not referenced by the visible live code
# (it only appears in a commented-out example).
control_img = "myimage.jpg"

# Animal-pose ControlNet weights, loaded in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)

# Stable Diffusion v1.5 (Flax revision) pipeline with the ControlNet attached.
# `params` holds the pipeline weights; the ControlNet weights are merged into
# it later, before generation.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# Pipeline parameters replicated onto every device. They are built lazily on
# the first `infer` call and reused, so the weights are not copied to the
# devices again on every request.
_replicated_params = None


def _get_replicated_params():
    """Return `params`, with the ControlNet weights attached, replicated per device."""
    global _replicated_params
    if _replicated_params is None:
        # The pipeline reads the ControlNet weights from this key of `params`.
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate an image from a text prompt and a pose conditioning image.

    Args:
        prompts: Text prompt string (one prompt, as sent by the Gradio textbox).
        negative_prompts: Negative prompt string.
        image: Conditioning image as a NumPy array, passed to
            ``PIL.Image.fromarray``. Presumably HxWx3 uint8, as delivered by
            ``gr.Image``; confirm.
        seed: Seed for the JAX PRNG. The default 0 matches the previous
            fixed-seed behaviour.
        num_inference_steps: Number of denoising steps (default 50).

    Returns:
        The first generated image as a float32 NumPy array.

    Raises:
        ValueError: If no conditioning image was supplied.
    """
    if image is None:
        # Gradio passes None when no image has been uploaded.
        raise ValueError("Please upload a conditioning image before running.")

    # `shard` reshapes the leading batch axis to (local_device_count, -1, ...),
    # and the pmap-ed pipeline needs one PRNG key per device. Use one sample per
    # device so the batch always divides evenly.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(int(seed)), num_samples)

    pil_image = Image.fromarray(image)
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([pil_image] * num_samples))

    images = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=int(num_inference_steps),
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Keep only the first image. The two leading axes are assumed to be
    # (device, per-device batch); confirm against the diffusers version in use.
    output = images[0, 0]

    del pil_image
    gc.collect()
    # Hand Gradio a plain float32 NumPy array rather than the pipeline's own
    # (possibly bfloat16) array.
    return np.array(output, dtype=np.float32)
# Gradio UI. The left column holds the inputs and the embedded training report.
# The right column holds the p5.js keypoint tool, the generated result and
# project links.
with gr.Blocks(theme='kfahn/AnimalPose') as demo:
    gr.Markdown(
        "\n"
        '<h1 style="text-align: center;">\n'
        "Animal Pose Control Net\n"
        "</h1>\n"
        '<h3 style="text-align: center;"> This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.\n'
        "</h3>\n"
    )

    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(
                label="Prompt",
                placeholder="animal standing, best quality, highres",
            )
            negative_prompts = gr.Textbox(
                label="Negative Prompt",
                value="lowres, two heads, bad muzzle, bad anatomy, missing ears, missing paws",
            )
            conditioning_image = gr.Image(label="Conditioning Image")
            run_btn = gr.Button("Run")
            report_frame = wandb_report(report_url)

        with gr.Column():
            sketch_frame = addp5sketch(sketch_url)
            result_image = gr.Image(label="Result")
            gr.Markdown(
                "\n"
                "[Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)\n"
                "[Diffusers model](https://huggingface.co/JFoz/dog-pose)\n"
                "[Github](https://github.com/fi4cr/animalpose)\n"
                f"[Training Report]({report_url})\n"
            )

    # Wire the Run button: the three inputs feed `infer`, and its image goes to the result view.
    run_btn.click(
        fn=infer,
        inputs=[prompts, negative_prompts, conditioning_image],
        outputs=result_image,
    )

demo.launch(debug=True)