[Feat]: Add support for KleidiAI quantization schemes
Description:

Allow int8_dynamic_activation_intx_weight to work with the aten _dyn_quant_matmul_4bit op.

Needs: pytorch/pytorch#134124 or PyTorch > 2.6.0

Signed-off-by: Nikhil Gupta <[email protected]>
ng-05 committed Jan 15, 2025
1 parent 11333ba commit 983215d
Showing 4 changed files with 112 additions and 21 deletions.
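
For context, here is a minimal usage sketch of the new path (not part of the commit). It assumes the import locations suggested by the changed files (quantize_ from torchao.quantization.quant_api, MappingType from torchao.quantization.quant_primitives, PerGroup from torchao.quantization.granularity, int8_dynamic_activation_intx_weight from torchao.experimental.quant_api), a PyTorch build that provides the aten _dyn_quant_* 4-bit ops, and an illustrative single-layer model:

```python
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

# Illustrative model; any nn.Linear whose in_features is divisible by the group size works.
model = nn.Sequential(nn.Linear(4096, 4096, bias=True)).eval()

# target="aten" routes the matmul through torch.ops.aten._dyn_quant_matmul_4bit.
# Per the checks added in this commit, it requires int4 weights, has_weight_zeros=True,
# MappingType.SYMMETRIC, and the default PackedLinearInt8DynamicActivationIntxWeightLayout.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        weight_mapping_type=MappingType.SYMMETRIC,
        target="aten",
    ),
)

with torch.no_grad():
    out = model(torch.randn(1, 4096))
```

The keyword arguments mirror the constraints enforced by the new NotImplementedError check in torchao/experimental/quant_api.py: int4 weights, symmetric weight mapping, and has_weight_zeros=True.
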
7 changes: 6 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -206,6 +206,7 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None
):
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
@@ -278,7 +279,11 @@ def from_hp_to_intx(

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
args = [data, scale, zero_point, _layout]
# Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias
if bias is not None:
args.append(bias)
tensor_impl = tensor_impl_ctr(*args)
return cls(
tensor_impl,
block_size,
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum, auto
from typing import Optional, Tuple

import torch
@@ -20,6 +21,9 @@
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@@ -31,17 +35,33 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

class Target(Enum):
"""Enum that indicates the backend target"""

NATIVE = auto()
ATEN = auto()

def target_from_str(target: str) -> Target:
if target.lower() == "native":
return Target.NATIVE
elif target.lower() == "aten":
return Target.ATEN
else:
raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
bit_width: Optional[int]
group_size: Optional[int]
has_weight_zeros: Optional[bool]
# The target platform for the layout, 'native' or 'aten'
target: Optional[Target]

def __init__(
self,
bit_width: Optional[int] = None,
group_size: Optional[int] = None,
has_weight_zeros: Optional[bool] = None,
target: Optional[str] = "native",
):
if bit_width is not None:
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
@@ -51,6 +71,7 @@ def __init__(
self.bit_width = bit_width
self.group_size = group_size
self.has_weight_zeros = has_weight_zeros
self.target = target_from_str(target)

if not self.has_params_set():
assert (
@@ -60,13 +81,14 @@ def __init__(
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

def extra_repr(self):
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

def has_params_set(self) -> bool:
return (
(self.bit_width is not None)
and (self.group_size is not None)
and (self.has_weight_zeros is not None)
and (self.target is not None)
)


@@ -125,9 +147,11 @@ def from_plain(
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
layout: Layout,
bias: Optional[torch.Tensor] = None,
):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"

# TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
# when AOTI supports int
@@ -136,6 +160,13 @@
n_tensor = torch.empty(0, n, dtype=torch.int8)
k_tensor = torch.empty(0, k, dtype=torch.int8)

if layout.target == Target.ATEN:
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
int_data = int_data.add(8)
int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

if layout.has_weight_zeros:
args = [
int_data.to(torch.int8),
@@ -211,16 +242,13 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
bias is None
bias is None or layout.target == Target.ATEN # Aten target allows bias
)


def _linear_impl(input_tensor, weight_tensor, bias):
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

def _impl_2d(input_tensor, weight_tensor):
def _impl_2d_native(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

@@ -255,6 +283,31 @@ def _impl_2d(input_tensor, weight_tensor):
torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
)(*args)

def _impl_2d_aten(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

m, k = input_tensor.shape
n, k_ = weight_tensor.shape
assert k_ == k
group_size = weight_tensor.tensor_impl.get_layout().group_size
packed_weight = weight_tensor.tensor_impl.packed_weight
return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)

target = weight_tensor.tensor_impl.get_layout().target

if target == Target.ATEN:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "Target.ATEN requires torch >= 2.6.0"
_impl_2d = _impl_2d_aten
elif target == Target.NATIVE:
_impl_2d = _impl_2d_native
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' "

if input_tensor.dim() == 2:
return _impl_2d(input_tensor, weight_tensor)

@@ -268,7 +321,6 @@ def _impl_2d(input_tensor, weight_tensor):
res = res.reshape(*lead_shape, m, n)
return res


register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
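
As an aside on the aten branch of from_plain above: the signed int4 weights are offset by 8 into [0, 15] and two neighbouring columns are packed into one uint8 before torch.ops.aten._dyn_quant_pack_4bit_weight is called. Below is a standalone sketch of just that nibble-packing step (plain torch only; the function names are illustrative and not part of the commit):

```python
import torch

def pack_int4_rows(int_data: torch.Tensor) -> torch.Tensor:
    """Pack signed int4 values (stored as int8, range [-8, 7]) two per byte.

    Mirrors the aten branch above: add 8 to move into [0, 15], then put the
    odd columns in the high nibble and the even columns in the low nibble.
    """
    assert int_data.dim() == 2 and int_data.shape[1] % 2 == 0
    unsigned = int_data.add(8).to(torch.uint8)          # [-8, 7] -> [0, 15]
    return (unsigned[:, 1::2] << 4) | unsigned[:, ::2]  # shape (n, k // 2)

def unpack_int4_rows(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4_rows, used only to check the round trip."""
    low = (packed & 0x0F).to(torch.int8) - 8
    high = (packed >> 4).to(torch.int8) - 8
    out = torch.empty(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    out[:, ::2] = low
    out[:, 1::2] = high
    return out

w = torch.randint(-8, 8, (4, 32), dtype=torch.int8)  # toy int4 weights
assert torch.equal(unpack_int4_rows(pack_int4_rows(w)), w)
```
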
53 changes: 42 additions & 11 deletions torchao/experimental/quant_api.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import logging
from typing import Optional, Union

@@ -18,14 +19,18 @@
PerGroup,
PerRow,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)
from torchao.dtypes import PlainLayout

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

import sys

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

@@ -506,6 +511,7 @@ def int8_dynamic_activation_intx_weight(
weight_dtype: torch.dtype = torch.int4,
granularity: Union[PerRow, PerGroup] = PerGroup(128),
has_weight_zeros: bool = False,
target: str = "native",
weight_mapping_type=MappingType.ASYMMETRIC,
act_mapping_type=MappingType.ASYMMETRIC,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow
@@ -531,13 +537,28 @@
- The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32)
- act_mapping_type must be MappingType.ASYMMETRIC
"""
try:
torch.ops.torchao._pack_8bit_act_4bit_weight
except AttributeError:
raise Exception(
"TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+ " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
)

if target == "aten":
if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
weight_dtype != torch.int4 or \
not has_weight_zeros or \
weight_mapping_type != MappingType.SYMMETRIC:
raise NotImplementedError(
f"target 'aten' requires:\n"
f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
f"- has_weight_zeros to be True,\n"
f"- weight_dtype to be torch.int4,\n"
f"- weight_mapping_type to be MappingType.SYMMETRIC"
)
elif not isinstance(layout, PlainLayout):
try:
torch.ops.torchao._pack_8bit_act_4bit_weight
except AttributeError:
raise Exception(
"TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+ " You can also set target to 'aten' if you are using ARM CPU."
+ " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
)

dtype_to_bit_width = {
torch.int1: 1,
Expand All @@ -556,7 +577,7 @@ def int8_dynamic_activation_intx_weight(
bit_width = dtype_to_bit_width[weight_dtype]
layout_arg = layout

def apply(weight):
def apply(weight, bias: Optional[torch.Tensor] = None):
if isinstance(granularity, PerGroup):
group_size = granularity.group_size
elif isinstance(granularity, PerRow):
@@ -569,6 +590,7 @@ def apply(weight):
assert weight.shape[-1] % group_size == 0

layout = layout_arg
scale_dtype = None
if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
assert (
weight.device == torch.device("cpu")
@@ -584,7 +606,13 @@ def apply(weight):
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target=target,
)
if target == "aten":
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype

quant_min = -(1 << (bit_width - 1))
quant_max = (1 << (bit_width - 1)) - 1
@@ -596,12 +624,14 @@ def apply(weight):
quant_min=quant_min,
quant_max=quant_max,
eps=torch.finfo(torch.float32).eps,
scale_dtype=scale_dtype,
zero_point_dtype=torch.int8,
preserve_zero=has_weight_zeros,
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
_layout=layout,
bias=bias
)

# Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
Expand All @@ -620,7 +650,8 @@ def apply(weight):
weight = to_linear_activation_quantized(weight, activation_quant_func)
return weight

return _get_linear_subclass_inserter(apply)
propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and target == "aten"
return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
7 changes: 5 additions & 2 deletions torchao/quantization/quant_api.py
@@ -450,15 +450,18 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs):
def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
"""Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
to the weight of linear module
"""

def insert_subclass(lin):
requires_grad = allow_requires_grad and lin.weight.requires_grad
args = [lin.weight]
if propagate_bias:
args.append(lin.bias)
lin.weight = torch.nn.Parameter(
constructor(lin.weight, **kwargs), requires_grad=requires_grad
constructor(*args, **kwargs), requires_grad=requires_grad
)
lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
return lin