yuki-imajuku
committed on
Commit
•
f8c0350
1
Parent(s):
03a7b59
Rename evoukiyoe_v1.py to evo_ukiyoe_v1.py
Browse files
evoukiyoe_v1.py → evo_ukiyoe_v1.py
RENAMED
@@ -6,22 +6,20 @@ from diffusers import (
|
|
6 |
StableDiffusionXLPipeline,
|
7 |
UNet2DConditionModel,
|
8 |
)
|
9 |
-
from diffusers.loaders import LoraLoaderMixin
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import safetensors
|
12 |
import torch
|
13 |
from tqdm import tqdm
|
14 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
15 |
|
16 |
-
|
17 |
# Base models (fine-tuned from SDXL-1.0)
|
18 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
19 |
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
|
20 |
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
|
21 |
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
|
26 |
|
27 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
@@ -108,7 +106,7 @@ def split_conv_attn(weights):
|
|
108 |
return {"conv": conv_tensors, "attn": attn_tensors}
|
109 |
|
110 |
|
111 |
-
def
|
112 |
# Load base models
|
113 |
sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
|
114 |
dpo_weights = split_conv_attn(
|
@@ -147,26 +145,15 @@ def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
|
|
147 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
148 |
unet.load_state_dict({**new_conv, **new_attn})
|
149 |
|
150 |
-
# Load LoRA weights
|
151 |
-
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
|
152 |
-
pretrained_model_name_or_path_or_dict=UKIYOE_REPO
|
153 |
-
)
|
154 |
-
LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
|
155 |
-
unet.fuse_lora(1.0)
|
156 |
-
|
157 |
# Load other modules
|
158 |
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
159 |
-
JSDXL_REPO,
|
160 |
-
subfolder="text_encoder",
|
161 |
-
torch_dtype=torch.float16,
|
162 |
-
variant="fp16",
|
163 |
)
|
164 |
tokenizer = AutoTokenizer.from_pretrained(
|
165 |
-
JSDXL_REPO,
|
166 |
-
subfolder="tokenizer",
|
167 |
-
use_fast=False,
|
168 |
)
|
169 |
|
|
|
170 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
171 |
SDXL_REPO,
|
172 |
unet=unet,
|
@@ -176,4 +163,24 @@ def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
|
|
176 |
variant="fp16",
|
177 |
)
|
178 |
pipe = pipe.to(device, dtype=torch.float16)
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
StableDiffusionXLPipeline,
|
7 |
UNet2DConditionModel,
|
8 |
)
|
|
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
import safetensors
|
11 |
import torch
|
12 |
from tqdm import tqdm
|
13 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
14 |
|
|
|
15 |
# Base models (fine-tuned from SDXL-1.0)
|
16 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
17 |
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
|
18 |
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
|
19 |
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
|
20 |
|
21 |
+
# LoRA weights
|
22 |
+
LORA_REPO = "SakanaAI/Evo-Ukiyoe-v1"
|
23 |
|
24 |
|
25 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
|
|
106 |
return {"conv": conv_tensors, "attn": attn_tensors}
|
107 |
|
108 |
|
109 |
+
def load_evo_ukiyoe(device="cuda") -> StableDiffusionXLPipeline:
|
110 |
# Load base models
|
111 |
sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
|
112 |
dpo_weights = split_conv_attn(
|
|
|
145 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
146 |
unet.load_state_dict({**new_conv, **new_attn})
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
# Load other modules
|
149 |
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
150 |
+
JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16",
|
|
|
|
|
|
|
151 |
)
|
152 |
tokenizer = AutoTokenizer.from_pretrained(
|
153 |
+
JSDXL_REPO, subfolder="tokenizer", use_fast=False,
|
|
|
|
|
154 |
)
|
155 |
|
156 |
+
# Load pipeline
|
157 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
158 |
SDXL_REPO,
|
159 |
unet=unet,
|
|
|
163 |
variant="fp16",
|
164 |
)
|
165 |
pipe = pipe.to(device, dtype=torch.float16)
|
166 |
+
|
167 |
+
# Load LoRA module
|
168 |
+
pipe.load_lora_weights(LORA_REPO)
|
169 |
+
pipe.fuse_lora(lora_scale=1.0)
|
170 |
+
return pipe
|
171 |
+
|
172 |
+
|
173 |
+
if __name__ == "__main__":
|
174 |
+
pipe: StableDiffusionXLPipeline = load_evo_ukiyoe()
|
175 |
+
images = pipe(
|
176 |
+
prompt="鶴が庭に立っている。雪が降っている。最高品質の輻の浮世絵。",
|
177 |
+
negative_prompt="",
|
178 |
+
width=1024,
|
179 |
+
height=1024,
|
180 |
+
guidance_scale=8.0,
|
181 |
+
num_inference_steps=50,
|
182 |
+
generator=torch.Generator().manual_seed(0),
|
183 |
+
num_images_per_prompt=1,
|
184 |
+
output_type="pil",
|
185 |
+
).images
|
186 |
+
images[0].save("out.png")
|