diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py index 8b0edb718e..ed93442807 100644 --- a/test/quantization/test_da8w4_cpu.py +++ b/test/quantization/test_da8w4_cpu.py @@ -4,4 +4,159 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -# Tests removed: Int8DynamicActivationInt4WeightConfig has been removed. +import copy +import unittest + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.prototype.int4_opaque_tensor import ( + Int8DynamicActInt4WeightOpaqueTensorConfig, +) +from torchao.quantization.quant_primitives import MappingType +from torchao.quantization.utils import compute_error +from torchao.utils import torch_version_at_least + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDa8w4Cpu(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.7.0"), "Test only enabled for 2.7+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("sym_quant_a", [True, False]) + def test_8da4w_cpu(self, dtype, x_dim, bias, bs, 
sym_quant_a): + if sym_quant_a and not torch_version_at_least("2.8.0"): + # symmetric int8 activation not supported until PT 2.8 + return + device = "cpu" + m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) + m_ref = copy.deepcopy(m) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + act_mapping_type = ( + MappingType.SYMMETRIC if sym_quant_a else MappingType.ASYMMETRIC + ) + with torch.no_grad(): + quantize_( + m, + Int8DynamicActInt4WeightOpaqueTensorConfig( + group_size=32, + act_mapping_type=act_mapping_type, + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] + + # compare against float reference: result should have reasonable accuracy + torch._dynamo.reset() + y_ref = m_ref(*example_inputs) + sqnr = compute_error(y_ref, y) + if dtype == torch.float: + assert sqnr > 20, f"SQNR too low: {sqnr}" + else: + assert sqnr > 15, f"SQNR too low: {sqnr}" + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.8.0"), "Test only enabled for 2.8+") + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + def test_8da4w_concat_linear_cpu(self, x_dim, bias): + from torchao.prototype.inductor.fx_passes import ( + register_da8w4_concat_linear_cpu_pass, + ) + + register_da8w4_concat_linear_cpu_pass() + N, K = 64, 128 + + class Mod(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias) + self.linear2 = torch.nn.Linear(K, N, bias=bias) + self.linear3 = torch.nn.Linear(K, N, bias=bias) + + def forward(self, x): + a = self.linear1(x) + b = self.linear2(x) + c = 
self.linear3(x) + return a + b + c + + dtype = torch.bfloat16 + device = "cpu" + m = Mod(bias).eval().to(dtype).to(device) + x_shape = [2] * x_dim + x_shape[-1] = K + x = torch.rand(x_shape, dtype=dtype, device=device) + with torch.no_grad(): + quantize_( + m, + Int8DynamicActInt4WeightOpaqueTensorConfig( + group_size=32, + act_mapping_type=MappingType.SYMMETRIC, + ), + ) + # Need to turn on freezing to get the pattern + # set enable_concat_linear to true to enable the fusion + with torch._inductor.config.patch( + {"freezing": True, "cpp.enable_concat_linear": True} + ): + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + x, + ) + # ensure the expected op occurs only once in the code after fusion + # The trailing "(" is to avoid matching the op in the comment + assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1 + with torch._inductor.config.patch( + {"freezing": True, "cpp.enable_concat_linear": False} + ): + y_ref, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + x, + ) + assert torch.allclose(y, y_ref) + + +common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/prototype/int4_opaque_tensor/__init__.py b/torchao/prototype/int4_opaque_tensor/__init__.py index 29352457cb..f7e5e824d6 100644 --- a/torchao/prototype/int4_opaque_tensor/__init__.py +++ b/torchao/prototype/int4_opaque_tensor/__init__.py @@ -1,7 +1,11 @@ -from .inference_workflow import Int4WeightOnlyOpaqueTensorConfig +from .inference_workflow import ( + Int4WeightOnlyOpaqueTensorConfig, + Int8DynamicActInt4WeightOpaqueTensorConfig, +) from .int4_opaque_tensor import Int4OpaqueTensor __all__ = [ "Int4OpaqueTensor", "Int4WeightOnlyOpaqueTensorConfig", + "Int8DynamicActInt4WeightOpaqueTensorConfig", ] diff --git a/torchao/prototype/int4_opaque_tensor/inference_workflow.py 
b/torchao/prototype/int4_opaque_tensor/inference_workflow.py index d1a92d2b2f..3a9e440932 100644 --- a/torchao/prototype/int4_opaque_tensor/inference_workflow.py +++ b/torchao/prototype/int4_opaque_tensor/inference_workflow.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. import logging -from dataclasses import dataclass +import types +from dataclasses import dataclass, field import torch @@ -13,9 +14,9 @@ from torchao.core.config import AOBaseConfig logger = logging.getLogger(__name__) -import types from torchao.quantization.quant_api import _linear_extra_repr +from torchao.quantization.quant_primitives import MappingType from torchao.quantization.quantize_.workflows import ( Int4ChooseQParamsAlgorithm, ) @@ -86,3 +87,69 @@ def _int4_weight_only_transform( module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module + + +@dataclass +class Int8DynamicActInt4WeightOpaqueTensorConfig(AOBaseConfig): + """ + Configuration for int8 dynamic activation + int4 weight quantization on CPU, + using Int4OpaqueTensor (tensor subclassing) with the da8w4_linear_cpu backend. + + Weights are quantized per-group (asymmetric int4) and prepacked at quantization time. + Activations are quantized dynamically per-token at runtime. 
+ + Args: + `group_size`: quantization group size for weights; K must be divisible by group_size + `act_mapping_type`: activation quantization type: + - MappingType.ASYMMETRIC (default): uint8 activation quantization + - MappingType.SYMMETRIC: int8 activation quantization (requires PyTorch >= 2.8) + """ + + group_size: int = 32 + act_mapping_type: MappingType = field( + default_factory=lambda: MappingType.ASYMMETRIC + ) + set_inductor_config: bool = True + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.prototype.int4_opaque_tensor.Int8DynamicActInt4WeightOpaqueTensorConfig" + ) + + +@register_quantize_module_handler(Int8DynamicActInt4WeightOpaqueTensorConfig) +def _int8_dynamic_act_int4_weight_transform( + module: torch.nn.Module, config: Int8DynamicActInt4WeightOpaqueTensorConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying DA8W4 quant requires module to have weight attribute" + + f" but {module} does not have one" + ) + assert "CPU" in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), ( + "DA8W4 on CPU requires the da8w4_linear_cpu kernel to be built and available" + ) + weight = module.weight + if weight.shape[-1] % config.group_size != 0: + logger.info( + f"Skipping DA8W4 quantization: weight shape {weight.shape} is not compatible " + f"with group_size {config.group_size}" + ) + return module + if weight.shape[0] % 32 != 0 or weight.shape[-1] % 2 != 0: + logger.info( + f"Skipping DA8W4 quantization: weight shape {weight.shape} requires " + "N divisible by 32 and K divisible by 2" + ) + return module + + new_weight = Int4OpaqueTensor.from_hp_da8w4( + weight, + group_size=config.group_size, + act_mapping_type=config.act_mapping_type, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module diff --git 
a/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py index 976f219167..015d529281 100644 --- a/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py +++ b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py @@ -15,13 +15,19 @@ _choose_qparams_affine_tinygemm, _choose_qparams_and_quantize_affine_hqq, _quantize_affine_tinygemm, + choose_qparams_affine, + quantize_affine, ) from torchao.quantization.quantize_.workflows import ( Int4ChooseQParamsAlgorithm, ) -from torchao.quantization.utils import pack_tinygemm_scales_and_zeros +from torchao.quantization.utils import ( + _get_per_token_block_size, + pack_tinygemm_scales_and_zeros, +) from torchao.utils import ( TorchAOBaseTensor, + torch_version_at_least, ) __all__ = [ @@ -33,34 +39,44 @@ class Int4OpaqueTensor(TorchAOBaseTensor): """ - int4 weight-only quantization on CPU with tinygemm (groupwise quantization only). The packing format is determined on ISA and shape. - This is an opaque tensor subclass, the packing format is not exposed to the rest of the system. See the note below for more details. + int4 weight quantization on CPU with two supported paths: - Tensor Attributes: - qdata: preshuffled and packed int4 weight for CPU tinygemm kernel, always viewed as a 2D (N, K/2) tensor, last dimension is packed - preshuffling is specific to CPU kernels based on ISA and shape, see Note below. - scale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype + 1. A16W4 (weight-only): float16/bfloat16/float32 activation + int4 weight, + using tinygemm kernel (_weight_int4pack_mm_for_cpu). - Non-Tensor Attributes: - block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size). - we only support group_size = 32/64/128. - shape: shape of the original Tensor + 2. 
DA8W4 (dynamic activation): int8 dynamic activation + int4 weight, + using da8w4 kernel (da8w4_linear_cpu). Activation is quantized + per-token dynamically at runtime. - Optional Tensor Data Attributes: - act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, - we'll multiply activation Tensor with act_pre_scale before applying dynamic - quantization to activation or running quantized mm op + The path is selected based on `act_mapping_type`: + - None → A16W4 tinygemm path + - "asymmetric" or "symmetric" → DA8W4 path - Note on Details for data layout for CPU tinygemm kernel: + Mandatory Tensor Attributes: + qdata: packed int4 weight. + A16W4: preshuffled for tinygemm, shape (N, K/2) + DA8W4: prepacked for da8w4_linear_cpu (4D) + scale_and_zero: weight quantization params. + A16W4: packed scales+zeros for tinygemm, shape (K/group_size, N, 2) + DA8W4: packed scales, shape (N/block_n, G, block_n) - We use AVX512 to compute TINYGEMM on CPU. We can also leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune. - For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16. - See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details. + Mandatory Non-Tensor Attributes: + block_size: quantization block size, e.g. 
[1, group_size] + shape: original weight shape [N, K] + + Optional Tensor Data Attributes: + act_pre_scale: activation pre-scale (A16W4 only) + qzeros: packed weight zero-points for DA8W4, shape (N/block_n, G, block_n) + compensation: weight compensation for DA8W4, shape (N/block_n, K/block_k, block_n) + + Optional Non-Tensor Attributes: + act_mapping_type: None for A16W4; "asymmetric" or "symmetric" for DA8W4 """ tensor_data_names = ["qdata", "scale_and_zero"] tensor_attribute_names = ["block_size", "shape"] - optional_tensor_data_names = ["act_pre_scale"] + optional_tensor_data_names = ["act_pre_scale", "qzeros", "compensation"] + optional_tensor_attribute_names = ["act_mapping_type"] def __new__( cls, @@ -69,6 +85,9 @@ def __new__( block_size, shape, act_pre_scale: Optional[torch.Tensor] = None, + qzeros: Optional[torch.Tensor] = None, + compensation: Optional[torch.Tensor] = None, + act_mapping_type: Optional[str] = None, ): kwargs = {} kwargs["device"] = qdata.device @@ -83,17 +102,26 @@ def __init__( block_size: List[int], shape: torch.Size, act_pre_scale: Optional[torch.Tensor] = None, + qzeros: Optional[torch.Tensor] = None, + compensation: Optional[torch.Tensor] = None, + act_mapping_type: Optional[str] = None, ): super().__init__() self.qdata = qdata self.scale_and_zero = scale_and_zero self.block_size = block_size self.act_pre_scale = act_pre_scale + self.qzeros = qzeros + self.compensation = compensation + self.act_mapping_type = act_mapping_type def _quantization_type(self): - s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" - if self.act_pre_scale is not None: - s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + if self.act_mapping_type is not None: + s = f"da8w4, shape={self.shape}, block_size={self.block_size}, act={self.act_mapping_type}, device={self.device}" + else: + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + if self.act_pre_scale is not None: + s += f", 
act_pre_scale.shape={self.act_pre_scale.shape}" return s @classmethod @@ -186,11 +214,168 @@ def from_hp( act_pre_scale=None, ) + @classmethod + def from_hp_da8w4( + cls, + w: torch.Tensor, + group_size: int = 32, + act_mapping_type: MappingType = MappingType.ASYMMETRIC, + ): + """ + Quantize a float weight tensor for DA8W4 (int8 dynamic activation + int4 weight) on CPU. + + The weight is quantized per-group (asymmetric int4), then prepacked via + torch.ops.torchao.da8w4_linear_prepack_cpu for the CPU kernel. + + Args: + w: float weight tensor, shape [N, K], must be on CPU + group_size: quantization group size, K must be divisible by group_size + act_mapping_type: MappingType.ASYMMETRIC (uint8 activation, default) or + MappingType.SYMMETRIC (int8 activation, requires PyTorch >= 2.8) + """ + assert w.ndim == 2 and w.device.type == "cpu", ( + f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}" + ) + assert w.shape[1] % group_size == 0, ( + f"K={w.shape[1]} must be divisible by group_size={group_size}" + ) + assert w.shape[0] % 32 == 0 and w.shape[1] % 2 == 0, ( + f"N={w.shape[0]} must be divisible by 32 and K={w.shape[1]} must be even for DA8W4" + ) + original_shape = w.shape + block_size = [1, group_size] + + # Quantize weight: asymmetric int4 per-group → uint8 [N, K], values in [0, 15] + scale, zero_point = choose_qparams_affine( + w, + MappingType.ASYMMETRIC, + block_size, + torch.uint8, + quant_min=0, + quant_max=15, + eps=1e-6, + scale_dtype=torch.float32, + zero_point_dtype=torch.int32, + ) + int4_weight = quantize_affine( + w, + block_size, + scale, + zero_point, + torch.uint8, + quant_min=0, + quant_max=15, + ).to(torch.uint8) + + # Prepack for da8w4_linear_cpu + packed_weight, packed_scales, packed_qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu( + int4_weight, + scale, + zero_point.to(torch.int8), + ) + ) + + act_str = ( + "symmetric" if act_mapping_type == MappingType.SYMMETRIC else "asymmetric" + ) + return cls( + 
qdata=packed_weight, + scale_and_zero=packed_scales, + block_size=block_size, + shape=original_shape, + act_pre_scale=None, + qzeros=packed_qzeros, + compensation=compensation, + act_mapping_type=act_str, + ) + implements = Int4OpaqueTensor.implements implements_torch_function = Int4OpaqueTensor.implements_torch_function +def _da8w4_linear(input_tensor, weight_tensor, bias): + """DA8W4 linear: dynamically quantize activation per-token, then call da8w4_linear_cpu.""" + orig_act_size = input_tensor.size() + orig_dtype = input_tensor.dtype + + # Reshape activation to 2D + act_fp = input_tensor.reshape(-1, input_tensor.shape[-1]) + per_token_block_size = _get_per_token_block_size(act_fp) + + if weight_tensor.act_mapping_type == "symmetric": + assert torch_version_at_least("2.8.0"), ( + "Symmetric int8 activation quantization requires PyTorch 2.8+" + ) + # Symmetric int8 quantization: values in [-127, 127] + act_scale, act_zero_point = choose_qparams_affine( + act_fp, + MappingType.SYMMETRIC, + per_token_block_size, + torch.int8, + quant_min=-127, + quant_max=127, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float32, + zero_point_dtype=torch.int8, + ) + act_int = quantize_affine( + act_fp, + per_token_block_size, + act_scale, + act_zero_point, + torch.int8, + quant_min=-127, + quant_max=127, + ) + else: + assert torch_version_at_least("2.7.0"), ( + "Asymmetric uint8 activation quantization requires PyTorch 2.7+" + ) + # Asymmetric uint8 quantization: values in [0, 255] + act_scale, act_zero_point = choose_qparams_affine( + act_fp, + MappingType.ASYMMETRIC, + per_token_block_size, + torch.uint8, + quant_min=0, + quant_max=255, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float32, + zero_point_dtype=torch.int32, + ) + act_int = quantize_affine( + act_fp, + per_token_block_size, + act_scale, + act_zero_point, + torch.uint8, + quant_min=0, + quant_max=255, + ) + + act_scale_1d = act_scale.reshape(-1) + act_qzeros_1d = 
act_zero_point.reshape(-1).to(torch.int32) + + y = torch.ops.torchao.da8w4_linear_cpu.default( + act_int.contiguous(), + act_scale_1d, + act_qzeros_1d, + weight_tensor.qdata, + weight_tensor.scale_and_zero, + weight_tensor.qzeros, + weight_tensor.compensation, + bias.float() if bias is not None else bias, + orig_dtype, + ) + + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + return y.to(orig_dtype) + + @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): @@ -212,6 +397,19 @@ def _(func, types, args, kwargs): f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" ) + # DA8W4 path: dynamic int8 activation + int4 weight + if weight_tensor.act_mapping_type is not None: + if weight_tensor.act_mapping_type == "symmetric": + assert torch_version_at_least("2.8.0"), ( + "Symmetric int8 activation quantization requires PyTorch 2.8+" + ) + else: + assert torch_version_at_least("2.7.0"), ( + "Asymmetric uint8 activation quantization requires PyTorch 2.7+" + ) + return _da8w4_linear(input_tensor, weight_tensor, bias) + + # A16W4 path: float activation + int4 weight (tinygemm) if weight_tensor.act_pre_scale is not None: input_tensor = input_tensor * weight_tensor.act_pre_scale