[Feat]: Add support for kleidiai quantization schemes #1447

Open · wants to merge 2 commits into base: main

Changes from 1 commit
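
For orientation, a minimal usage sketch of the path this PR adds. Import paths, the group size, and the model are assumptions inferred from the diff below, not code from the PR:

import torch
from torchao.quantization.quant_api import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import MappingType
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=True)).eval()
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,                      # the 'aten' path asserts this is True
        weight_mapping_type=MappingType.SYMMETRIC,  # and that weight mapping is symmetric
        target="aten",                              # new in this PR: use the KleidiAI-backed aten kernels
    ),
)
out = model(torch.randn(1, 256))
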
7 changes: 6 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -206,6 +206,7 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None,
):
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
@@ -278,7 +279,11 @@ def from_hp_to_intx(

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
args = [data, scale, zero_point, _layout]
# Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias
if bias is not None:
args.append(bias)
tensor_impl = tensor_impl_ctr(*args)
return cls(
tensor_impl,
block_size,
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum, auto
from typing import Optional, Tuple

import torch
@@ -20,6 +21,9 @@
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@@ -31,17 +35,33 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

class Target(Enum):
"""Enum that indicates the backend target"""

NATIVE = auto()
ATEN = auto()

def target_from_str(target: str) -> Target:
if target.lower() == "native":
return Target.NATIVE
elif target.lower() == "aten":
return Target.ATEN
else:
raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
bit_width: Optional[int]
group_size: Optional[int]
has_weight_zeros: Optional[bool]
Contributor: The packed weights from Kleidi have bias packed with them, right? If so, let's add has_bias: Optional[bool] here to the layout (a sketch of one possible shape follows after has_params_set below).

# The target platform for the layout, 'native' or 'aten'
target: Optional[Target]

def __init__(
self,
bit_width: Optional[int] = None,
group_size: Optional[int] = None,
has_weight_zeros: Optional[bool] = None,
target: Optional[str] = "native",
):
if bit_width is not None:
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
@@ -51,6 +71,7 @@ def __init__(
self.bit_width = bit_width
self.group_size = group_size
self.has_weight_zeros = has_weight_zeros
self.target = target_from_str(target)

if not self.has_params_set():
assert (
@@ -60,13 +81,14 @@
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

def extra_repr(self):
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

def has_params_set(self) -> bool:
return (
(self.bit_width is not None)
and (self.group_size is not None)
and (self.has_weight_zeros is not None)
and (self.target is not None)
)
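
A standalone sketch of one way the reviewer's has_bias suggestion could look. This is a hypothetical illustration, not code from this PR:

from dataclasses import dataclass
from typing import Optional

@dataclass
class PackedLayoutSketch:
    # Hypothetical stand-in for PackedLinearInt8DynamicActivationIntxWeightLayout,
    # recording whether bias is packed with the weights alongside the other packing params.
    bit_width: Optional[int] = None
    group_size: Optional[int] = None
    has_weight_zeros: Optional[bool] = None
    has_bias: Optional[bool] = None   # suggested addition
    target: str = "native"

    def has_params_set(self) -> bool:
        return all(
            v is not None
            for v in (self.bit_width, self.group_size, self.has_weight_zeros, self.has_bias)
        )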


@@ -125,9 +147,11 @@ def from_plain(
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
layout: Layout,
bias: Optional[torch.Tensor] = None,
):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"

# TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
# when AOTI supports int
@@ -136,6 +160,13 @@
n_tensor = torch.empty(0, n, dtype=torch.int8)
k_tensor = torch.empty(0, k, dtype=torch.int8)

if layout.target == Target.ATEN:
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
int_data = int_data.add(8)
int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

if layout.has_weight_zeros:
args = [
int_data.to(torch.int8),
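
To make the aten-target packing above concrete, a minimal standalone sketch of the nibble packing it performs on a hypothetical 2x4 int4 tensor (torch.ops.aten._dyn_quant_pack_4bit_weight then consumes the packed uint8 data together with scales and bias):

import torch

w = torch.tensor([[-8, 7, 0, -1],
                  [ 3, -4, 5, 2]], dtype=torch.int8)      # int4 values in [-8, 7]
u = (w + 8).to(torch.int32)                               # shift to the unsigned nibble range [0, 15]
packed = ((u[:, 1::2] << 4) | u[:, ::2]).to(torch.uint8)  # odd columns -> high nibble, even -> low

# Unpacking recovers the original columns:
even = (packed & 0x0F).to(torch.int8) - 8   # columns 0, 2, ...
odd = (packed >> 4).to(torch.int8) - 8      # columns 1, 3, ...
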
@@ -211,16 +242,13 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
bias is None
bias is None or layout.target == Target.ATEN # Aten target allows bias
)


def _linear_impl(input_tensor, weight_tensor, bias):
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

def _impl_2d(input_tensor, weight_tensor):
def _impl_2d_native(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

@@ -255,6 +283,31 @@ def _impl_2d(input_tensor, weight_tensor):
torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
)(*args)

def _impl_2d_aten(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

m, k = input_tensor.shape
n, k_ = weight_tensor.shape
assert k_ == k
group_size = weight_tensor.tensor_impl.get_layout().group_size
packed_weight = weight_tensor.tensor_impl.packed_weight
return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)

target = weight_tensor.tensor_impl.get_layout().target

if target == Target.ATEN:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "Target.ATEN requires torch >= 2.6.0"
_impl_2d = _impl_2d_aten
elif target == Target.NATIVE:
_impl_2d = _impl_2d_native
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' "

if input_tensor.dim() == 2:
return _impl_2d(input_tensor, weight_tensor)

@@ -268,7 +321,6 @@ def _impl_2d(input_tensor, weight_tensor):
res = res.reshape(*lead_shape, m, n)
return res
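
The truncated tail above shows the N-D handling: inputs with more than two dims are flattened to 2-D, run through whichever _impl_2d was selected, and reshaped back. A minimal sketch of that pattern, with a plain matmul standing in for the quantized kernel:

import torch

def matmul_nd_via_2d(x, weight):
    # Flatten leading dims, run the 2-D kernel, then restore the leading shape.
    *lead_shape, m, k = x.shape
    n = weight.shape[0]
    res = x.reshape(-1, k) @ weight.t()   # stand-in for _impl_2d(input_tensor, weight_tensor)
    return res.reshape(*lead_shape, m, n)

y = matmul_nd_via_2d(torch.randn(2, 3, 5, 8), torch.randn(4, 8))
assert y.shape == (2, 3, 5, 4)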


register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
53 changes: 42 additions & 11 deletions torchao/experimental/quant_api.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import logging
from typing import Optional, Union

@@ -18,14 +19,18 @@
PerGroup,
PerRow,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)
from torchao.dtypes import PlainLayout

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

import sys

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

@@ -506,6 +511,7 @@ def int8_dynamic_activation_intx_weight(
weight_dtype: torch.dtype = torch.int4,
granularity: Union[PerRow, PerGroup] = PerGroup(128),
has_weight_zeros: bool = False,
target: str = "native",
Contributor: I think it would be better to pass this in the layout's constructor, because it isn't related to the quantization intent but to packing format/kernel selection, e.g. layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native").

weight_mapping_type=MappingType.ASYMMETRIC,
act_mapping_type=MappingType.ASYMMETRIC,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow
@@ -531,13 +537,28 @@
- The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32)
- act_mapping_type must be MappingType.ASYMMETRIC
"""
try:
torch.ops.torchao._pack_8bit_act_4bit_weight
except AttributeError:
raise Exception(
"TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+ " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
)

if target == "aten":
Contributor (@metascroy, Jan 17, 2025): Can this be something like:

if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
    assert (act_mapping_type == MappingType.ASYMMETRIC), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"

    if target == "aten":
        # Do KleidiAI-specific checks

    if target == "native":
        # Do try/except import logic

if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
weight_dtype != torch.int4 or \
has_weight_zeros != True or \
Contributor: It looks like the KleidiAI op does not take the zero points during packing (scale-only quantization)? So shouldn't has_weight_zeros be false?

weight_mapping_type != MappingType.SYMMETRIC:
raise NotImplementedError(
f"target 'aten' requires:\n"
f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
f"- has_weight_zeros to be True,\n"
f"- weight_dtype to be torch.int4,\n"
f"- weight_mapping_type to be MappingType.SYMMETRIC"
)
elif not isinstance(layout, PlainLayout):
Contributor: Guard this try/except on isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) instead. In case another layout is added in the future, guarding on not PlainLayout is too broad.

try:
torch.ops.torchao._pack_8bit_act_4bit_weight
except AttributeError:
raise Exception(
"TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+ " You can also set target to 'aten' if you are using ARM CPU."
+ " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
)

dtype_to_bit_width = {
torch.int1: 1,
@@ -556,7 +577,7 @@ def int8_dynamic_activation_intx_weight(
bit_width = dtype_to_bit_width[weight_dtype]
layout_arg = layout

def apply(weight):
def apply(weight, bias: Optional[torch.Tensor] = None):
if isinstance(granularity, PerGroup):
group_size = granularity.group_size
elif isinstance(granularity, PerRow):
@@ -569,6 +590,7 @@ def apply(weight):
assert weight.shape[-1] % group_size == 0

layout = layout_arg
scale_dtype = None
if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
assert (
weight.device == torch.device("cpu")
@@ -584,7 +606,13 @@ def apply(weight):
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target=target,
)
if target == "aten":
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype
Contributor: Only bfloat16 on PerGroup, but not on PerRow?


quant_min = -(1 << (bit_width - 1))
quant_max = (1 << (bit_width - 1)) - 1
@@ -596,12 +624,14 @@
quant_min=quant_min,
quant_max=quant_max,
eps=torch.finfo(torch.float32).eps,
scale_dtype=scale_dtype,
zero_point_dtype=torch.int8,
preserve_zero=has_weight_zeros,
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
_layout=layout,
bias=bias,
)

# Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
@@ -620,7 +650,8 @@ def apply(weight):
weight = to_linear_activation_quantized(weight, activation_quant_func)
return weight

return _get_linear_subclass_inserter(apply)
propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and target == "aten"
return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
7 changes: 5 additions & 2 deletions torchao/quantization/quant_api.py
@@ -450,15 +450,18 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs):
def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
"""Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
to the weight of linear module
"""

def insert_subclass(lin):
requires_grad = allow_requires_grad and lin.weight.requires_grad
args = [lin.weight]
if propagate_bias:
args.append(lin.bias)
lin.weight = torch.nn.Parameter(
constructor(lin.weight, **kwargs), requires_grad=requires_grad
constructor(*args, **kwargs), requires_grad=requires_grad
)
lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
return lin
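
A small sketch of what propagate_bias changes, using a hypothetical constructor and importing the private helper edited above:

import torch
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def fake_constructor(weight, bias=None):
    # With propagate_bias=True the module's bias arrives here alongside the weight.
    print("received bias:", bias is not None)
    return weight

lin = torch.nn.Linear(8, 4, bias=True)
_get_linear_subclass_inserter(fake_constructor, propagate_bias=True)(lin)  # prints "received bias: True"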