Skip to content

Commit 329a798

Browse files
Copilot and justinchuby authored
Add support for FLOAT8E8M0 data type (onnx#128)
This PR adds comprehensive support for the FLOAT8E8M0 data type that was added to ONNX in onnx/onnx#7030. ## Changes Made - **Added FLOAT8E8M0 enum value**: Set to 24 (next available value after FLOAT4E2M1=23) - **Updated numpy type mapping**: Added support for `ml_dtypes.float8_e8m0fnu` - **Added type properties**: Configured as 8-bit floating point, signed type - **Added short name**: "f8e8m0" for compact representation - **Updated serialization**: Added FLOAT8E8M0 to appropriate sets in `serde.py` for proper tensor serialization/deserialization - **Added tests**: Included parameterized test case and conditional ONNX compatibility check ## Testing The implementation includes comprehensive testing: ```python import onnx_ir._enums as enums import ml_dtypes import numpy as np # Create tensor with FLOAT8E8M0 type data = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.float8_e8m0fnu) tensor = ir_core.Tensor(data) assert tensor.dtype == enums.DataType.FLOAT8E8M0 # Test properties assert enums.DataType.FLOAT8E8M0.is_floating_point() == True assert enums.DataType.FLOAT8E8M0.bitwidth == 8 assert enums.DataType.FLOAT8E8M0.short_name() == 'f8e8m0' # Test serialization round-trip tensor_proto = serde.serialize_tensor(tensor) assert tensor_proto.data_type == 24 ``` All existing tests continue to pass, ensuring no regression in functionality. Fixes onnx#127. <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent d619b1b commit 329a798

File tree

8 files changed

+67
-8
lines changed

8 files changed

+67
-8
lines changed

src/onnx_ir/_core.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@
7878
_enums.DataType.FLOAT8E4M3FNUZ,
7979
_enums.DataType.FLOAT8E5M2,
8080
_enums.DataType.FLOAT8E5M2FNUZ,
81+
_enums.DataType.FLOAT8E8M0,
8182
_enums.DataType.INT4,
8283
_enums.DataType.UINT4,
8384
_enums.DataType.FLOAT4E2M1,
@@ -261,6 +262,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
261262
ml_dtypes.float8_e4m3fn,
262263
ml_dtypes.float8_e5m2fnuz,
263264
ml_dtypes.float8_e5m2,
265+
ml_dtypes.float8_e8m0fnu,
264266
):
265267
raise TypeError(
266268
f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
@@ -319,6 +321,8 @@ def _maybe_view_np_array_with_ml_dtypes(
319321
return array.view(ml_dtypes.float8_e5m2)
320322
if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
321323
return array.view(ml_dtypes.float8_e5m2fnuz)
324+
if dtype == _enums.DataType.FLOAT8E8M0:
325+
return array.view(ml_dtypes.float8_e8m0fnu)
322326
if dtype == _enums.DataType.INT4:
323327
return array.view(ml_dtypes.int4)
324328
if dtype == _enums.DataType.UINT4:

src/onnx_ir/_core_test.py

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ def test_init_requires_type_when_value_is_not_np_array(self):
5252
("float8e4m3fnuz", np.uint8, ir.DataType.FLOAT8E4M3FNUZ),
5353
("float8e5m2", np.uint8, ir.DataType.FLOAT8E5M2),
5454
("float8e5m2fnuz", np.uint8, ir.DataType.FLOAT8E5M2FNUZ),
55+
("float8e8m0", np.uint8, ir.DataType.FLOAT8E8M0),
5556
("int4", np.int8, ir.DataType.INT4),
5657
("int4_uint8", np.uint8, ir.DataType.INT4),
5758
("uint4", np.uint8, ir.DataType.UINT4),
@@ -396,15 +397,28 @@ def test_external_tensor_bfloat16(self):
396397
ir.DataType.FLOAT8E5M2FNUZ,
397398
ml_dtypes.float8_e5m2fnuz,
398399
),
400+
(
401+
"FLOAT8E8M0",
402+
ir.DataType.FLOAT8E8M0,
403+
ml_dtypes.float8_e8m0fnu,
404+
),
399405
]
400406
)
401407
def test_external_tensor_float8(self, _: str, dtype: ir.DataType, np_dtype):
402-
expected_array = np.array(
403-
[[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]]
404-
).astype(np_dtype)
405-
tensor_proto = ir.serde.serialize_tensor(
406-
ir.Tensor(expected_array.view(np.uint8), dtype=dtype)
407-
)
408+
# FLOAT8E8M0 has different precision characteristics (8 exponent bits, 0 mantissa bits)
409+
# It can only represent powers of 2 and special values
410+
if dtype == ir.DataType.FLOAT8E8M0:
411+
expected_array = np.array(
412+
[[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0]]
413+
).astype(np_dtype)
414+
tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype))
415+
else:
416+
expected_array = np.array(
417+
[[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]]
418+
).astype(np_dtype)
419+
tensor_proto = ir.serde.serialize_tensor(
420+
ir.Tensor(expected_array.view(np.uint8), dtype=dtype)
421+
)
408422
with tempfile.TemporaryDirectory() as temp_dir:
409423
_to_external_tensor(tensor_proto, temp_dir, "tensor.bin")
410424
tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir)

