5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import logging
8
+ from enum import Enum , auto
8
9
from typing import Optional , Tuple
9
10
10
11
import torch
20
21
from torchao .quantization .quant_primitives import (
21
22
ZeroPointDomain ,
22
23
)
24
+ from torchao .utils import (
25
+ TORCH_VERSION_AT_LEAST_2_6 ,
26
+ )
23
27
24
28
logger = logging .getLogger (__name__ )
25
29
logger .setLevel (logging .WARNING )
31
35
handler .setFormatter (formatter )
32
36
logger .addHandler (handler )
33
37
38
class Target(Enum):
    """Backend target a packed layout is lowered to.

    NATIVE uses torchao's own ops (``torch.ops.torchao``); ATEN uses
    PyTorch's ATen 4-bit dynamic-quant kernels.
    """

    # auto() assigns 1 and 2 in declaration order; only identity matters.
    NATIVE = auto()
    ATEN = auto()
43
+
44
def target_from_str(target: str) -> Target:
    """Parse a case-insensitive backend name into a ``Target``.

    Args:
        target: backend name; "native" or "aten" (any casing).

    Returns:
        The matching ``Target`` member.

    Raises:
        ValueError: if ``target`` is not a recognized backend name.
    """
    # Normalize once instead of calling .lower() per comparison.
    normalized = target.lower()
    if normalized == "native":
        return Target.NATIVE
    if normalized == "aten":
        return Target.ATEN
    raise ValueError(f"Invalid target: {target}")
34
51
35
52
class PackedLinearInt8DynamicActivationIntxWeightLayout (Layout ):
36
53
bit_width : Optional [int ]
37
54
group_size : Optional [int ]
38
55
has_weight_zeros : Optional [bool ]
56
+ # The target platform for the layout, 'native' or 'aten'
57
+ target : Optional [Target ]
39
58
40
59
def __init__ (
41
60
self ,
42
61
bit_width : Optional [int ] = None ,
43
62
group_size : Optional [int ] = None ,
44
63
has_weight_zeros : Optional [bool ] = None ,
64
+ target : Optional [str ] = "native" ,
45
65
):
46
66
if bit_width is not None :
47
67
assert bit_width >= 1 and bit_width <= 8 , "bit_width must be 1 to 8"
@@ -51,6 +71,7 @@ def __init__(
51
71
self .bit_width = bit_width
52
72
self .group_size = group_size
53
73
self .has_weight_zeros = has_weight_zeros
74
+ self .target = target_from_str (target )
54
75
55
76
if not self .has_params_set ():
56
77
assert (
@@ -60,13 +81,14 @@ def __init__(
60
81
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"
61
82
62
83
def extra_repr(self):
    """Summarize the layout's configuration for its repr output."""
    parts = (
        f"group_size={self.group_size}",
        f"bit_width={self.bit_width}",
        f"has_weight_zeros={self.has_weight_zeros}",
        f"target={self.target}",
    )
    return ", ".join(parts)
64
85
65
86
def has_params_set(self) -> bool:
    """Return True when every layout parameter has been populated.

    All of bit_width, group_size, has_weight_zeros, and target must be
    non-None for the layout to be considered fully configured.
    """
    params = (
        self.bit_width,
        self.group_size,
        self.has_weight_zeros,
        self.target,
    )
    return all(p is not None for p in params)
71
93
72
94
@@ -125,9 +147,11 @@ def from_plain(
125
147
scale : torch .Tensor ,
126
148
zero_point : Optional [torch .Tensor ],
127
149
layout : Layout ,
150
+ bias : Optional [torch .Tensor ] = None ,
128
151
):
129
152
assert isinstance (layout , PackedLinearInt8DynamicActivationIntxWeightLayout )
130
153
assert layout .has_params_set (), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
154
+ assert layout .target in {Target .NATIVE , Target .ATEN }, f"Unexpected target: { layout .target } "
131
155
132
156
# TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
133
157
# when AOTI supports int
@@ -136,6 +160,13 @@ def from_plain(
136
160
n_tensor = torch .empty (0 , n , dtype = torch .int8 )
137
161
k_tensor = torch .empty (0 , k , dtype = torch .int8 )
138
162
163
+ if layout .target == Target .ATEN :
164
+ assert TORCH_VERSION_AT_LEAST_2_6 , f"aten target is requires torch version > 2.6.0"
165
+ int_data = int_data .add (8 )
166
+ int_data = (int_data [::,1 ::2 ] << 4 | int_data [::,::2 ] ).to (torch .uint8 )
167
+ packed_weight = torch .ops .aten ._dyn_quant_pack_4bit_weight (int_data , scale , bias , layout .group_size , k , n )
168
+ return cls (packed_weight , layout , group_size_tensor , n_tensor , k_tensor )
169
+
139
170
if layout .has_weight_zeros :
140
171
args = [
141
172
int_data .to (torch .int8 ),
@@ -211,16 +242,13 @@ def __tensor_unflatten__(
211
242
def _linear_check(input_tensor, weight_tensor, bias):
    """Dispatch predicate for the packed int8-activation/intx-weight linear.

    Matches when the weight uses PackedLinearInt8DynamicActivationIntxWeightLayout
    and bias is either absent or the layout targets ATEN (the only target that
    allows a bias here).
    """
    layout = weight_tensor.tensor_impl.get_layout()
    if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
        return False
    # Aten target allows bias; native does not.
    return bias is None or layout.target == Target.ATEN
216
247
217
248
218
249
def _linear_impl (input_tensor , weight_tensor , bias ):
219
- assert (
220
- bias is None
221
- ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"
222
250
223
- def _impl_2d (input_tensor , weight_tensor ):
251
+ def _impl_2d_native (input_tensor , weight_tensor ):
224
252
assert input_tensor .dim () == 2
225
253
assert weight_tensor .dim () == 2
226
254
@@ -255,6 +283,31 @@ def _impl_2d(input_tensor, weight_tensor):
255
283
torch .ops .torchao , f"_linear_8bit_act_{ bit_width } bit{ wzp_suffix } _weight"
256
284
)(* args )
257
285
286
+ def _impl_2d_aten (input_tensor , weight_tensor ):
287
+ assert input_tensor .dim () == 2
288
+ assert weight_tensor .dim () == 2
289
+
290
+ m , k = input_tensor .shape
291
+ n , k_ = weight_tensor .shape
292
+ assert k_ == k
293
+ group_size = weight_tensor .tensor_impl .get_layout ().group_size
294
+ packed_weight = weight_tensor .tensor_impl .packed_weight
295
+ return torch .ops .aten ._dyn_quant_matmul_4bit (
296
+ input_tensor , packed_weight , group_size , k_ , n )
297
+
298
+ target = weight_tensor .tensor_impl .get_layout ().target
299
+
300
+ if target == Target .ATEN :
301
+ assert (
302
+ TORCH_VERSION_AT_LEAST_2_6 == 1
303
+ ), "Target.ATEN requires torch >= 2.6.0"
304
+ _impl_2d = _impl_2d_aten
305
+ elif target == Target .NATIVE :
306
+ _impl_2d = _impl_2d_native
307
+ assert (
308
+ bias is None
309
+ ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' "
310
+
258
311
if input_tensor .dim () == 2 :
259
312
return _impl_2d (input_tensor , weight_tensor )
260
313
@@ -268,7 +321,6 @@ def _impl_2d(input_tensor, weight_tensor):
268
321
res = res .reshape (* lead_shape , m , n )
269
322
return res
270
323
271
-
272
324
register_aqt_quantized_linear_dispatch (
273
325
_linear_check ,
274
326
_linear_impl ,
0 commit comments