|
from typing import Literal, Union, Dict |
|
import os |
|
import shutil |
|
import fire |
|
from diffusers import StableDiffusionPipeline |
|
from safetensors.torch import safe_open, save_file |
|
|
|
import torch |
|
from .lora import ( |
|
tune_lora_scale, |
|
patch_pipe, |
|
collapse_lora, |
|
monkeypatch_remove_lora, |
|
) |
|
from .lora_manager import lora_join |
|
from .to_ckpt_v2 import convert_to_ckpt |
|
|
|
|
|
def _text_lora_path(path: str) -> str: |
|
assert path.endswith(".pt"), "Only .pt files are supported" |
|
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) |
|
|
|
|
|
def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float = 0.5,
    alpha_2: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
        "ljl",  # was missing from the Literal even though the branch exists below
    ] = "lpl",
    with_text_lora: bool = False,
):
    """Merge LoRA weights and/or a base pipeline into a single output file.

    Args:
        path_1: First input. A LoRA file for "lpl"/"ljl", or a diffusers
            pretrained-model path for "upl"/"upl-ckpt-v2".
        path_2: Second input; a LoRA file in every mode.
        output_path: Destination for the merged result.
        alpha_1: Weight for the first LoRA ("lpl"), or the collapse ratio
            applied to the LoRA in "upl"/"upl-ckpt-v2".
        alpha_2: Weight for the second LoRA ("lpl" only).
        mode:
            "lpl"         - linear combination alpha_1 * lora_1 + alpha_2 * lora_2
                            of two LoRA files (.pt or .safetensors).
            "upl"         - collapse LoRA ``path_2`` into the pipeline loaded
                            from ``path_1`` and save a diffusers pipeline.
            "upl-ckpt-v2" - like "upl", but export an A1111-compatible .ckpt
                            plus a textual-inversion embedding .pt.
            "ljl"         - join two .safetensors LoRAs (alphas have no effect).
        with_text_lora: In "lpl" mode with .pt files, also merge the companion
            ``*.text_encoder.pt`` files when both exist.

    Raises:
        ValueError: If ``mode`` is unknown, or if "lpl" inputs do not share a
            supported extension (.pt or .safetensors).
    """
    print("Lora Add, mode " + mode)
    if mode == "lpl":
        if path_1.endswith(".pt") and path_2.endswith(".pt"):
            # Merge the UNet LoRAs, and optionally the text-encoder LoRAs.
            for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
                [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
                if with_text_lora
                else []
            ):
                print("Loading", _path_1, _path_2)
                out_list = []
                if opt == "text_encoder":
                    # Text-encoder LoRAs are optional companions; skip if absent.
                    if not os.path.exists(_path_1):
                        print(f"No text encoder found in {_path_1}, skipping...")
                        continue
                    if not os.path.exists(_path_2):
                        # BUGFIX: previously reported _path_1 here even though
                        # the missing file is _path_2.
                        print(f"No text encoder found in {_path_2}, skipping...")
                        continue

                # NOTE(review): no map_location — a CUDA-saved LoRA requires a
                # CUDA-capable machine to merge. TODO confirm whether
                # map_location="cpu" is acceptable for callers.
                l1 = torch.load(_path_1)
                l2 = torch.load(_path_2)

                # .pt LoRAs are flat lists of alternating (up, down) tensors;
                # pair them up so each factor is merged with its counterpart.
                l1pairs = zip(l1[::2], l1[1::2])
                l2pairs = zip(l2[::2], l2[1::2])

                for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                    # Merge in place on l1's tensors.
                    x1.data = alpha_1 * x1.data + alpha_2 * x2.data
                    y1.data = alpha_1 * y1.data + alpha_2 * y2.data

                    out_list.append(x1)
                    out_list.append(y1)

                if opt == "unet":
                    print("Saving merged UNET to", output_path)
                    torch.save(out_list, output_path)

                elif opt == "text_encoder":
                    print("Saving merged text encoder to", _text_lora_path(output_path))
                    torch.save(
                        out_list,
                        _text_lora_path(output_path),
                    )

        elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
            safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
            safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

            # On metadata key collisions, file 2 wins.
            metadata = dict(safeloras_1.metadata())
            metadata.update(dict(safeloras_2.metadata()))

            ret_tensor = {}

            for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
                if keys.startswith("text_encoder") or keys.startswith("unet"):
                    # LoRA weight tensors: linear combination of both files.
                    tens1 = safeloras_1.get_tensor(keys)
                    tens2 = safeloras_2.get_tensor(keys)

                    tens = alpha_1 * tens1 + alpha_2 * tens2
                    ret_tensor[keys] = tens
                else:
                    # Non-LoRA tensors (e.g. token embeddings) are copied
                    # through, preferring file 1 when present in both.
                    if keys in safeloras_1.keys():
                        tens1 = safeloras_1.get_tensor(keys)
                    else:
                        tens1 = safeloras_2.get_tensor(keys)

                    ret_tensor[keys] = tens1

            save_file(ret_tensor, output_path, metadata)

        else:
            # BUGFIX: previously mismatched/unsupported extensions fell
            # through silently and produced no output at all.
            raise ValueError(
                "lpl mode requires both inputs to be .pt or both .safetensors"
            )

    elif mode == "upl":
        print(
            f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        patch_pipe(loaded_pipeline, path_2)

        # Fold the LoRA deltas into the base weights at ratio alpha_1, then
        # strip the LoRA modules so a plain pipeline is saved.
        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":
        assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
        # Basename minus the ".ckpt" suffix doubles as the A1111 token name.
        name = os.path.basename(output_path)[0:-5]

        print(
            f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        # patch_ti=False: textual-inversion tokens are exported separately
        # below instead of being patched into the tokenizer.
        tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        # Save a diffusers pipeline to a temp dir, convert it to .ckpt,
        # then clean up.
        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)

        shutil.rmtree(_tmp_output)

        # Stack the learned token embeddings into an A1111 textual-inversion
        # embedding file alongside the .ckpt.
        keys = sorted(tok_dict.keys())
        tok_catted = torch.stack([tok_dict[k] for k in keys])
        ret = {
            "string_to_token": {"*": torch.tensor(265)},
            "string_to_param": {"*": tok_catted},
            "name": name,
        }

        torch.save(ret, output_path[:-5] + ".pt")
        print(
            f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
        )
    elif mode == "ljl":
        print("Using Join mode : alpha will not have an effect here.")
        assert path_1.endswith(".safetensors") and path_2.endswith(
            ".safetensors"
        ), "Only .safetensors files are supported"

        safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
        safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
        save_file(total_tensor, output_path, total_metadata)

    else:
        raise ValueError(f"Unknown mode {mode}")
|
|
|
|
|
def main():
    # CLI entry point: fire exposes `add`'s parameters as command-line flags
    # (e.g. --path_1 ... --mode lpl).
    fire.Fire(add)