diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 3d51ed048f..cc6ffa8f0e 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -45,6 +45,8 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -99,6 +101,10 @@
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
+WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
@@ -118,9 +124,20 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
 
 
-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_zero_point_domain=ZeroPointDomain.INT,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+                weight_zp_domain=weight_zero_point_domain,
+            ),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
@@ -959,10 +976,11 @@ def _test_lin_weight_subclass_api_impl(
             mod[0].weight.tensor_impl.get_plain()
 
         test = mod(x)
+
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
         mod_qc = torch.compile(mod, mode="max-autotune")
@@ -970,14 +988,37 @@ def _test_lin_weight_subclass_api_impl(
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+                WEIGHT_ZERO_POINT_DOMAINS,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(
+        self, device, dtype, act_mapping, weight_zero_point_domain
+    ):
+        from functools import partial
+
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+            weight_zero_point_domain=weight_zero_point_domain,
         )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py
index 0526ee01b2..8ec15eb201 100644
--- a/test/quantization/test_observer.py
+++ b/test/quantization/test_observer.py
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )
 
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 eps=torch.finfo(torch.float32).eps,
                 scale_dtype=torch.float,
                 zero_point_dtype=torch.int,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
             )
         else:
             weight_observer = None
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)
 
         if observe_weight:
             weight_observer = linear.weight.weight_observer
@@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)
 
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 102e76cb1a..00fe300864 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    # ZeroPointDomain.NONE should work
+    def test_none_zero_point_domain(self):
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        _, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        self.assertTrue(zero_point is None)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index e7aca34c5f..c476ece97c 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -262,7 +262,7 @@ def from_hp_to_intx(
             )
             # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
             # TODO should probably consolidate ZeroPointDomain.NONE and None
-            if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            if zero_point_domain == ZeroPointDomain.NONE:
                 zero_point = None
             data = quantize_affine(
                 input_float,
@@ -360,7 +360,7 @@ def from_hp_to_floatx(
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
                 use_hqq=False,
             )
@@ -387,7 +387,7 @@ def from_hp_to_floatx_static(
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
             )
         else:
diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index 502e3c13e9..b61dee8ba4 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -38,7 +38,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         self.int_data = int_data
@@ -64,7 +64,10 @@ def __init__(
         self._layout = _layout
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self._layout]
+        if self.zero_point is not None:
+            return ["int_data", "scale", "zero_point"], [self._layout]
+        else:
+            return ["int_data", "scale"], [self._layout]
 
     @classmethod
     def __tensor_unflatten__(
@@ -73,7 +76,7 @@ def __tensor_unflatten__(
         int_data, scale, zero_point = (
             tensor_data_dict["int_data"],
             tensor_data_dict["scale"],
-            tensor_data_dict["zero_point"],
+            tensor_data_dict.get("zero_point", None),
         )
         (_layout,) = tensor_attributes
         return cls(int_data, scale, zero_point, _layout)
@@ -83,7 +86,9 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
             self._layout,
         )
 
@@ -91,7 +96,7 @@ def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
-            fn(self.zero_point),
+            fn(self.zero_point) if self.zero_point is not None else None,
             self._layout,
         )
 
@@ -134,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 return PlainAQTTensorImpl(
                     aten.slice.Tensor(self.int_data, dim, start, end, step),
                     self.scale.view(-1),
-                    self.zero_point.view(-1),
+                    self.zero_point.view(-1) if self.zero_point is not None else None,
                     self._layout,
                 )
             else:
@@ -148,7 +153,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.int_data, self.scale, self.zero_point
 
     def get_layout(self) -> Layout:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b2eff196fd..184b96334c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -387,7 +387,7 @@ def insert_observers_(
         eps=torch.finfo(torch.float32).eps,
         scale_dtype=torch.float,
         zero_point_dtype=torch.int,
-        zero_point_domain=None,
+        zero_point_domain=ZeroPointDomain.NONE,
     )
 
     # Create a linear module
@@ -688,7 +688,7 @@ def int4_weight_only(
     group_size=128,
     layout=TensorCoreTiledLayout(inner_k_tiles=8),
     use_hqq=False,
-    zero_point_domain=None,
+    zero_point_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -731,7 +731,7 @@ def apply_int4_weight_only_quant(weight):
         assert (
             type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
         ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain is None:
+        if zero_point_domain == ZeroPointDomain.NONE:
             # the first value is the default one
             zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
         else:
@@ -857,6 +857,7 @@ def int8_dynamic_activation_int8_weight(
     layout=PlainLayout(),
     act_mapping_type=MappingType.SYMMETRIC,
     weight_only_decode=False,
+    weight_zp_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -901,6 +902,7 @@ def get_weight_block_size(x):
             eps=eps,
             zero_point_dtype=zero_point_dtype,
             _layout=layout,
+            zero_point_domain=weight_zp_domain,
         )
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index fddd21c43e..2a12bacc91 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -334,6 +334,7 @@ def _quantize_affine(
         zero_point,
         quant_min,
         quant_max,
+        output_dtype,
         zero_point_domain,
     ).to(output_dtype)
 
@@ -345,6 +346,7 @@ def _quantize_affine_no_dtype_cast(
     zero_point: Optional[torch.Tensor],
     quant_min: Union[int, float],
     quant_max: Union[int, float],
+    quant_dtype: Optional[torch.dtype],
    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
 ) -> torch.Tensor:
     """
@@ -389,13 +391,12 @@ def _quantize_affine_no_dtype_cast(
         assert (
             zero_point is None
         ), "zero_point should be None when zero_point_domain is NONE"
-        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
-    elif zero_point_domain is None:
-        # This case handles quantization for float8 we expect no zero point and no zero point domain
-        assert (
-            zero_point is None
-        ), "zero_point should be None when zero_point_domain is None"
-        quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+        if _is_float8_type(quant_dtype):
+            quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+        else:
+            quant = torch.clamp(
+                torch.round(input * (1.0 / scale)), quant_min, quant_max
+            )
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
@@ -538,16 +539,6 @@ def _dequantize_affine_no_dtype_check(
         ), "zero_point should be None when zero_point_domain is NONE"
         dequant = input.to(output_dtype)
         dequant = dequant * scale
-    elif zero_point_domain is None:
-        # This case handles dequantization for float8 we expect no zero point and no zero point domain
-        assert (
-            zero_point is None
-        ), "zero_point should be None when zero_point_domain is None"
-        assert _is_float8_type(
-            input.dtype
-        ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"
-        dequant = input.to(output_dtype)
-        dequant = dequant * scale
     else:
         assert (
             zero_point_domain == ZeroPointDomain.FLOAT.name
@@ -674,6 +665,7 @@ def _do_fake_quantize_affine(
         zero_point,
         quant_min,
         quant_max,
+        quant_dtype,
         zero_point_domain.name,
     )
     dq = _dequantize_affine_no_dtype_check(
@@ -901,8 +893,11 @@ def _choose_qparams_affine(
             raise ValueError(
                 "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization"
             )
+        if zero_point_domain == ZeroPointDomain.NONE.name:
+            zero_point = None
+        else:
+            zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
         scale = torch.clamp(scale, min=eps)
-        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
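
Usage sketch (not part of the patch): a minimal, hypothetical exercise of the `ZeroPointDomain.NONE` path this diff enables. It mirrors the argument order of the new `test_none_zero_point_domain` test and the new `weight_zp_domain` keyword; the `torch.nn.Sequential` model and tensor shapes are illustrative assumptions, not taken from the PR.

```python
import torch

from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

# With ZeroPointDomain.NONE, choose_qparams_affine now returns zero_point=None
# instead of a dummy mid-point tensor (see the _choose_qparams_affine hunk).
input = torch.randn(10, 256)
scale, zero_point = choose_qparams_affine(
    input,
    MappingType.SYMMETRIC,
    (1, 128),      # block_size
    torch.int8,    # target dtype
    None,          # quant_min (derived from the target dtype)
    None,          # quant_max (derived from the target dtype)
    1e-6,          # eps
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int64,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.NONE,
)
assert zero_point is None  # no zero-point tensor is materialized or serialized

# The new weight_zp_domain keyword (default ZeroPointDomain.NONE) threads the
# domain through the int8 dynamic-activation / int8-weight config, so the
# quantized weights carry no zero point at all.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))
quantize_(
    model,
    int8_dynamic_activation_int8_weight(
        act_mapping_type=MappingType.SYMMETRIC,
        weight_zp_domain=ZeroPointDomain.NONE,
    ),
)
```

Dropping the zero point for symmetric int8 weights also changes what `PlainAQTTensorImpl` serializes: `__tensor_flatten__` now omits the `zero_point` entry entirely when it is `None`, which is why `__tensor_unflatten__` falls back to `tensor_data_dict.get("zero_point", None)` in the diff above.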