Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion test/quantization/test_da8w4_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,159 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Tests removed: Int8DynamicActivationInt4WeightConfig has been removed.
import copy
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao import quantize_
from torchao.prototype.int4_opaque_tensor import (
Int8DynamicActInt4WeightOpaqueTensorConfig,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.utils import compute_error
from torchao.utils import torch_version_at_least


class ToyLinearModel(torch.nn.Module):
    """Minimal two-layer linear stack used as a quantization test fixture.

    Layer widths are m -> n -> k; both layers start in float32.
    """

    def __init__(self, m=64, n=32, k=64, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        """Return a 1-tuple with one random input sized to linear1's width."""
        sample = torch.randn(
            batch_size, self.linear1.in_features, dtype=dtype, device=device
        )
        return (sample,)

    def forward(self, x):
        # Plain sequential application of the two layers.
        return self.linear2(self.linear1(x))


class TestDa8w4Cpu(TestCase):
    """End-to-end CPU tests for dynamic int8-activation / int4-weight (DA8W4)
    quantization via Int8DynamicActInt4WeightOpaqueTensorConfig."""

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.7.0"), "Test only enabled for 2.7+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize("sym_quant_a", [True, False])
    def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
        """Quantize a toy model, compile it, and check that (a) the dedicated
        da8w4 CPU kernel appears in the generated code and (b) accuracy vs.
        the float reference is reasonable (SQNR threshold per dtype)."""
        if sym_quant_a and not torch_version_at_least("2.8.0"):
            # symmetric int8 activation not supported until PT 2.8
            return
        device = "cpu"
        m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
        # Keep an unquantized copy as the accuracy reference.
        m_ref = copy.deepcopy(m)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            # Add a leading dim to exercise 3-D activation inputs.
            example_inputs = (example_inputs[0].unsqueeze(0),)

        act_mapping_type = (
            MappingType.SYMMETRIC if sym_quant_a else MappingType.ASYMMETRIC
        )
        with torch.no_grad():
            quantize_(
                m,
                Int8DynamicActInt4WeightOpaqueTensorConfig(
                    group_size=32,
                    act_mapping_type=act_mapping_type,
                ),
            )
            y, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]

            # compare against float reference: result should have reasonable accuracy
            torch._dynamo.reset()
            y_ref = m_ref(*example_inputs)
            sqnr = compute_error(y_ref, y)
            if dtype == torch.float:
                assert sqnr > 20, f"SQNR too low: {sqnr}"
            else:
                # lower bar for reduced-precision dtypes (bf16 / fp16)
                assert sqnr > 15, f"SQNR too low: {sqnr}"

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.8.0"), "Test only enabled for 2.8+")
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    def test_8da4w_concat_linear_cpu(self, x_dim, bias):
        """Check the inductor fx pass that fuses multiple DA8W4 linears sharing
        one input into a single concatenated kernel call, and verify the fused
        output matches the unfused one."""
        from torchao.prototype.inductor.fx_passes import (
            register_da8w4_concat_linear_cpu_pass,
        )

        register_da8w4_concat_linear_cpu_pass()
        N, K = 64, 128

        class Mod(torch.nn.Module):
            # Three parallel linears fed by the same input -> fusable pattern.
            def __init__(self, bias):
                super().__init__()
                self.linear1 = torch.nn.Linear(K, N, bias=bias)
                self.linear2 = torch.nn.Linear(K, N, bias=bias)
                self.linear3 = torch.nn.Linear(K, N, bias=bias)

            def forward(self, x):
                a = self.linear1(x)
                b = self.linear2(x)
                c = self.linear3(x)
                return a + b + c

        dtype = torch.bfloat16
        device = "cpu"
        m = Mod(bias).eval().to(dtype).to(device)
        x_shape = [2] * x_dim
        x_shape[-1] = K
        x = torch.rand(x_shape, dtype=dtype, device=device)
        with torch.no_grad():
            quantize_(
                m,
                Int8DynamicActInt4WeightOpaqueTensorConfig(
                    group_size=32,
                    act_mapping_type=MappingType.SYMMETRIC,
                ),
            )
        # Need to turn on freezing to get the pattern
        # set enable_concat_linear to true to enable the fusion
        with torch._inductor.config.patch(
            {"freezing": True, "cpp.enable_concat_linear": True}
        ):
            y, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                x,
            )
        # ensure the expected op occurs only once in the code after fusion
        # The trailing "(" is to avoid matching the op in the comment
        assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
        with torch._inductor.config.patch(
            {"freezing": True, "cpp.enable_concat_linear": False}
        ):
            y_ref, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                x,
            )
        # Fusion must not change numerics.
        assert torch.allclose(y, y_ref)


