File size: 2,969 Bytes
258d8c9 c74095e 975dc6e c74095e 258d8c9 880828c c74095e 258d8c9 c74095e 258d8c9 090c9fa c74095e 090c9fa c74095e febb26d 090c9fa 8118b09 c74095e febb26d 090c9fa 8118b09 c74095e 090c9fa c74095e 258d8c9 880828c f9183eb 880828c c74095e 880828c 258d8c9 880828c 258d8c9 880828c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
# Load the keypoint-editor HTML widget that the UI renders via gr.HTML.
# gr.HTML expects a single HTML string, so read the file whole; readlines()
# produced a list, which is not valid HTML content. The name `lines` is kept
# because the UI code refers to it.
# NOTE(review): assumes test.html is UTF-8 encoded - confirm.
with open("test.html", encoding="utf-8") as f:
    lines = f.read()
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed for the generator; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
# ControlNet weights for the animal-pose conditioning shown in this demo,
# loaded in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)

# Stable Diffusion v1.5 (its "flax" weights revision) combined with the
# ControlNet above. `params` holds the pipeline weights. The ControlNet
# weights stay separate in `controlnet_params` and are attached to `params`
# by the inference code before the pipeline is called.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# Device-replicated pipeline params, built lazily on the first request.
_replicated_params = None


def _get_replicated_params():
    """Attach the ControlNet weights to `params` and replicate them across devices once.

    `replicate` copies every weight onto each local device. Doing that per
    request wasted time and memory bandwidth, so the result is cached.
    """
    global _replicated_params
    if _replicated_params is None:
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate images from a text prompt guided by a pose control image.

    Args:
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt.
        image: Control image as an array accepted by ``PIL.Image.fromarray``
            (presumably the NumPy array that ``gr.Image()`` supplies).
        seed: Seed for the JAX PRNG. Defaults to 0, the previously fixed value.
        num_inference_steps: Number of denoising steps; defaults to 50.

    Returns:
        A list of PIL images, one per local JAX device.

    Raises:
        gr.Error: If no control image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a control image.")

    # `shard` splits the leading batch axis across jax.device_count() devices,
    # and the PRNG key must be split once per device. So the batch is one sample
    # per device; a fixed batch of 1 crashes on multi-device hosts.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    control = Image.fromarray(image)
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([control] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per_device_batch, H, W, C) result into a flat batch
    # before converting to PIL.
    flat = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat))
mytheme = gr.themes.Default(primary_hue="slate")
# Example control image. Hugging Face serves the raw file under `/resolve/`.
# The `/blob/` path returns the HTML file-viewer page, not the PNG.
control_image = "https://huggingface.co/spaces/kfahn/Animal_Pose_Control_Net/resolve/main/image_control.png"

with gr.Blocks(theme=mytheme) as demo:
    gr.Markdown(
        """
# Animal Pose Control Net
## This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning.
""")
    with gr.Column():
        with gr.Row():
            pos_prompts = gr.Textbox(label="Prompt")
            neg_prompts = gr.Textbox(label="Negative Prompt")
        image = gr.Image()
    with gr.Column():
        with gr.Row():
            explain = gr.Textbox("Keypoint Tool: Use mouse to move joints")
        with gr.Row():
            # gr.HTML needs one HTML string, so join `lines` if it is a
            # sequence of lines (joining a str returns it unchanged).
            keypoint_tool = gr.HTML("".join(lines))
    gr.Markdown(
        """
* [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
* [Diffusers model](https://huggingface.co/JFoz/dog-cat-pose), [Web UI model](https://huggingface.co/JFoz/dog-pose)
* [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5)
""")
    btn = gr.Button("Run")
    gallery = gr.Gallery(label="Output")
    # Inside gr.Blocks, listeners take component objects. The string shortcuts
    # ("text", "image", "gallery") only work with gr.Interface. Button.click has
    # no `examples` argument; examples are a separate gr.Examples component.
    btn.click(fn=infer, inputs=[pos_prompts, neg_prompts, image], outputs=gallery)
    gr.Examples(
        examples=[["a Labrador crossing the road", "low quality", control_image]],
        inputs=[pos_prompts, neg_prompts, image],
    )

demo.launch()
|