src/onnx_ir/_enums.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ class DataType(enum.IntEnum):
6565
UINT4 = 21
6666
INT4 = 22
6767
FLOAT4E2M1 = 23
68+
FLOAT8E8M0 = 24
6869

6970
@classmethod
7071
def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -167,6 +168,7 @@ def is_floating_point(self) -> bool:
167168
DataType.FLOAT8E5M2,
168169
DataType.FLOAT8E5M2FNUZ,
169170
DataType.FLOAT4E2M1,
171+
DataType.FLOAT8E8M0,
170172
}
171173

172174
def is_integer(self) -> bool:
@@ -209,6 +211,7 @@ def is_signed(self) -> bool:
209211
DataType.FLOAT8E5M2FNUZ,
210212
DataType.INT4,
211213
DataType.FLOAT4E2M1,
214+
DataType.FLOAT8E8M0,
212215
}
213216

214217
def __repr__(self) -> str:
@@ -241,6 +244,7 @@ def __str__(self) -> str:
241244
DataType.UINT4: 4,
242245
DataType.INT4: 4,
243246
DataType.FLOAT4E2M1: 4,
247+
DataType.FLOAT8E8M0: 8,
244248
}
245249

246250

@@ -266,6 +270,7 @@ def __str__(self) -> str:
266270
np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
267271
np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
268272
np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
273+
np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
269274
np.dtype(ml_dtypes.int4): DataType.INT4,
270275
np.dtype(ml_dtypes.uint4): DataType.UINT4,
271276
}
@@ -290,6 +295,7 @@ def __str__(self) -> str:
290295
DataType.FLOAT8E5M2: "f8e5m2",
291296
DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
292297
DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
298+
DataType.FLOAT8E8M0: "f8e8m0",
293299
DataType.FLOAT4E2M1: "f4e2m1",
294300
DataType.COMPLEX64: "c64",
295301
DataType.COMPLEX128: "c128",

src/onnx_ir/_enums_test.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ def test_enums_are_the_same_as_spec(self):
3838
self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4)
3939
if hasattr(onnx.TensorProto, "FLOAT4E2M1"):
4040
self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1)
41+
if hasattr(onnx.TensorProto, "FLOAT8E8M0"):
42+
self.assertEqual(_enums.DataType.FLOAT8E8M0, onnx.TensorProto.FLOAT8E8M0)
4143
self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)
4244

