Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import gc | |
import torch | |
from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel | |
from safetensors.torch import load_file | |
from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline | |
from models.unet import UNet2DConditionModel | |
from models.controlnet import ControlNetModel | |
from . import utils | |
# Architecture configuration handed to `UNet2DConditionModel.from_config` by
# get_pipeline when the base model is a local single-file checkpoint; the
# converted SDXL UNet weights are then loaded into the resulting module with
# strict=True, so these values must match the checkpoint's layer shapes.
UNET_CONFIG = {
    # activations / conditioning embeddings
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    # per-resolution attention and channel layout
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    # time embedding
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": None,
    "use_linear_projection": True,
}
# Configuration used by get_pipeline to build the ControlNeXt controlnet
# module (via `ControlNetModel.from_config`) before its weights are loaded
# with `load_controlnext_controlnet_weights`.
# NOTE(review): every key below is also listed in "_use_default_values".
# diffusers' config loader appears to treat keys in that list as "fall back to
# the __init__ default", which would make these values informational only —
# confirm against models.controlnet before relying on changing them.
CONTROLNET_CONFIG = {
    "in_channels": [128, 128],
    "out_channels": [128, 256],
    "groups": [4, 8],
    "time_embed_dim": 256,
    "final_out_channels": 320,
    "_use_default_values": [
        "time_embed_dim",
        "groups",
        "in_channels",
        "final_out_channels",
        "out_channels",
    ],
}
def get_pipeline(
    pretrained_model_name_or_path,
    unet_model_name_or_path,
    controlnet_model_name_or_path,
    vae_model_name_or_path=None,
    lora_path=None,
    load_weight_increasement=False,
    enable_xformers_memory_efficient_attention=False,
    revision=None,
    variant=None,
    hf_cache_dir=None,
    use_safetensors=True,
    device=None,
):
    """Build a StableDiffusionXLControlNeXtPipeline ready for inference.

    Args:
        pretrained_model_name_or_path: Base model. If it is an existing local
            file it is treated as a single-file checkpoint (``.safetensors``
            via ``load_file``, anything else via ``torch.load``); otherwise it
            is passed to ``from_pretrained`` as a repo id or directory.
        unet_model_name_or_path: ControlNeXt UNet weights passed to
            ``pipeline.load_controlnext_unet_weights``; skipped when ``None``.
        controlnet_model_name_or_path: ControlNeXt controlnet weights. When
            ``None`` no controlnet module is constructed or loaded.
        vae_model_name_or_path: Optional VAE loaded with
            ``AutoencoderKL.from_pretrained`` and used instead of the base one.
        lora_path: Optional LoRA weights passed to ``load_lora_weights``.
        load_weight_increasement: Forwarded to
            ``load_controlnext_unet_weights``.
        enable_xformers_memory_efficient_attention: If true, calls
            ``enable_xformers_memory_efficient_attention`` on the pipeline.
        revision: Forwarded to the pipeline's ``from_pretrained``.
        variant: Forwarded to ``from_pretrained`` for the UNet and pipeline.
        hf_cache_dir: ``cache_dir`` forwarded to every loader.
        use_safetensors: Forwarded to the ``from_pretrained`` loaders.
        device: Target device; the finished pipeline is moved there with
            ``dtype=torch.float16``.

    Returns:
        The assembled pipeline, with its scheduler replaced by a
        ``UniPCMultistepScheduler`` built from the original scheduler config.
    """
    pipeline_init_kwargs = {}
    # Evaluate once: decides both how the UNet and how the pipeline are loaded.
    is_single_file = os.path.isfile(pretrained_model_name_or_path)

    # --- Base UNet ---------------------------------------------------------
    print(f"loading unet from {pretrained_model_name_or_path}")
    if is_single_file:
        # Single-file checkpoint: extract the UNet part of the full state dict
        # and rename its keys to the diffusers layout before loading.
        if pretrained_model_name_or_path.endswith(".safetensors"):
            unet_sd = load_file(pretrained_model_name_or_path)
        else:
            # map_location="cpu": CUDA-saved checkpoints must also load on
            # CPU-only hosts, and load_state_dict copies into the (CPU) model
            # anyway, so staging the tensors on a GPU would only waste memory.
            # NOTE(review): torch.load unpickles arbitrary objects — only point
            # this at trusted checkpoints.
            unet_sd = torch.load(pretrained_model_name_or_path, map_location="cpu")
        unet_sd = utils.extract_unet_state_dict(unet_sd)
        unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
        unet = UNet2DConditionModel.from_config(UNET_CONFIG)
        unet.load_state_dict(unet_sd, strict=True)
        # The weights now live in `unet`; drop the duplicate copy before the
        # rest of the pipeline is loaded to keep peak memory down.
        del unet_sd
    else:
        unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=hf_cache_dir,
            variant=variant,
            torch_dtype=torch.float16,
            use_safetensors=use_safetensors,
            subfolder="unet",
        )
    unet = unet.to(dtype=torch.float16)
    pipeline_init_kwargs["unet"] = unet

    # --- Optional VAE override ----------------------------------------------
    if vae_model_name_or_path is not None:
        print(f"loading vae from {vae_model_name_or_path}")
        vae = AutoencoderKL.from_pretrained(
            vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16
        ).to(device)
        pipeline_init_kwargs["vae"] = vae

    # --- ControlNeXt controlnet skeleton (weights are loaded further below) --
    if controlnet_model_name_or_path is not None:
        # NOTE(review): the final `pipeline.to(device, dtype=torch.float16)`
        # appears to cast this module to fp16 as well — confirm whether fp32
        # is meant to persist.
        pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(
            CONTROLNET_CONFIG
        ).to(device, dtype=torch.float32)

    # --- Assemble the pipeline around the components above -----------------
    print(f"loading pipeline from {pretrained_model_name_or_path}")
    if is_single_file:
        pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
            pretrained_model_name_or_path,
            use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
            local_files_only=True,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )
    else:
        pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
            pretrained_model_name_or_path,
            revision=revision,
            variant=variant,
            use_safetensors=use_safetensors,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )

    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)

    # --- ControlNeXt weights --------------------------------------------------
    if unet_model_name_or_path is not None:
        print(f"loading controlnext unet from {unet_model_name_or_path}")
        pipeline.load_controlnext_unet_weights(
            unet_model_name_or_path,
            load_weight_increasement=load_weight_increasement,
            use_safetensors=True,
            torch_dtype=torch.float16,
            cache_dir=hf_cache_dir,
        )
    if controlnet_model_name_or_path is not None:
        print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
        pipeline.load_controlnext_controlnet_weights(
            controlnet_model_name_or_path,
            use_safetensors=True,
            torch_dtype=torch.float32,
            cache_dir=hf_cache_dir,
        )

    pipeline.set_progress_bar_config()
    pipeline = pipeline.to(device, dtype=torch.float16)

    if lora_path is not None:
        pipeline.load_lora_weights(lora_path)
    if enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    gc.collect()
    # Match plain "cuda" as well as indexed devices ("cuda:0",
    # torch.device("cuda", 1)); the old exact comparison missed the latter.
    if device is not None and str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pipeline
def get_scheduler(
    scheduler_name,
    scheduler_config,
):
    """Instantiate a diffusers scheduler chosen by its display name.

    Args:
        scheduler_name: One of 'Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM'
            (exact, case-sensitive match).
        scheduler_config: Config passed to the chosen class's ``from_config``.

    Returns:
        The new scheduler instance.

    Raises:
        ValueError: If ``scheduler_name`` is not one of the supported names.
    """
    # Ordered (display name, diffusers class name) pairs; scanned in order
    # with `==`, mirroring an if/elif chain.
    known_schedulers = (
        ('Euler A', 'EulerAncestralDiscreteScheduler'),
        ('UniPC', 'UniPCMultistepScheduler'),
        ('Euler', 'EulerDiscreteScheduler'),
        ('DDIM', 'DDIMScheduler'),
        ('DDPM', 'DDPMScheduler'),
    )
    for display_name, class_name in known_schedulers:
        if scheduler_name == display_name:
            # Imported lazily so only a valid request touches diffusers.
            from diffusers import schedulers as diffusers_schedulers
            scheduler_cls = getattr(diffusers_schedulers, class_name)
            return scheduler_cls.from_config(scheduler_config)
    raise ValueError(f"Unknown scheduler: {scheduler_name}")