Update evo_nishikie_v1.py
Browse files- evo_nishikie_v1.py +12 -25
evo_nishikie_v1.py
CHANGED
@@ -3,8 +3,8 @@ from io import BytesIO
|
|
3 |
import os
|
4 |
from typing import Dict, List, Union
|
5 |
|
6 |
-
from PIL import Image
|
7 |
-
from controlnet_aux import
|
8 |
from diffusers import (
|
9 |
ControlNetModel,
|
10 |
StableDiffusionXLControlNetPipeline,
|
@@ -17,7 +17,8 @@ import torch
|
|
17 |
from tqdm import tqdm
|
18 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
19 |
|
20 |
-
|
|
|
21 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
22 |
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
|
23 |
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
|
@@ -29,6 +30,9 @@ UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
|
|
29 |
# Evo-Nishikie
|
30 |
NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
|
31 |
|
|
|
|
|
|
|
32 |
|
33 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
34 |
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
@@ -124,6 +128,7 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
|
|
124 |
)
|
125 |
jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
|
126 |
jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
|
|
|
127 |
# Merge base models
|
128 |
tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
|
129 |
new_conv = merge_models(
|
@@ -144,11 +149,14 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
|
|
144 |
0.2198623756106564,
|
145 |
],
|
146 |
)
|
|
|
|
|
147 |
del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
|
148 |
gc.collect()
|
149 |
if "cuda" in device:
|
150 |
torch.cuda.empty_cache()
|
151 |
|
|
|
152 |
unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
|
153 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
154 |
unet.load_state_dict({**new_conv, **new_attn})
|
@@ -181,26 +189,5 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
|
|
181 |
# Load Evo-Ukiyoe weights
|
182 |
pipe.load_lora_weights(UKIYOE_REPO)
|
183 |
pipe.fuse_lora(lora_scale=1.0)
|
184 |
-
return pipe
|
185 |
-
|
186 |
|
187 |
-
|
188 |
-
url = "https://sakana.ai/assets/nedo-grant/nedo_grant.jpeg"
|
189 |
-
original_image = Image.open(
|
190 |
-
BytesIO(requests.get(url).content)
|
191 |
-
).resize((1024, 1024), Image.Resampling.LANCZOS)
|
192 |
-
canny_detector = CannyDetector()
|
193 |
-
canny_image = canny_detector(original_image, image_resolution=1024)
|
194 |
-
pipe: StableDiffusionXLControlNetPipeline = load_evo_nishikie()
|
195 |
-
images = pipe(
|
196 |
-
prompt="銀杏が色づく。草木が生えた地面と青空の富士山。最高品質の輻の浮世絵。",
|
197 |
-
negative_prompt="暗い。",
|
198 |
-
image=canny_image,
|
199 |
-
guidance_scale=8.0,
|
200 |
-
controlnet_conditioning_scale=0.6,
|
201 |
-
num_inference_steps=50,
|
202 |
-
generator=torch.Generator().manual_seed(0),
|
203 |
-
num_images_per_prompt=1,
|
204 |
-
output_type="pil",
|
205 |
-
).images
|
206 |
-
images[0].save("out.png")
|
|
|
3 |
import os
|
4 |
from typing import Dict, List, Union
|
5 |
|
6 |
+
from PIL import Image, ImageFilter
|
7 |
+
from controlnet_aux import LineartDetector
|
8 |
from diffusers import (
|
9 |
ControlNetModel,
|
10 |
StableDiffusionXLControlNetPipeline,
|
|
|
17 |
from tqdm import tqdm
|
18 |
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
19 |
|
20 |
+
|
21 |
+
# Base models
|
22 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
23 |
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
|
24 |
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
|
|
|
30 |
# Evo-Nishikie
|
31 |
NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
|
32 |
|
33 |
+
# Threshold for image binarization
|
34 |
+
BINARY_THRESHOLD = 40
|
35 |
+
|
36 |
|
37 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
38 |
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
|
|
128 |
)
|
129 |
jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
|
130 |
jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
|
131 |
+
|
132 |
# Merge base models
|
133 |
tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
|
134 |
new_conv = merge_models(
|
|
|
149 |
0.2198623756106564,
|
150 |
],
|
151 |
)
|
152 |
+
|
153 |
+
# Delete no longer needed variables to free
|
154 |
del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
|
155 |
gc.collect()
|
156 |
if "cuda" in device:
|
157 |
torch.cuda.empty_cache()
|
158 |
|
159 |
+
# Instantiate UNet
|
160 |
unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
|
161 |
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
162 |
unet.load_state_dict({**new_conv, **new_attn})
|
|
|
189 |
# Load Evo-Ukiyoe weights
|
190 |
pipe.load_lora_weights(UKIYOE_REPO)
|
191 |
pipe.fuse_lora(lora_scale=1.0)
|
|
|
|
|
192 |
|
193 |
+
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|