yuki-imajuku commited on
Commit
f8c0350
1 Parent(s): 03a7b59

Rename evoukiyoe_v1.py to evo_ukiyoe_v1.py

Browse files
Files changed (1) hide show
  1. evoukiyoe_v1.py → evo_ukiyoe_v1.py +27 -20
evoukiyoe_v1.py → evo_ukiyoe_v1.py RENAMED
@@ -6,22 +6,20 @@ from diffusers import (
6
  StableDiffusionXLPipeline,
7
  UNet2DConditionModel,
8
  )
9
- from diffusers.loaders import LoraLoaderMixin
10
  from huggingface_hub import hf_hub_download
11
  import safetensors
12
  import torch
13
  from tqdm import tqdm
14
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
15
 
16
-
17
  # Base models (fine-tuned from SDXL-1.0)
18
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
19
  DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
20
  JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
21
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
22
 
23
- # Evo-Ukiyoe
24
- UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
25
 
26
 
27
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
@@ -108,7 +106,7 @@ def split_conv_attn(weights):
108
  return {"conv": conv_tensors, "attn": attn_tensors}
109
 
110
 
111
- def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
112
  # Load base models
113
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
114
  dpo_weights = split_conv_attn(
@@ -147,26 +145,15 @@ def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
147
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
148
  unet.load_state_dict({**new_conv, **new_attn})
149
 
150
- # Load LoRA weights
151
- state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
152
- pretrained_model_name_or_path_or_dict=UKIYOE_REPO
153
- )
154
- LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
155
- unet.fuse_lora(1.0)
156
-
157
  # Load other modules
158
  text_encoder = CLIPTextModelWithProjection.from_pretrained(
159
- JSDXL_REPO,
160
- subfolder="text_encoder",
161
- torch_dtype=torch.float16,
162
- variant="fp16",
163
  )
164
  tokenizer = AutoTokenizer.from_pretrained(
165
- JSDXL_REPO,
166
- subfolder="tokenizer",
167
- use_fast=False,
168
  )
169
 
 
170
  pipe = StableDiffusionXLPipeline.from_pretrained(
171
  SDXL_REPO,
172
  unet=unet,
@@ -176,4 +163,24 @@ def load_evoukiyoe(device="cuda") -> StableDiffusionXLPipeline:
176
  variant="fp16",
177
  )
178
  pipe = pipe.to(device, dtype=torch.float16)
179
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  StableDiffusionXLPipeline,
7
  UNet2DConditionModel,
8
  )
 
9
  from huggingface_hub import hf_hub_download
10
  import safetensors
11
  import torch
12
  from tqdm import tqdm
13
  from transformers import AutoTokenizer, CLIPTextModelWithProjection
14
 
 
15
  # Base models (fine-tuned from SDXL-1.0)
16
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
17
  DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
18
  JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
19
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
20
 
21
+ # LoRA weights
22
+ LORA_REPO = "SakanaAI/Evo-Ukiyoe-v1"
23
 
24
 
25
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
 
106
  return {"conv": conv_tensors, "attn": attn_tensors}
107
 
108
 
109
+ def load_evo_ukiyoe(device="cuda") -> StableDiffusionXLPipeline:
110
  # Load base models
111
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
112
  dpo_weights = split_conv_attn(
 
145
  unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
146
  unet.load_state_dict({**new_conv, **new_attn})
147
 
 
 
 
 
 
 
 
148
  # Load other modules
149
  text_encoder = CLIPTextModelWithProjection.from_pretrained(
150
+ JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16",
 
 
 
151
  )
152
  tokenizer = AutoTokenizer.from_pretrained(
153
+ JSDXL_REPO, subfolder="tokenizer", use_fast=False,
 
 
154
  )
155
 
156
+ # Load pipeline
157
  pipe = StableDiffusionXLPipeline.from_pretrained(
158
  SDXL_REPO,
159
  unet=unet,
 
163
  variant="fp16",
164
  )
165
  pipe = pipe.to(device, dtype=torch.float16)
166
+
167
+ # Load LoRA module
168
+ pipe.load_lora_weights(LORA_REPO)
169
+ pipe.fuse_lora(lora_scale=1.0)
170
+ return pipe
171
+
172
+
173
+ if __name__ == "__main__":
174
+ pipe: StableDiffusionXLPipeline = load_evo_ukiyoe()
175
+ images = pipe(
176
+ prompt="鶴が庭に立っている。雪が降っている。最高品質の輻の浮世絵。",
177
+ negative_prompt="",
178
+ width=1024,
179
+ height=1024,
180
+ guidance_scale=8.0,
181
+ num_inference_steps=50,
182
+ generator=torch.Generator().manual_seed(0),
183
+ num_images_per_prompt=1,
184
+ output_type="pil",
185
+ ).images
186
+ images[0].save("out.png")