# comfy-gguf-unet-loader / convert_sd3_to_kohya_unet.py
# Author: twodgirl — commit c3bcb08 ("Create converter scripts.")
from safetensors.torch import load_file
import gguf
import sys
def load_transformer_by_original_checkpoint(ckpt_path, prefix='model.diffusion_model.'):
    """Load a safetensors checkpoint and keep only the diffusion-model tensors.

    Args:
        ckpt_path: Path to the input .safetensors checkpoint file.
        prefix: Key prefix to filter on. Defaults to the SD3 single-file
            layout's diffusion-model namespace, preserving the original
            behavior for existing callers.

    Returns:
        dict mapping the original (un-stripped) keys to their tensors for
        every key that starts with ``prefix``. Keys are deliberately kept
        verbatim — the GGUF writer below emits them unchanged.
    """
    sd3 = load_file(ckpt_path)
    return {key: tensor for key, tensor in sd3.items() if key.startswith(prefix)}
if __name__ == '__main__':
    # Usage: convert_sd3_to_kohya_unet.py <output.gguf> <input.safetensors>
    filepath = sys.argv[1]  # Output GGUF filepath.
    # NOTE(review): 'sd3' is not a real llama.cpp arch; it is used here as a
    # tag for downstream UNet loaders (original comment said "sd3, sdxl").
    writer = gguf.GGUFWriter(filepath, arch='sd3')
    target_quant = gguf.GGMLQuantizationType.Q8_0
    # Input safetensors filepath, e.g. downloaded from civit.
    sd_fp16 = load_transformer_by_original_checkpoint(sys.argv[2])
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    writer.add_file_type(target_quant)
    for key, tensor in sd_fp16.items():
        # 1-D and 4-D tensors are kept at F16 instead of Q8_0 — presumably
        # because block-quantization constraints don't suit their shapes
        # (bias/norm vectors, conv kernels); TODO confirm against gguf docs.
        if tensor.ndim in (1, 4):
            q = gguf.GGMLQuantizationType.F16
        else:
            q = target_quant
        # Quantize and hand the buffer straight to the writer. The original
        # also stored every quantized tensor in a side dict that was read
        # exactly once, doubling peak memory for no benefit.
        writer.add_tensor(key, gguf.quants.quantize(tensor.numpy(), q), raw_dtype=q)
    writer.write_header_to_file(filepath)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()