from typing import Literal, Union, Dict

import os
import shutil
import fire
from diffusers import StableDiffusionPipeline
import torch

from .lora import tune_lora_scale, weight_apply_lora
from .to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
    # Derive the companion text-encoder LoRA path from a UNet LoRA path,
    # e.g. "model.pt" -> "model.text_encoder.pt".
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
    ] = "lpl",
    with_text_lora: bool = False,
):
    print("Lora Add, mode " + mode)
    if mode == "lpl":
        # "lpl": merge two LoRA .pt files by linear interpolation of their weights.
        for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
            [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
            if with_text_lora
            else []
        ):
            print("Loading", _path_1, _path_2)
            out_list = []
            if opt == "text_encoder":
                if not os.path.exists(_path_1):
                    print(f"No text encoder found in {_path_1}, skipping...")
                    continue
                if not os.path.exists(_path_2):
                    print(f"No text encoder found in {_path_2}, skipping...")
                    continue

            l1 = torch.load(_path_1)
            l2 = torch.load(_path_2)

            # LoRA weights are stored as a flat list of alternating (up, down) tensors.
            l1pairs = zip(l1[::2], l1[1::2])
            l2pairs = zip(l2[::2], l2[1::2])

            for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
                x1.data = alpha * x1.data + (1 - alpha) * x2.data
                y1.data = alpha * y1.data + (1 - alpha) * y2.data

                out_list.append(x1)
                out_list.append(y1)

            if opt == "unet":
                print("Saving merged UNET to", output_path)
                torch.save(out_list, output_path)

            elif opt == "text_encoder":
                print("Saving merged text encoder to", _text_lora_path(output_path))
                torch.save(
                    out_list,
                    _text_lora_path(output_path),
                )

    elif mode == "upl":
        # "upl": apply a LoRA .pt on top of a base pipeline and save in diffusers format.
        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
        if with_text_lora:
            weight_apply_lora(
                loaded_pipeline.text_encoder,
                torch.load(_text_lora_path(path_2)),
                alpha=alpha,
                target_replace_module=["CLIPAttention"],
            )

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":
        # "upl-ckpt-v2": same as "upl", then convert the result to a single .ckpt file.
        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
        if with_text_lora:
            weight_apply_lora(
                loaded_pipeline.text_encoder,
                torch.load(_text_lora_path(path_2)),
                alpha=alpha,
                target_replace_module=["CLIPAttention"],
            )

        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)
        # remove the tmp_output folder
        shutil.rmtree(_tmp_output)

    else:
        print("Unknown mode", mode)
        raise ValueError(f"Unknown mode {mode}")


def main():
    fire.Fire(add)