import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
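# HTML for the in-browser keypoint sketch tool, embedded in the UI below.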
with open("test.html") as f:
    lines = f.readlines()
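# Build a deterministic JAX PRNG key from an integer seed.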
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
#def addp5sketch(url):
#    iframe = f'<iframe src="{url}" style="border:none;height:525px;width:100%"></iframe>'
#    return gr.HTML(iframe)
def wandb_report(url):
    # Embed a Weights & Biases report in an iframe.
    iframe = f'<iframe src="{url}" style="border:none;height:1024px;width:100%"></iframe>'
    return gr.HTML(iframe)
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'
control_img = 'myimage.jpg'
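# Load the animal-pose ControlNet weights and the Flax Stable Diffusion pipeline.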
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose", dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)
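# Generate an image conditioned on the prompt and the pose keypoint image.
# Inputs are sharded and params replicated so the call can be pmapped (jit=True).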
def infer(prompts, negative_prompts, image):
    params["controlnet"] = controlnet_params
    num_samples = 1  # jax.device_count()
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())
    image = Image.fromarray(image)

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([image] * num_samples)

    # Replicate the parameters across devices and shard the per-sample inputs.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images[0]

    # Drop the device axis and convert to PIL so Gradio can display the result.
    output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return output_images[0]
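# Gradio UI: prompt inputs on the left, keypoint tool and result on the right.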
with gr.Blocks(theme='kfahn/AnimalPose') as demo:
    gr.Markdown(
        """
        # Animal Pose Control Net
        ## This is a demo of the Animal Pose ControlNet, a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning.

        [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
        [Diffusers model](https://huggingface.co/JFoz/dog-pose)
        [Github](https://github.com/fi4cr/animalpose)
        [Training Report](https://wandb.ai/john-fozard/AP10K-pose/runs/wn89ezaw)
        """)
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(label="Prompt", placeholder="black cocker spaniel sitting on a lawn, best quality")
            negative_prompts = gr.Textbox(label="Negative Prompt", value="lowres, bad anatomy, missing ears, missing paws")
            conditioning_image = gr.Image(label="Conditioning Image")
            run_btn = gr.Button("Run")
        with gr.Column():
            #keypoint_tool = addp5sketch(sketch_url)
            # gr.HTML expects a single string, so join the lines read from test.html.
            keypoint_tool = gr.HTML("".join(lines))
            output = gr.Image(
                label="Result",
            )
    run_btn.click(fn=infer, inputs=[prompts, negative_prompts, conditioning_image], outputs=output)
    #gr.Interface(fn=infer, inputs=["text", "text", "image"], outputs=output,
    #             examples=[["a Labrador crossing the road", "low quality", "myimage.jpg"]])
    #with gr.Row():
    #    report = wandb_report(report_url)

demo.launch(debug=True)