Spaces:
Runtime error
Runtime error
File size: 1,738 Bytes
a4d7b31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
"""
Pure python version of Safetensors safe_open
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
"""
import json
import mmap
import os
import torch
class SafetensorsWrapper:
    """Read-only handle over tensors that have already been loaded.

    Provides the subset of the ``safetensors`` ``safe_open`` handle API that
    this module implements: ``metadata()``, ``keys()`` and
    ``get_tensor(name)``. It also supports the context-manager protocol so
    that the usage documented for the real library
    (``with safe_open(path, framework="pt") as f: ...``) works unchanged.
    Entering returns the wrapper itself. Leaving releases nothing, because
    the wrapper owns no file handle, only the two mappings given to it.
    """

    def __init__(self, metadata, tensors):
        # metadata: free-form header metadata (safe_open passes the header's
        #   "__metadata__" entry, or {} when it is absent).
        # tensors: mapping from tensor name to tensor.
        self._metadata = metadata
        self._tensors = tensors

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) lets any exception raised inside the
        # with-body propagate to the caller.
        return None

    def metadata(self):
        """Return the metadata mapping supplied at construction."""
        return self._metadata

    def keys(self):
        """Return the tensor names (a keys view of the underlying mapping)."""
        return self._tensors.keys()

    def get_tensor(self, k):
        """Return the tensor named ``k``; a missing name raises ``KeyError``."""
        return self._tensors[k]
# Maps safetensors dtype codes (the "dtype" field of each header entry) to
# torch dtypes. Only codes whose torch counterparts are long-standing are
# listed. FP8 and unsigned 16/32/64-bit codes are omitted because their
# torch dtypes exist only in recent releases.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
def create_tensor(storage, info, offset):
    """Build the tensor described by one safetensors header entry.

    Args:
        storage: byte-addressable storage covering the file (as built by
            ``safe_open``).
        info: header entry holding ``"dtype"``, ``"shape"`` and
            ``"data_offsets"`` (a ``[begin, end)`` pair of byte positions).
        offset: amount added to both data offsets to turn them into
            positions inside ``storage``.

    Returns:
        The bytes ``storage[offset + begin : offset + end]``, reinterpreted
        as ``DTYPES[info["dtype"]]`` and reshaped to ``info["shape"]``.

    Raises:
        KeyError: if the dtype code has no entry in ``DTYPES``.
    """
    torch_dtype = DTYPES[info["dtype"]]
    target_shape = info["shape"]
    begin, end = info["data_offsets"]

    # Take the raw bytes as uint8 first, then reinterpret them as the
    # declared element type with Tensor.view(dtype).
    byte_slice = storage[offset + begin : offset + end]
    as_bytes = torch.asarray(byte_slice, dtype=torch.uint8)
    reinterpreted = as_bytes.view(dtype=torch_dtype)
    return reinterpreted.reshape(target_shape)
def safe_open(filename, framework="pt", device="cpu"):
    """Load every tensor of a safetensors file.

    Mirrors the ``safetensors.safe_open`` entry point for the PyTorch
    framework only. All tensors are materialized eagerly and moved to
    ``device``.

    Args:
        filename: path to the ``.safetensors`` file.
        framework: must be ``"pt"``; any other value is rejected.
        device: target passed to ``Tensor.to`` for every tensor.

    Returns:
        SafetensorsWrapper exposing ``metadata()``, ``keys()`` and
        ``get_tensor(name)``.

    Raises:
        ValueError: if ``framework`` is not ``"pt"``, if the file is too
            short for its 8-byte length prefix or declared header, if the
            header is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``), or if the header is not a JSON object.
    """
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")

    # Layout: 8-byte little-endian header length n, n bytes of JSON header,
    # then raw tensor data. The file is binary, so open it in "rb". A plain
    # buffered read is enough for the small header; the bulk data is mapped
    # by torch below.
    with open(filename, mode="rb") as file_obj:
        size = os.fstat(file_obj.fileno()).st_size
        header = file_obj.read(8)
        if len(header) < 8:
            raise ValueError(
                f"{filename!r} is too small to be a safetensors file "
                f"({size} bytes)"
            )
        n = int.from_bytes(header, "little")
        if n > size - 8:
            raise ValueError(
                f"header length {n} exceeds the {size - 8} bytes available "
                f"in {filename!r}"
            )
        metadata = json.loads(file_obj.read(n))

    if not isinstance(metadata, dict):
        raise ValueError(f"safetensors header of {filename!r} is not a JSON object")

    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # data_offsets in the header are relative to the first byte after it.
    offset = n + 8
    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )
|