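"""Gradio demo for the Animal Pose ControlNet (JFoz/dog-cat-pose), built on the
diffusers Flax Stable Diffusion ControlNet pipeline."""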
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
# Load the HTML for the keypoint sketch tool shown in the right-hand column.
with open("test.html") as f:
    html = f.read()
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
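
# Helpers that wrap an external page (the p5.js keypoint sketch or the W&B
# training report) in an iframe for display inside the app.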
def addp5sketch(url):
    iframe = f'<iframe src="{url}" style="border:none;height:525px;width:100%"></iframe>'
    return gr.HTML(iframe)

def wandb_report(url):
    iframe = f'<iframe src="{url}" style="border:none;height:1024px;width:100%"></iframe>'
    return gr.HTML(iframe)
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'
control_img = 'myimage.jpg'
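
# Load the ControlNet weights and the Flax Stable Diffusion v1.5 pipeline in
# bfloat16 to reduce memory use.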
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose", dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)
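
# infer() tokenizes the prompts, merges the ControlNet weights into the
# pipeline params, replicates them across local devices, shards the inputs,
# and runs the pmapped sampling loop (jit=True).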
def infer(prompts, negative_prompts, image):
    params["controlnet"] = controlnet_params

    # One sample per device: shard() below splits the leading batch axis
    # across devices, so the batch size must be a multiple of the device count.
    num_samples = jax.device_count()
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())

    image = Image.fromarray(image)

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([image] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # (devices, batch, H, W, C) -> (num_samples, H, W, C) -> list of PIL images
    output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return output_images
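
# Build the UI: prompt inputs and conditioning image on the left, the keypoint
# sketch tool and the output gallery on the right.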
with gr.Blocks(theme='kfahn/AnimalPose') as demo:
    gr.Markdown(
    """
    # Animal Pose Control Net
    ## This is a demo of the Animal Pose ControlNet, a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning.

    [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
    [Diffusers model](https://huggingface.co/JFoz/dog-pose)
    [Github](https://github.com/fi4cr/animalpose)
    [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5)
    """)
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(label="Prompt")
            negative_prompts = gr.Textbox(label="Negative Prompt")
            conditioning_image = gr.Image(label="Conditioning Image")
        with gr.Column():
            #keypoint_tool = addp5sketch(sketch_url)
            keypoint_tool = gr.HTML(html)
            gallery = gr.Gallery(label="Generated Images")
    submit_btn = gr.Button(value="Submit")
    # Button.click takes no `examples` argument; wire the output to the gallery
    # component and expose the example inputs via gr.Examples instead.
    submit_btn.click(fn=infer, inputs=[prompts, negative_prompts, conditioning_image], outputs=gallery)
    gr.Examples(examples=[["a Labrador crossing the road", "low quality", "myimage.jpg"]],
                inputs=[prompts, negative_prompts, conditioning_image])
    #with gr.Row():
    #    report = wandb_report(report_url)
if __name__ == "__main__":
    demo.launch()