# Materialize concrete test cases for every parametrized combination above.
common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)


if __name__ == "__main__":
    run_tests()
6 changes: 5 additions & 1 deletion torchao/prototype/int4_opaque_tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from .inference_workflow import Int4WeightOnlyOpaqueTensorConfig
from .inference_workflow import (
Int4WeightOnlyOpaqueTensorConfig,
Int8DynamicActInt4WeightOpaqueTensorConfig,
)
from .int4_opaque_tensor import Int4OpaqueTensor

__all__ = [
"Int4OpaqueTensor",
"Int4WeightOnlyOpaqueTensorConfig",
"Int8DynamicActInt4WeightOpaqueTensorConfig",
]
71 changes: 69 additions & 2 deletions torchao/prototype/int4_opaque_tensor/inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
import types
from dataclasses import dataclass, field

import torch

import torchao
from torchao.core.config import AOBaseConfig

logger = logging.getLogger(__name__)
import types

from torchao.quantization.quant_api import _linear_extra_repr
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows import (
Int4ChooseQParamsAlgorithm,
)
Expand Down Expand Up @@ -86,3 +87,69 @@ def _int4_weight_only_transform(
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
class Int8DynamicActInt4WeightOpaqueTensorConfig(AOBaseConfig):
    """
    Configuration for int8 dynamic activation + int4 weight quantization on CPU,
    using Int4OpaqueTensor (tensor subclassing) with the da8w4_linear_cpu backend.

    Weights are quantized per-group (asymmetric int4) and prepacked at quantization time.
    Activations are quantized dynamically per-token at runtime.

    Args:
        `group_size`: quantization group size for weights; K must be divisible by group_size
        `act_mapping_type`: activation quantization type:
            - MappingType.ASYMMETRIC (default): uint8 activation quantization
            - MappingType.SYMMETRIC: int8 activation quantization (requires PyTorch >= 2.8)
        `set_inductor_config`: if True, apply torchao's recommended inductor
            settings when this config is applied via `quantize_`
    """

    group_size: int = 32
    # Enum members are immutable singletons, so a plain default is safe for a
    # dataclass field — no default_factory indirection needed.
    act_mapping_type: MappingType = MappingType.ASYMMETRIC
    set_inductor_config: bool = True

    def __post_init__(self):
        # Record config usage for API telemetry.
        torch._C._log_api_usage_once(
            "torchao.prototype.int4_opaque_tensor.Int8DynamicActInt4WeightOpaqueTensorConfig"
        )


@register_quantize_module_handler(Int8DynamicActInt4WeightOpaqueTensorConfig)
def _int8_dynamic_act_int4_weight_transform(
    module: torch.nn.Module, config: Int8DynamicActInt4WeightOpaqueTensorConfig
) -> torch.nn.Module:
    """Replace ``module.weight`` with a DA8W4-prepacked Int4OpaqueTensor.

    Returns the module unchanged (after an info log) when the weight shape is
    incompatible with the kernel's layout requirements.
    """
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    assert hasattr(module, "weight"), (
        "applying DA8W4 quant requires module to have weight attribute"
        + f" but {module} does not have one"
    )
    assert "CPU" in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), (
        "DA8W4 on CPU requires the da8w4_linear_cpu kernel to be built and available"
    )

    weight = module.weight
    out_features, in_features = weight.shape[0], weight.shape[-1]

    # Weights are quantized per-group along K, so K must split evenly.
    if in_features % config.group_size != 0:
        logger.info(
            f"Skipping DA8W4 quantization: weight shape {weight.shape} is not compatible "
            f"with group_size {config.group_size}"
        )
        return module
    # Prepacked layout constraints: N divisible by 32, K divisible by 2.
    if out_features % 32 != 0 or in_features % 2 != 0:
        logger.info(
            f"Skipping DA8W4 quantization: weight shape {weight.shape} requires "
            "N divisible by 32 and K divisible by 2"
        )
        return module

    packed = Int4OpaqueTensor.from_hp_da8w4(
        weight,
        group_size=config.group_size,
        act_mapping_type=config.act_mapping_type,
    )
    module.weight = torch.nn.Parameter(packed, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module
Loading
Loading