
Commit c44df1f

Dmitry Rogozhkin committed
Fix ZeroPointDomain.NONE support & make it default for da8w8 weights
1 parent 7b3caa6 commit c44df1f

7 files changed: +116 additions, -48 deletions
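
With this change, int8 dynamic-activation / int8-weight (da8w8) quantization defaults to ZeroPointDomain.NONE for weights, i.e. quantized weights carry a scale but no zero point. A minimal usage sketch of the updated API (illustrative code, not part of this commit; assumes torchao with this change installed, and the model and shapes are placeholders):

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    # Toy model; any module with float Linear layers works the same way.
    model = torch.nn.Sequential(torch.nn.Linear(256, 512)).eval()

    # weight_zp_domain is new in this commit and now defaults to ZeroPointDomain.NONE;
    # act_mapping_type already existed. Both are spelled out here only for clarity.
    quantize_(
        model,
        int8_dynamic_activation_int8_weight(
            act_mapping_type=MappingType.SYMMETRIC,
            weight_zp_domain=ZeroPointDomain.NONE,
        ),
    )

    out = model(torch.randn(8, 256))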

test/integration/test_integration.py

Lines changed: 49 additions & 8 deletions
@@ -45,6 +45,8 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -99,6 +101,10 @@

 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
+WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()


@@ -118,9 +124,20 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)


-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_zero_point_domain=ZeroPointDomain.INT,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+                weight_zp_domain=weight_zero_point_domain,
+            ),
+            set_inductor_config=False,
+        )
     if not TORCH_VERSION_AT_LEAST_2_5:
         unwrap_tensor_subclass(mod)
     else:
@@ -959,25 +976,49 @@ def _test_lin_weight_subclass_api_impl(
         mod[0].weight.tensor_impl.get_plain()

         test = mod(x)
+
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )

         mod_qc = torch.compile(mod, mode="max-autotune")
         test_comp = mod_qc(x)
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )

-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+                WEIGHT_ZERO_POINT_DOMAINS,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(
+        self, device, dtype, act_mapping, weight_zero_point_domain
+    ):
+        from functools import partial
+
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+            weight_zero_point_domain=weight_zero_point_domain,
         )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")

test/quantization/test_observer.py

Lines changed: 7 additions & 8 deletions
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )


@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 eps=torch.finfo(torch.float32).eps,
                 scale_dtype=torch.float,
                 zero_point_dtype=torch.int,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
             )
         else:
             weight_observer = None
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)

         if observe_weight:
             weight_observer = linear.weight.weight_observer
@@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)


test/quantization/test_quant_primitives.py

Lines changed: 26 additions & 0 deletions
@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)

+    # ZeroPointDomain.NONE should work
+    def test_none_zero_point_domain(self):
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        _, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        self.assertTrue(zero_point is None)
+

 if __name__ == "__main__":
     unittest.main()
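
The new test above checks that choose_qparams_affine returns zero_point=None when zero_point_domain=ZeroPointDomain.NONE. Numerically, ZeroPointDomain.NONE corresponds to symmetric quantization where dequantization is just scale * q with no zero-point correction; a plain-PyTorch sketch of that round trip (conceptual only, per-row scales, not the torchao kernels):

    import torch

    x = torch.randn(10, 256)

    # Per-row symmetric int8 scale: map the largest |value| in each row to 127.
    scale = x.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

    # ZeroPointDomain.NONE: no zero point, so dequantize is simply scale * q.
    x_dq = q.to(torch.float32) * scale

    # Round-trip error is bounded by half a quantization step per element.
    assert (x - x_dq).abs().max() <= 0.5 * scale.max() + 1e-6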

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 3 deletions
@@ -262,7 +262,7 @@ def from_hp_to_intx(
             )
             # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
             # TODO should probably consolidate ZeroPointDomain.NONE and None
-            if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            if zero_point_domain == ZeroPointDomain.NONE:
                 zero_point = None
             data = quantize_affine(
                 input_float,
@@ -360,7 +360,7 @@ def from_hp_to_floatx(
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
                 use_hqq=False,
             )
@@ -387,7 +387,7 @@ def from_hp_to_floatx_static(
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
             )
         else:

torchao/dtypes/uintx/plain_layout.py

Lines changed: 13 additions & 8 deletions
@@ -38,7 +38,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         self.int_data = int_data
@@ -64,7 +64,10 @@ def __init__(
         self._layout = _layout

     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self._layout]
+        if self.zero_point is not None:
+            return ["int_data", "scale", "zero_point"], [self._layout]
+        else:
+            return ["int_data", "scale"], [self._layout]

     @classmethod
     def __tensor_unflatten__(
@@ -73,7 +76,7 @@ def __tensor_unflatten__(
         int_data, scale, zero_point = (
             tensor_data_dict["int_data"],
             tensor_data_dict["scale"],
-            tensor_data_dict["zero_point"],
+            tensor_data_dict.get("zero_point", None),
         )
         (_layout,) = tensor_attributes
         return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +86,17 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
             self._layout,
         )

     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
-            fn(self.zero_point),
+            fn(self.zero_point) if self.zero_point is not None else None,
             self._layout,
         )

@@ -134,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 return PlainAQTTensorImpl(
                     aten.slice.Tensor(self.int_data, dim, start, end, step),
                     self.scale.view(-1),
-                    self.zero_point.view(-1),
+                    self.zero_point.view(-1) if self.zero_point is not None else None,
                     self._layout,
                 )
             else:
@@ -148,7 +153,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.int_data, self.scale, self.zero_point

     def get_layout(self) -> Layout:
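
Since get_plain() can now return zero_point=None, downstream code that unpacks the plain tensors (as the integration test does via mod[0].weight.tensor_impl.get_plain()) has to handle the optional case. A sketch of such a consumer (illustrative helper, not part of this commit; assumes per-output-channel scales and an integer-domain zero point when one is present):

    import torch
    from typing import Optional

    def dequantize_plain(
        int_data: torch.Tensor,              # (out_features, in_features) int8 weight
        scale: torch.Tensor,                 # one scale per output channel
        zero_point: Optional[torch.Tensor],  # None for ZeroPointDomain.NONE weights
    ) -> torch.Tensor:
        scale = scale.view(-1, 1).to(torch.float32)
        if zero_point is None:
            # ZeroPointDomain.NONE: dequantize without a zero point.
            return int_data.to(torch.float32) * scale
        zero_point = zero_point.view(-1, 1).to(torch.float32)
        return (int_data.to(torch.float32) - zero_point) * scale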

torchao/quantization/quant_api.py

Lines changed: 5 additions & 3 deletions
@@ -387,7 +387,7 @@ def insert_observers_(
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )

         # Create a linear module
@@ -688,7 +688,7 @@ def int4_weight_only(
     group_size=128,
     layout=TensorCoreTiledLayout(inner_k_tiles=8),
     use_hqq=False,
-    zero_point_domain=None,
+    zero_point_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -731,7 +731,7 @@ def apply_int4_weight_only_quant(weight):
         assert (
             type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
         ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain is None:
+        if zero_point_domain == ZeroPointDomain.NONE:
             # the first value is the default one
             zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
         else:
@@ -857,6 +857,7 @@ def int8_dynamic_activation_int8_weight(
     layout=PlainLayout(),
     act_mapping_type=MappingType.SYMMETRIC,
     weight_only_decode=False,
+    weight_zp_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -901,6 +902,7 @@ def get_weight_block_size(x):
             eps=eps,
             zero_point_dtype=zero_point_dtype,
             _layout=layout,
+            zero_point_domain=weight_zp_domain,
         )
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight