# Gradio demo for ControlNeXt-SDXL (anime model, canny-edge conditioning).

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from utils import utils, tools, preprocess

# BASE_MODEL_PATH = "stablediffusionapi/neta-art-xl-v2"
VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"
REPO_ID = "Pbihao/ControlNeXt"
UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
CACHE_DIR = None


def ui():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fetch the base checkpoint plus the ControlNeXt UNet/ControlNet weights
    # from the Hugging Face Hub (cached locally after the first run).
    model_file = hf_hub_download(
        repo_id='Lykon/AAM_XL_AnimeMix',
        filename='AAM_XL_Anime_Mix.safetensors',
        cache_dir=CACHE_DIR,
    )
    unet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=UNET_FILENAME,
        cache_dir=CACHE_DIR,
    )
    controlnet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=CONTROLNET_FILENAME,
        cache_dir=CACHE_DIR,
    )

    # Build the pipeline once at startup; it is shared by all requests.
    # load_weight_increasement applies the trained UNet weight deltas on top
    # of the base checkpoint rather than replacing the weights outright.
    pipeline = tools.get_pipeline(
        pretrained_model_name_or_path=model_file,
        unet_model_name_or_path=unet_file,
        controlnet_model_name_or_path=controlnet_file,
        vae_model_name_or_path=VAE_PATH,
        load_weight_increasement=True,
        device=device,
        hf_cache_dir=CACHE_DIR,
        use_safetensors=True,
        enable_xformers_memory_efficient_attention=True,
    )

    preprocessors = ['canny']
    schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 520px;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("""
        # [ControlNeXt](https://github.com/dvlab-research/ControlNeXt) Official Demo
        """)
        with gr.Row():
            with gr.Column(scale=9):
                prompt = gr.Textbox(lines=3, placeholder='prompt', container=False)
                negative_prompt = gr.Textbox(lines=3, placeholder='negative prompt', container=False)
            with gr.Column(scale=1):
                generate_button = gr.Button("Generate", variant='primary', min_width=96)

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    control_image = gr.Image(
                        value=None,
                        label='Condition',
                        sources=['upload'],
                        type='pil',
                        height=512,
                        show_download_button=True,
                        show_share_button=True,
                    )
                with gr.Row():
                    scheduler = gr.Dropdown(
                        label='Scheduler',
                        choices=schedulers,
                        value='Euler A',
                        multiselect=False,
                        allow_custom_value=False,
                        filterable=True,
                    )
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=20, label='Steps')
                with gr.Row():
                    # step=0.5 so the 7.5 default stays reachable after dragging.
                    cfg_scale = gr.Slider(minimum=1, maximum=30, step=0.5, value=7.5, label='CFG Scale')
                    controlnet_scale = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='ControlNet Scale')
                with gr.Row():
                    seed = gr.Number(label='Seed', step=1, precision=0, value=-1)
                with gr.Row():
                    processor = gr.Dropdown(
                        label='Image Preprocessor',
                        choices=preprocessors,
                        value='canny',
                    )
                    process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
            with gr.Column(scale=1):
                output = gr.Gallery(
                    label='Output',
                    value=None,
                    object_fit='scale-down',
                    columns=4,
                    height=512,
                    show_download_button=True,
                    show_share_button=True,
                )

        def generate(
            prompt,
            control_image,
            negative_prompt,
            cfg_scale,
            controlnet_scale,
            num_inference_steps,
            scheduler,
            seed,
        ):
            # Fail fast before touching the shared pipeline state.
            if control_image is None:
                raise gr.Error('Please upload an image.')

            pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
            # seed == -1 means "random"; otherwise clamp into the int32 range.
            generator = (
                torch.Generator(device=device).manual_seed(max(0, min(seed, np.iinfo(np.int32).max)))
                if seed != -1 else None
            )
            # Snap the condition image to roughly 1024px with both sides
            # divisible by 32, as the model expects.
            width, height = utils.around_reso(
                control_image.width, control_image.height,
                reso=1024, max_width=2048, max_height=2048, divisible=32,
            )
            control_image = control_image.resize((width, height)).convert('RGB')

            with torch.autocast(device):
                output_images = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    controlnet_image=control_image,
                    controlnet_scale=controlnet_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    guidance_scale=cfg_scale,
                    num_inference_steps=num_inference_steps,
                ).images
            return output_images

        def process(image, processor):
            if image is None:
                raise gr.Error('Please upload an image.')
            extractor = preprocess.get_extractor(processor)
            return extractor(image)

        generate_button.click(
            fn=generate,
            inputs=[prompt, control_image, negative_prompt, cfg_scale, controlnet_scale,
                    num_inference_steps, scheduler, seed],
            outputs=[output],
        )
        process_button.click(
            fn=process,
            inputs=[control_image, processor],
            outputs=[control_image],
        )

    return demo


if __name__ == '__main__':
    demo = ui()
    demo.queue().launch()
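
# Minimal smoke test, assuming this file is saved as `app.py` (the filename is
# an assumption) and the ControlNeXt repo's `utils` package is on the import
# path:
#
#   python app.py
#
# Gradio prints a local URL (http://127.0.0.1:7860 by default). Upload a
# condition image, click "Process" to extract canny edges, then "Generate".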