Skip to content

Commit 2b69b98

Browse files
authored
Merge branch 'main' into titaiwang/fix_avgpool_output_shape
2 parents f97ca29 + e292b4a commit 2b69b98

File tree

11 files changed

+217
-227
lines changed

11 files changed

+217
-227
lines changed

docs/docsgen/source/api/helper.md

Lines changed: 20 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -6,136 +6,9 @@
66
.. currentmodule:: onnx.helper
77
```
88

9-
```{eval-rst}
10-
.. autosummary::
11-
12-
find_min_ir_version_for
13-
get_all_tensor_dtypes
14-
get_attribute_value
15-
get_node_attr_value
16-
set_metadata_props
17-
set_model_props
18-
float32_to_bfloat16
19-
float32_to_float8e4m3
20-
float32_to_float8e5m2
21-
make_attribute
22-
make_attribute_ref
23-
make_empty_tensor_value_info
24-
make_function
25-
make_graph
26-
make_map
27-
make_map_type_proto
28-
make_model
29-
make_node
30-
make_operatorsetid
31-
make_opsetid
32-
make_model_gen_version
33-
make_optional
34-
make_optional_type_proto
35-
make_sequence
36-
make_sequence_type_proto
37-
make_sparse_tensor
38-
make_sparse_tensor_type_proto
39-
make_sparse_tensor_value_info
40-
make_tensor
41-
make_tensor_sequence_value_info
42-
make_tensor_type_proto
43-
make_training_info
44-
make_tensor_value_info
45-
make_value_info
46-
np_dtype_to_tensor_dtype
47-
printable_attribute
48-
printable_dim
49-
printable_graph
50-
printable_node
51-
printable_tensor_proto
52-
printable_type
53-
printable_value_info
54-
split_complex_to_pairs
55-
create_op_set_id_version_map
56-
strip_doc_string
57-
pack_float32_to_4bit
58-
tensor_dtype_to_np_dtype
59-
tensor_dtype_to_storage_tensor_dtype
60-
tensor_dtype_to_string
61-
tensor_dtype_to_field
62-
```
63-
64-
## getter
65-
66-
```{eval-rst}
67-
.. autofunction:: onnx.helper.get_attribute_value
68-
```
69-
70-
```{eval-rst}
71-
.. autofunction:: onnx.helper.get_node_attr_value
72-
```
73-
74-
## setter
75-
76-
```{eval-rst}
77-
.. autofunction:: onnx.helper.set_metadata_props
78-
```
79-
80-
```{eval-rst}
81-
.. autofunction:: onnx.helper.set_model_props
82-
```
83-
84-
## print
85-
86-
```{eval-rst}
87-
.. autofunction:: onnx.helper.printable_attribute
88-
```
89-
90-
```{eval-rst}
91-
.. autofunction:: onnx.helper.printable_dim
92-
```
93-
94-
```{eval-rst}
95-
.. autofunction:: onnx.helper.printable_graph
96-
```
97-
98-
```{eval-rst}
99-
.. autofunction:: onnx.helper.printable_node
100-
```
101-
102-
```{eval-rst}
103-
.. autofunction:: onnx.helper.printable_tensor_proto
104-
```
105-
106-
```{eval-rst}
107-
.. autofunction:: onnx.helper.printable_type
108-
```
109-
110-
```{eval-rst}
111-
.. autofunction:: onnx.helper.printable_value_info
112-
```
113-
114-
## tools
115-
116-
```{eval-rst}
117-
.. autofunction:: onnx.helper.find_min_ir_version_for
118-
```
119-
120-
```{eval-rst}
121-
.. autofunction:: onnx.helper.split_complex_to_pairs
122-
```
123-
124-
```{eval-rst}
125-
.. autofunction:: onnx.helper.create_op_set_id_version_map
126-
```
127-
128-
```{eval-rst}
129-
.. autofunction:: onnx.helper.strip_doc_string
130-
```
131-
132-
```{eval-rst}
133-
.. autofunction:: onnx.helper.pack_float32_to_4bit
134-
```
135-
1369
(l-onnx-make-function)=
13710

138-
## make function
11+
## Helper functions to make ONNX graph components
13912

14013
All functions used to create an ONNX graph.
14114

@@ -239,7 +112,7 @@ All functions used to create an ONNX graph.
239112
.. autofunction:: onnx.helper.make_value_info
240113
```
241114

242-
## type mappings
115+
## Type Mappings
243116

244117
```{eval-rst}
245118
.. autofunction:: onnx.helper.get_all_tensor_dtypes
@@ -265,16 +138,30 @@ All functions used to create an ONNX graph.
265138
.. autofunction:: onnx.helper.tensor_dtype_to_string
266139
```
267140

