# ControlNeXt/utils/tools.py
import os
import gc
import torch
from diffusers import UniPCMultistepScheduler, AutoencoderKL
from safetensors.torch import load_file
from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
from models.unet import UNet2DConditionModel
from models.controlnet import ControlNetModel  # project-local module, not diffusers.ControlNetModel
from . import utils
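
# UNet config matching the SDXL base model (stabilityai/stable-diffusion-xl-base-1.0);
# used to rebuild the UNet when loading from a local single-file checkpoint.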
UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [
5,
10,
20
],
"block_out_channels": [
320,
640,
1280
],
"center_input_sample": False,
"class_embed_type": None,
"class_embeddings_concat": False,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 2048,
"cross_attention_norm": None,
"down_block_types": [
"DownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D"
],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": None,
"encoder_hid_dim_type": None,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": None,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
"projection_class_embeddings_input_dim": 2816,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "default",
"sample_size": 128,
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"transformer_layers_per_block": [
1,
2,
10
],
"up_block_types": [
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"UpBlock2D"
],
"upcast_attention": None,
"use_linear_projection": True
}
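
# Config for the lightweight ControlNeXt control module; note that
# final_out_channels (320) matches the UNet's first block width above.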
CONTROLNET_CONFIG = {
'in_channels': [128, 128],
'out_channels': [128, 256],
'groups': [4, 8],
'time_embed_dim': 256,
'final_out_channels': 320,
'_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels']
}
def get_pipeline(
pretrained_model_name_or_path,
unet_model_name_or_path,
controlnet_model_name_or_path,
vae_model_name_or_path=None,
lora_path=None,
load_weight_increasement=False,
enable_xformers_memory_efficient_attention=False,
revision=None,
variant=None,
hf_cache_dir=None,
use_safetensors=True,
device=None,
):
    """Build a StableDiffusionXLControlNeXtPipeline.

    Loads the base SDXL weights from `pretrained_model_name_or_path` (a Hub id,
    a local directory, or a single-file checkpoint), then optionally applies
    ControlNeXt UNet/ControlNet weights, a custom VAE, and LoRA weights.
    """
    pipeline_init_kwargs = {}
print(f"loading unet from {pretrained_model_name_or_path}")
if os.path.isfile(pretrained_model_name_or_path):
# load unet from local checkpoint
unet_sd = load_file(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".safetensors") else torch.load(pretrained_model_name_or_path)
unet_sd = utils.extract_unet_state_dict(unet_sd)
unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
unet = UNet2DConditionModel.from_config(UNET_CONFIG)
unet.load_state_dict(unet_sd, strict=True)
else:
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
cache_dir=hf_cache_dir,
variant=variant,
torch_dtype=torch.float16,
use_safetensors=use_safetensors,
subfolder="unet",
)
unet = unet.to(dtype=torch.float16)
pipeline_init_kwargs["unet"] = unet
if vae_model_name_or_path is not None:
print(f"loading vae from {vae_model_name_or_path}")
vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device)
pipeline_init_kwargs["vae"] = vae
    if controlnet_model_name_or_path is not None:
        # Instantiate the ControlNeXt control module in fp32; its weights are
        # loaded below via load_controlnext_controlnet_weights.
        pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32)
print(f"loading pipeline from {pretrained_model_name_or_path}")
if os.path.isfile(pretrained_model_name_or_path):
pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
pretrained_model_name_or_path,
use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
local_files_only=True,
cache_dir=hf_cache_dir,
**pipeline_init_kwargs,
)
else:
pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
pretrained_model_name_or_path,
revision=revision,
variant=variant,
use_safetensors=use_safetensors,
cache_dir=hf_cache_dir,
**pipeline_init_kwargs,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
if unet_model_name_or_path is not None:
print(f"loading controlnext unet from {unet_model_name_or_path}")
pipeline.load_controlnext_unet_weights(
unet_model_name_or_path,
load_weight_increasement=load_weight_increasement,
use_safetensors=True,
torch_dtype=torch.float16,
cache_dir=hf_cache_dir,
)
if controlnet_model_name_or_path is not None:
print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
pipeline.load_controlnext_controlnet_weights(
controlnet_model_name_or_path,
use_safetensors=True,
torch_dtype=torch.float32,
cache_dir=hf_cache_dir,
)
pipeline.set_progress_bar_config()
pipeline = pipeline.to(device, dtype=torch.float16)
if lora_path is not None:
pipeline.load_lora_weights(lora_path)
if enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
    gc.collect()
    # `device` may be "cuda", "cuda:0", or a torch.device; match any CUDA device.
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        torch.cuda.empty_cache()
return pipeline
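
# Example usage (a sketch; the checkpoint paths below are placeholders):
#
#   pipeline = get_pipeline(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       unet_model_name_or_path="checkpoints/controlnext_unet.safetensors",
#       controlnet_model_name_or_path="checkpoints/controlnext_controlnet.safetensors",
#       device="cuda",
#   )
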
def get_scheduler(
    scheduler_name,
    scheduler_config,
):
    """Return a diffusers scheduler instance built from `scheduler_config`."""
    from diffusers.schedulers import (
        DDIMScheduler,
        DDPMScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        UniPCMultistepScheduler,
    )
    schedulers = {
        'Euler A': EulerAncestralDiscreteScheduler,
        'UniPC': UniPCMultistepScheduler,
        'Euler': EulerDiscreteScheduler,
        'DDIM': DDIMScheduler,
        'DDPM': DDPMScheduler,
    }
    if scheduler_name not in schedulers:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")
    return schedulers[scheduler_name].from_config(scheduler_config)
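

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the SDXL base weights are reachable
    # (the Hub id below is an assumption); ControlNeXt weights are skipped so
    # only the base pipeline build and the scheduler swap are exercised.
    pipe = get_pipeline(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet_model_name_or_path=None,
        controlnet_model_name_or_path=None,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    pipe.scheduler = get_scheduler("Euler A", pipe.scheduler.config)
    print(f"scheduler: {type(pipe.scheduler).__name__}")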