import gc
import os
from typing import Dict, List, Union

from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from huggingface_hub import hf_hub_download
import safetensors.torch
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection

# Hugging Face Hub repositories of the source models whose UNet weights are merged.
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"

# Repository holding the Evo-Ukiyoe LoRA that is fused into the merged UNet.
UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"


def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    """Load a checkpoint as a state dict, handling both safetensors and pickled torch files."""
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == "safetensors":
        return safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        return torch.load(checkpoint_file, map_location=device)


def load_from_pretrained(
    repo_id,
    filename="diffusion_pytorch_model.fp16.safetensors",
    subfolder="unet",
    device="cuda",
) -> Dict[str, torch.Tensor]:
    """Download a UNet checkpoint from the Hub and return its state dict."""
    return load_state_dict(
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
        ),
        device=device,
    )
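# Illustrative usage (not executed here): fetching the fp16 SDXL base UNet weights.
# The default `filename`/`subfolder` arguments already point at the UNet checkpoint,
# so only the repository id is needed; `device="cpu"` keeps the loaded tensors off the GPU.
#
#     sdxl_unet_sd = load_from_pretrained(SDXL_REPO, device="cpu")
#     print(len(sdxl_unet_sd), next(iter(sdxl_unet_sd)))  # number of tensors, first key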


def reshape_weight_task_tensors(task_tensors, weights):
    """
    Reshapes `weights` so it broadcasts against `task_tensors` by unsqueezing the remaining dimensions.

    Args:
        task_tensors (`torch.Tensor`): The tensor whose number of dimensions `weights` is matched to.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: The reshaped tensor.
    """
    new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
    weights = weights.view(new_shape)
    return weights
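# Illustrative example (not executed here): with 4 models and a stack of 4 conv kernels of
# shape (4, 320, 4, 3, 3), a weight vector of shape (4,) is reshaped to (4, 1, 1, 1, 1) so it
# broadcasts over each kernel during the weighted sum:
#
#     w = reshape_weight_task_tensors(torch.zeros(4, 320, 4, 3, 3), torch.ones(4))
#     assert w.shape == (4, 1, 1, 1, 1)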


def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors using `linear`.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    task_tensors = torch.stack(task_tensors, dim=0)
    weights = reshape_weight_task_tensors(task_tensors, weights)
    weighted_task_tensors = task_tensors * weights
    mixed_task_tensors = weighted_task_tensors.sum(dim=0)
    return mixed_task_tensors
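# Illustrative example (not executed here): a per-model weighted average of two tensors.
# With weights (0.25, 0.75), each merged element is 0.25 * a + 0.75 * b:
#
#     merged = linear(
#         task_tensors=[torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])],
#         weights=torch.tensor([0.25, 0.75]),
#     )
#     # merged == tensor([0.25, 0.75])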


def merge_models(task_tensors, weights):
    """Merge a list of state dicts into one by taking a weighted sum of every parameter."""
    keys = list(task_tensors[0].keys())
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        w_list = []
        for sd in task_tensors:
            # Pop so the source state dicts are consumed as we go, keeping memory usage down.
            w = sd.pop(key)
            w_list.append(w)
        new_w = linear(task_tensors=w_list, weights=weights)
        state_dict[key] = new_w
    return state_dict
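# Illustrative example (not executed here): merging two single-key state dicts with a 70/30
# split. Note that `merge_models` pops keys from its inputs, so the source dicts are emptied
# as a side effect:
#
#     sd_a = {"conv_in.weight": torch.ones(2, 2)}
#     sd_b = {"conv_in.weight": torch.zeros(2, 2)}
#     merged = merge_models([sd_a, sd_b], [0.7, 0.3])
#     # merged["conv_in.weight"] is a 2x2 tensor filled with 0.7; sd_a and sd_b are now empty.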


def split_conv_attn(weights):
    """Split a UNet state dict into attention-projection tensors and everything else."""
    attn_tensors = {}
    conv_tensors = {}
    for key in list(weights.keys()):
        if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
            attn_tensors[key] = weights.pop(key)
        else:
            conv_tensors[key] = weights.pop(key)
    return {"conv": conv_tensors, "attn": attn_tensors}
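# Illustrative example (not executed here): query/key/value/output projections land in the
# "attn" bucket, every other parameter (convolutions, norms, embeddings, ...) in "conv",
# so the two groups can later be merged with different coefficients:
#
#     split = split_conv_attn(
#         {
#             "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight": torch.zeros(1),
#             "conv_in.weight": torch.zeros(1),
#         }
#     )
#     # split["attn"] holds the to_q projection, split["conv"] holds conv_in.weight.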


def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
    # Load the UNet state dicts of the four source models and split each into
    # attention-projection and non-attention ("conv") groups.
    sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
    dpo_weights = split_conv_attn(
        load_from_pretrained(
            DPO_REPO, "diffusion_pytorch_model.safetensors", device=device
        )
    )
    jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
    jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))

    # Merge the two groups separately, each with its own set of per-model coefficients.
    tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
    new_conv = merge_models(
        [sd["conv"] for sd in tensors],
        [
            0.15928833971605916,
            0.1032449268871776,
            0.6503217149752791,
            0.08714501842148402,
        ],
    )
    new_attn = merge_models(
        [sd["attn"] for sd in tensors],
        [
            0.1877279276437178,
            0.20014114603909822,
            0.3922685507065275,
            0.2198623756106564,
        ],
    )
    # The source state dicts were consumed by merge_models; drop the remaining references.
    del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
    gc.collect()
    if "cuda" in device:
        torch.cuda.empty_cache()

    # Build a fresh SDXL UNet and load the merged weights into it.
    unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict({**new_conv, **new_attn})

    # Load the Evo-Ukiyoe LoRA and fuse it into the merged UNet at full strength.
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        pretrained_model_name_or_path_or_dict=UKIYOE_REPO
    )
    LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
    unet.fuse_lora(1.0)

    # Use the Japanese Stable Diffusion XL text encoder and tokenizer so the pipeline
    # accepts Japanese prompts.
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO,
        subfolder="text_encoder",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO,
        subfolder="tokenizer",
        use_fast=False,
    )

    # Assemble the SDXL pipeline around the merged UNet and the Japanese text encoder.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe = pipe.to(device, dtype=torch.float16)
    return pipe
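

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: a CUDA GPU is available, and the prompt text,
    # sampling settings, and output path below are illustrative choices, not values
    # prescribed by this module.
    pipe = load_evoukiyoe(device="cuda")
    image = pipe(
        prompt="鶴が舞う雪景色、浮世絵",  # "cranes dancing in a snowy landscape, ukiyo-e" (example prompt)
        num_inference_steps=40,
        guidance_scale=7.0,
    ).images[0]
    image.save("evo_ukiyoe_sample.png")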