John6666 commited on
Commit
e576798
1 Parent(s): c200633

Upload 12 files

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Convert diffusers SDXL repo to single Safetensors
3
  emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: red
@@ -7,6 +7,7 @@ sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Convert HF Diffusers repo to single safetensors file V2 (for SDXL / SD 1.5 / LoRA)
3
  emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: red
 
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,30 +1,47 @@
1
  import gradio as gr
2
  import os
3
- from convert_repo_to_safetensors_gr import convert_repo_to_safetensors_multi
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
- css = """"""
 
 
 
7
 
8
- with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
9
- gr.Markdown(
10
- f"""
11
- - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors/tree/main/local).
12
- """)
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
15
- is_half = gr.Checkbox(label="Half precision", value=True)
16
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
17
- uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None)
 
 
 
 
 
 
 
 
 
 
18
  run_button = gr.Button(value="Convert")
19
  st_file = gr.Files(label="Output", interactive=False)
20
  st_md = gr.Markdown()
 
 
 
 
 
 
 
21
 
22
  gr.on(
23
  triggers=[repo_id.submit, run_button.click],
24
  fn=convert_repo_to_safetensors_multi,
25
- inputs=[repo_id, st_file, is_upload, uploaded_urls, is_half],
26
  outputs=[st_file, uploaded_urls, st_md],
27
  )
 
28
 
29
  demo.queue()
30
  demo.launch()
 
1
  import gradio as gr
2
  import os
3
+ from convert_repo_to_safetensors_gr import convert_repo_to_safetensors_multi, clear_safetensors
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
+ css = """
7
+ .title { text-align: center; !important; }
8
+ .footer { text-align: center; !important; }
9
+ """
10
 
11
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
12
+ gr.Markdown("# HF Diffusers repo to WebUI/ComfyUI single safetensors file converter (for SDXL / SD 1.5 / LoRA)", elem_classes="title")
 
 
 
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
 
15
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
16
+ with gr.Accordion("Advanced", open=False):
17
+ dtype = gr.Radio(label="Output data type", choices=["fp16", "fp32", "bf16", "fp8", "default"], value="fp16")
18
+ with gr.Accordion("Upload to your repo", open=True):
19
+ with gr.Row():
20
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
21
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).")
22
+ with gr.Row():
23
+ newrepo_id = gr.Textbox(label="Upload repo ID", placeholder="yourid/newrepo", value="", max_lines=1)
24
+ newrepo_type = gr.Radio(label="Upload repo type", choices=["model", "dataset"], value="model")
25
+ is_private = gr.Checkbox(label="Create private repo", value=True)
26
+ uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
27
  run_button = gr.Button(value="Convert")
28
  st_file = gr.Files(label="Output", interactive=False)
29
  st_md = gr.Markdown()
30
+ delete_button = gr.Button(value="Delete Safetensors")
31
+ gr.DuplicateButton(value="Duplicate Space")
32
+ gr.Markdown(
33
+ f"""
34
+ - Thanks to [xi0v](https://huggingface.co/xi0v)
35
+ - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors/tree/main/local).
36
+ """, elem_classes="footer")
37
 
38
  gr.on(
39
  triggers=[repo_id.submit, run_button.click],
40
  fn=convert_repo_to_safetensors_multi,
41
+ inputs=[repo_id, hf_token, st_file, uploaded_urls, dtype, is_upload, newrepo_id, newrepo_type, is_private],
42
  outputs=[st_file, uploaded_urls, st_md],
43
  )
44
+ delete_button.click(clear_safetensors, None, [st_file], queue=False, show_api=False)
45
 
46
  demo.queue()
47
  demo.launch()
convert_repo_to_safetensors_gr.py CHANGED
@@ -10,6 +10,15 @@ import torch
10
  from safetensors.torch import load_file, save_file
11
  import gradio as gr
12
 
 
 
 
 
 
 
 
 
 
13
  # =================#
14
  # UNet Conversion #
15
  # =================#
@@ -269,8 +278,7 @@ def convert_openai_text_enc_state_dict(text_enc_dict):
269
  return text_enc_dict
270
 
271
 
272
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, progress=gr.Progress(track_tqdm=True)):
273
- progress(0, desc="Start converting...")
274
  # Path for safetensors
275
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
276
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -326,74 +334,109 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, p
326
  # Put together new checkpoint
327
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
328
 
329
- if half:
330
- state_dict = {k: v.half() for k, v in state_dict.items()}
 
 
331
 
332
  save_file(state_dict, checkpoint_path)
333
- progress(1, desc="Converted.")
334
 
335
 
 
336
  def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
337
- from huggingface_hub import snapshot_download
338
  try:
339
- snapshot_download(repo_id=repo_id, local_dir=dir_path)
 
340
  except Exception as e:
341
  print(f"Error: Failed to download {repo_id}. {e}")
 
342
  return
343
 
344
 
345
- def upload_safetensors_to_repo(filename, progress=gr.Progress(track_tqdm=True)):
346
- from huggingface_hub import HfApi, hf_hub_url
347
- import os
348
- from pathlib import Path
349
  output_filename = Path(filename).name
350
- hf_token = os.environ.get("HF_TOKEN")
351
- repo_id = os.environ.get("HF_OUTPUT_REPO")
352
- api = HfApi()
353
  try:
 
354
  progress(0, desc="Start uploading...")
355
- api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_id=repo_id, token=hf_token)
356
  progress(1, desc="Uploaded.")
357
- url = hf_hub_url(repo_id=repo_id, filename=output_filename)
358
  except Exception as e:
359
- print(f"Error: Failed to upload to {repo_id}. ")
 
360
  return None
361
  return url
362
 
363
 
364
- def convert_repo_to_safetensors(repo_id, half=True, progress=gr.Progress(track_tqdm=True)):
365
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
366
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
 
367
  download_repo(repo_id, download_dir)
368
- convert_diffusers_to_safetensors(download_dir, output_filename, half)
 
 
 
369
  return output_filename
370
 
371
 
