import sys

import gguf
from safetensors.torch import load_file


def load_transformer_by_original_checkpoint(ckpt_path):
    """Load only the diffusion-transformer weights from a full SD3 checkpoint."""
    sd3 = load_file(ckpt_path)
    sd = {}
    for key in sd3.keys():
        if key.startswith('model.diffusion_model.'):
            sd[key] = sd3[key]
    return sd


if __name__ == '__main__':
    filepath = sys.argv[1]  # Output GGUF filepath.
    # The arch value is just a label here (e.g. 'sd3', 'sdxl').
    writer = gguf.GGUFWriter(filepath, arch='sd3')
    target_quant = gguf.GGMLQuantizationType.Q8_0

    # Input safetensors filepath, e.g. a checkpoint downloaded from Civitai.
    sd_fp16 = load_transformer_by_original_checkpoint(sys.argv[2])

    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    writer.add_file_type(target_quant)

    sd = {}
    for key in sd_fp16.keys():
        tensor = sd_fp16[key]
        # Keep 1-D tensors (norms, biases) and 4-D tensors (conv / patch-embed weights)
        # in F16; quantize the remaining weight matrices to Q8_0.
        if len(tensor.shape) == 1 or len(tensor.shape) == 4:
            q = gguf.GGMLQuantizationType.F16
        else:
            q = target_quant
        # Note: tensor.numpy() assumes fp16/fp32 weights; bf16 tensors would need a cast first.
        sd[key] = gguf.quants.quantize(tensor.numpy(), q)
        writer.add_tensor(key, sd[key], raw_dtype=q)

    writer.write_header_to_file(filepath)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
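
# Usage sketch (the script and checkpoint filenames below are placeholders, not from the source):
#   python convert_sd3_to_gguf.py sd3_medium-Q8_0.gguf sd3_medium.safetensors
# argv[1] is the GGUF file to write; argv[2] is the input safetensors checkpoint.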