4345
@parameterized.parameterized.expand(
@@ -73,6 +75,7 @@ def test_enums_are_the_same_as_spec(self):
7375
("uint4", np.dtype(ml_dtypes.uint4), _enums.DataType.UINT4),
7476
("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4),
7577
("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1),
78+
("float8e8m0", np.dtype(ml_dtypes.float8_e8m0fnu), _enums.DataType.FLOAT8E8M0),
7679
(
7780
"onnx_ref_bfloat16",
7881
onnx._custom_element_types.bfloat16,

src/onnx_ir/serde.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,7 @@ def numpy(self) -> np.ndarray:
405405
_enums.DataType.FLOAT8E4M3FNUZ,
406406
_enums.DataType.FLOAT8E5M2,
407407
_enums.DataType.FLOAT8E5M2FNUZ,
408+
_enums.DataType.FLOAT8E8M0,
408409
_enums.DataType.INT16,
409410
_enums.DataType.INT32,
410411
_enums.DataType.INT4,
@@ -505,6 +506,7 @@ def tobytes(self) -> bytes:
505506
_enums.DataType.FLOAT8E4M3FNUZ,
506507
_enums.DataType.FLOAT8E5M2,
507508
_enums.DataType.FLOAT8E5M2FNUZ,
509+
_enums.DataType.FLOAT8E8M0,
508510
_enums.DataType.INT4,
509511
_enums.DataType.UINT4,
510512
_enums.DataType.FLOAT4E2M1,

src/onnx_ir/serde_test.py

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -224,11 +224,34 @@ def test_tensor_proto_tensor_bfloat16(self):
224224
onnx.TensorProto.FLOAT8E5M2FNUZ,
225225
ml_dtypes.float8_e5m2fnuz,
226226
),
227+
(
228+
"FLOAT8E8M0",
229+
24, # FLOAT8E8M0 value from the enum
230+
ml_dtypes.float8_e8m0fnu,
231+
),
227232
]
228233
)
229234
def test_tensor_proto_tensor_float8(self, _: str, dtype: int, np_dtype):
230-
expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]])
231-
tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 9], expected_array)
235+
# FLOAT8E8M0 has different precision characteristics (8 exponent bits, 0 mantissa bits)
236+
# It can only represent powers of 2 and special values
237+
if dtype == 24: # FLOAT8E8M0
238+
expected_array = np.array([[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0]])
239+
else:
240+
expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]])
241+
242+
# Handle the case where ONNX doesn't support FLOAT8E8M0 yet (value 24)
243+
if dtype == 24: # FLOAT8E8M0
244+
# Create tensor proto manually since ONNX helper might not support this type yet
245+
tensor_proto = onnx.TensorProto()
246+
tensor_proto.name = "test_tensor"
247+
tensor_proto.data_type = dtype
248+
tensor_proto.dims[:] = [1, 9]
249+
tensor_proto.raw_data = expected_array.astype(np_dtype).tobytes()
250+
else:
251+
tensor_proto = onnx.helper.make_tensor(
252+
"test_tensor", dtype, [1, 9], expected_array
253+
)
254+
232255
tensor = serde.TensorProtoTensor(tensor_proto)
233256
np.testing.assert_array_equal(
234257
tensor.numpy().view(np_dtype).astype(np.float32), expected_array
@@ -371,6 +394,7 @@ def test_tensor_proto_tensor_empty_tensor(self):
371394
("FLOAT8E4M3FNUZ", ir.DataType.FLOAT8E4M3FNUZ),
372395
("FLOAT8E5M2", ir.DataType.FLOAT8E5M2),
373396
("FLOAT8E5M2FNUZ", ir.DataType.FLOAT8E5M2FNUZ),
397+
("FLOAT8E8M0", ir.DataType.FLOAT8E8M0),
374398
("UINT4", ir.DataType.UINT4),
375399
("INT4", ir.DataType.INT4),
376400
("FLOAT4E2M1", ir.DataType.FLOAT4E2M1),
@@ -406,6 +430,7 @@ def test_round_trip_numpy_conversion_from_raw_data(
406430
ir.DataType.FLOAT8E5M2,
407431
ir.DataType.FLOAT8E4M3FN,
408432
ir.DataType.BFLOAT16,
433+
ir.DataType.FLOAT8E8M0,
409434
}:
410435
# There is a bug in ml_dtypes that causes equality checks to fail for these dtypes
411436
# See https://github.com/jax-ml/ml_dtypes/issues/301

src/onnx_ir/tensor_adapters.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
6868
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
6969
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
7070
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
71+
torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
7172
torch.int16: ir.DataType.INT16,
7273
torch.int32: ir.DataType.INT32,
7374
torch.int64: ir.DataType.INT64,
@@ -104,6 +105,7 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
104105
ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
105106
ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
106107
ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
108+
ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
107109
ir.DataType.INT16: torch.int16,
108110
ir.DataType.INT32: torch.int32,
109111
ir.DataType.INT64: torch.int64,
@@ -142,6 +144,7 @@ def numpy(self) -> npt.NDArray:
142144
ir.DataType.FLOAT8E4M3FNUZ,
143145
ir.DataType.FLOAT8E5M2,
144146
ir.DataType.FLOAT8E5M2FNUZ,
147+
ir.DataType.FLOAT8E8M0,
145148
}:
146149
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
147150

src/onnx_ir/tensor_adapters_test.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ class TorchTensorTest(unittest.TestCase):
3737
(torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz),
3838
(torch.float8_e5m2, ml_dtypes.float8_e5m2),
3939
(torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz),
40+
(torch.float8_e8m0fnu, ml_dtypes.float8_e8m0fnu),
4041
(torch.int16, np.int16),
4142
(torch.int32, np.int32),
4243
(torch.int64, np.int64),
@@ -66,6 +67,7 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
6667
(torch.float8_e4m3fnuz,),
6768
(torch.float8_e5m2,),
6869
(torch.float8_e5m2fnuz,),
70+
(torch.float8_e8m0fnu,),
6971
(torch.int16,),
7072
(torch.int32,),
7173
(torch.int64,),

0 commit comments

Comments (0)