import gradio as gr import jax import jax.numpy as jnp import numpy as np from flax.jax_utils import replicate from flax.training.common_utils import shard from PIL import Image from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel import gc report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5' sketch_url = 'https://editor.p5js.org/kfahn/full/OshQky7RS' def create_key(seed=0): return jax.random.PRNGKey(seed) def addp5sketch(url): iframe = f'