import os

import fire
import torch

from lora_diffusion import (
    DEFAULT_TARGET_REPLACE,
    TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    UNET_DEFAULT_TARGET_REPLACE,
    convert_loras_to_safeloras_with_embeds,
    safetensors_available,
)
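
# Default sets of target modules for Lora replacement, keyed by the model
# name that convert() infers from each input file name.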
_target_by_name = {
    "unet": UNET_DEFAULT_TARGET_REPLACE,
    "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more PyTorch Lora and/or Textual Inversion embedding
    files into a single safetensors file.

    Pass all the input paths as arguments. Whether they are Textual Inversion
    embeds or Lora models will be auto-detected.

    For Lora models, their name is taken from the path, e.g.
        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set target_modules and/or rank by providing an argument
    prefixed by the name.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensors --unet.rank 8
    ```
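
    Converting a unet Lora, a text-encoder Lora, and learned embeds in one
    go could look like this (file names here are illustrative):

    ```
    python -m lora_diffusion.cli_pt_to_safetensors \
        lora_weight.pt lora_weight.text_encoder.pt learned_embeds.pt \
        --outpath lora_weight.safetensors --text_encoder.rank 8
    ```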
    """
    modelmap = {}
    embeds = {}

    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    for path in paths:
        data = torch.load(path)
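
        # Textual inversion embedding files hold a dict (token -> tensor);
        # anything else is treated as Lora weights.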
        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)

        else:
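            # Infer the model name from the file name, e.g.
            # "lora_weight.text_encoder.pt" -> "text_encoder"; a bare
            # "lora_weight.pt" defaults to "unet".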
            name_parts = os.path.split(path)[1].split(".")
            name = name_parts[-2] if len(name_parts) > 2 else "unet"

            model_settings = {
                "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
                "rank": 4,
            }
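
            # Apply any CLI overrides prefixed with this model's name,
            # e.g. --unet.rank 8 arrives here as settings["unet.rank"] = 8.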
            prefix = f"{name}."
            arg_settings = {
                k[len(prefix) :]: v
                for k, v in settings.items()
                if k.startswith(prefix)
            }
            model_settings = {**model_settings, **arg_settings}

            print(f"Loading Lora for {name} from {path} with settings {model_settings}")

            modelmap[name] = (
                path,
                model_settings["target_modules"],
                model_settings["rank"],
            )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
    fire.Fire(convert)


if __name__ == "__main__":
    main()