
Commit b1db94e

Authored by justinchuby and Copilot
Support saving external data as safetensors (onnx#306)
Supports saving a model to safetensors as an external data file with `ir.save_safetensors()`. The function closely follows the standard `ir.save` API and adds **sharding support** compatible with the Hugging Face Transformers convention. Initializers in subgraphs are also handled, as in `ir.load`.

Bug fixes:

- Fixed an error when converting an external 2-bit tensor to numpy.
- Fixed an error in `ir.load()` where the base directory was not set for initializers in subgraphs.

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9be787c commit b1db94e
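For orientation, a minimal usage sketch. The call below assumes `save_safetensors` takes the model and a destination path the way `ir.save` does; the exact parameter names, the safetensors file naming, and any sharding threshold are assumptions, not the confirmed API.

```python
import onnx_ir as ir

# Load an existing model (external data, if any, is resolved relative to the model file).
model = ir.load("model.onnx")

# Hypothetical call: write the graph to disk and store initializer data in a
# safetensors external-data file next to it. Argument names are illustrative only.
ir.save_safetensors(model, "exported/model.onnx")
```

When sharding applies, the Hugging Face Transformers convention splits weights across files named like `model-00001-of-00002.safetensors` with a JSON index mapping tensor names to shards; whether this PR reuses that exact naming is not shown in this diff.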

9 files changed: 1,175 additions, 10 deletions

docs/api/core.md

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@

 onnx_ir.load
 onnx_ir.save
+onnx_ir.save_safetensors
 onnx_ir.from_proto
 onnx_ir.from_onnx_text
 onnx_ir.to_proto

noxfile.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
     "typing_extensions>=4.10",
     "ml-dtypes",
     "onnxruntime",
+    "safetensors",
 )
 ONNX = "onnx==1.18"
 ONNXSCRIPT = "onnxscript"

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ pyyaml
 torch>=2.3
 torchvision>=0.18.0
 transformers>=4.37.2
+safetensors

 # Lint
 lintrunner>=0.10.7

src/onnx_ir/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -84,6 +84,7 @@
     # IO
     "load",
     "save",
+    "save_safetensors",
     # Flags
     "DEBUG",
     # Others
@@ -156,6 +157,7 @@
     TypeProtocol,
     ValueProtocol,
 )
+from onnx_ir._safetensors import save_safetensors
 from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto

 DEBUG = False

src/onnx_ir/_core.py

Lines changed: 7 additions & 1 deletion

@@ -757,7 +757,8 @@ def _load(self):
             _enums.DataType.UINT2,
         }:
             # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
-            dt = np.dtype(np.uint8).newbyteorder("<")
+            # No need to set endianness for uint8
+            dt = np.dtype(np.uint8)
             count = self.size // 2 + self.size % 2
         else:
             # Handle the byte order correctly by always using little endian
@@ -772,6 +773,11 @@ def _load(self):
             self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
                 self.dtype.numpy()
             )
+        elif self.dtype.bitwidth == 2:
+            # Unpack the 2bit arrays
+            self._array = _type_casting.unpack_2bitx4(self._array, shape).view(
+                self.dtype.numpy()
+            )
         else:
             self._array = self._array.reshape(shape)

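For context on the 2-bit fix above, here is a rough numpy sketch of what unpacking four 2-bit values per byte can look like. It is not the actual `_type_casting.unpack_2bitx4` implementation, and the assumed bit order (lowest bits of each byte first) may differ.

```python
import numpy as np

def unpack_2bitx4_sketch(packed: np.ndarray, shape: tuple[int, ...]) -> np.ndarray:
    """Unpack uint8 bytes, each assumed to hold four 2-bit values."""
    flat = packed.reshape(-1).astype(np.uint8)
    # Extract the four 2-bit fields of every byte (lowest bits first, by assumption).
    unpacked = np.stack([(flat >> s) & 0b11 for s in (0, 2, 4, 6)], axis=-1).reshape(-1)
    # The final byte may contain padding, so trim to the requested element count.
    return unpacked[: int(np.prod(shape))].reshape(shape)
```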
src/onnx_ir/_io.py

Lines changed: 4 additions & 4 deletions

@@ -106,11 +106,11 @@ def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo)
         base_dir = os.path.dirname(path)

         # Store the original initializer values so they can be restored if modify_model=False
-        initializer_values: list[_core.Value] = []
+        initialized_values: list[_core.Value] = []
         for graph in model.graphs():
             # Collect from all subgraphs as well
-            initializer_values.extend(graph.initializers.values())
-        tensors = [v.const_value for v in initializer_values]
+            initialized_values.extend(graph.initializers.values())
+        tensors = [v.const_value for v in initialized_values]

         try:
             model = _external_data.unload_from_model(
@@ -125,7 +125,7 @@

         finally:
             # Restore the original initializer values so the model is unchanged
-            for initializer, tensor in zip(initializer_values, tensors, strict=True):
+            for initializer, tensor in zip(initialized_values, tensors, strict=True):
                 initializer.const_value = tensor

     else:

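The `_io.py` change above hinges on a snapshot-and-restore pattern: initializer `const_value`s from every graph and subgraph are collected, temporarily replaced while the external data is written, and restored in `finally` so the in-memory model is left untouched. A standalone sketch of that idea follows; the helper name is illustrative and not part of the library.

```python
from contextlib import contextmanager

@contextmanager
def preserve_initializer_values(model):
    """Illustrative helper: snapshot initializer const_values across all (sub)graphs
    and restore them on exit, so the caller's model is left unchanged."""
    values = []
    for graph in model.graphs():  # walks the main graph and all subgraphs
        values.extend(graph.initializers.values())
    saved = [v.const_value for v in values]
    try:
        yield values
    finally:
        for value, tensor in zip(values, saved, strict=True):
            value.const_value = tensor
```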