import gradio as gr
import torch
import numpy as np
from huggingface_hub import hf_hub_download

from utils import utils, tools, preprocess

# BASE_MODEL_PATH = "stablediffusionapi/neta-art-xl-v2"
VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"
REPO_ID = "Pbihao/ControlNeXt"
UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
CACHE_DIR = None
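# Note: with CACHE_DIR = None, hf_hub_download falls back to the default
# Hugging Face cache directory.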


def ui():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_file = hf_hub_download(
        repo_id='Lykon/AAM_XL_AnimeMix',
        filename='AAM_XL_Anime_Mix.safetensors',
        cache_dir=CACHE_DIR,
    )
    unet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=UNET_FILENAME,
        cache_dir=CACHE_DIR,
    )
    controlnet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=CONTROLNET_FILENAME,
        cache_dir=CACHE_DIR,
    )
    pipeline = tools.get_pipeline(
        pretrained_model_name_or_path=model_file,
        unet_model_name_or_path=unet_file,
        controlnet_model_name_or_path=controlnet_file,
        vae_model_name_or_path=VAE_PATH,
        load_weight_increasement=True,
        device=device,
        hf_cache_dir=CACHE_DIR,
        use_safetensors=True,
        enable_xformers_memory_efficient_attention=True,
    )
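
    # Only the canny extractor is bundled with this demo; the scheduler names
    # below are assumed to be the ones tools.get_scheduler understands.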
    preprocessors = ['canny']
    schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 520px;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("""
        # [ControlNeXt](https://github.com/dvlab-research/ControlNeXt) Official Demo
        """)

        with gr.Row():
            with gr.Column(scale=9):
                prompt = gr.Textbox(lines=3, placeholder='prompt', container=False)
                negative_prompt = gr.Textbox(lines=3, placeholder='negative prompt', container=False)
            with gr.Column(scale=1):
                generate_button = gr.Button("Generate", variant='primary', min_width=96)

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    control_image = gr.Image(
                        value=None,
                        label='Condition',
                        sources=['upload'],
                        type='pil',
                        height=512,
                        show_download_button=True,
                        show_share_button=True,
                    )
                with gr.Row():
                    scheduler = gr.Dropdown(
                        label='Scheduler',
                        choices=schedulers,
                        value='Euler A',
                        multiselect=False,
                        allow_custom_value=False,
                        filterable=True,
                    )
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=20, label='Steps')
                with gr.Row():
                    cfg_scale = gr.Slider(minimum=1, maximum=30, step=0.5, value=7.5, label='CFG Scale')
                    controlnet_scale = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='ControlNet Scale')
                with gr.Row():
                    seed = gr.Number(label='Seed', step=1, precision=0, value=-1)
                with gr.Row():
                    processor = gr.Dropdown(
                        label='Image Preprocessor',
                        choices=preprocessors,
                        value='canny',
                    )
                    process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
            with gr.Column(scale=1):
                output = gr.Gallery(
                    label='Output',
                    value=None,
                    object_fit='scale-down',
                    columns=4,
                    height=512,
                    show_download_button=True,
                    show_share_button=True,
                )
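
        # generate() runs one text-to-image pass conditioned on the uploaded
        # control image (ideally already edge-extracted via the Process button).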
        def generate(
            prompt,
            control_image,
            negative_prompt,
            cfg_scale,
            controlnet_scale,
            num_inference_steps,
            scheduler,
            seed,
        ):
            if control_image is None:
                raise gr.Error('Please upload an image.')
            pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
            # A seed of -1 means "random"; any other value is clamped to the
            # int32 range accepted by torch.Generator.manual_seed.
            generator = torch.Generator(device=device).manual_seed(max(0, min(seed, np.iinfo(np.int32).max))) if seed != -1 else None
            # Round the condition's resolution toward reso=1024 while keeping
            # both sides divisible by 32 (a UNet latent-size requirement).
            width, height = utils.around_reso(control_image.width, control_image.height, reso=1024, max_width=2048, max_height=2048, divisible=32)
            control_image = control_image.resize((width, height)).convert('RGB')
            with torch.autocast(device):
                output_images = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    controlnet_image=control_image,
                    controlnet_scale=controlnet_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    guidance_scale=cfg_scale,
                    num_inference_steps=num_inference_steps,
                ).images
            return output_images
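
        # process() swaps the uploaded condition image for its extracted edge
        # map, writing the result back into the same gr.Image component.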
        def process(
            image,
            processor,
        ):
            if image is None:
                raise gr.Error('Please upload an image.')
            processor = preprocess.get_extractor(processor)
            image = processor(image)
            return image
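
        # Wire the buttons: "Process" runs the extractor, "Generate" runs the pipeline.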
        generate_button.click(
            fn=generate,
            inputs=[prompt, control_image, negative_prompt, cfg_scale, controlnet_scale, num_inference_steps, scheduler, seed],
            outputs=[output],
        )
        process_button.click(
            fn=process,
            inputs=[control_image, processor],
            outputs=[control_image],
        )

    return demo


if __name__ == '__main__':
    demo = ui()
    demo.queue().launch()