268-
## cast
141+
## Tools
269142

270143
```{eval-rst}
271-
.. autofunction:: onnx.helper.float32_to_bfloat16
144+
.. autofunction:: onnx.helper.find_min_ir_version_for
272145
```
273146

274147
```{eval-rst}
275-
.. autofunction:: onnx.helper.float32_to_float8e4m3
148+
.. autofunction:: onnx.helper.create_op_set_id_version_map
276149
```
277150

151+
## Other functions
152+
278153
```{eval-rst}
279-
.. autofunction:: onnx.helper.float32_to_float8e5m2
154+
.. autosummary::
155+
156+
get_attribute_value
157+
get_node_attr_value
158+
set_metadata_props
159+
set_model_props
160+
printable_attribute
161+
printable_dim
162+
printable_graph
163+
printable_node
164+
printable_tensor_proto
165+
printable_type
166+
printable_value_info
280167
```

docs/docsgen/source/api/numpy_helper.md

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
```{eval-rst}
88
.. autosummary::
99
10-
bfloat16_to_float32
11-
float8e4m3_to_float32
12-
float8e5m2_to_float32
1310
from_array
1411
from_dict
1512
from_list
@@ -65,35 +62,3 @@ these two functions use a custom dtype defined in :mod:`onnx._custom_element_typ
6562
```{eval-rst}
6663
.. autofunction:: onnx.numpy_helper.from_optional
6764
```
68-
69-
## tools
70-
71-
```{eval-rst}
72-
.. autofunction:: onnx.numpy_helper.convert_endian
73-
```
74-
75-
```{eval-rst}
76-
.. autofunction:: onnx.numpy_helper.combine_pairs_to_complex
77-
```
78-
79-
```{eval-rst}
80-
.. autofunction:: onnx.numpy_helper.create_random_int
81-
```
82-
83-
```{eval-rst}
84-
.. autofunction:: onnx.numpy_helper.unpack_int4
85-
```
86-
87-
## cast
88-
89-
```{eval-rst}
90-
.. autofunction:: onnx.numpy_helper.bfloat16_to_float32
91-
```
92-
93-
```{eval-rst}
94-
.. autofunction:: onnx.numpy_helper.float8e4m3_to_float32
95-
```
96-
97-
```{eval-rst}
98-
.. autofunction:: onnx.numpy_helper.float8e5m2_to_float32
99-
```

onnx/helper.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import google.protobuf.message
2020
import numpy as np
21+
import typing_extensions
2122

