import gradio as gr
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from utils import utils, tools, preprocess

# BASE_MODEL_PATH = "stablediffusionapi/neta-art-xl-v2"
VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"  # fp16-safe SDXL VAE
REPO_ID = "Pbihao/ControlNeXt"
UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
CACHE_DIR = None  # None falls back to the default Hugging Face cache directory
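

# Build the Gradio demo: download the weights, assemble the pipeline, and wire up the UI.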
def ui():
device = "cuda" if torch.cuda.is_available() else "cpu"
model_file = hf_hub_download(
repo_id='Lykon/AAM_XL_AnimeMix',
filename='AAM_XL_Anime_Mix.safetensors',
cache_dir=CACHE_DIR,
)
unet_file = hf_hub_download(
repo_id=REPO_ID,
filename=UNET_FILENAME,
cache_dir=CACHE_DIR,
)
controlnet_file = hf_hub_download(
repo_id=REPO_ID,
filename=CONTROLNET_FILENAME,
cache_dir=CACHE_DIR,
)
    pipeline = tools.get_pipeline(
        pretrained_model_name_or_path=model_file,
        unet_model_name_or_path=unet_file,
        controlnet_model_name_or_path=controlnet_file,
        vae_model_name_or_path=VAE_PATH,
        load_weight_increasement=True,  # apply the UNet file as increments on top of the base weights
        device=device,
        hf_cache_dir=CACHE_DIR,
        use_safetensors=True,
        enable_xformers_memory_efficient_attention=True,
    )
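
    # Options offered in the UI.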
    preprocessors = ['canny']
    schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 520px;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("""
        # [ControlNeXt](https://github.com/dvlab-research/ControlNeXt) Official Demo
        """)
        with gr.Row():
            with gr.Column(scale=9):
                prompt = gr.Textbox(lines=3, placeholder='prompt', container=False)
                negative_prompt = gr.Textbox(lines=3, placeholder='negative prompt', container=False)
            with gr.Column(scale=1):
                generate_button = gr.Button("Generate", variant='primary', min_width=96)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    control_image = gr.Image(
                        value=None,
                        label='Condition',
                        sources=['upload'],
                        type='pil',
                        height=512,
                        show_download_button=True,
                        show_share_button=True,
                    )
                with gr.Row():
                    scheduler = gr.Dropdown(
                        label='Scheduler',
                        choices=schedulers,
                        value='Euler A',
                        multiselect=False,
                        allow_custom_value=False,
                        filterable=True,
                    )
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=20, label='Steps')
                with gr.Row():
                    cfg_scale = gr.Slider(minimum=1, maximum=30, step=1, value=7.5, label='CFG Scale')
                    controlnet_scale = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='ControlNet Scale')
                with gr.Row():
                    seed = gr.Number(label='Seed', step=1, precision=0, value=-1)  # -1 means random seed
                with gr.Row():
                    processor = gr.Dropdown(
                        label='Image Preprocessor',
                        choices=preprocessors,
                        value='canny',
                    )
                    process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
            with gr.Column(scale=1):
                output = gr.Gallery(
                    label='Output',
                    value=None,
                    object_fit='scale-down',
                    columns=4,
                    height=512,
                    show_download_button=True,
                    show_share_button=True,
                )
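
        # Run a single text-to-image pass conditioned on the (preprocessed) control image.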
        def generate(
            prompt,
            control_image,
            negative_prompt,
            cfg_scale,
            controlnet_scale,
            num_inference_steps,
            scheduler,
            seed,
        ):
            pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
            # A seed of -1 means "random"; otherwise clamp into the non-negative int32 range.
            generator = torch.Generator(device=device).manual_seed(max(0, min(seed, np.iinfo(np.int32).max))) if seed != -1 else None
            if control_image is None:
                raise gr.Error('Please upload an image.')
            # Pick a generation size close to reso=1024 with sides divisible by 32 (helper from the repo's utils).
            width, height = utils.around_reso(control_image.width, control_image.height, reso=1024, max_width=2048, max_height=2048, divisible=32)
            control_image = control_image.resize((width, height)).convert('RGB')
            with torch.autocast(device):
                output_images = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    controlnet_image=control_image,
                    controlnet_scale=controlnet_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    guidance_scale=cfg_scale,
                    num_inference_steps=num_inference_steps,
                ).images
            return output_images
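
        # Replace the uploaded image with the output of the selected preprocessor (e.g. canny edges).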
        def process(image, processor):
            if image is None:
                raise gr.Error('Please upload an image.')
            extractor = preprocess.get_extractor(processor)
            return extractor(image)
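
        # Wire up the buttons: "Process" overwrites the condition image in place, "Generate" fills the gallery.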
        generate_button.click(
            fn=generate,
            inputs=[prompt, control_image, negative_prompt, cfg_scale, controlnet_scale, num_inference_steps, scheduler, seed],
            outputs=[output],
        )
        process_button.click(
            fn=process,
            inputs=[control_image, processor],
            outputs=[control_image],
        )

    return demo


if __name__ == '__main__':
    demo = ui()
    demo.queue().launch()