import math
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.nn as nn


class LoraInjectedLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=False, r=4):
        super().__init__()

        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            )

        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale


def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: List[str] = ["CrossAttention", "Attention"],
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """
    Inject LoRA layers into `model` and return the trainable LoRA parameter groups.
    """

    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "Linear":
                    weight = _child_module.weight
                    bias = _child_module.bias
                    _tmp = LoraInjectedLinear(
                        _child_module.in_features,
                        _child_module.out_features,
                        _child_module.bias is not None,
                        r,
                    )
                    _tmp.linear.weight = weight
                    if bias is not None:
                        _tmp.linear.bias = bias

                    # switch the module
                    _module._modules[name] = _tmp

                    require_grad_params.append(
                        _module._modules[name].lora_up.parameters()
                    )
                    require_grad_params.append(
                        _module._modules[name].lora_down.parameters()
                    )

                    if loras is not None:
                        _module._modules[name].lora_up.weight = loras.pop(0)
                        _module._modules[name].lora_down.weight = loras.pop(0)

                    _module._modules[name].lora_up.weight.requires_grad = True
                    _module._modules[name].lora_down.weight.requires_grad = True
                    names.append(name)

    return require_grad_params, names


def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
    loras = []

    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for _child_module in _module.modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
                    loras.append((_child_module.lora_up, _child_module.lora_down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras


def save_lora_weight(
    model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
):
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.append(_up.weight)
        weights.append(_down.weight)

    torch.save(weights, path)


def save_lora_as_json(model, path="./lora.json"):
    weights = []
    for _up, _down in extract_lora_ups_down(model):
        weights.append(_up.weight.detach().cpu().numpy().tolist())
        weights.append(_down.weight.detach().cpu().numpy().tolist())

    import json

    with open(path, "w") as f:
        json.dump(weights, f)
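
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): one way to combine the
# helpers above for training.  `ToyCrossAttention` and `_example_inject_and_save`
# are hypothetical names used only for this example; the class merely has to
# appear in `target_replace_module` for the nn.Linear layers inside it to be
# wrapped by LoraInjectedLinear.
def _example_inject_and_save(path: str = "./lora.pt"):
    import itertools

    class ToyCrossAttention(nn.Module):  # hypothetical stand-in for an attention block
        def __init__(self, dim: int = 32):
            super().__init__()
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_v = nn.Linear(dim, dim, bias=False)

        def forward(self, x):
            return self.to_q(x) + self.to_v(x)

    model = nn.Sequential(ToyCrossAttention())
    model.requires_grad_(False)  # freeze the base weights; only LoRA will train

    # Wrap every nn.Linear inside ToyCrossAttention and collect only the
    # low-rank parameters for the optimizer.
    params, names = inject_trainable_lora(
        model, target_replace_module=["ToyCrossAttention"], r=4
    )
    optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1e-4)

    # One dummy training step on random data.
    out = model(torch.randn(2, 32))
    out.mean().backward()
    optimizer.step()

    # Persist just the low-rank matrices (kilobytes instead of the full model).
    save_lora_weight(model, path, target_replace_module=["ToyCrossAttention"])
    return model, names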

def weight_apply_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
):
    """Merge LoRA weights directly into the nn.Linear weights of the target modules."""
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for _child_module in _module.modules():
                if _child_module.__class__.__name__ == "Linear":
                    weight = _child_module.weight

                    up_weight = loras.pop(0).detach().to(weight.device)
                    down_weight = loras.pop(0).detach().to(weight.device)

                    # W <- W + alpha * U @ D
                    weight = weight + alpha * (up_weight @ down_weight).type(
                        weight.dtype
                    )
                    _child_module.weight = nn.Parameter(weight)


def monkeypatch_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
    """Replace plain nn.Linear layers with LoraInjectedLinear and load the given up/down weights."""
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "Linear":
                    weight = _child_module.weight
                    bias = _child_module.bias
                    _tmp = LoraInjectedLinear(
                        _child_module.in_features,
                        _child_module.out_features,
                        _child_module.bias is not None,
                        r=r,
                    )
                    _tmp.linear.weight = weight
                    if bias is not None:
                        _tmp.linear.bias = bias

                    # switch the module
                    _module._modules[name] = _tmp

                    up_weight = loras.pop(0)
                    down_weight = loras.pop(0)

                    _module._modules[name].lora_up.weight = nn.Parameter(
                        up_weight.type(weight.dtype)
                    )
                    _module._modules[name].lora_down.weight = nn.Parameter(
                        down_weight.type(weight.dtype)
                    )

                    _module._modules[name].to(weight.device)


def monkeypatch_replace_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
    """Rebuild already-injected LoraInjectedLinear layers and load new up/down weights into them."""
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
                    weight = _child_module.linear.weight
                    bias = _child_module.linear.bias
                    _tmp = LoraInjectedLinear(
                        _child_module.linear.in_features,
                        _child_module.linear.out_features,
                        _child_module.linear.bias is not None,
                        r=r,
                    )
                    _tmp.linear.weight = weight
                    if bias is not None:
                        _tmp.linear.bias = bias

                    # switch the module
                    _module._modules[name] = _tmp

                    up_weight = loras.pop(0)
                    down_weight = loras.pop(0)

                    _module._modules[name].lora_up.weight = nn.Parameter(
                        up_weight.type(weight.dtype)
                    )
                    _module._modules[name].lora_down.weight = nn.Parameter(
                        down_weight.type(weight.dtype)
                    )

                    _module._modules[name].to(weight.device)


def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=["CrossAttention", "Attention"],
    alpha: float = 1.0,
    beta: float = 1.0,
):
    """Blend new LoRA weights into already-injected layers: new = alpha * loaded + beta * existing."""
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
                    weight = _child_module.linear.weight

                    up_weight = loras.pop(0)
                    down_weight = loras.pop(0)

                    _module._modules[name].lora_up.weight = nn.Parameter(
                        up_weight.type(weight.dtype).to(weight.device) * alpha
                        + _module._modules[name].lora_up.weight.to(weight.device) * beta
                    )
                    _module._modules[name].lora_down.weight = nn.Parameter(
                        down_weight.type(weight.dtype).to(weight.device) * alpha
                        + _module._modules[name].lora_down.weight.to(weight.device)
                        * beta
                    )

                    _module._modules[name].to(weight.device)


def tune_lora_scale(model, alpha: float = 1.0):
    """Set the scale of every LoraInjectedLinear in the model (0.0 disables the LoRA contribution)."""
    for _module in model.modules():
        if _module.__class__.__name__ == "LoraInjectedLinear":
            _module.scale = alpha


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def _ti_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["ti", "pt"])
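
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): loading weights saved by
# `save_lora_weight` into a fresh model at inference time.  `build_model` and
# `_example_load_for_inference` are hypothetical names for this example;
# `build_model` stands for any factory that recreates the trained architecture.
def _example_load_for_inference(
    build_model, path: str = "./lora.pt", target=["CrossAttention", "Attention"]
):
    model = build_model()

    # First load: wrap the plain nn.Linear layers and install the saved
    # low-rank weights.  For a model that already contains LoraInjectedLinear
    # layers, use monkeypatch_replace_lora instead.  `r` should match the rank
    # used at training time.
    monkeypatch_lora(model, torch.load(path), target_replace_module=target, r=4)

    # The LoRA contribution can be rescaled (or disabled with 0.0) at any time.
    tune_lora_scale(model, alpha=0.8)

    # Alternative: bake the update directly into the base weights instead of
    # keeping separate LoRA layers.
    # weight_apply_lora(build_model(), torch.load(path),
    #                   target_replace_module=target, alpha=1.0)
    return model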

def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None, idempotent=False
):
    """Add a learned textual-inversion embedding to the tokenizer and text encoder."""
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to dtype of text_encoder
    dtype = text_encoder.get_input_embeddings().weight.dtype

    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    i = 1
    if num_added_tokens == 0 and idempotent:
        return token

    while num_added_tokens == 0:
        print(f"The tokenizer already contains the token {token}.")
        token = f"{token[:-1]}-{i}>"
        print(f"Attempting to add the token {token}.")
        num_added_tokens = tokenizer.add_tokens(token)
        i += 1

    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

    return token


def patch_pipe(
    pipe,
    unet_path,
    token,
    alpha: float = 1.0,
    r: int = 4,
    patch_text=False,
    patch_ti=False,
    idempotent_token=True,
):
    """
    Load the LoRA weights saved at `unet_path` (and the companion
    `.text_encoder.pt` / `.ti.pt` files) into a diffusers pipeline.
    If LoRA layers are already injected, their weights are replaced.
    """
    ti_path = _ti_lora_path(unet_path)
    text_path = _text_lora_path(unet_path)

    unet_has_lora = False
    text_encoder_has_lora = False

    for _module in pipe.unet.modules():
        if _module.__class__.__name__ == "LoraInjectedLinear":
            unet_has_lora = True

    for _module in pipe.text_encoder.modules():
        if _module.__class__.__name__ == "LoraInjectedLinear":
            text_encoder_has_lora = True

    if not unet_has_lora:
        monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
    else:
        monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)

    if patch_text:
        if not text_encoder_has_lora:
            monkeypatch_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=["CLIPAttention"],
                r=r,
            )
        else:
            monkeypatch_replace_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=["CLIPAttention"],
                r=r,
            )
    if patch_ti:
        token = load_learned_embed_in_clip(
            ti_path,
            pipe.text_encoder,
            pipe.tokenizer,
            token,
            idempotent=idempotent_token,
        )
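
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original API): patching a diffusers
# pipeline end to end.  `_example_patch_pipe`, the model id, the token name,
# and the weight file name are placeholders for this example; `lora_weight.pt`
# is assumed to sit next to its `.text_encoder.pt` and `.ti.pt` companions.
def _example_patch_pipe():
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Inject (or replace) LoRA layers in the UNet and text encoder, and add
    # the learned token embedding to the tokenizer / text encoder.
    patch_pipe(
        pipe,
        "./lora_weight.pt",
        token="<s1>",
        r=4,
        patch_text=True,
        patch_ti=True,
    )

    # LoRA strength can still be adjusted after patching.
    tune_lora_scale(pipe.unet, 1.0)
    tune_lora_scale(pipe.text_encoder, 1.0)
    return pipe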