1818
1919import google .protobuf .message
2020import numpy as np
21+ import typing_extensions
2122
2223import onnx ._custom_element_types as custom_np_types
2324from onnx import (
@@ -356,20 +357,28 @@ def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
356357 set_metadata_props (model , dict_value )
357358
358359
359- def split_complex_to_pairs (ca : Sequence [np .complex64 ]) -> Sequence [int ]:
360+ def _split_complex_to_pairs (ca : Sequence [np .complex64 ]) -> Sequence [int ]:
360361 return [
361362 (ca [i // 2 ].real if (i % 2 == 0 ) else ca [i // 2 ].imag ) # type: ignore[misc]
362363 for i in range (len (ca ) * 2 )
363364 ]
364365
365366
@typing_extensions.deprecated(
    "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
    category=FutureWarning,
)
def float32_to_bfloat16(*args: Any, **kwargs: Any) -> int:
    """Deprecated public alias for the private bfloat16 converter.

    Emits a FutureWarning on call and forwards all arguments unchanged.
    """
    # Annotated *args/**kwargs with Any for consistency with the
    # float32_to_float8e5m2 wrapper below.
    return _float32_to_bfloat16(*args, **kwargs)
373+
374+
366375# convert a float32 value to a bfloat16 (as int)
367376# By default, this conversion rounds-to-nearest-even and supports NaN
368377# Setting `truncate` to True enables a simpler conversion. In this mode the
369378# conversion is performed by simply dropping the 2 least significant bytes of
370379# the significand. In this mode an error of up to 1 bit may be introduced and
# preservation of NaN values is not guaranteed.
372- def float32_to_bfloat16 (fval : float , truncate : bool = False ) -> int :
381+ def _float32_to_bfloat16 (fval : float , truncate : bool = False ) -> int :
373382 ival = int .from_bytes (struct .pack ("<f" , fval ), "little" )
374383 if truncate :
375384 return ival >> 16
@@ -382,7 +391,15 @@ def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
382391 return (ival + rounded ) >> 16
383392
384393
385- def float32_to_float8e4m3 ( # noqa: PLR0911
@typing_extensions.deprecated(
    "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
    category=FutureWarning,
)
def float32_to_float8e4m3(*args: Any, **kwargs: Any) -> int:
    """Deprecated public alias for the private float8e4m3 converter.

    Emits a FutureWarning on call and forwards all arguments unchanged.
    """
    # Annotated *args/**kwargs with Any for consistency with the
    # float32_to_float8e5m2 wrapper below.
    return _float32_to_float8e4m3(*args, **kwargs)
400+
401+
402+ def _float32_to_float8e4m3 ( # noqa: PLR0911
386403 fval : float ,
387404 scale : float = 1.0 ,
388405 fn : bool = True ,
@@ -516,7 +533,15 @@ def float32_to_float8e4m3( # noqa: PLR0911
516533 return int (ret )
517534
518535
519- def float32_to_float8e5m2 ( # noqa: PLR0911
@typing_extensions.deprecated(
    "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
    category=FutureWarning,
)
def float32_to_float8e5m2(*args: Any, **kwargs: Any) -> int:
    """Deprecated public alias for the private float8e5m2 converter.

    Emits a FutureWarning on call and forwards all arguments unchanged.
    """
    return _float32_to_float8e5m2(*args, **kwargs)
542+
543+
544+ def _float32_to_float8e5m2 ( # noqa: PLR0911
520545 fval : float ,
521546 scale : float = 1.0 ,
522547 fn : bool = False ,
@@ -642,7 +667,15 @@ def float32_to_float8e5m2( # noqa: PLR0911
642667 raise NotImplementedError ("fn and uz must be both False or True." )
643668
644669
@typing_extensions.deprecated(
    "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
    category=FutureWarning,
)
def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
    """Deprecated public alias for the private 4-bit packing helper.

    Emits a FutureWarning on call and forwards both arguments unchanged.
    """
    return _pack_float32_to_4bit(array, signed=signed)
677+
678+ def _pack_float32_to_4bit (array : np .ndarray | Sequence , signed : bool ) -> np .ndarray :
    """Convert an array of float32 values to a 4bit data-type and pack every two consecutive elements in a byte.
647680 See :ref:`onnx-detail-int4` for technical details.
648681
@@ -662,15 +695,23 @@ def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarr
662695 array_flat = np .append (array_flat , np .array ([0 ]))
663696
664697 def single_func (x , y ) -> np .ndarray :
665- return subbyte .float32x2_to_4bitx2 (x , y , signed )
698+ return subbyte ._float32x2_to_4bitx2 (x , y , signed )
666699
667700 func = np .frompyfunc (single_func , 2 , 1 )
668701
669702 arr : np .ndarray = func (array_flat [0 ::2 ], array_flat [1 ::2 ])
670703 return arr .astype (np .uint8 )
671704
672705
@typing_extensions.deprecated(
    "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
    category=FutureWarning,
)
def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
    """Deprecated public alias for the private float4e2m1 packing helper.

    Emits a FutureWarning on call and forwards the input array unchanged.
    """
    packed = _pack_float32_to_float4e2m1(array)
    return packed
713+
714+ def _pack_float32_to_float4e2m1 (array : np .ndarray | Sequence ) -> np .ndarray :
    """Convert an array of float32 values to float4e2m1 and pack every two consecutive elements in a byte.
675716 See :ref:`onnx-detail-float4` for technical details.
676717
@@ -688,7 +729,7 @@ def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
688729 if is_odd_volume :
689730 array_flat = np .append (array_flat , np .array ([0 ]))
690731
691- arr = subbyte .float32x2_to_float4e2m1x2 (array_flat [0 ::2 ], array_flat [1 ::2 ])
732+ arr = subbyte ._float32x2_to_float4e2m1x2 (array_flat [0 ::2 ], array_flat [1 ::2 ])
692733 return arr .astype (np .uint8 )
693734
694735
@@ -759,7 +800,7 @@ def make_tensor(
759800 tensor .raw_data = vals
760801 else :
761802 if data_type in (TensorProto .COMPLEX64 , TensorProto .COMPLEX128 ):
762- vals = split_complex_to_pairs (vals )
803+ vals = _split_complex_to_pairs (vals )
763804 elif data_type == TensorProto .FLOAT16 :
764805 vals = (
765806 np .array (vals ).astype (np_dtype ).view (dtype = np .uint16 ).flatten ().tolist ()
@@ -772,13 +813,13 @@ def make_tensor(
772813 TensorProto .FLOAT8E5M2FNUZ ,
773814 ):
774815 fcast = {
775- TensorProto .BFLOAT16 : float32_to_bfloat16 ,
776- TensorProto .FLOAT8E4M3FN : float32_to_float8e4m3 ,
777- TensorProto .FLOAT8E4M3FNUZ : lambda * args : float32_to_float8e4m3 ( # type: ignore[misc]
816+ TensorProto .BFLOAT16 : _float32_to_bfloat16 ,
817+ TensorProto .FLOAT8E4M3FN : _float32_to_float8e4m3 ,
818+ TensorProto .FLOAT8E4M3FNUZ : lambda * args : _float32_to_float8e4m3 ( # type: ignore[misc]
778819 * args , uz = True
779820 ),
780- TensorProto .FLOAT8E5M2 : float32_to_float8e5m2 ,
781- TensorProto .FLOAT8E5M2FNUZ : lambda * args : float32_to_float8e5m2 ( # type: ignore[misc]
821+ TensorProto .FLOAT8E5M2 : _float32_to_float8e5m2 ,
822+ TensorProto .FLOAT8E5M2FNUZ : lambda * args : _float32_to_float8e5m2 ( # type: ignore[misc]
782823 * args , fn = True , uz = True
783824 ),
784825 }[
@@ -801,9 +842,9 @@ def make_tensor(
801842 # to uint8 regardless of the value of 'signed'. Using int8 would cause
802843 # the size of int4 tensors to increase ~5x if the tensor contains negative values (due to
803844 # the way negative values are serialized by protobuf).
804- vals = pack_float32_to_4bit (vals , signed = signed ).flatten ().tolist ()
845+ vals = _pack_float32_to_4bit (vals , signed = signed ).flatten ().tolist ()
805846 elif data_type == TensorProto .FLOAT4E2M1 :
806- vals = pack_float32_to_float4e2m1 (vals ).flatten ().tolist ()
847+ vals = _pack_float32_to_float4e2m1 (vals ).flatten ().tolist ()
807848 elif data_type == TensorProto .BOOL :
808849 vals = np .array (vals ).astype (int )
809850 elif data_type == TensorProto .STRING :
0 commit comments