|
import argparse
|
|
from pathlib import Path
|
|
import json
|
|
import re
|
|
import gc
|
|
from safetensors.torch import load_file, save_file
|
|
import torch
|
|
|
|
|
|
# Default reference key file (one key per line, read by keys_from_file) that
# validation compares against. The path is relative, so it is resolved
# against the current working directory, not this script's location.
SDXL_KEYS_FILE: str = "keys/sdxl_keys.txt"
|
|
|
|
|
|
def list_uniq(l):
    """Return the distinct items of *l*, keeping each item's first-occurrence order.

    Items must be hashable. ``dict.fromkeys`` preserves insertion order
    (guaranteed since Python 3.7), so this runs in linear time, unlike
    ``sorted(set(l), key=l.index)``, which rescanned *l* for every item.
    Any iterable is accepted, not just sequences with ``.index``.
    """
    return list(dict.fromkeys(l))
|
|
|
|
|
|
def read_safetensors_metadata(path: str):
    """Return the ``__metadata__`` mapping from a safetensors file header.

    The header is read as an 8-byte little-endian length followed by that
    many bytes of UTF-8 JSON; no tensor data is read. An empty dict is
    returned when the header has no ``__metadata__`` entry.
    """
    with open(path, 'rb') as handle:
        length_prefix = handle.read(8)
        raw_header = handle.read(int.from_bytes(length_prefix, 'little'))
    parsed = json.loads(raw_header.decode('utf-8'))
    return parsed.get('__metadata__', {})
|
|
|
|
|
|
def keys_from_file(path: str):
    """Read one key per line from the UTF-8 text file at *path*.

    Surrounding whitespace is stripped and blank lines are skipped, so a
    stray empty line never becomes a bogus ``''`` key.

    Returns the list of keys. If the file cannot be opened or decoded as
    UTF-8, the error is printed and an empty list is returned (best-effort).
    Other exceptions propagate instead of being swallowed by a ``return``
    inside ``finally``.
    """
    try:
        with open(Path(path), encoding='utf-8', mode='r') as f:
            return [key for key in (line.strip() for line in f) if key]
    except (OSError, UnicodeDecodeError) as e:
        print(e)
        return []
|
|
|
|
|
|
def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
    """Compare *keys* against the reference key list stored in *rfile*.

    Returns a ``(missing, added)`` tuple:

    - ``missing``: reference keys absent from *keys*, in reference-file order.
    - ``added``: entries of *keys* absent from the reference, in *keys* order.

    Each key is reported at most once. If *rfile* cannot be read,
    ``keys_from_file`` yields an empty reference, so every key is "added".
    """
    rkeys = keys_from_file(rfile)
    # Build the membership sets once; the previous code rebuilt them for
    # every key, making the comparison quadratic.
    key_set = set(keys)
    ref_set = set(rkeys)
    # dict.fromkeys de-duplicates while keeping first-occurrence order, which
    # reproduces the ordering of the former list_uniq(keys + rkeys) scan.
    missing = [k for k in dict.fromkeys(rkeys) if k not in key_set]
    added = [k for k in dict.fromkeys(keys) if k not in ref_set]
    return missing, added
|
|
|
|
|
|
def read_safetensors_key(path: str):
    """Return the tensor names stored in the safetensors file at *path*.

    Only the JSON header is parsed: an 8-byte little-endian length prefix
    followed by that many bytes of UTF-8 JSON. No tensor data is loaded,
    unlike a full ``load_file``, which reads every weight into memory.
    Every header entry except ``__metadata__`` is a tensor name.

    Names are returned sorted. NOTE(review): this is believed to match the
    order safetensors' ``load_file`` produced (its ``safe_open().keys()`` is
    sorted); verify if exact order matters downstream.

    On a read/parse failure, the error is printed and an empty list is
    returned. Previously a failing ``load_file`` left ``state_dict`` unbound,
    so the ``finally`` clause crashed with ``UnboundLocalError``.
    """
    # Same cap the safetensors reference implementation uses. It prevents a
    # non-safetensors file (garbage length prefix) from being slurped whole.
    max_header_size = 100_000_000
    try:
        with open(Path(path), 'rb') as f:
            header_size = int.from_bytes(f.read(8), 'little')
            if header_size > max_header_size:
                raise ValueError(
                    f"{path}: header size {header_size} exceeds "
                    f"{max_header_size} bytes; not a safetensors file?"
                )
            header = json.loads(f.read(header_size).decode('utf-8'))
        if not isinstance(header, dict):
            raise ValueError(f"{path}: safetensors header is not a JSON object")
        return sorted(name for name in header if name != '__metadata__')
    except (OSError, ValueError) as e:
        # ValueError also covers JSONDecodeError and UnicodeDecodeError.
        print(e)
        return []
|
|
|
|
|
|
def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
    """Write *keys* (one per line, UTF-8) to *path*, optionally with validation reports.

    When *is_validate* is true, the keys are compared against the reference
    list in *rpath* via ``validate_keys``. The results are written in the same
    directory as *path*, as ``<stem>_missing.txt`` and ``<stem>_added.txt``.

    Returns True on success. Returns False if *keys* is empty or a file could
    not be written; in the latter case the error is printed.
    """
    if not keys:
        return False
    out_path = Path(path)
    try:
        out_path.write_text("\n".join(keys), encoding='utf-8')
        if is_validate:
            missing, added = validate_keys(keys, rpath)
            # with_name keeps the reports beside the key list; a bare
            # ``Path(path).stem + ...`` would drop the directory and write
            # them into the current working directory instead.
            out_path.with_name(out_path.stem + "_missing.txt").write_text(
                "\n".join(missing), encoding='utf-8')
            out_path.with_name(out_path.stem + "_added.txt").write_text(
                "\n".join(added), encoding='utf-8')
        return True
    except (OSError, UnicodeError) as e:
        print(e)
        return False
|
|
|
|
|
|
def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
    """Print a safetensors file's metadata and tensor keys, optionally saving them.

    Parameters:
        input: path to the safetensors file.
        out_filename: if non-empty, the keys are also written there through
            ``write_safetensors_key`` (with validation reports when
            *is_validate* is set).
        is_validate: also compare the keys against *rfile* and print the
            missing and added keys.
        rfile: path of the reference key list.

    Nothing is printed or written when no keys were read.
    """
    keys = read_safetensors_key(input)
    if not keys:
        return
    if out_filename:
        write_safetensors_key(keys, out_filename, is_validate, rfile)
    print("Metadata:")
    print(read_safetensors_metadata(input))
    print("\nKeys:")
    print("\n".join(keys))
    if not is_validate:
        return
    missing, added = validate_keys(keys, rfile)
    for heading, group in (("\nMissing Keys:", missing), ("\nAdded Keys:", added)):
        print(heading)
        print("\n".join(group))
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Input safetensors file.")
    parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
    parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
    parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
    parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")

    args = parser.parse_args()

    # --output takes precedence over --save. With neither flag, keys are only
    # printed. out_filename must always be bound: it used to be undefined
    # in that case, crashing the plain `script.py model.safetensors` call
    # with NameError.
    if args.output:
        out_filename = args.output
    elif args.save:
        out_filename = Path(args.input).stem + ".txt"
    else:
        out_filename = ""

    stkey(args.input, out_filename, args.val, args.rfile)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|