Fix ZeroPointDomain.NONE support & make it default for da8w8 weights

Dmitry Rogozhkin committed Jan 15, 2025
1 parent 7b3caa6 commit c44df1f
Showing 7 changed files with 116 additions and 48 deletions.
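For orientation, here is a minimal sketch of the API surface this commit touches, using only names visible in the diff below (`quantize_`, `int8_dynamic_activation_int8_weight`, the new `weight_zp_domain` parameter); the model and shapes are illustrative, not taken from the commit.

```python
import torch
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

# After this commit, da8w8 weight quantization defaults to
# ZeroPointDomain.NONE: the int8 weights carry scales but no zero_point.
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).eval()
quantize_(model, int8_dynamic_activation_int8_weight())

# The previous int-domain zero point remains available explicitly:
model_int_zp = torch.nn.Sequential(torch.nn.Linear(256, 512)).eval()
quantize_(
    model_int_zp,
    int8_dynamic_activation_int8_weight(weight_zp_domain=ZeroPointDomain.INT),
)
```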
57 changes: 49 additions & 8 deletions test/integration/test_integration.py
```diff
@@ -45,6 +45,8 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -99,6 +101,10 @@
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
+WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
 
@@ -118,9 +124,20 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
 
 
-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_zero_point_domain=ZeroPointDomain.INT,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+                weight_zp_domain=weight_zero_point_domain,
+            ),
+            set_inductor_config=False,
+        )
     if not TORCH_VERSION_AT_LEAST_2_5:
         unwrap_tensor_subclass(mod)
     else:
@@ -959,25 +976,49 @@ def _test_lin_weight_subclass_api_impl(
         mod[0].weight.tensor_impl.get_plain()
 
         test = mod(x)
 
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
         mod_qc = torch.compile(mod, mode="max-autotune")
         test_comp = mod_qc(x)
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
-        )
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+                WEIGHT_ZERO_POINT_DOMAINS,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(
+        self, device, dtype, act_mapping, weight_zero_point_domain
+    ):
+        from functools import partial
+
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+            weight_zero_point_domain=weight_zero_point_domain,
+        )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
```
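The assertions above require an SQNR of at least 35 dB between the float reference and the quantized model. For reference, a common definition of that metric is sketched below, on the assumption that torchao's `SQNR` helper computes the usual power ratio; the helper may differ in detail.

```python
import torch

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: power of the reference
    # output over the power of the error the quantized model introduces.
    return 10 * torch.log10(ref.pow(2).sum() / (ref - test).pow(2).sum())
```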
15 changes: 7 additions & 8 deletions test/quantization/test_observer.py
```diff
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )
 
 
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 eps=torch.finfo(torch.float32).eps,
                 scale_dtype=torch.float,
                 zero_point_dtype=torch.int,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
             )
         else:
             weight_observer = None
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)
 
         if observe_weight:
             weight_observer = linear.weight.weight_observer
@@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)
 
```
26 changes: 26 additions & 0 deletions test/quantization/test_quant_primitives.py
```diff
@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    # ZeroPointDomain.NONE should work
+    def test_none_zero_point_domain(self):
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        _, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        self.assertTrue(zero_point is None)
+
 
 if __name__ == "__main__":
     unittest.main()
```
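A hedged sketch of the round trip the new test enables: with `ZeroPointDomain.NONE`, `choose_qparams_affine` now returns `zero_point=None`, and the affine primitives accept that `None` end to end. Argument order follows the `torchao.quantization.quant_primitives` signatures exercised above; the shapes are illustrative.

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

x = torch.randn(10, 256)
block_size = (1, 128)
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    torch.int8,
    eps=1e-6,
    scale_dtype=torch.float32,
    zero_point_domain=ZeroPointDomain.NONE,
)
assert zero_point is None  # the behavior this commit fixes

q = quantize_affine(
    x, block_size, scale, zero_point, torch.int8,
    zero_point_domain=ZeroPointDomain.NONE,
)
dq = dequantize_affine(
    q, block_size, scale, zero_point, torch.int8,
    zero_point_domain=ZeroPointDomain.NONE,
)
```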
6 changes: 3 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
```diff
@@ -262,7 +262,7 @@ def from_hp_to_intx(
         )
         # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
         # TODO should probably consolidate ZeroPointDomain.NONE and None
-        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+        if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
         data = quantize_affine(
             input_float,
@@ -360,7 +360,7 @@ def from_hp_to_floatx(
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
                 use_hqq=False,
             )
@@ -387,7 +387,7 @@ def from_hp_to_floatx_static(
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
             )
         else:
```
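A hedged sketch of the constructor path changed above, via `to_affine_quantized_intx` (the public alias for `from_hp_to_intx`): with `ZeroPointDomain.NONE`, the resulting tensor's plain representation carries no zero point. The keyword usage here is an assumption based on the signature fragments visible in this diff.

```python
import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

w = torch.randn(64, 128)
aqt = to_affine_quantized_intx(
    w,
    MappingType.SYMMETRIC,
    (1, 128),  # per-group block size (illustrative)
    torch.int8,
    eps=1e-6,
    zero_point_domain=ZeroPointDomain.NONE,
)
# get_plain() now returns an Optional zero_point, per the layout change below.
_, _, zp = aqt.tensor_impl.get_plain()
assert zp is None
```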
21 changes: 13 additions & 8 deletions torchao/dtypes/uintx/plain_layout.py
```diff
@@ -38,7 +38,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         self.int_data = int_data
@@ -64,7 +64,10 @@ def __init__(
         self._layout = _layout
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self._layout]
+        if self.zero_point is not None:
+            return ["int_data", "scale", "zero_point"], [self._layout]
+        else:
+            return ["int_data", "scale"], [self._layout]
 
     @classmethod
     def __tensor_unflatten__(
@@ -73,7 +76,7 @@ def __tensor_unflatten__(
         int_data, scale, zero_point = (
             tensor_data_dict["int_data"],
            tensor_data_dict["scale"],
-            tensor_data_dict["zero_point"],
+            tensor_data_dict.get("zero_point", None),
         )
         (_layout,) = tensor_attributes
         return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +86,17 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
             self._layout,
         )
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
-            fn(self.zero_point),
+            fn(self.zero_point) if self.zero_point is not None else None,
             self._layout,
         )
 
@@ -134,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 return PlainAQTTensorImpl(
                     aten.slice.Tensor(self.int_data, dim, start, end, step),
                     self.scale.view(-1),
-                    self.zero_point.view(-1),
+                    self.zero_point.view(-1) if self.zero_point is not None else None,
                     self._layout,
                 )
             else:
@@ -148,7 +153,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.int_data, self.scale, self.zero_point
 
     def get_layout(self) -> Layout:
```
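Why `__tensor_flatten__` must drop the `"zero_point"` key rather than store a placeholder: the subclass is rebuilt strictly from the names flatten returns, so an absent tensor has to be absent from that list and recovered with `.get(..., None)` on the other side. A standalone toy illustration of the contract (not torchao code):

```python
from typing import Optional

import torch

class PackedWeight:
    """Toy stand-in for PlainAQTTensorImpl's flatten/unflatten contract."""

    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor,
                 zero_point: Optional[torch.Tensor]):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point

    def __tensor_flatten__(self):
        # Only advertise tensors that actually exist; the runtime rebuilds
        # the object strictly from the names returned here.
        if self.zero_point is not None:
            return ["int_data", "scale", "zero_point"], []
        return ["int_data", "scale"], []

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, attrs, outer_size, outer_stride):
        # .get(...) tolerates the key that flatten omitted above.
        return cls(
            tensor_data_dict["int_data"],
            tensor_data_dict["scale"],
            tensor_data_dict.get("zero_point", None),
        )

# Round trip with no zero point, mirroring ZeroPointDomain.NONE:
p = PackedWeight(torch.zeros(4, dtype=torch.int8), torch.ones(1), None)
names, ctx = p.__tensor_flatten__()
q = PackedWeight.__tensor_unflatten__(
    {n: getattr(p, n) for n in names}, ctx, None, None
)
assert q.zero_point is None
```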
8 changes: 5 additions & 3 deletions torchao/quantization/quant_api.py
```diff
@@ -387,7 +387,7 @@ def insert_observers_(
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
 
         # Create a linear module
@@ -688,7 +688,7 @@ def int4_weight_only(
     group_size=128,
     layout=TensorCoreTiledLayout(inner_k_tiles=8),
     use_hqq=False,
-    zero_point_domain=None,
+    zero_point_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -731,7 +731,7 @@ def apply_int4_weight_only_quant(weight):
         assert (
             type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
         ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain is None:
+        if zero_point_domain == ZeroPointDomain.NONE:
             # the first value is the default one
             zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
         else:
@@ -857,6 +857,7 @@ def int8_dynamic_activation_int8_weight(
     layout=PlainLayout(),
     act_mapping_type=MappingType.SYMMETRIC,
     weight_only_decode=False,
+    weight_zp_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -901,6 +902,7 @@ def get_weight_block_size(x):
             eps=eps,
             zero_point_dtype=zero_point_dtype,
             _layout=layout,
+            zero_point_domain=weight_zp_domain,
         )
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight
```
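One subtlety in the `int4_weight_only` change above: `ZeroPointDomain.NONE` now doubles as the "not specified" sentinel, and the effective domain is looked up per layout. A sketch of that resolution logic with an illustrative stand-in mapping (torchao's real `LAYOUT_TO_ZERO_POINT_DOMAIN` keys on layout classes, not strings):

```python
from enum import Enum, auto

class ZeroPointDomain(Enum):  # stand-in for torchao's enum
    INT = auto()
    FLOAT = auto()
    NONE = auto()

# Illustrative stand-in: first entry per layout is that layout's default.
LAYOUT_TO_ZERO_POINT_DOMAIN = {
    "TensorCoreTiledLayout": [ZeroPointDomain.FLOAT],
    "MarlinSparseLayout": [ZeroPointDomain.INT],
}

def resolve_domain(layout_name: str,
                   zero_point_domain=ZeroPointDomain.NONE) -> ZeroPointDomain:
    if zero_point_domain == ZeroPointDomain.NONE:
        # the first value is the default one
        return LAYOUT_TO_ZERO_POINT_DOMAIN[layout_name][0]
    return zero_point_domain
```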