Commit 983215d
[Feat]: Add support for KleidiAI quantization schemes

Description: Allow int8_dynamic_activation_intx_weight to work with the aten _dyn_quant_matmul_4bit op.
Needs: pytorch/pytorch#134124 or PyTorch > 2.6.0.
Signed-off-by: Nikhil Gupta <[email protected]>

1 parent 11333ba · commit 983215d

4 files changed (+112, -21 lines)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 1 deletion
@@ -206,6 +206,7 @@ def from_hp_to_intx(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
+        bias: Optional[torch.Tensor] = None
     ):
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
@@ -278,7 +279,11 @@ def from_hp_to_intx(

         data = _layout.post_process(data)
         tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
+        args = [data, scale, zero_point, _layout]
+        # Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias
+        if bias is not None:
+            args.append(bias)
+        tensor_impl = tensor_impl_ctr(*args)
         return cls(
             tensor_impl,
             block_size,
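The change above is small but easy to misread: the bias only reaches the tensor-impl constructor when it is not None, so existing constructors that take the usual four arguments keep working. A self-contained sketch of the same pattern (hypothetical names, not torchao code):

from typing import Callable, List, Optional

import torch

def call_tensor_impl_ctr(ctr: Callable, data, scale, zero_point, layout,
                         bias: Optional[torch.Tensor] = None):
    # Bias is appended only when provided, so layouts whose impl constructor has
    # no bias parameter (everything except the packed "aten" layout) are untouched.
    args: List = [data, scale, zero_point, layout]
    if bias is not None:
        args.append(bias)
    return ctr(*args)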

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 59 additions & 7 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
+from enum import Enum, auto
 from typing import Optional, Tuple

 import torch
@@ -20,6 +21,9 @@
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
 )
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -31,17 +35,33 @@
 handler.setFormatter(formatter)
 logger.addHandler(handler)

+class Target(Enum):
+    """Enum that indicates the backend target"""
+
+    NATIVE = auto()
+    ATEN = auto()
+
+def target_from_str(target: str) -> Target:
+    if target.lower() == "native":
+        return Target.NATIVE
+    elif target.lower() == "aten":
+        return Target.ATEN
+    else:
+        raise ValueError(f"Invalid target: {target}")

 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
     bit_width: Optional[int]
     group_size: Optional[int]
     has_weight_zeros: Optional[bool]
+    # The target platform for the layout, 'native' or 'aten'
+    target: Optional[Target]

     def __init__(
         self,
         bit_width: Optional[int] = None,
         group_size: Optional[int] = None,
         has_weight_zeros: Optional[bool] = None,
+        target: Optional[str] = "native",
     ):
         if bit_width is not None:
             assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
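A short usage sketch of the new target plumbing (parameter values are illustrative; the imports assume the classes defined in this file):

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    Target,
)

# target_from_str normalizes the string, so "aten" and "ATEN" both map to Target.ATEN.
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4,
    group_size=32,
    has_weight_zeros=True,
    target="aten",
)
assert layout.target == Target.ATEN
assert layout.has_params_set()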
@@ -51,6 +71,7 @@ def __init__(
         self.bit_width = bit_width
         self.group_size = group_size
         self.has_weight_zeros = has_weight_zeros
+        self.target = target_from_str(target)

         if not self.has_params_set():
             assert (
@@ -60,13 +81,14 @@ def __init__(
             ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

     def extra_repr(self):
-        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
+        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

     def has_params_set(self) -> bool:
         return (
             (self.bit_width is not None)
             and (self.group_size is not None)
             and (self.has_weight_zeros is not None)
+            and (self.target is not None)
         )

@@ -125,9 +147,11 @@ def from_plain(
         scale: torch.Tensor,
         zero_point: Optional[torch.Tensor],
         layout: Layout,
+        bias: Optional[torch.Tensor] = None,
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
         assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
+        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"

         # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
         # when AOTI supports int
@@ -136,6 +160,13 @@ def from_plain(
         n_tensor = torch.empty(0, n, dtype=torch.int8)
         k_tensor = torch.empty(0, k, dtype=torch.int8)

+        if layout.target == Target.ATEN:
+            assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version > 2.6.0"
+            int_data = int_data.add(8)
+            int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
+
         if layout.has_weight_zeros:
             args = [
                 int_data.to(torch.int8),
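The two packing lines above are dense, so here is a self-contained sketch (not torchao code) that spells out the byte layout and round-trips it: signed int4 values in [-8, 7] are shifted to unsigned [0, 15], then each pair of adjacent columns is fused into one byte, with the even column in the low nibble and the odd column in the high nibble.

import torch

# Signed int4 weights, shape (n, k) with k even; values in [-8, 7].
int_data = torch.randint(-8, 8, (2, 8), dtype=torch.int8)

# Same byte layout as from_plain above: shift to unsigned [0, 15], then fuse each
# pair of adjacent columns into one byte (even column -> low nibble, odd -> high).
shifted = int_data.add(8).to(torch.uint8)
packed = shifted[:, 1::2] << 4 | shifted[:, ::2]   # shape (n, k // 2), dtype uint8

# Round-trip to verify the layout.
low = (packed & 0x0F).to(torch.int8) - 8
high = (packed >> 4).to(torch.int8) - 8
unpacked = torch.stack([low, high], dim=-1).reshape(int_data.shape)
assert torch.equal(unpacked, int_data)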
@@ -211,16 +242,13 @@
 def _linear_check(input_tensor, weight_tensor, bias):
     layout = weight_tensor.tensor_impl.get_layout()
     return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
     )


 def _linear_impl(input_tensor, weight_tensor, bias):
-    assert (
-        bias is None
-    ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

-    def _impl_2d(input_tensor, weight_tensor):
+    def _impl_2d_native(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2

@@ -255,6 +283,31 @@ def _impl_2d(input_tensor, weight_tensor):
             torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
         )(*args)

+    def _impl_2d_aten(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+        return torch.ops.aten._dyn_quant_matmul_4bit(
+            input_tensor, packed_weight, group_size, k_, n)
+
+    target = weight_tensor.tensor_impl.get_layout().target
+
+    if target == Target.ATEN:
+        assert (
+            TORCH_VERSION_AT_LEAST_2_6 == 1
+        ), "Target.ATEN requires torch >= 2.6.0"
+        _impl_2d = _impl_2d_aten
+    elif target == Target.NATIVE:
+        _impl_2d = _impl_2d_native
+        assert (
+            bias is None
+        ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native'"
+
     if input_tensor.dim() == 2:
         return _impl_2d(input_tensor, weight_tensor)

@@ -268,7 +321,6 @@ def _impl_2d(input_tensor, weight_tensor):
     res = res.reshape(*lead_shape, m, n)
     return res

-
 register_aqt_quantized_linear_dispatch(
     _linear_check,
     _linear_impl,
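Taken together with the from_plain change, the aten target is a two-op flow: pack once at quantization time, then run the dynamic-quant matmul on every forward. A hedged wrapper sketch (argument order copied from the diff; availability and the exact shape/dtype expectations of these private aten ops assume a torch >= 2.6 build with KleidiAI and are not verified here):

import torch

def dyn_quant_linear_4bit(x, int_data_uint8, scale, bias, group_size, k, n):
    # Pack the nibble-packed int4 weight, its scales and the optional bias once...
    packed = torch.ops.aten._dyn_quant_pack_4bit_weight(
        int_data_uint8, scale, bias, group_size, k, n
    )
    # ...then dynamically quantize the float activation x (shape (m, k)) and run
    # the 4-bit matmul, yielding a float output of shape (m, n).
    return torch.ops.aten._dyn_quant_matmul_4bit(x, packed, group_size, k, n)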

torchao/experimental/quant_api.py

Lines changed: 42 additions & 11 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 import logging
 from typing import Optional, Union

@@ -18,14 +19,18 @@
     PerGroup,
     PerRow,
 )
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)
+from torchao.dtypes import PlainLayout

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)

-import sys

 handler = logging.StreamHandler(sys.stdout)
-formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+formatter = logging.Formatter(
+    "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 handler.setFormatter(formatter)
 logger.addHandler(handler)

@@ -506,6 +511,7 @@ def int8_dynamic_activation_intx_weight(
     weight_dtype: torch.dtype = torch.int4,
     granularity: Union[PerRow, PerGroup] = PerGroup(128),
     has_weight_zeros: bool = False,
+    target: str = "native",
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),  # PlainLayout() also works, but will be slow
@@ -531,13 +537,28 @@
     - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32)
     - act_mapping_type must be MappingType.ASYMMETRIC
     """
-    try:
-        torch.ops.torchao._pack_8bit_act_4bit_weight
-    except AttributeError:
-        raise Exception(
-            "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
-            + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
-        )
+
+    if target == "aten":
+        if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
+            weight_dtype != torch.int4 or \
+            has_weight_zeros != True or \
+            weight_mapping_type != MappingType.SYMMETRIC:
+            raise NotImplementedError(
+                f"target 'aten' requires:\n"
+                f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
+                f"- has_weight_zeros to be True,\n"
+                f"- weight_dtype to be torch.int4,\n"
+                f"- weight_mapping_type to be MappingType.SYMMETRIC"
+            )
+    elif not isinstance(layout, PlainLayout):
+        try:
+            torch.ops.torchao._pack_8bit_act_4bit_weight
+        except AttributeError:
+            raise Exception(
+                "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+                + " You can also set target to 'aten' if you are using ARM CPU."
+                + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
+            )

     dtype_to_bit_width = {
         torch.int1: 1,
@@ -556,7 +577,7 @@
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout

-    def apply(weight):
+    def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
             group_size = granularity.group_size
         elif isinstance(granularity, PerRow):
@@ -569,6 +590,7 @@ def apply(weight):
        assert weight.shape[-1] % group_size == 0

        layout = layout_arg
+       scale_dtype = None
        if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
            assert (
                weight.device == torch.device("cpu")
@@ -584,7 +606,13 @@ def apply(weight):
                bit_width=bit_width,
                group_size=group_size,
                has_weight_zeros=has_weight_zeros,
+               target=target,
            )
+           if target == "aten":
+               assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version > 2.6.0"
+               if torch.backends.kleidiai.is_available():
+                   if isinstance(granularity, PerGroup):
+                       scale_dtype = torch.bfloat16  # KleidiAI kernel requires bfloat16 scale_dtype

        quant_min = -(1 << (bit_width - 1))
        quant_max = (1 << (bit_width - 1)) - 1
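The scale-dtype selection above reduces to a small rule; a sketch of it as a standalone helper (not a torchao API, though torch.backends.kleidiai.is_available() is the same check the diff uses):

import torch

def choose_scale_dtype(target: str, is_per_group: bool):
    # Groupwise KleidiAI kernels expect bfloat16 scales; in every other case the
    # default (None) lets the quantization primitives pick the scale dtype.
    if target == "aten" and is_per_group and torch.backends.kleidiai.is_available():
        return torch.bfloat16
    return None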
@@ -596,12 +624,14 @@ def apply(weight):
            quant_min=quant_min,
            quant_max=quant_max,
            eps=torch.finfo(torch.float32).eps,
+           scale_dtype=scale_dtype,
            zero_point_dtype=torch.int8,
            preserve_zero=has_weight_zeros,
            zero_point_domain=ZeroPointDomain.INT
            if has_weight_zeros
            else ZeroPointDomain.NONE,
            _layout=layout,
+           bias=bias
        )

        # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
@@ -620,7 +650,8 @@ def apply(weight):
        weight = to_linear_activation_quantized(weight, activation_quant_func)
        return weight

-   return _get_linear_subclass_inserter(apply)
+   propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"
+   return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
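Putting the pieces together, a hedged end-to-end usage sketch of the new target (parameter values follow the constraints enforced earlier in this function; import paths are assumptions based on the torchao tree at this commit, and actually running it needs an ARM CPU build of torch >= 2.6 with KleidiAI):

import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

model = nn.Sequential(nn.Linear(4096, 4096, bias=True)).eval()

# target="aten" routes linear through torch.ops.aten._dyn_quant_matmul_4bit and
# allows the module bias to be packed together with the 4-bit weights.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        weight_mapping_type=MappingType.SYMMETRIC,
        target="aten",
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)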

torchao/quantization/quant_api.py

Lines changed: 5 additions & 2 deletions
@@ -450,15 +450,18 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


-def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs):
+def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
     """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
     to the weight of linear module
     """

     def insert_subclass(lin):
         requires_grad = allow_requires_grad and lin.weight.requires_grad
+        args = [lin.weight]
+        if propagate_bias == True:
+            args.append(lin.bias)
         lin.weight = torch.nn.Parameter(
-            constructor(lin.weight, **kwargs), requires_grad=requires_grad
+            constructor(*args, **kwargs), requires_grad=requires_grad
         )
         lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
         return lin
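The effect of propagate_bias is easiest to see in isolation; a runnable toy sketch (dummy constructor, not torchao code) of the same switch:

import torch
import torch.nn as nn

def fake_quantizer(weight, bias=None):
    # Stand-in for the apply(weight, bias=None) constructor built in
    # torchao/experimental/quant_api.py; it just reports whether a bias arrived.
    print("got bias:", bias is not None)
    return weight

def insert_subclass(lin, constructor, propagate_bias=False):
    args = [lin.weight]
    if propagate_bias:
        args.append(lin.bias)
    lin.weight = torch.nn.Parameter(constructor(*args), requires_grad=False)
    return lin

lin = nn.Linear(16, 8, bias=True)
insert_subclass(lin, fake_quantizer, propagate_bias=True)  # prints: got bias: True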