2223
import onnx._custom_element_types as custom_np_types
2324
from onnx import (
@@ -356,20 +357,28 @@ def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
356357
set_metadata_props(model, dict_value)
357358

358359

359-
def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
360+
def _split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
360361
return [
361362
(ca[i // 2].real if (i % 2 == 0) else ca[i // 2].imag) # type: ignore[misc]
362363
for i in range(len(ca) * 2)
363364
]
364365

365366

367+
@typing_extensions.deprecated(
368+
"Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
369+
category=FutureWarning,
370+
)
371+
def float32_to_bfloat16(*args, **kwargs) -> int:
372+
return _float32_to_bfloat16(*args, **kwargs)
373+
374+
366375
# convert a float32 value to a bfloat16 (as int)
367376
# By default, this conversion rounds-to-nearest-even and supports NaN
368377
# Setting `truncate` to True enables a simpler conversion. In this mode the
369378
# conversion is performed by simply dropping the 2 least significant bytes of
370379
# the significand. In this mode an error of up to 1 bit may be introduced and
371380
# preservation of NaN values is not guaranteed.
372-
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
381+
def _float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
373382
ival = int.from_bytes(struct.pack("<f", fval), "little")
374383
if truncate:
375384
return ival >> 16
@@ -382,7 +391,15 @@ def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
382391
return (ival + rounded) >> 16
383392

384393

385-
def float32_to_float8e4m3( # noqa: PLR0911
394+
@typing_extensions.deprecated(
395+
"Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
396+
category=FutureWarning,
397+
)
398+
def float32_to_float8e4m3(*args, **kwargs) -> int:
399+
return _float32_to_float8e4m3(*args, **kwargs)
400+
401+
402+
def _float32_to_float8e4m3( # noqa: PLR0911
386403
fval: float,
387404
scale: float = 1.0,
388405
fn: bool = True,
@@ -516,7 +533,15 @@ def float32_to_float8e4m3( # noqa: PLR0911
516533
return int(ret)
517534

518535

519-
def float32_to_float8e5m2( # noqa: PLR0911
536+
@typing_extensions.deprecated(
537+
"Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
538+
category=FutureWarning,
539+
)
540+
def float32_to_float8e5m2(*args: Any, **kwargs: Any) -> int:
541+
return _float32_to_float8e5m2(*args, **kwargs)
542+
543+
544+
def _float32_to_float8e5m2( # noqa: PLR0911
520545
fval: float,
521546
scale: float = 1.0,
522547
fn: bool = False,
@@ -642,7 +667,15 @@ def float32_to_float8e5m2( # noqa: PLR0911
642667
raise NotImplementedError("fn and uz must be both False or True.")
643668

644669

670+
@typing_extensions.deprecated(
671+
"Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
672+
category=FutureWarning,
673+
)
645674
def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
675+
return _pack_float32_to_4bit(array, signed)
676+
677+
678+
def _pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
646679
"""Convert an array of float32 values to a 4bit data-type and pack every two consecutive elements in a byte.
647680
See :ref:`onnx-detail-int4` for technical details.
648681
@@ -662,15 +695,23 @@ def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarr
662695
array_flat = np.append(array_flat, np.array([0]))
663696

664697
def single_func(x, y) -> np.ndarray:
665-
return subbyte.float32x2_to_4bitx2(x, y, signed)
698+
return subbyte._float32x2_to_4bitx2(x, y, signed)
666699

667700
func = np.frompyfunc(single_func, 2, 1)
668701

669702
arr: np.ndarray = func(array_flat[0::2], array_flat[1::2])
670703
return arr.astype(np.uint8)
671704

672705

706+
@typing_extensions.deprecated(
707+
"Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion",
708+
category=FutureWarning,
709+
)
673710
def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
711+
return _pack_float32_to_float4e2m1(array)
712+
713+
714+
def _pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
674715
"""Convert an array of float32 values to float4e2m1 and pack every two consecutive elements in a byte.
675716
See :ref:`onnx-detail-float4` for technical details.
676717
@@ -688,7 +729,7 @@ def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
688729
if is_odd_volume:
689730
array_flat = np.append(array_flat, np.array([0]))
690731

691-
arr = subbyte.float32x2_to_float4e2m1x2(array_flat[0::2], array_flat[1::2])
732+
arr = subbyte._float32x2_to_float4e2m1x2(array_flat[0::2], array_flat[1::2])
692733
return arr.astype(np.uint8)
693734

694735

@@ -759,7 +800,7 @@ def make_tensor(
759800
tensor.raw_data = vals
760801
else:
761802
if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
762-
vals = split_complex_to_pairs(vals)
803+
vals = _split_complex_to_pairs(vals)
763804
elif data_type == TensorProto.FLOAT16:
764805
vals = (
765806
np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
@@ -772,13 +813,13 @@ def make_tensor(
772813
TensorProto.FLOAT8E5M2FNUZ,
773814
):
774815
fcast = {
775-
TensorProto.BFLOAT16: float32_to_bfloat16,
776-
TensorProto.FLOAT8E4M3FN: float32_to_float8e4m3,
777-
TensorProto.FLOAT8E4M3FNUZ: lambda *args: float32_to_float8e4m3( # type: ignore[misc]
816+
TensorProto.BFLOAT16: _float32_to_bfloat16,
817+
TensorProto.FLOAT8E4M3FN: _float32_to_float8e4m3,
818+
TensorProto.FLOAT8E4M3FNUZ: lambda *args: _float32_to_float8e4m3( # type: ignore[misc]
778819
*args, uz=True
779820
),
780-
TensorProto.FLOAT8E5M2: float32_to_float8e5m2,
781-
TensorProto.FLOAT8E5M2FNUZ: lambda *args: float32_to_float8e5m2( # type: ignore[misc]
821+
TensorProto.FLOAT8E5M2: _float32_to_float8e5m2,
822+
TensorProto.FLOAT8E5M2FNUZ: lambda *args: _float32_to_float8e5m2( # type: ignore[misc]
782823
*args, fn=True, uz=True
783824
),
784825
}[
@@ -801,9 +842,9 @@ def make_tensor(
801842
# to uint8 regardless of the value of 'signed'. Using int8 would cause
802843
# the size of int4 tensors to increase ~5x if the tensor contains negative values (due to
803844
# the way negative values are serialized by protobuf).
804-
vals = pack_float32_to_4bit(vals, signed=signed).flatten().tolist()
845+
vals = _pack_float32_to_4bit(vals, signed=signed).flatten().tolist()
805846
elif data_type == TensorProto.FLOAT4E2M1:
806-
vals = pack_float32_to_float4e2m1(vals).flatten().tolist()
847+
vals = _pack_float32_to_float4e2m1(vals).flatten().tolist()
807848
elif data_type == TensorProto.BOOL:
808849
vals = np.array(vals).astype(int)
809850
elif data_type == TensorProto.STRING:

0 commit comments

Comments
 (0)