mkshing committed on
Commit
7277a01
1 Parent(s): 7d23de5

Update evo_nishikie_v1.py

Browse files
Files changed (1) hide show
  1. evo_nishikie_v1.py +12 -25
evo_nishikie_v1.py CHANGED
@@ -3,8 +3,8 @@ from io import BytesIO
3
  import os
4
  from typing import Dict, List, Union
5
 
6
- from PIL import Image
7
- from controlnet_aux import CannyDetector
8
  from diffusers import (
9
  ControlNetModel,
10
  StableDiffusionXLControlNetPipeline,
@@ -17,7 +17,8 @@ import torch
17
  from tqdm import tqdm
18
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
19
 
20
- # Base models (fine-tuned from SDXL-1.0)
 
21
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
22
  DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
23
  JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
@@ -29,6 +30,9 @@ UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
29
  # Evo-Nishikie
30
  NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
31
 
 
 
 
32
 
33
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
34
  file_extension = os.path.basename(checkpoint_file).split(".")[-1]
@@ -124,6 +128,7 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
124
  )
125
  jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
126
  jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
 
127
  # Merge base models
128
  tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
129
  new_conv = merge_models(
@@ -144,11 +149,14 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
144
  0.2198623756106564,
145
  ],
146
  )
 
 
147
  del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
148
  gc.collect()
149
  if "cuda" in device:
150
  torch.cuda.empty_cache()
151
 
 
152
  unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
153
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
154
  unet.load_state_dict({**new_conv, **new_attn})
@@ -181,26 +189,5 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
181
  # Load Evo-Ukiyoe weights
182
  pipe.load_lora_weights(UKIYOE_REPO)
183
  pipe.fuse_lora(lora_scale=1.0)
184
- return pipe
185
-
186
 
187
- if __name__ == "__main__":
188
- url = "https://sakana.ai/assets/nedo-grant/nedo_grant.jpeg"
189
- original_image = Image.open(
190
- BytesIO(requests.get(url).content)
191
- ).resize((1024, 1024), Image.Resampling.LANCZOS)
192
- canny_detector = CannyDetector()
193
- canny_image = canny_detector(original_image, image_resolution=1024)
194
- pipe: StableDiffusionXLControlNetPipeline = load_evo_nishikie()
195
- images = pipe(
196
- prompt="銀杏が色づく。草木が生えた地面と青空の富士山。最高品質の輻の浮世絵。",
197
- negative_prompt="暗い。",
198
- image=canny_image,
199
- guidance_scale=8.0,
200
- controlnet_conditioning_scale=0.6,
201
- num_inference_steps=50,
202
- generator=torch.Generator().manual_seed(0),
203
- num_images_per_prompt=1,
204
- output_type="pil",
205
- ).images
206
- images[0].save("out.png")
 
3
  import os
4
  from typing import Dict, List, Union
5
 
6
+ from PIL import Image, ImageFilter
7
+ from controlnet_aux import LineartDetector
8
  from diffusers import (
9
  ControlNetModel,
10
  StableDiffusionXLControlNetPipeline,
 
17
  from tqdm import tqdm
18
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
19
 
20
+
21
+ # Base models
22
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
23
  DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
24
  JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
 
30
  # Evo-Nishikie
31
  NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
32
 
33
+ # Threshold for image binarization
34
+ BINARY_THRESHOLD = 40
35
+
36
 
37
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
38
  file_extension = os.path.basename(checkpoint_file).split(".")[-1]
 
128
  )
129
  jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
130
  jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
131
+
132
  # Merge base models
133
  tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
134
  new_conv = merge_models(
 
149
  0.2198623756106564,
150
  ],
151
  )
152
+
153
+ # Delete no longer needed variables to free memory
154
  del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
155
  gc.collect()
156
  if "cuda" in device:
157
  torch.cuda.empty_cache()
158
 
159
+ # Instantiate UNet
160
  unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
161
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
162
  unet.load_state_dict({**new_conv, **new_attn})
 
189
  # Load Evo-Ukiyoe weights
190
  pipe.load_lora_weights(UKIYOE_REPO)
191
  pipe.fuse_lora(lora_scale=1.0)
 
 
192
 
193
+ return pipe