T2V-Turbo / utils /lora_handler.py
Ji4chenLi
initial test
9f200a2
raw
history blame
4.41 kB
import torch
from types import SimpleNamespace
from .lora import (
extract_lora_ups_down,
inject_trainable_lora_extended,
monkeypatch_or_replace_lora_extended,
)
CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]
lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")
lora_func_types = dict(loader="loader", injector="injector")
lora_args = dict(
model=None,
loras=None,
target_replace_module=[],
target_module=[],
r=4,
search_class=[torch.nn.Linear],
dropout=0,
lora_bias="none",
)
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)
LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
def filter_dict(_dict, keys=[]):
if len(keys) == 0:
assert "Keys cannot empty for filtering return dict."
for k in keys:
if k not in lora_args.keys():
assert f"{k} does not exist in available LoRA arguments"
return {k: v for k, v in _dict.items() if k in keys}
class LoraHandler(object):
def __init__(
self,
version: str = LoraVersions.cloneofsimo,
use_unet_lora: bool = False,
use_text_lora: bool = False,
save_for_webui: bool = False,
only_for_webui: bool = False,
lora_bias: str = "none",
unet_replace_modules: list = ["UNet3DConditionModel"],
):
self.version = version
assert self.is_cloneofsimo_lora()
self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
self.lora_bias = lora_bias
self.use_unet_lora = use_unet_lora
self.use_text_lora = use_text_lora
self.save_for_webui = save_for_webui
self.only_for_webui = only_for_webui
self.unet_replace_modules = unet_replace_modules
self.use_lora = any([use_text_lora, use_unet_lora])
if self.use_lora:
print(f"Using LoRA Version: {self.version}")
def is_cloneofsimo_lora(self):
return self.version == LoraVersions.cloneofsimo
def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
if func_type == LoraFuncTypes.loader:
return monkeypatch_or_replace_lora_extended
if func_type == LoraFuncTypes.injector:
return inject_trainable_lora_extended
assert "LoRA Version does not exist."
def get_lora_func_args(
self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
):
return_dict = lora_args.copy()
return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
return_dict.update(
{
"model": model,
"loras": lora_path,
"target_replace_module": replace_modules,
"r": r,
}
)
return return_dict
def do_lora_injection(
self,
model,
replace_modules,
bias="none",
dropout=0,
r=4,
lora_loader_args=None,
):
REPLACE_MODULES = replace_modules
params = None
negation = None
injector_args = lora_loader_args
params, negation = self.lora_injector(**injector_args)
for _up, _down in extract_lora_ups_down(
model, target_replace_module=REPLACE_MODULES
):
if all(x is not None for x in [_up, _down]):
print(
f"Lora successfully injected into {model.__class__.__name__}."
)
break
return params, negation
def add_lora_to_model(
self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
):
params = None
negation = None
lora_loader_args = self.get_lora_func_args(
lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
)
if use_lora:
params, negation = self.do_lora_injection(
model,
replace_modules,
bias=self.lora_bias,
lora_loader_args=lora_loader_args,
dropout=dropout,
r=r,
)
params = model if params is None else params
return params, negation