from typing import List

import torch
from safetensors import safe_open

from diffusers import StableDiffusionPipeline

from .lora import (
    monkeypatch_or_replace_safeloras,
    apply_learned_embed_in_clip,
    set_lora_diag,
    parse_safeloras_embeds,
)

def lora_join(lora_safetensors: list):
    """Join several LoRA safetensors into one tensor dict and metadata dict.

    Down weights are concatenated along dim 0 and up weights along dim 1, so the
    joined LoRA's rank is the sum of the individual ranks. Learned embeddings are
    renamed to <s{model_idx}-{token_idx}> tokens.

    Returns (total_tensor, total_metadata, ranklist, token_size_list).
    """
    metadatas = [dict(safelora.metadata()) for safelora in lora_safetensors]
    _total_metadata = {}
    total_metadata = {}
    total_tensor = {}
    total_rank = 0
    ranklist = []

    for _metadata in metadatas:
        rankset = []
        for k, v in _metadata.items():
            if k.endswith("rank"):
                rankset.append(int(v))

        assert len(set(rankset)) <= 1, "Rank should be the same per model"
        if len(rankset) == 0:
            rankset = [0]

        total_rank += rankset[0]
        _total_metadata.update(_metadata)
        ranklist.append(rankset[0])

    # Copy over all non-embedding metadata entries.
    for k, v in _total_metadata.items():
        if v != "<embed>":
            total_metadata[k] = v

    tensorkeys = set()
    for safelora in lora_safetensors:
        tensorkeys.update(safelora.keys())

    for keys in tensorkeys:
        if keys.startswith("text_encoder") or keys.startswith("unet"):
            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetensors]

            is_down = keys.endswith("down")

            # Down weights stack along dim 0, up weights along dim 1, so every
            # joined tensor shares the summed rank.
            if is_down:
                _tensor = torch.cat(tensorset, dim=0)
                assert _tensor.shape[0] == total_rank
            else:
                _tensor = torch.cat(tensorset, dim=1)
                assert _tensor.shape[1] == total_rank

            total_tensor[keys] = _tensor
            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
            total_metadata[keys_rank] = str(total_rank)

    token_size_list = []
    for idx, safelora in enumerate(lora_safetensors):
        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
        for jdx, token in enumerate(sorted(tokens)):

            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"

            print(f"Embedding {token} replaced with <s{idx}-{jdx}>")

        token_size_list.append(len(tokens))

    return total_tensor, total_metadata, ranklist, token_size_list

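# Standalone sketch of joining LoRAs with lora_join (hypothetical file names; assumes
# the inputs were produced by this repo's training). The joined tensors/metadata can
# be written back out with safetensors.torch.save_file, which accepts str->str metadata:
#
#     from safetensors import safe_open
#     from safetensors.torch import save_file
#
#     handles = [
#         safe_open("lora_a.safetensors", framework="pt", device="cpu"),
#         safe_open("lora_b.safetensors", framework="pt", device="cpu"),
#     ]
#     tensors, metadata, ranklist, token_sizes = lora_join(handles)
#     save_file(tensors, "lora_joined.safetensors", metadata=metadata)
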
class DummySafeTensorObject:
    """In-memory stand-in that mimics the safe_open handle interface."""

    def __init__(self, tensor: dict, metadata):
        self.tensor = tensor
        self._metadata = metadata

    def keys(self):
        return self.tensor.keys()

    def metadata(self):
        return self._metadata

    def get_tensor(self, key):
        return self.tensor[key]

class LoRAManager:
    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):

        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):

        # Lazily open every LoRA file on CPU via safetensors.
        self._lora_safetensors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]

        # Join them into a single LoRA whose rank is the sum of the individual ranks.
        (
            total_tensor,
            total_metadata,
            self.ranklist,
            self.token_size_list,
        ) = lora_join(self._lora_safetensors)

        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

        # Patch the pipeline with the joined LoRA and register its embedding tokens.
        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
        tok_dict = parse_safeloras_embeds(self.total_safelora)

        apply_learned_embed_in_clip(
            tok_dict,
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):

        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"

        # Repeat each LoRA's scale across its rank positions so set_lora_diag
        # scales that LoRA's slice of the joined update.
        diags = []
        for scale, rank in zip(scales, self.ranklist):
            diags = diags + [scale] * rank

        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
        if prompt is not None:
            # Each <idx + 1> placeholder expands to that LoRA's learned
            # embedding tokens, e.g. <1> -> <s0-0><s0-1>...
            for idx, tok_size in enumerate(self.token_size_list):
                prompt = prompt.replace(
                    f"<{idx + 1}>",
                    "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
                )

        return prompt
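
# Minimal usage sketch (assumptions: two hypothetical LoRA files "lora_a.safetensors"
# and "lora_b.safetensors" trained with this repo, and a CUDA device; the base model
# id, scales, and prompt are illustrative only). Run as a module so the relative
# import above resolves.
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    manager = LoRAManager(["lora_a.safetensors", "lora_b.safetensors"], pipe)

    # One scale per joined LoRA; each is repeated across that LoRA's ranks.
    manager.tune([0.8, 0.5])

    # <1> / <2> expand to the re-mapped embedding tokens of the first / second LoRA.
    expanded = manager.prompt("a photo of <1> standing next to <2>")
    image = pipe(expanded).images[0]
    image.save("joined_lora_sample.png")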