Commit 8f97916

Add 2 bit support to onnx (onnx#7446)
### Motivation and Context

Add support for 2-bit data types in ONNX to enable emerging low-bit model formats. Recent research and frameworks are exploring 2-bit quantized models for improved efficiency, and ML dtypes already include 2-bit representations. This change ensures ONNX can represent and interoperate with these new models.

- Add new data types INT2/UINT2 and related helper functions.
- Update `Cast`, `CastLike`, `DequantizeLinear` and `QuantizeLinear`.
- Update non-compute operators `Constant`, `ConstantOfShape`, `Identity`, `Reshape`, `Shape`, `Size`, `If`, `Loop`, `Scan`, `Flatten`, `Pad`, `Squeeze`, `Unsqueeze`, `Transpose`.
- Update IR version to 13 and opset version to 25.

### Issue

onnx#7159

---------

Signed-off-by: vraspar <vrajang@outlook.com>
1 parent 76981e3 commit 8f97916
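A minimal sketch of the workflow this enables, assuming a build that includes this commit; the `make_tensor`/`to_array` pattern mirrors the existing 4-bit helpers shown in the diffs below:

```python
import onnx
from onnx import TensorProto, helper, numpy_helper

# int2 values must fit the range [-2, 1]; uint2 values fit [0, 3].
t = helper.make_tensor("w", TensorProto.INT2, dims=[4], vals=[0, 1, -1, -2])
arr = numpy_helper.to_array(t)  # unpacks to an ml_dtypes.int2 ndarray
print(arr)  # [0 1 -1 -2]
```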

File tree

402 files changed (+4316 / -246 lines)


docs/Changelog.md

Lines changed: 1263 additions & 0 deletions (large diff; not rendered)

docs/IR.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -422,8 +422,8 @@ It is common to represent a tensor as a nested list. This generally works fine,
 |Group|Types|Description|
 |---|---|---|
 Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz, float4e2m1|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433), [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915), and the [Open Compute Project](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-Signed Integer Types|int4, int8, int16, int32, int64|Signed integers are supported for 4-64 bit widths.
-Unsigned Integer Types|uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 4-64 bit widths.
+Signed Integer Types|int2, int4, int8, int16, int32, int64|Signed integers are supported for 2-64 bit widths.
+Unsigned Integer Types|uint2, uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 2-64 bit widths.
 Complex Types|complex64, complex128|A complex number with either 32- or 64-bit real and imaginary parts.
 Other|string|Strings represent textual data. All strings are encoded using UTF-8.
 Other|bool|Boolean values represent data with only two values, typically true and false.
```
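The matching NumPy-level dtypes come from `ml_dtypes` (see the `onnx/_mapping.py` hunk below). Note that `ml_dtypes` stores one element per byte in memory; the four-per-byte packing applies only to serialized ONNX tensors. A quick sketch:

```python
import ml_dtypes
import numpy as np

# int2 covers [-2, 1]; uint2 covers [0, 3].
a = np.array([-2, -1, 0, 1], dtype=ml_dtypes.int2)
b = np.array([0, 1, 2, 3], dtype=ml_dtypes.uint2)
print(a.itemsize, b.itemsize)  # 1 1 -- one (unpacked) byte per element
```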

docs/Operators.md

Lines changed: 264 additions & 81 deletions (large diff; not rendered)

docs/TestCoverage.md

Lines changed: 175 additions & 2 deletions
```diff
@@ -4619,6 +4619,16 @@ test_cases = [
     ("FLOAT4E2M1", "FLOAT16"),
     ("FLOAT", "FLOAT4E2M1"),
     ("FLOAT16", "FLOAT4E2M1"),
+    ("FLOAT", "UINT2"),
+    ("FLOAT16", "UINT2"),
+    ("FLOAT", "INT2"),
+    ("FLOAT16", "INT2"),
+    ("UINT2", "FLOAT"),
+    ("UINT2", "FLOAT16"),
+    ("UINT2", "UINT8"),
+    ("INT2", "FLOAT"),
+    ("INT2", "FLOAT16"),
+    ("INT2", "INT8"),
 ]

 for from_type, to_type in test_cases:
```
```diff
@@ -4675,6 +4685,9 @@ for from_type, to_type in test_cases:
     elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"):
         np_fp32 = np.arange(-9, 16).astype(np.float32)
         input_shape = (5, 5)
+    elif from_type in ("UINT2", "INT2") or to_type in ("UINT2", "INT2"):
+        np_fp32 = np.arange(-3, 4).astype(np.float32)
+        input_shape = (7, 1)
     elif from_type == "FLOAT4E2M1" or to_type == "FLOAT4E2M1":
         np_fp32 = np.array(
             [
```
```diff
@@ -4735,6 +4748,12 @@ for from_type, to_type in test_cases:
         input = make_tensor(
             "input", from_dtype, input_shape, vals=packed.tobytes(), raw=True
         )
+    elif from_type in TWO_BIT_TYPES:
+        np_from = np_fp32.astype(from_np_dtype)
+        packed = onnx.numpy_helper._pack_2bitx4(np_from)
+        input = make_tensor(
+            "input", from_dtype, input_shape, vals=packed.tobytes(), raw=True
+        )
     else:
         np_from = np_fp32.astype(from_np_dtype)
         input = make_tensor(
```
```diff
@@ -4756,6 +4775,11 @@ for from_type, to_type in test_cases:
         output = make_tensor(
             "output", to_dtype, input_shape, vals=packed.tobytes(), raw=True
         )
+    elif to_type in TWO_BIT_TYPES:
+        packed = onnx.numpy_helper._pack_2bitx4(np_from.astype(to_np_dtype))
+        output = make_tensor(
+            "output", to_dtype, input_shape, vals=packed.tobytes(), raw=True
+        )
     else:
         output = make_tensor(
             "output",
```
```diff
@@ -4985,6 +5009,16 @@ test_cases = [
     ("FLOAT4E2M1", "FLOAT16"),
     ("FLOAT", "FLOAT4E2M1"),
     ("FLOAT16", "FLOAT4E2M1"),
+    ("FLOAT", "UINT2"),
+    ("FLOAT16", "UINT2"),
+    ("FLOAT", "INT2"),
+    ("FLOAT16", "INT2"),
+    ("UINT2", "FLOAT"),
+    ("UINT2", "FLOAT16"),
+    ("UINT2", "UINT8"),
+    ("INT2", "FLOAT"),
+    ("INT2", "FLOAT16"),
+    ("INT2", "INT8"),
 ]

 f8_types = {"FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2", "FLOAT8E5M2FNUZ"}
```
```diff
@@ -5043,6 +5077,9 @@ for from_type, to_type in test_cases:
     elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"):
         np_fp32 = np.arange(-9, 16).astype(np.float32)
         input_shape = (5, 5)
+    elif from_type in ("UINT2", "INT2") or to_type in ("UINT2", "INT2"):
+        np_fp32 = np.arange(-3, 4).astype(np.float32)
+        input_shape = (7, 1)
     elif from_type == "FLOAT4E2M1" or to_type == "FLOAT4E2M1":
         np_fp32 = np.array(
             [
```
```diff
@@ -5103,6 +5140,14 @@ for from_type, to_type in test_cases:
         input = make_tensor(
             "input", from_dtype, input_shape, vals=packed.tobytes(), raw=True
         )
+    elif from_type in TWO_BIT_TYPES:
+        np_from = np_fp32.astype(from_np_dtype)
+        packed = onnx.numpy_helper._pack_2bitx4(np_from)
+        # No byteswap needed on big-endian machines as _pack_2bitx4()
+        # returns a numpy array with uint8 datatype.
+        input = make_tensor(
+            "input", from_dtype, input_shape, vals=packed.tobytes(), raw=True
+        )
     else:
         np_from = np_fp32.astype(from_np_dtype)
         input = make_tensor(
```
```diff
@@ -5124,6 +5169,13 @@ for from_type, to_type in test_cases:
         output = make_tensor(
             "output", to_dtype, input_shape, vals=packed.tobytes(), raw=True
         )
+    elif to_type in TWO_BIT_TYPES:
+        packed = onnx.numpy_helper._pack_2bitx4(np_from.astype(to_np_dtype))
+        # No byteswap needed on big-endian machines as _pack_2bitx4()
+        # returns a numpy array with uint8 datatype.
+        output = make_tensor(
+            "output", to_dtype, input_shape, vals=packed.tobytes(), raw=True
+        )
     else:
         output = make_tensor(
             "output",
```
```diff
@@ -7663,7 +7715,7 @@ expect(node, inputs=[x], outputs=[y], name="test_depthtospace_example")


 ### DequantizeLinear
-There are 12 test cases, listed as following:
+There are 14 test cases, listed as following:
 <details>
 <summary>axis</summary>

```
````diff
@@ -7950,6 +8002,32 @@ expect(
 )
 ```

+</details>
+<details>
+<summary>int2</summary>
+
+```python
+node = onnx.helper.make_node(
+    "DequantizeLinear",
+    inputs=["x", "x_scale", "x_zero_point"],
+    outputs=["y"],
+    axis=0,
+)
+
+# scalar zero point and scale
+x = make_tensor("x", TensorProto.INT2, [4], [0, 1, -1, -2])
+x_scale = np.float32(2)
+x_zero_point = make_tensor("x_zero_point", TensorProto.INT2, (1,), [1])
+y = np.array([-2, 0, -4, -6], dtype=np.float32)
+
+expect(
+    node,
+    inputs=[x, x_scale, x_zero_point],
+    outputs=[y],
+    name="test_dequantizelinear_int2",
+)
+```
+
 </details>
 <details>
 <summary>int4</summary>
````
````diff
@@ -8000,6 +8078,32 @@ expect(
 )
 ```

+</details>
+<details>
+<summary>uint2</summary>
+
+```python
+node = onnx.helper.make_node(
+    "DequantizeLinear",
+    inputs=["x", "x_scale", "x_zero_point"],
+    outputs=["y"],
+    axis=0,
+)
+
+# scalar zero point and scale
+x = make_tensor("x", TensorProto.UINT2, [4], [0, 1, 2, 3])
+x_scale = np.float32(2)
+x_zero_point = make_tensor("x_zero_point", TensorProto.UINT2, (1,), [1])
+y = np.array([-2, 0, 2, 4], dtype=np.float32)
+
+expect(
+    node,
+    inputs=[x, x_scale, x_zero_point],
+    outputs=[y],
+    name="test_dequantizelinear_uint2",
+)
+```
+
 </details>
 <details>
 <summary>uint4</summary>
````
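Both new DequantizeLinear cases follow the standard formula `y = (x - zero_point) * scale`; a quick NumPy re-derivation of the int2 expected output:

```python
import numpy as np

x = np.array([0, 1, -1, -2], dtype=np.float32)  # unpacked INT2 input values
x_scale, x_zero_point = np.float32(2), np.float32(1)
print((x - x_zero_point) * x_scale)  # [-2.  0. -4. -6.] -- matches the test
```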
```diff
@@ -16674,7 +16778,7 @@ for quant_type_name in ["uint8", "int8"]:


 ### QuantizeLinear
-There are 11 test cases, listed as following:
+There are 13 test cases, listed as following:
 <details>
 <summary>axis</summary>

```
````diff
@@ -16992,6 +17096,40 @@ expect(
 )
 ```

+</details>
+<details>
+<summary>int2</summary>
+
+```python
+node = onnx.helper.make_node(
+    "QuantizeLinear",
+    inputs=["x", "y_scale", "y_zero_point"],
+    outputs=["y"],
+    axis=0,
+)
+x = np.array(
+    [
+        [0.0, 2.5, 4.8, 8.6],
+        [-4.0, -3.0, 1.0, 2.0],
+        [-0.0, -2.5, -4.8, -8.6],
+    ],
+    dtype=np.float32,
+)
+y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
+y_zero_point = make_tensor(
+    "y_zero_point", TensorProto.INT2, y_scale.shape, np.zeros_like(y_scale)
+)
+y = make_tensor(
+    "y", TensorProto.INT2, x.shape, [0, 1, 1, 1, -1, -1, 0, 1, 0, -1, -1, -2]
+)
+expect(
+    node,
+    inputs=[x, y_scale, y_zero_point],
+    outputs=[y],
+    name="test_quantizelinear_int2",
+)
+```
+
 </details>
 <details>
 <summary>int4</summary>
````
````diff
@@ -17106,6 +17244,41 @@ expect(
 )
 ```

+</details>
+<details>
+<summary>uint2</summary>
+
+```python
+node = onnx.helper.make_node(
+    "QuantizeLinear",
+    inputs=["x", "y_scale", "y_zero_point"],
+    outputs=["y"],
+    axis=0,
+)
+
+x = np.array(
+    [
+        [0.0, 2.5, 4.8, 8.6],
+        [-2.0, -1.0, 1.0, 3.0],
+        [4.0, 5.0, 6.0, 7.0],
+    ],
+    dtype=np.float32,
+)
+y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
+y_zero_point = make_tensor(
+    "y_zero_point", TensorProto.UINT2, y_scale.shape, np.zeros_like(y_scale)
+)
+y = make_tensor(
+    "y", TensorProto.UINT2, x.shape, [0, 1, 2, 3, 0, 0, 0, 1, 1, 1, 2, 2]
+)
+expect(
+    node,
+    inputs=[x, y_scale, y_zero_point],
+    outputs=[y],
+    name="test_quantizelinear_uint2",
+)
+```
+
 </details>
 <details>
 <summary>uint4</summary>
````
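For QuantizeLinear the int2/uint2 paths round half-to-even and saturate to the 2-bit range ([-2, 1] signed, [0, 3] unsigned); re-deriving the first row of `test_quantizelinear_int2` (scale 2.0, zero point 0):

```python
import numpy as np

x = np.array([0.0, 2.5, 4.8, 8.6], dtype=np.float32)
y = np.clip(np.rint(x / 2.0), -2, 1).astype(np.int8)  # np.rint ties to even
print(y)  # [0 1 1 1] -- matches the first row of the expected INT2 tensor
```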

docs/Versioning.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -197,7 +197,7 @@ ONNX version|IR version|Opset version ai.onnx|Opset version ai.onnx.ml|Opset ver
 1.18.0|11|23|5|1
 1.19.0|12|24|5|1
 1.19.1|12|24|5|1
-1.20.0|12|24|5|1
+1.20.0|13|25|5|1

 A programmatically accessible version of the above table is available [here](../onnx/helper.py). Limited version number
 information is also maintained in [version.h](../onnx/common/version.h) and [schema.h](../onnx/defs/schema.h).
```
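The bumped numbers are observable at runtime; a sketch, assuming a build that includes this commit:

```python
import onnx
import onnx.defs

print(onnx.IR_VERSION)                  # 13 with this change
print(onnx.defs.onnx_opset_version())   # 25 with this change
```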

docs/docsgen/source/technical/index.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -17,4 +17,5 @@ deeper than the code documentation.
 float8
 int4
 float4
+int2
 ```
````
docs/docsgen/source/technical/int2.md

Lines changed: 43 additions & 0 deletions
````diff
@@ -0,0 +1,43 @@
+<!--
+Copyright (c) ONNX Project Contributors
+
+SPDX-License-Identifier: Apache-2.0
+-->
+(onnx-detail-int2)=
+
+# 2 bit integer types
+
+## Papers
+
+[T-MAC: CPU Renaissance via Table Lookup for Low-Bit LLM Deployment on Edge](https://arxiv.org/abs/2407.00088)
+
+T-MAC is an innovative lookup-table (LUT)-based method designed for efficient low-bit LLM (i.e., weight-quantized LLM) inference on CPUs. T-MAC directly supports mpGEMM without dequantization, while simultaneously eliminating multiplications and reducing the additions required. Specifically, T-MAC transforms the traditional data-type-centric multiplication into bit-wise table lookup, and enables a unified and scalable mpGEMM solution.
+
+## Cast
+
+Cast from 2 bit to any higher precision type is exact.
+Cast to a 2 bit type is done by rounding to the nearest integer
+(with ties to even) and truncating.
+
+## Packing and Unpacking (2-bit)
+
+All 2-bit types are stored as 4×2-bit values in a single byte. The elements are packed from the least significant bits (LSB) to the most significant bits (MSB). That is, for consecutive elements x0, x1, x2, x3 in the array:
+
+Packing:
+```
+pack(x0, x1, x2, x3):
+    (x0 & 0x03) |
+    ((x1 & 0x03) << 2) |
+    ((x2 & 0x03) << 4) |
+    ((x3 & 0x03) << 6)
+```
+
+Unpacking:
+```
+x0 = z & 0x03
+x1 = (z >> 2) & 0x03
+x2 = (z >> 4) & 0x03
+x3 = (z >> 6) & 0x03
+```
+In case the total number of elements is not divisible by 4, zero-padding is applied to the remaining higher bits of the final byte.
+The storage size of a 2-bit tensor of N elements is ceil(N / 4) bytes.
````
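The unpacking rules above can be sketched as a NumPy reference (hypothetical name `unpack_2bitx4_ref`; `n` is the original element count, with sign extension for int2):

```python
import numpy as np

def unpack_2bitx4_ref(packed: np.ndarray, n: int, signed: bool) -> np.ndarray:
    z = packed.astype(np.uint8)
    parts = [(z >> s) & 0x03 for s in (0, 2, 4, 6)]    # x0..x3 from each byte
    out = np.stack(parts, axis=-1).ravel()[:n].astype(np.int8)
    if signed:
        out = np.where(out >= 2, out - 4, out)         # sign-extend 2-bit values
    return out

byte = np.array([0b10110100], dtype=np.uint8)          # packs 0, 1, -1, -2
print(unpack_2bitx4_ref(byte, 4, signed=True))         # [ 0  1 -1 -2]
```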

onnx/_mapping.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -110,4 +110,10 @@ class TensorDtypeMap(NamedTuple):
         int(TensorProto.INT32),
         "TensorProto.FLOAT8E8M0",
     ),
+    int(TensorProto.UINT2): TensorDtypeMap(
+        np.dtype(ml_dtypes.uint2), int(TensorProto.INT32), "TensorProto.UINT2"
+    ),
+    int(TensorProto.INT2): TensorDtypeMap(
+        np.dtype(ml_dtypes.int2), int(TensorProto.INT32), "TensorProto.INT2"
+    ),
 }
```
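With these entries in place, the public dtype-conversion helpers should cover the 2-bit types as well (a sketch):

```python
import ml_dtypes
import numpy as np
from onnx import TensorProto, helper

assert helper.tensor_dtype_to_np_dtype(TensorProto.INT2) == np.dtype(ml_dtypes.int2)
assert helper.np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.uint2)) == TensorProto.UINT2
```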