372
- def convert_repo_to_safetensors_multi(repo_id, files, is_upload, urls, half=True, progress=gr.Progress(track_tqdm=True)):
373
- file = convert_repo_to_safetensors(repo_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  if not urls: urls = []
 
375
  url = ""
376
  if is_upload:
377
- url = upload_safetensors_to_repo(file, half)
378
- if url: urls.append(url)
 
 
 
 
 
 
379
  md = ""
380
  for u in urls:
381
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
382
- if not files: files = []
383
- files.append(file)
384
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
385
 
386
 
 
 
 
 
 
 
 
 
387
  if __name__ == "__main__":
388
  parser = argparse.ArgumentParser()
389
 
390
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
391
- parser.add_argument("--half", default=True, help="Save weights in half precision.")
392
 
393
  args = parser.parse_args()
394
  assert args.repo_id is not None, "Must provide a Repo ID!"
395
 
396
- convert_repo_to_safetensors(args.repo_id, args.half)
397
 
398
 
399
  # Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
 
10
  from safetensors.torch import load_file, save_file
11
  import gradio as gr
12
 
13
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
14
+ import os
15
+ from pathlib import Path
16
+ import shutil
17
+ import gc
18
+ from utils import get_token, set_token, is_repo_exists, get_model_type
19
+ from convert_repo_to_safetensors_sd_gr import convert_repo_to_safetensors as convert_repo_to_safetensors_sd
20
+ from convert_repo_to_safetensors_sdxl_lora_gr import convert_repo_to_safetensors_sdxl_lora
21
+
22
  # =================#
23
  # UNet Conversion #
24
  # =================#
 
278
  return text_enc_dict
279
 
280
 
281
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
 
282
  # Path for safetensors
283
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
284
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
 
334
  # Put together new checkpoint
335
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
336
 
337
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
338
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
339
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
340
+ elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
341
 
342
  save_file(state_dict, checkpoint_path)
 
343
 
344
 
345
+ # https://huggingface.co/docs/huggingface_hub/v0.25.1/en/package_reference/file_download#huggingface_hub.snapshot_download
346
  def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
347
+ hf_token = get_token()
348
  try:
349
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
350
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"], force_download=True)
351
  except Exception as e:
352
  print(f"Error: Failed to download {repo_id}. {e}")
353
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
354
  return
355
 
356
 
357
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
 
 
 
358
  output_filename = Path(filename).name
359
+ hf_token = get_token()
360
+ api = HfApi(token=hf_token)
 
361
  try:
362
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
363
  progress(0, desc="Start uploading...")
364
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
365
  progress(1, desc="Uploaded.")
366
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
367
  except Exception as e:
368
+ print(f"Error: Failed to upload to {repo_id}. {e}")
369
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
370
  return None
371
  return url
372
 
373
 
374
+ def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
375
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
376
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
377
+ progress(0, desc="Start downloading...")
378
  download_repo(repo_id, download_dir)
379
+ progress(0, desc="Start converting...")
380
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
381
+ progress(1, desc="Converted.")
382
+ shutil.rmtree(download_dir)
383
  return output_filename
384
 
385
 
386
+ def convert_repo_to_safetensors_multi(repo_id, hf_token, files, urls, dtype="fp16", is_upload=False,
387
+ newrepo_id="", repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
388
+ if hf_token: set_token(hf_token)
389
+ else: set_token(os.environ.get("HF_TOKEN"))
390
+ if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
391
+ if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
392
+ model_type = get_model_type(repo_id)
393
+ if model_type == "SDXL":
394
+ gr.Info(f"Converting {model_type} model.")
395
+ file = convert_repo_to_safetensors(repo_id, dtype)
396
+ elif model_type == "SD 1.5":
397
+ gr.Info(f"Converting {model_type} model.")
398
+ file = convert_repo_to_safetensors_sd(repo_id, dtype)
399
+ elif model_type == "LoRA":
400
+ gr.Info(f"Converting {model_type}.")
401
+ file = convert_repo_to_safetensors_sdxl_lora(repo_id)
402
+ else: raise gr.Error(f"Unsupported model type: {model_type}")
403
  if not urls: urls = []
404
+ if not files: files = []
405
  url = ""
406
  if is_upload:
407
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
408
+ if url:
409
+ urls.append(url)
410
+ Path(file).unlink()
411
+ else: files.append(file)
412
+ else:
413
+ files.append(file)
414
+ progress(1, desc="Processing...")
415
  md = ""
416
  for u in urls:
417
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
418
+ gc.collect()
 
419
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
420
 
421
 
422
+ def clear_safetensors():
423
+ for p in Path('.').glob('*.safetensors'):
424
+ p.unlink()
425
+ print("Deleted.")
426
+ gc.collect()
427
+ return gr.update(value=[])
428
+
429
+
430
  if __name__ == "__main__":
431
  parser = argparse.ArgumentParser()
432
 
433
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
434
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
435
 
436
  args = parser.parse_args()
437
  assert args.repo_id is not None, "Must provide a Repo ID!"
438
 
439
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
440
 
441
 
442
  # Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
convert_repo_to_safetensors_sd_gr.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ import gradio as gr
12
+
13
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
14
+ import os
15
+ from pathlib import Path
16
+ import shutil
17
+ import gc
18
+ from utils import get_token, set_token, is_repo_exists
19
+
20
+ # =================#
21
+ # UNet Conversion #
22
+ # =================#
23
+
24
+ unet_conversion_map = [
25
+ # (stable-diffusion, HF Diffusers)
26
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
27
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
28
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
29
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
30
+ ("input_blocks.0.0.weight", "conv_in.weight"),
31
+ ("input_blocks.0.0.bias", "conv_in.bias"),
32
+ ("out.0.weight", "conv_norm_out.weight"),
33
+ ("out.0.bias", "conv_norm_out.bias"),
34
+ ("out.2.weight", "conv_out.weight"),
35
+ ("out.2.bias", "conv_out.bias"),
36
+ ]
37
+
38
+ unet_conversion_map_resnet = [
39
+ # (stable-diffusion, HF Diffusers)
40
+ ("in_layers.0", "norm1"),
41
+ ("in_layers.2", "conv1"),
42
+ ("out_layers.0", "norm2"),
43
+ ("out_layers.3", "conv2"),
44
+ ("emb_layers.1", "time_emb_proj"),
45
+ ("skip_connection", "conv_shortcut"),
46
+ ]
47
+
48
+ unet_conversion_map_layer = []
49
+ # hardcoded number of downblocks and resnets/attentions...
50
+ # would need smarter logic for other networks.
51
+ for i in range(4):
52
+ # loop over downblocks/upblocks
53
+
54
+ for j in range(2):
55
+ # loop over resnets/attentions for downblocks
56
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
57
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
58
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
59
+
60
+ if i < 3:
61
+ # no attention layers in down_blocks.3
62
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
63
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
64
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
65
+
66
+ for j in range(3):
67
+ # loop over resnets/attentions for upblocks
68
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
69
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
70
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
71
+
72
+ if i > 0:
73
+ # no attention layers in up_blocks.0
74
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
75
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
76
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
77
+
78
+ if i < 3:
79
+ # no downsample in down_blocks.3
80
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
81
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
82
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
83
+
84
+ # no upsample in up_blocks.3
85
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
86
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
87
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
88
+
89
+ hf_mid_atn_prefix = "mid_block.attentions.0."
90
+ sd_mid_atn_prefix = "middle_block.1."
91
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
92
+
93
+ for j in range(2):
94
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
95
+ sd_mid_res_prefix = f"middle_block.{2*j}."
96
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
97
+
98
+
99
+ def convert_unet_state_dict(unet_state_dict):
100
+ # buyer beware: this is a *brittle* function,
101
+ # and correct output requires that all of these pieces interact in
102
+ # the exact order in which I have arranged them.
103
+ mapping = {k: k for k in unet_state_dict.keys()}
104
+ for sd_name, hf_name in unet_conversion_map:
105
+ mapping[hf_name] = sd_name
106
+ for k, v in mapping.items():
107
+ if "resnets" in k:
108
+ for sd_part, hf_part in unet_conversion_map_resnet:
109
+ v = v.replace(hf_part, sd_part)
110
+ mapping[k] = v
111
+ for k, v in mapping.items():
112
+ for sd_part, hf_part in unet_conversion_map_layer:
113
+ v = v.replace(hf_part, sd_part)
114
+ mapping[k] = v
115
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
116
+ return new_state_dict
117
+
118
+
119
+ # ================#
120
+ # VAE Conversion #
121
+ # ================#
122
+
123
+ vae_conversion_map = [
124
+ # (stable-diffusion, HF Diffusers)
125
+ ("nin_shortcut", "conv_shortcut"),
126
+ ("norm_out", "conv_norm_out"),
127
+ ("mid.attn_1.", "mid_block.attentions.0."),
128
+ ]
129
+
130
+ for i in range(4):
131
+ # down_blocks have two resnets
132
+ for j in range(2):
133
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
134
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
135
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
136
+
137
+ if i < 3:
138
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
139
+ sd_downsample_prefix = f"down.{i}.downsample."
140
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
141
+
142
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
143
+ sd_upsample_prefix = f"up.{3-i}.upsample."
144
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
145
+
146
+ # up_blocks have three resnets
147
+ # also, up blocks in hf are numbered in reverse from sd
148
+ for j in range(3):
149
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
150
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
151
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
152
+
153
+ # this part accounts for mid blocks in both the encoder and the decoder
154
+ for i in range(2):
155
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
156
+ sd_mid_res_prefix = f"mid.block_{i+1}."
157
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
158
+
159
+
160
+ vae_conversion_map_attn = [
161
+ # (stable-diffusion, HF Diffusers)
162
+ ("norm.", "group_norm."),
163
+ ("q.", "query."),
164
+ ("k.", "key."),
165
+ ("v.", "value."),
166
+ ("proj_out.", "proj_attn."),
167
+ ]
168
+
169
+ # This is probably not the most ideal solution, but it does work.
170
+ vae_extra_conversion_map = [
171
+ ("to_q", "q"),
172
+ ("to_k", "k"),
173
+ ("to_v", "v"),
174
+ ("to_out.0", "proj_out"),
175
+ ]
176
+
177
+
178
+ def reshape_weight_for_sd(w):
179
+ # convert HF linear weights to SD conv2d weights
180
+ if not w.ndim == 1:
181
+ return w.reshape(*w.shape, 1, 1)
182
+ else:
183
+ return w
184
+
185
+
186
+ def convert_vae_state_dict(vae_state_dict):
187
+ mapping = {k: k for k in vae_state_dict.keys()}
188
+ for k, v in mapping.items():
189
+ for sd_part, hf_part in vae_conversion_map:
190
+ v = v.replace(hf_part, sd_part)
191
+ mapping[k] = v
192
+ for k, v in mapping.items():
193
+ if "attentions" in k:
194
+ for sd_part, hf_part in vae_conversion_map_attn:
195
+ v = v.replace(hf_part, sd_part)
196
+ mapping[k] = v
197
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
198
+ weights_to_convert = ["q", "k", "v", "proj_out"]
199
+ keys_to_rename = {}
200
+ for k, v in new_state_dict.items():
201
+ for weight_name in weights_to_convert:
202
+ if f"mid.attn_1.{weight_name}.weight" in k:
203
+ print(f"Reshaping {k} for SD format")
204
+ new_state_dict[k] = reshape_weight_for_sd(v)
205
+ for weight_name, real_weight_name in vae_extra_conversion_map:
206
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
207
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
208
+ for k, v in keys_to_rename.items():
209
+ if k in new_state_dict:
210
+ print(f"Renaming {k} to {v}")
211
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
212
+ del new_state_dict[k]
213
+ return new_state_dict
214
+
215
+
216
+ # =========================#
217
+ # Text Encoder Conversion #
218
+ # =========================#
219
+
220
+
221
+ textenc_conversion_lst = [
222
+ # (stable-diffusion, HF Diffusers)
223
+ ("resblocks.", "text_model.encoder.layers."),
224
+ ("ln_1", "layer_norm1"),
225
+ ("ln_2", "layer_norm2"),
226
+ (".c_fc.", ".fc1."),
227
+ (".c_proj.", ".fc2."),
228
+ (".attn", ".self_attn"),
229
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
230
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
231
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
232
+ ]
233
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
234
+ textenc_pattern = re.compile("|".join(protected.keys()))
235
+
236
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
237
+ code2idx = {"q": 0, "k": 1, "v": 2}
238
+
239
+
240
+ def convert_text_enc_state_dict_v20(text_enc_dict):
241
+ new_state_dict = {}
242
+ capture_qkv_weight = {}
243
+ capture_qkv_bias = {}
244
+ for k, v in text_enc_dict.items():
245
+ if (
246
+ k.endswith(".self_attn.q_proj.weight")
247
+ or k.endswith(".self_attn.k_proj.weight")
248
+ or k.endswith(".self_attn.v_proj.weight")
249
+ ):
250
+ k_pre = k[: -len(".q_proj.weight")]
251
+ k_code = k[-len("q_proj.weight")]
252
+ if k_pre not in capture_qkv_weight:
253
+ capture_qkv_weight[k_pre] = [None, None, None]
254
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
255
+ continue
256
+
257
+ if (
258
+ k.endswith(".self_attn.q_proj.bias")
259
+ or k.endswith(".self_attn.k_proj.bias")
260
+ or k.endswith(".self_attn.v_proj.bias")
261
+ ):
262
+ k_pre = k[: -len(".q_proj.bias")]
263
+ k_code = k[-len("q_proj.bias")]
264
+ if k_pre not in capture_qkv_bias:
265
+ capture_qkv_bias[k_pre] = [None, None, None]
266
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
267
+ continue
268
+
269
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
270
+ new_state_dict[relabelled_key] = v
271
+
272
+ for k_pre, tensors in capture_qkv_weight.items():
273
+ if None in tensors:
274
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
275
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
276
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
277
+
278
+ for k_pre, tensors in capture_qkv_bias.items():
279
+ if None in tensors:
280
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
281
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
282
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
283
+
284
+ return new_state_dict
285
+
286
+
287
+ def convert_text_enc_state_dict(text_enc_dict):
288
+ return text_enc_dict
289
+
290
+
291
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
292
+ # Path for safetensors
293
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
294
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
295
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
296
+
297
+ # Load models from safetensors if it exists, if it doesn't pytorch
298
+ if osp.exists(unet_path):
299
+ unet_state_dict = load_file(unet_path, device="cpu")
300
+ else:
301
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
302
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
303
+
304
+ if osp.exists(vae_path):
305
+ vae_state_dict = load_file(vae_path, device="cpu")
306
+ else:
307
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
308
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
309
+
310
+ if osp.exists(text_enc_path):
311
+ text_enc_dict = load_file(text_enc_path, device="cpu")
312
+ else:
313
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
314
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
315
+
316
+ # Convert the UNet model
317
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
318
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
319
+
320
+ # Convert the VAE model
321
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
322
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
323
+
324
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
325
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
326
+
327
+ if is_v20_model:
328
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
329
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
330
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
331
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
332
+ else:
333
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
334
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
335
+
336
+ # Put together new checkpoint
337
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
338
+
339
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
340
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
341
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
342
+ elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
343
+
344
+ save_file(state_dict, checkpoint_path)
345
+
346
+
347
+ # https://huggingface.co/docs/huggingface_hub/v0.25.1/en/package_reference/file_download#huggingface_hub.snapshot_download
348
+ def download_repo(repo_id, dir_path):
349
+ hf_token = get_token()
350
+ try:
351
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
352
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"])
353
+ except Exception as e:
354
+ print(f"Error: Failed to download {repo_id}. {e}")
355
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
356
+ return
357
+
358
+
359
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
360
+ output_filename = Path(filename).name
361
+ hf_token = get_token()
362
+ api = HfApi(token=hf_token)
363
+ try:
364
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
365
+ progress(0, desc="Start uploading...")
366
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
367
+ progress(1, desc="Uploaded.")
368
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
369
+ except Exception as e:
370
+ print(f"Error: Failed to upload to {repo_id}. {e}")
371
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
372
+ return None
373
+ return url
374
+
375
+
376
+ def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
377
+ download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
378
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
379
+ progress(0, desc="Start downloading...")
380
+ download_repo(repo_id, download_dir)
381
+ progress(0, desc="Start converting...")
382
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
383
+ progress(1, desc="Converted.")
384
+ shutil.rmtree(download_dir)
385
+ return output_filename
386
+
387
+
388
+ def convert_repo_to_safetensors_multi_sd(repo_id, hf_token, files, urls, dtype="fp16", is_upload=False,
389
+ newrepo_id="", repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
390
+ if hf_token: set_token(hf_token)
391
+ else: set_token(os.environ.get("HF_TOKEN"))
392
+ if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
393
+ if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
394
+ file = convert_repo_to_safetensors(repo_id, dtype)
395
+ if not urls: urls = []
396
+ url = ""
397
+ if is_upload:
398
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
399
+ if url: urls.append(url)
400
+ progress(1, desc="Processing...")
401
+ md = ""
402
+ for u in urls:
403
+ md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
404
+ if not files: files = []
405
+ files.append(file)
406
+ return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
407
+
408
+
409
+ def clear_safetensors():
410
+ for p in Path('.').glob('*.safetensors'):
411
+ p.unlink()
412
+ print("Deleted.")
413
+ gc.collect()
414
+ return gr.update(value=[])
415
+
416
+
417
+ if __name__ == "__main__":
418
+ parser = argparse.ArgumentParser()
419
+
420
+ parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
421
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
422
+
423
+ args = parser.parse_args()
424
+ assert args.repo_id is not None, "Must provide a Repo ID!"
425
+
426
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
convert_repo_to_safetensors_sdxl_lora_gr.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2
+ # This means that you can input your diffusers-trained LoRAs and
3
+ # Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4
+
5
+ # To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6
+ # https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7
+ # and run the script:
8
+ # python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9
+ # now you can use corgy.safetensors in your WebUI of choice!
10
+
11
+ # To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12
+ # LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13
+ # Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15
+ # - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16
+ # Canonical diffusers training scripts:
17
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18
+ # - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19
+
20
+ import argparse
21
+ import os
22
+
23
+ from safetensors.torch import load_file, save_file
24
+ from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
25
+ from pathlib import Path
26
+ import gradio as gr
27
+
28
+ from huggingface_hub import hf_hub_download, HfApi
29
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
30
+ import os
31
+ from pathlib import Path
32
+ import shutil
33
+ import gc
34
+ from utils import get_token, set_token, is_repo_exists, get_model_type
35
+
36
+ def convert_and_save(input_lora, output_lora=None):
37
+ if output_lora is None:
38
+ base_name = os.path.splitext(input_lora)[0]
39
+ output_lora = f"{base_name}_webui.safetensors"
40
+
41
+ diffusers_state_dict = load_file(input_lora)
42
+ try:
43
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
44
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
45
+ except Exception: # skipped
46
+ kohya_state_dict = diffusers_state_dict
47
+ save_file(kohya_state_dict, output_lora)
48
+
49
+
50
+ def download_repo_lora(repo_id, local_file, progress=gr.Progress(track_tqdm=True)):
51
+ hf_token = get_token()
52
+ lora_filename = "pytorch_lora_weights.safetensors"
53
+ lora_path = Path(lora_filename)
54
+ api = HfApi(token=hf_token)
55
+ try:
56
+ if not api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token):
57
+ print(f"Error: This repo isn't diffusers LoRA repo: {repo_id}.")
58
+ return None
59
+ if lora_path.exists():
60
+ print(f"Error: Download file already exists: {lora_filename}.")
61
+ return None
62
+ hf_hub_download(repo_id=repo_id, filename=lora_filename, local_dir=".")
63
+ if lora_path.exists(): lora_path.rename(Path(local_file))
64
+ except Exception as e:
65
+ print(f"Error: Failed to download from {repo_id}. {e}")
66
+ return local_file
67
+
68
+
69
+ def convert_repo_to_safetensors_sdxl_lora(repo_id, progress=gr.Progress(track_tqdm=True)):
70
+ download_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_diffusers.safetensors"
71
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_webui.safetensors"
72
+ progress(0, desc="Start downloading...")
73
+ download_repo_lora(repo_id, download_filename)
74
+ progress(0, desc="Start converting...")
75
+ convert_and_save(download_filename, output_filename)
76
+ progress(1, desc="Converted.")
77
+ Path(download_filename).unlink()
78
+ return output_filename
79
+
80
+
81
+ def convert_repo_to_safetensors_sdxl_lora_multi(repo_id, files, progress=gr.Progress(track_tqdm=True)):
82
+ file = convert_repo_to_safetensors_sdxl_lora(repo_id)
83
+ if not files: files = []
84
+ files.append(file)
85
+ return gr.update(value=files)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format from Repo.")
90
+ parser.add_argument("--repo_id", type=str, required=True, help="URL to the Repo of input LoRA model in the diffusers format.")
91
+
92
+ args = parser.parse_args()
93
+
94
+ convert_repo_to_safetensors_sdxl_lora(args.repo_id)
95
+
96
+
97
+ # Usage: python convert_repo_to_safetensors_sdxl_lora.py --repo_id nroggendorff/zelda-lora
local/convert_repo_to_safetensors.py CHANGED
@@ -269,7 +269,7 @@ def convert_openai_text_enc_state_dict(text_enc_dict):
269
  return text_enc_dict
270
 
271
 
272
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
273
  # Path for safetensors
274
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
275
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -325,8 +325,9 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
325
  # Put together new checkpoint
326
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
327
 
328
- if half:
329
- state_dict = {k: v.half() for k, v in state_dict.items()}
 
330
 
331
  save_file(state_dict, checkpoint_path)
332
 
@@ -336,15 +337,15 @@ def download_repo(repo_id, dir_path):
336
  try:
337
  snapshot_download(repo_id=repo_id, local_dir=dir_path)
338
  except Exception as e:
339
- print(f"Error: Failed to download {repo_id}. ")
340
  return
341
 
342
 
343
- def convert_repo_to_safetensors(repo_id, half=True):
344
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
345
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
346
  download_repo(repo_id, download_dir)
347
- convert_diffusers_to_safetensors(download_dir, output_filename, half)
348
  return output_filename
349
 
350
 
@@ -352,12 +353,12 @@ if __name__ == "__main__":
352
  parser = argparse.ArgumentParser()
353
 
354
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
355
- parser.add_argument("--half", default=True, help="Save weights in half precision.")
356
 
357
  args = parser.parse_args()
358
  assert args.repo_id is not None, "Must provide a Repo ID!"
359
 
360
- convert_repo_to_safetensors(args.repo_id, args.half)
361
 
362
 
363
  # Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
 
269
  return text_enc_dict
270
 
271
 
272
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
273
  # Path for safetensors
274
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
275
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
 
325
  # Put together new checkpoint
326
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
327
 
328
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
329
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
330
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
331
 
332
  save_file(state_dict, checkpoint_path)
333
 
 
337
  try:
338
  snapshot_download(repo_id=repo_id, local_dir=dir_path)
339
  except Exception as e:
340
+ print(f"Error: Failed to download {repo_id}. {e}")
341
  return
342
 
343
 
344
+ def convert_repo_to_safetensors(repo_id, dtype="fp16"):
345
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
346
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
347
  download_repo(repo_id, download_dir)
348
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
349
  return output_filename
350
 
351
 
 
353
  parser = argparse.ArgumentParser()
354
 
355
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
356
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
357
 
358
  args = parser.parse_args()
359
  assert args.repo_id is not None, "Must provide a Repo ID!"
360
 
361
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
362
 
363
 
364
  # Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
local/convert_repo_to_safetensors_sd.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+
12
+
13
+ # =================#
14
+ # UNet Conversion #
15
+ # =================#
16
+
17
+ unet_conversion_map = [
18
+ # (stable-diffusion, HF Diffusers)
19
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
20
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
21
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
22
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
23
+ ("input_blocks.0.0.weight", "conv_in.weight"),
24
+ ("input_blocks.0.0.bias", "conv_in.bias"),
25
+ ("out.0.weight", "conv_norm_out.weight"),
26
+ ("out.0.bias", "conv_norm_out.bias"),
27
+ ("out.2.weight", "conv_out.weight"),
28
+ ("out.2.bias", "conv_out.bias"),
29
+ ]
30
+
31
+ unet_conversion_map_resnet = [
32
+ # (stable-diffusion, HF Diffusers)
33
+ ("in_layers.0", "norm1"),
34
+ ("in_layers.2", "conv1"),
35
+ ("out_layers.0", "norm2"),
36
+ ("out_layers.3", "conv2"),
37
+ ("emb_layers.1", "time_emb_proj"),
38
+ ("skip_connection", "conv_shortcut"),
39
+ ]
40
+
41
+ unet_conversion_map_layer = []
42
+ # hardcoded number of downblocks and resnets/attentions...
43
+ # would need smarter logic for other networks.
44
+ for i in range(4):
45
+ # loop over downblocks/upblocks
46
+
47
+ for j in range(2):
48
+ # loop over resnets/attentions for downblocks
49
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
50
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
51
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
52
+
53
+ if i < 3:
54
+ # no attention layers in down_blocks.3
55
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
56
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
57
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
58
+
59
+ for j in range(3):
60
+ # loop over resnets/attentions for upblocks
61
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
62
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
63
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
64
+
65
+ if i > 0:
66
+ # no attention layers in up_blocks.0
67
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
68
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
69
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
70
+
71
+ if i < 3:
72
+ # no downsample in down_blocks.3
73
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
74
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
75
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
76
+
77
+ # no upsample in up_blocks.3
78
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
79
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
80
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
81
+
82
+ hf_mid_atn_prefix = "mid_block.attentions.0."
83
+ sd_mid_atn_prefix = "middle_block.1."
84
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
85
+
86
+ for j in range(2):
87
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
88
+ sd_mid_res_prefix = f"middle_block.{2*j}."
89
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
90
+
91
+
92
+ def convert_unet_state_dict(unet_state_dict):
93
+ # buyer beware: this is a *brittle* function,
94
+ # and correct output requires that all of these pieces interact in
95
+ # the exact order in which I have arranged them.
96
+ mapping = {k: k for k in unet_state_dict.keys()}
97
+ for sd_name, hf_name in unet_conversion_map:
98
+ mapping[hf_name] = sd_name
99
+ for k, v in mapping.items():
100
+ if "resnets" in k:
101
+ for sd_part, hf_part in unet_conversion_map_resnet:
102
+ v = v.replace(hf_part, sd_part)
103
+ mapping[k] = v
104
+ for k, v in mapping.items():
105
+ for sd_part, hf_part in unet_conversion_map_layer:
106
+ v = v.replace(hf_part, sd_part)
107
+ mapping[k] = v
108
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
109
+ return new_state_dict
110
+
111
+
112
+ # ================#
113
+ # VAE Conversion #
114
+ # ================#
115
+
116
+ vae_conversion_map = [
117
+ # (stable-diffusion, HF Diffusers)
118
+ ("nin_shortcut", "conv_shortcut"),
119
+ ("norm_out", "conv_norm_out"),
120
+ ("mid.attn_1.", "mid_block.attentions.0."),
121
+ ]
122
+
123
+ for i in range(4):
124
+ # down_blocks have two resnets
125
+ for j in range(2):
126
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
127
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
128
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
129
+
130
+ if i < 3:
131
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
132
+ sd_downsample_prefix = f"down.{i}.downsample."
133
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
134
+
135
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
136
+ sd_upsample_prefix = f"up.{3-i}.upsample."
137
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
138
+
139
+ # up_blocks have three resnets
140
+ # also, up blocks in hf are numbered in reverse from sd
141
+ for j in range(3):
142
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
143
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
144
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
145
+
146
+ # this part accounts for mid blocks in both the encoder and the decoder
147
+ for i in range(2):
148
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
149
+ sd_mid_res_prefix = f"mid.block_{i+1}."
150
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
151
+
152
+
153
+ vae_conversion_map_attn = [
154
+ # (stable-diffusion, HF Diffusers)
155
+ ("norm.", "group_norm."),
156
+ ("q.", "query."),
157
+ ("k.", "key."),
158
+ ("v.", "value."),
159
+ ("proj_out.", "proj_attn."),
160
+ ]
161
+
162
+ # This is probably not the most ideal solution, but it does work.
163
+ vae_extra_conversion_map = [
164
+ ("to_q", "q"),
165
+ ("to_k", "k"),
166
+ ("to_v", "v"),
167
+ ("to_out.0", "proj_out"),
168
+ ]
169
+
170
+
171
+ def reshape_weight_for_sd(w):
172
+ # convert HF linear weights to SD conv2d weights
173
+ if not w.ndim == 1:
174
+ return w.reshape(*w.shape, 1, 1)
175
+ else:
176
+ return w
177
+
178
+
179
+ def convert_vae_state_dict(vae_state_dict):
180
+ mapping = {k: k for k in vae_state_dict.keys()}
181
+ for k, v in mapping.items():
182
+ for sd_part, hf_part in vae_conversion_map:
183
+ v = v.replace(hf_part, sd_part)
184
+ mapping[k] = v
185
+ for k, v in mapping.items():
186
+ if "attentions" in k:
187
+ for sd_part, hf_part in vae_conversion_map_attn:
188
+ v = v.replace(hf_part, sd_part)
189
+ mapping[k] = v
190
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
191
+ weights_to_convert = ["q", "k", "v", "proj_out"]
192
+ keys_to_rename = {}
193
+ for k, v in new_state_dict.items():
194
+ for weight_name in weights_to_convert:
195
+ if f"mid.attn_1.{weight_name}.weight" in k:
196
+ print(f"Reshaping {k} for SD format")
197
+ new_state_dict[k] = reshape_weight_for_sd(v)
198
+ for weight_name, real_weight_name in vae_extra_conversion_map:
199
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
200
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
201
+ for k, v in keys_to_rename.items():
202
+ if k in new_state_dict:
203
+ print(f"Renaming {k} to {v}")
204
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
205
+ del new_state_dict[k]
206
+ return new_state_dict
207
+
208
+
209
+ # =========================#
210
+ # Text Encoder Conversion #
211
+ # =========================#
212
+
213
+
214
+ textenc_conversion_lst = [
215
+ # (stable-diffusion, HF Diffusers)
216
+ ("resblocks.", "text_model.encoder.layers."),
217
+ ("ln_1", "layer_norm1"),
218
+ ("ln_2", "layer_norm2"),
219
+ (".c_fc.", ".fc1."),
220
+ (".c_proj.", ".fc2."),
221
+ (".attn", ".self_attn"),
222
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
223
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
224
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
225
+ ]
226
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
227
+ textenc_pattern = re.compile("|".join(protected.keys()))
228
+
229
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
230
+ code2idx = {"q": 0, "k": 1, "v": 2}
231
+
232
+
233
+ def convert_text_enc_state_dict_v20(text_enc_dict):
234
+ new_state_dict = {}
235
+ capture_qkv_weight = {}
236
+ capture_qkv_bias = {}
237
+ for k, v in text_enc_dict.items():
238
+ if (
239
+ k.endswith(".self_attn.q_proj.weight")
240
+ or k.endswith(".self_attn.k_proj.weight")
241
+ or k.endswith(".self_attn.v_proj.weight")
242
+ ):
243
+ k_pre = k[: -len(".q_proj.weight")]
244
+ k_code = k[-len("q_proj.weight")]
245
+ if k_pre not in capture_qkv_weight:
246
+ capture_qkv_weight[k_pre] = [None, None, None]
247
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
248
+ continue
249
+
250
+ if (
251
+ k.endswith(".self_attn.q_proj.bias")
252
+ or k.endswith(".self_attn.k_proj.bias")
253
+ or k.endswith(".self_attn.v_proj.bias")
254
+ ):
255
+ k_pre = k[: -len(".q_proj.bias")]
256
+ k_code = k[-len("q_proj.bias")]
257
+ if k_pre not in capture_qkv_bias:
258
+ capture_qkv_bias[k_pre] = [None, None, None]
259
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
260
+ continue
261
+
262
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
263
+ new_state_dict[relabelled_key] = v
264
+
265
+ for k_pre, tensors in capture_qkv_weight.items():
266
+ if None in tensors:
267
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
268
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
269
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
270
+
271
+ for k_pre, tensors in capture_qkv_bias.items():
272
+ if None in tensors:
273
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
274
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
275
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
276
+
277
+ return new_state_dict
278
+
279
+
280
+ def convert_text_enc_state_dict(text_enc_dict):
281
+ return text_enc_dict
282
+
283
+
284
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
285
+ # Path for safetensors
286
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
287
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
288
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
289
+
290
+ # Load models from safetensors if it exists, if it doesn't pytorch
291
+ if osp.exists(unet_path):
292
+ unet_state_dict = load_file(unet_path, device="cpu")
293
+ else:
294
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
295
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
296
+
297
+ if osp.exists(vae_path):
298
+ vae_state_dict = load_file(vae_path, device="cpu")
299
+ else:
300
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
301
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
302
+
303
+ if osp.exists(text_enc_path):
304
+ text_enc_dict = load_file(text_enc_path, device="cpu")
305
+ else:
306
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
307
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
308
+
309
+ # Convert the UNet model
310
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
311
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
312
+
313
+ # Convert the VAE model
314
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
315
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
316
+
317
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
318
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
319
+
320
+ if is_v20_model:
321
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
322
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
323
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
324
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
325
+ else:
326
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
327
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
328
+
329
+ # Put together new checkpoint
330
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
331
+
332
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
333
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
334
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
335
+
336
+ save_file(state_dict, checkpoint_path)
337
+
338
+
339
+ def download_repo(repo_id, dir_path):
340
+ from huggingface_hub import snapshot_download
341
+ try:
342
+ snapshot_download(repo_id=repo_id, local_dir=dir_path)
343
+ except Exception as e:
344
+ print(f"Error: Failed to download {repo_id}. ")
345
+ return
346
+
347
+
348
+ def convert_repo_to_safetensors(repo_id, dtype="fp16"):
349
+ download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
350
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
351
+ download_repo(repo_id, download_dir)
352
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
353
+ return output_filename
354
+
355
+
356
+ if __name__ == "__main__":
357
+ parser = argparse.ArgumentParser()
358
+
359
+ parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
360
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
361
+
362
+ args = parser.parse_args()
363
+ assert args.repo_id is not None, "Must provide a Repo ID!"
364
+
365
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
local/convert_repo_to_safetensors_sdxl_lora.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2
+ # This means that you can input your diffusers-trained LoRAs and
3
+ # Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4
+
5
+ # To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6
+ # https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7
+ # and run the script:
8
+ # python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9
+ # now you can use corgy.safetensors in your WebUI of choice!
10
+
11
+ # To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12
+ # LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13
+ # Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15
+ # - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16
+ # Canonical diffusers training scripts:
17
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18
+ # - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19
+
20
+ import argparse
21
+ import os
22
+
23
+ from safetensors.torch import load_file, save_file
24
+ from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
25
+ from pathlib import Path
26
+
27
+ def convert_and_save(input_lora, output_lora=None):
28
+ if output_lora is None:
29
+ base_name = os.path.splitext(input_lora)[0]
30
+ output_lora = f"{base_name}_webui.safetensors"
31
+
32
+ diffusers_state_dict = load_file(input_lora)
33
+ try:
34
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
35
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
36
+ except Exception: # skipped
37
+ kohya_state_dict = diffusers_state_dict
38
+ save_file(kohya_state_dict, output_lora)
39
+
40
+
41
+ def download_repo_lora(repo_id, local_file):
42
+ from huggingface_hub import hf_hub_download, HfApi
43
+ lora_filename = "pytorch_lora_weights.safetensors"
44
+ lora_path = Path(lora_filename)
45
+ api = HfApi()
46
+ try:
47
+ if not api.file_exists(repo_id=repo_id, filename=lora_filename):
48
+ print(f"Error: This repo isn't diffusers LoRA repo: {repo_id}. ")
49
+ return None
50
+ if lora_path.exists():
51
+ print(f"Error: Download file already exists: {lora_filename}. ")
52
+ return None
53
+ hf_hub_download(repo_id=repo_id, filename=lora_filename, local_dir=".")
54
+ if lora_path.exists(): lora_path.rename(Path(local_file))
55
+ except Exception as e:
56
+ print(f"Error: Failed to download from {repo_id}. {e}")
57
+ return local_file
58
+
59
+
60
+ def convert_repo_to_safetensors_sdxl_lora(repo_id):
61
+ download_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_diffusers.safetensors"
62
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_webui.safetensors"
63
+ download_repo_lora(repo_id, download_filename)
64
+ convert_and_save(download_filename, output_filename)
65
+ return output_filename
66
+
67
+
68
+ if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format from Repo.")
70
+ parser.add_argument("--repo_id", type=str, required=True, help="URL to the Repo of input LoRA model in the diffusers format.")
71
+
72
+ args = parser.parse_args()
73
+
74
+ convert_repo_to_safetensors_sdxl_lora(args.repo_id)
75
+
76
+
77
+ # Usage: python convert_repo_to_safetensors_sdxl_lora.py --repo_id nroggendorff/zelda-lora
local/requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  torch
2
  safetensors
3
- huggingface-hub
 
 
 
 
 
1
  torch
2
  safetensors
3
+ huggingface-hub
4
+ accelerate
5
+ diffusers
6
+ transformers
7
+ peft
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git-lfs aria2
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  torch
2
  safetensors
3
- huggingface-hub
 
 
 
 
 
1
  torch
2
  safetensors
3
+ huggingface-hub
4
+ accelerate
5
+ diffusers
6
+ transformers
7
+ peft
utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+
10
+
11
+ def get_token():
12
+ try:
13
+ token = HfFolder.get_token()
14
+ except Exception:
15
+ token = ""
16
+ return token
17
+
18
+
19
+ def set_token(token):
20
+ try:
21
+ HfFolder.save_token(token)
22
+ except Exception:
23
+ print(f"Error: Failed to save token.")
24
+
25
+
26
+ def get_user_agent():
27
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
28
+
29
+
30
+ def is_repo_exists(repo_id: str, repo_type: str="model"):
31
+ hf_token = get_token()
32
+ api = HfApi(token=hf_token)
33
+ try:
34
+ if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
35
+ else: return False
36
+ except Exception as e:
37
+ print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
38
+ return True # for safe
39
+
40
+
41
+ MODEL_TYPE_CLASS = {
42
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
43
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
44
+ "diffusers:FluxPipeline": "FLUX",
45
+ }
46
+
47
+
48
+ def get_model_type(repo_id: str):
49
+ hf_token = get_token()
50
+ api = HfApi(token=hf_token)
51
+ lora_filename = "pytorch_lora_weights.safetensors"
52
+ diffusers_filename = "model_index.json"
53
+ default = "SDXL"
54
+ try:
55
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
56
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
57
+ model = api.model_info(repo_id=repo_id, token=hf_token)
58
+ tags = model.tags
59
+ for tag in tags:
60
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
61
+ except Exception:
62
+ return default
63
+ return default
64
+
65
+
66
+ def list_sub(a, b):
67
+ return [e for e in a if e not in b]
68
+
69
+
70
+ def is_repo_name(s):
71
+ return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
72
+
73
+
74
+ def split_hf_url(url: str):
75
+ try:
76
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.safetensors)(?:\?download=true)?$', url)[0])
77
+ if len(s) < 4: return "", "", "", ""
78
+ repo_id = s[1]
79
+ repo_type = "dataset" if s[0] == "datasets" else "model"
80
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
81
+ filename = urllib.parse.unquote(s[3])
82
+ return repo_id, filename, subfolder, repo_type
83
+ except Exception as e:
84
+ print(e)
85
+
86
+
87
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
88
+ hf_token = get_token()
89
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
90
+ try:
91
+ if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
92
+ else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
93
+ except Exception as e:
94
+ print(f"Failed to download: {e}")
95
+
96
+
97
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
98
+ hf_token = get_token()
99
+ url = url.strip()
100
+ if "drive.google.com" in url:
101
+ original_dir = os.getcwd()
102
+ os.chdir(directory)
103
+ os.system(f"gdown --fuzzy {url}")
104
+ os.chdir(original_dir)
105
+ elif "huggingface.co" in url:
106
+ url = url.replace("?download=true", "")
107
+ if "/blob/" in url:
108
+ url = url.replace("/blob/", "/resolve/")
109
+ #user_header = f'"Authorization: Bearer {hf_token}"'
110
+ if hf_token:
111
+ download_hf_file(directory, url)
112
+ #os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
113
+ else:
114
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
115
+ elif "civitai.com" in url:
116
+ if "?" in url:
117
+ url = url.split("?")[0]
118
+ if civitai_api_key:
119
+ url = url + f"?token={civitai_api_key}"
120
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
121
+ else:
122
+ print("You need an API key to download Civitai models.")
123
+ else:
124
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
125
+
126
+
127
+ def get_local_model_list(dir_path):
128
+ model_list = []
129
+ valid_extensions = ('.safetensors')
130
+ for file in Path(dir_path).glob("**/*.*"):
131
+ if file.is_file() and file.suffix in valid_extensions:
132
+ file_path = str(file)
133
+ model_list.append(file_path)
134
+ return model_list
135
+
136
+
137
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
138
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
139
+ print(f"Use HF Repo: {url}")
140
+ new_file = url
141
+ elif not "http" in url and Path(url).exists():
142
+ print(f"Use local file: {url}")
143
+ new_file = url
144
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
145
+ print(f"File to download alreday exists: {url}")
146
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
147
+ else:
148
+ print(f"Start downloading: {url}")
149
+ before = get_local_model_list(temp_dir)
150
+ try:
151
+ download_thing(temp_dir, url.strip(), civitai_key)
152
+ except Exception:
153
+ print(f"Download failed: {url}")
154
+ return ""
155
+ after = get_local_model_list(temp_dir)
156
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
157
+ if not new_file:
158
+ print(f"Download failed: {url}")
159
+ return ""
160
+ print(f"Download completed: {url}")
161
+ return new_file