File size: 3,105 Bytes
258d8c9 c74095e 975dc6e c74095e 258d8c9 c74095e 258d8c9 c74095e 258d8c9 8118b09 c74095e 8118b09 c74095e febb26d ae56ea3 8118b09 c74095e febb26d ae56ea3 8118b09 c74095e ae56ea3 c74095e 258d8c9 c74095e 258d8c9 9dc506e c74095e ae56ea3 258d8c9 febb26d 258d8c9 febb26d e62de2d cf41876 8118b09 c74095e 8118b09 c74095e 258d8c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
def create_key(seed=0):
    """Return a JAX PRNG key built from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key
# Load the animal-pose ControlNet and the Stable Diffusion v1.5 Flax pipeline,
# both in bfloat16. The ControlNet weights come back separately as
# `controlnet_params` and are merged into `params["controlnet"]` before inference.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# Device-replicated copy of `params`, built on the first request and reused:
# replicating the full model weights onto every device is expensive and the
# weights never change between requests.
_replicated_params = None


def infer(prompt, negative_prompt=None, image=None, seed=0):
    """Generate images from a text prompt, conditioned on a pose image.

    The Gradio interface calls this as ``infer(prompt, negative_prompt, image)``.
    The older two-argument form ``infer(prompt, image)`` is still accepted.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Optional text to steer generation away from. ``None``
            or ``""`` means no negative prompt (the pipeline default is used).
        image: Pose conditioning image, either array-like (what Gradio's image
            input supplies) or a ``PIL.Image.Image``.
        seed: Seed for the JAX PRNG. A fixed seed gives reproducible output.

    Returns:
        A list of ``PIL.Image.Image``, one generated image per JAX device.

    Raises:
        ValueError: If no conditioning image is supplied.
    """
    global _replicated_params

    if image is None and not isinstance(negative_prompt, (str, type(None))):
        # Legacy call shape infer(prompt, image): the image arrived in the
        # second positional slot.
        negative_prompt, image = None, negative_prompt
    if image is None:
        raise ValueError("A pose conditioning image is required.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # One sample per device: `shard` splits the leading batch axis across all
    # devices, and the rng below is split the same number of ways.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([image] * num_samples))

    # Only pass negative prompt ids when one was given. Otherwise leave the
    # pipeline's default (None), exactly as when no negative prompt is used.
    neg_prompt_ids = None
    if negative_prompt:
        neg_prompt_ids = shard(
            pipe.prepare_text_inputs([negative_prompt] * num_samples)
        )

    if _replicated_params is None:
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_replicated_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=neg_prompt_ids,
        jit=True,
    ).images

    # Collapse the per-device axis so there is one image per sample.
    flat = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(flat)
title = "Animal Pose Control Net"
description = "This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with a new type of conditioning."

# Project links, rendered below the demo through the Interface `article`.
# They were previously a gr.Markdown created after launch() and outside any
# gr.Blocks context, so they were not part of the displayed app.
article = """
* [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
* [Diffusers model](https://huggingface.co/JFoz/dog-cat-pose), [Web UI model](https://huggingface.co/JFoz/dog-pose)
* [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5)
"""

# Input values are passed to infer() positionally, in this order:
# (prompt, negative prompt, pose image).
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.Image(label="Pose image"),
    ],
    outputs=gr.Gallery(label="Generated images"),
    title=title,
    description=description,
    article=article,
    theme="gradio/soft",
    examples=[["a Labrador crossing the road", "low quality", "image_control.png"]],
)
demo.launch()
|