mkshing committed on
Commit ec826a6
1 Parent(s): 4d317d3

Create evoukiyoe_v1.py

Files changed (1)
  1. evoukiyoe_v1.py +179 -0
evoukiyoe_v1.py ADDED
@@ -0,0 +1,179 @@
import gc
import os
from typing import Dict, List, Union

from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from huggingface_hub import hf_hub_download
import safetensors.torch
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection


# Base models (fine-tuned from SDXL-1.0)
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"

# Evo-Ukiyoe
UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"


def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    """Load a checkpoint as a state dict, dispatching on the file extension."""
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == "safetensors":
        return safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        return torch.load(checkpoint_file, map_location=device)


def load_from_pretrained(
    repo_id,
    filename="diffusion_pytorch_model.fp16.safetensors",
    subfolder="unet",
    device="cuda",
) -> Dict[str, torch.Tensor]:
    """Download a weight file from the Hugging Face Hub and load it as a state dict."""
    return load_state_dict(
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
        ),
        device=device,
    )
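
# Example (illustrative, not in the original file): fetch the fp16 UNet
# state dict of the SDXL base model onto the CPU.
# >>> sd = load_from_pretrained(SDXL_REPO, device="cpu")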


def reshape_weight_task_tensors(task_tensors, weights):
    """
    Reshapes `weights` to match the shape of `task_tensors` by unsqueezing in the remaining dimensions.

    Args:
        task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: The reshaped tensor.
    """
    new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
    weights = weights.view(new_shape)
    return weights


def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors using `linear`.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    task_tensors = torch.stack(task_tensors, dim=0)
    # Broadcast the per-model weights over the stacked tensors and take the
    # weighted sum along the stacking dimension.
    weights = reshape_weight_task_tensors(task_tensors, weights)
    weighted_task_tensors = task_tensors * weights
    mixed_task_tensors = weighted_task_tensors.sum(dim=0)
    return mixed_task_tensors
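
# Example (a minimal sketch): a weighted average of two toy tensors.
# >>> a, b = torch.zeros(2, 2), torch.ones(2, 2)
# >>> linear([a, b], torch.tensor([0.25, 0.75]))
# tensor([[0.7500, 0.7500],
#         [0.7500, 0.7500]])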


def merge_models(task_tensors, weights):
    """Linearly merge a list of state dicts, key by key."""
    keys = list(task_tensors[0].keys())
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        w_list = []
        for sd in task_tensors:
            w = sd.pop(key)
            w_list.append(w)
        new_w = linear(task_tensors=w_list, weights=weights)
        state_dict[key] = new_w
    return state_dict


def split_conv_attn(weights):
    """Split a UNet state dict into attention-projection weights and the rest."""
    attn_tensors = {}
    conv_tensors = {}
    for key in list(weights.keys()):
        if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
            attn_tensors[key] = weights.pop(key)
        else:
            conv_tensors[key] = weights.pop(key)
    return {"conv": conv_tensors, "attn": attn_tensors}
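
# Example (illustrative key names): attention projections such as
# "...attn1.to_q.weight" land under "attn"; everything else (conv, norm,
# time-embedding weights, ...) lands under "conv".
# >>> parts = split_conv_attn(load_from_pretrained(SDXL_REPO, device="cpu"))
# >>> len(parts["attn"]), len(parts["conv"])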


def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
    # Load base models
    sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
    dpo_weights = split_conv_attn(
        load_from_pretrained(
            DPO_REPO, "diffusion_pytorch_model.safetensors", device=device
        )
    )
    jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
    jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
    # Merge base models, with separate merge ratios for the attention
    # projections and for all remaining weights
    tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
    new_conv = merge_models(
        [sd["conv"] for sd in tensors],
        [
            0.15928833971605916,
            0.1032449268871776,
            0.6503217149752791,
            0.08714501842148402,
        ],
    )
    new_attn = merge_models(
        [sd["attn"] for sd in tensors],
        [
            0.1877279276437178,
            0.20014114603909822,
            0.3922685507065275,
            0.2198623756106564,
        ],
    )
    # Free the per-model state dicts before instantiating the UNet
    del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
    gc.collect()
    if "cuda" in device:
        torch.cuda.empty_cache()

    unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict({**new_conv, **new_attn})

    # Load LoRA weights
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        pretrained_model_name_or_path_or_dict=UKIYOE_REPO
    )
    LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
    unet.fuse_lora(1.0)

    # Load other modules (the JSDXL text encoder and tokenizer handle
    # Japanese prompts)
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO,
        subfolder="text_encoder",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO,
        subfolder="tokenizer",
        use_fast=False,
    )

    pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe = pipe.to(device, dtype=torch.float16)
    return pipe
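

# Usage sketch (assumptions: a CUDA device is available; the prompt, seed,
# sampling settings, and output path below are illustrative, not prescribed
# by the release).
if __name__ == "__main__":
    pipe = load_evoukiyoe()
    prompt = "鶴が舞う雪景色、浮世絵、最高品質"  # hypothetical Japanese example prompt
    image = pipe(
        prompt,
        guidance_scale=8.0,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(0),
    ).images[0]
    image.save("evoukiyoe_v1_sample.png")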