from safetensors.torch import load_file | |
import gguf | |
import sys | |
def load_transformer_by_original_checkpoint(ckpt_path, prefix='model.diffusion_model.'):
    """Load the tensors whose key starts with ``prefix`` from a safetensors file.

    Args:
        ckpt_path: Path to the ``.safetensors`` checkpoint.
        prefix: Key prefix selecting which tensors to keep. Matching keys are
            returned unchanged (the prefix is NOT stripped).

    Returns:
        dict mapping each matching key to its torch tensor (loaded on CPU).
    """
    # Local import: safe_open lives in the top-level safetensors package, which
    # the file already depends on via safetensors.torch.
    from safetensors import safe_open

    # safe_open reads tensors on demand, so keys outside ``prefix`` (e.g. any
    # bundled VAE / text-encoder weights, if present) are never loaded into
    # memory, unlike load_file which materialises the whole checkpoint.
    with safe_open(ckpt_path, framework='pt') as f:
        return {key: f.get_tensor(key) for key in f.keys() if key.startswith(prefix)}
def select_quant_type(shape, target_qtype, fallback_qtype, block_size):
    """Choose the GGML storage type for a tensor of the given shape.

    Args:
        shape: Tensor shape (any sequence of ints, e.g. ``torch.Size``).
        target_qtype: Type used for tensors that can be block-quantized.
        fallback_qtype: Type used for every other tensor.
        block_size: Elements per quantization block of ``target_qtype``.

    Returns:
        ``fallback_qtype`` for 0-D, 1-D and 4-D tensors, and for tensors whose
        last dimension is not a multiple of ``block_size`` (block quantizers
        cannot encode partial blocks); otherwise ``target_qtype``.
    """
    if len(shape) in (0, 1, 4):
        return fallback_qtype
    if shape[-1] % block_size != 0:
        return fallback_qtype
    return target_qtype


def main(argv=None):
    """Convert a safetensors checkpoint's diffusion-model weights to a GGUF file.

    Tensors are stored as Q8_0 where possible and as F16 otherwise (see
    ``select_quant_type``).

    Args:
        argv: ``[gguf_output_path, safetensors_input_path]``; defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``argv`` does not contain exactly two paths.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 2:
        sys.exit(f'usage: {sys.argv[0]} OUTPUT.gguf INPUT.safetensors')
    # GGUF output path, then safetensors input path (e.g. downloaded from civit).
    filepath, ckpt_path = args

    writer = gguf.GGUFWriter(filepath, arch='sd3')  # Arch is a fake value, sd3, sdxl.
    target_quant = gguf.GGMLQuantizationType.Q8_0
    fallback_quant = gguf.GGMLQuantizationType.F16
    block_size = gguf.GGML_QUANT_SIZES[target_quant][0]

    sd_fp16 = load_transformer_by_original_checkpoint(ckpt_path)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    # general.file_type expects a LlamaFileType, whose numbering differs from
    # GGMLQuantizationType: passing GGMLQuantizationType.Q8_0 (8) would label
    # the file as MOSTLY_Q5_0.
    writer.add_file_type(gguf.LlamaFileType.MOSTLY_Q8_0)

    for key in list(sd_fp16):
        # Pop so each source tensor can be released once it has been converted.
        tensor = sd_fp16.pop(key)
        q = select_quant_type(tensor.shape, target_quant, fallback_quant, block_size)
        # Tensor.numpy() does not support bfloat16, so upcast first; the
        # fp16 -> fp32 conversion is exact, so fp16 checkpoints are unaffected.
        data = gguf.quants.quantize(tensor.float().numpy(), q)
        writer.add_tensor(key, data, raw_dtype=q)

    writer.write_header_to_file()  # output path was already given to the constructor
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()


if __name__ == '__main__':
    main()