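"""Utilities for injecting, loading, and saving LoRA weights (cloneofsimo's
implementation) for the UNet3DConditionModel and the CLIP text encoder."""
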
import os
import warnings
import torch
from typing import Union
from types import SimpleNamespace
from models.unet.unet_3d_condition import UNet3DConditionModel
from transformers import CLIPTextModel
from .convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20
from .lora import (
extract_lora_ups_down,
inject_trainable_lora_extended,
save_lora_weight,
train_patch_pipe,
monkeypatch_or_replace_lora,
monkeypatch_or_replace_lora_extended
)
FILE_BASENAMES = ['unet', 'text_encoder']
LORA_FILE_TYPES = ['.pt', '.safetensors']
CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
lora_versions = dict(
stable_lora = "stable_lora",
cloneofsimo = "cloneofsimo"
)
lora_func_types = dict(
loader = "loader",
injector = "injector"
)
lora_args = dict(
model = None,
loras = None,
target_replace_module = [],
target_module = [],
r = 4,
search_class = [torch.nn.Linear],
dropout = 0,
lora_bias = 'none'
)
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)
LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
def filter_dict(_dict, keys=[]):
    """Return a copy of `_dict` restricted to `keys`, validating that every key
    is a known LoRA argument."""
    assert len(keys) > 0, "Keys cannot be empty when filtering the return dict."
    for k in keys:
        assert k in lora_args, f"{k} does not exist in the available LoRA arguments."
    return {k: v for k, v in _dict.items() if k in keys}
class LoraHandler(object):
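    """Creates, injects, loads, and saves LoRA layers for the UNet and text encoder
    using cloneofsimo's LoRA implementation."""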
def __init__(
self,
        version: str = LoraVersions.cloneofsimo,
use_unet_lora: bool = False,
use_text_lora: bool = False,
save_for_webui: bool = False,
only_for_webui: bool = False,
lora_bias: str = 'none',
unet_replace_modules: list = None,
text_encoder_replace_modules: list = None
):
self.version = version
self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
self.lora_bias = lora_bias
self.use_unet_lora = use_unet_lora
self.use_text_lora = use_text_lora
self.save_for_webui = save_for_webui
self.only_for_webui = only_for_webui
self.unet_replace_modules = unet_replace_modules
self.text_encoder_replace_modules = text_encoder_replace_modules
self.use_lora = any([use_text_lora, use_unet_lora])
def is_cloneofsimo_lora(self):
return self.version == LoraVersions.cloneofsimo
    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
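        """Return the LoRA loader or injector function for the configured version."""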
if self.is_cloneofsimo_lora():
if func_type == LoraFuncTypes.loader:
return monkeypatch_or_replace_lora_extended
if func_type == LoraFuncTypes.injector:
return inject_trainable_lora_extended
assert "LoRA Version does not exist."
def check_lora_ext(self, lora_file: str):
return lora_file.endswith(tuple(LORA_FILE_TYPES))
def get_lora_file_path(
self,
lora_path: str,
model: Union[UNet3DConditionModel, CLIPTextModel]
):
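        """Search `lora_path` for a '.pt' or '.safetensors' file whose name contains
        'unet' or 'text_encoder', depending on the model type. Returns the full file
        path, or None if no matching file is found."""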
if os.path.exists(lora_path):
            lora_filenames = os.listdir(lora_path)
is_unet = isinstance(model, UNet3DConditionModel)
is_text = isinstance(model, CLIPTextModel)
idx = 0 if is_unet else 1
base_name = FILE_BASENAMES[idx]
for lora_filename in lora_filenames:
is_lora = self.check_lora_ext(lora_filename)
if not is_lora:
continue
if base_name in lora_filename:
return os.path.join(lora_path, lora_filename)
return None
    def handle_lora_load(self, file_name: str, lora_loader_args: dict = None):
self.lora_loader(**lora_loader_args)
print(f"Successfully loaded LoRA from: {file_name}")
    def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None):
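        """Load LoRA weights for `model` from `lora_path` if a matching file exists;
        otherwise report that new LoRA layers will be injected instead."""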
try:
lora_file = self.get_lora_file_path(lora_path, model)
if lora_file is not None:
lora_loader_args.update({"lora_path": lora_file})
self.handle_lora_load(lora_file, lora_loader_args)
else:
print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
except Exception as e:
print(f"An error occurred while loading a LoRA file: {e}")
def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
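        """Build the keyword arguments for the LoRA injector, filtered to the keys
        used by the cloneofsimo implementation."""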
return_dict = lora_args.copy()
if self.is_cloneofsimo_lora():
return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
return_dict.update({
"model": model,
"loras": self.get_lora_file_path(lora_path, model),
"target_replace_module": replace_modules,
"r": r,
"scale": scale,
"dropout_p": dropout,
})
return return_dict
def do_lora_injection(
self,
model,
replace_modules,
bias='none',
dropout=0,
r=4,
lora_loader_args=None,
):
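        """Inject trainable LoRA layers into `model` and confirm the injection by
        extracting the up/down weights of the targeted modules. Returns the LoRA
        parameters, the negation list, and whether the injection is hybrid (i.e.
        it also handles loading, so no separate load step is needed)."""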
REPLACE_MODULES = replace_modules
params = None
negation = None
is_injection_hybrid = False
if self.is_cloneofsimo_lora():
is_injection_hybrid = True
injector_args = lora_loader_args
params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended
for _up, _down in extract_lora_ups_down(
model,
target_replace_module=REPLACE_MODULES):
if all(x is not None for x in [_up, _down]):
print(f"Lora successfully injected into {model.__class__.__name__}.")
break
return params, negation, is_injection_hybrid
return params, negation, is_injection_hybrid
def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
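        """Inject LoRA into `model` when `use_lora` is set, loading existing weights
        from `lora_path` if the injection is not hybrid. Returns the trainable
        parameters (or the model itself when none are returned) and the negation list."""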
params = None
negation = None
lora_loader_args = self.get_lora_func_args(
lora_path,
use_lora,
model,
replace_modules,
r,
dropout,
self.lora_bias,
scale
)
if use_lora:
params, negation, is_injection_hybrid = self.do_lora_injection(
model,
replace_modules,
bias=self.lora_bias,
lora_loader_args=lora_loader_args,
dropout=dropout,
r=r
)
if not is_injection_hybrid:
self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
params = model if params is None else params
return params, negation
def save_cloneofsimo_lora(self, model, save_path, step, flag):
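        """Save UNet and text-encoder LoRA weights to '{step}_unet.pt' and
        '{step}_text_encoder.pt' under `save_path`."""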
def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
if condition and replace_modules is not None:
save_path = f"{save_path}/{step}_{name}.pt"
save_lora_weight(model, save_path, replace_modules, flag)
save_lora(
model.unet,
FILE_BASENAMES[0],
self.use_unet_lora,
self.unet_replace_modules,
step,
save_path,
flag
)
save_lora(
model.text_encoder,
FILE_BASENAMES[1],
self.use_text_lora,
self.text_encoder_replace_modules,
step,
save_path,
flag
)
# train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
    def save_lora_weights(self, model, save_path: str = '', step: str = '', flag=None):
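        """Create a 'lora' subdirectory under `save_path` and save the current
        cloneofsimo LoRA weights there."""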
save_path = f"{save_path}/lora"
os.makedirs(save_path, exist_ok=True)
if self.is_cloneofsimo_lora():
if any([self.save_for_webui, self.only_for_webui]):
warnings.warn(
"""
                    You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implementation.
Only 'stable_lora' is supported for saving to a compatible webui file.
"""
)
self.save_cloneofsimo_lora(model, save_path, step, flag)
def inject_spatial_loras(unet, use_unet_lora, lora_unet_dropout, lora_path, lora_rank, spatial_lora_num):
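    """Create `spatial_lora_num` LoraHandler instances that target Transformer2DModel
    blocks and add their LoRA layers to `unet`. Returns the handlers, the per-handler
    LoRA parameter lists, and the negation list from the last injection."""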
lora_managers_spatial = []
unet_lora_params_spatial_list = []
for i in range(spatial_lora_num):
lora_manager_spatial = LoraHandler(
use_unet_lora=use_unet_lora,
unet_replace_modules=["Transformer2DModel"]
)
lora_managers_spatial.append(lora_manager_spatial)
unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
use_unet_lora,
unet,
lora_manager_spatial.unet_replace_modules,
lora_unet_dropout,
lora_path + '/spatial/lora/',
r=lora_rank
)
unet_lora_params_spatial_list.append(unet_lora_params_spatial)
return lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_spatial
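
# Example usage (a minimal sketch; the unet object and argument values below are
# illustrative assumptions, not taken from this repository's training scripts):
#
#   lora_managers, lora_params_list, negation = inject_spatial_loras(
#       unet=unet,                       # a loaded UNet3DConditionModel
#       use_unet_lora=True,
#       lora_unet_dropout=0.1,
#       lora_path='./outputs/train',
#       lora_rank=32,
#       spatial_lora_num=1,
#   )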