
Commit 0aa8dbd

Enable fp16+int4 mixed precision path for int4 XPU path with int zero point (#2240)
* Enable fp16 path for int4 xpu path with int zero point
* Update int4_xpu_layout.py
* Fix typo
1 parent 4d5f657 commit 0aa8dbd
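
For context, a minimal usage sketch of what this commit enables: int4 weight-only quantization on XPU with integer zero points and fp16 activations. The API names below (quantize_, int4_weight_only, Int4XPULayout, ZeroPointDomain.INT) are my reading of torchao around this commit and may differ in other versions; treat the snippet as illustrative rather than as the exact invocation exercised by this PR.

# Illustrative sketch only, under the API assumptions stated above.
import torch
from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

# fp16 model on XPU; before this commit only bf16 activations hit the fused path.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(device="xpu", dtype=torch.float16)
quantize_(
    model,
    int4_weight_only(
        group_size=128,
        layout=Int4XPULayout(),
        zero_point_domain=ZeroPointDomain.INT,  # int8 zero points -> the dispatch path renamed below
    ),
)
x = torch.randn(1, 4096, device="xpu", dtype=torch.float16)
y = model(x)  # dispatched via _linear_fp_act_uint4_weight_int8_zero_impl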

File tree: 2 files changed (+7, -12 lines)

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 5 deletions
@@ -46,8 +46,8 @@
 from torchao.dtypes.uintx.int4_xpu_layout import (
     _linear_bf16_act_uint4_weight_float_zero_check,
     _linear_bf16_act_uint4_weight_float_zero_impl,
-    _linear_bf16_act_uint4_weight_int8_zero_check,
-    _linear_bf16_act_uint4_weight_int8_zero_impl,
+    _linear_fp_act_uint4_weight_int8_zero_check,
+    _linear_fp_act_uint4_weight_int8_zero_impl,
 )
 from torchao.dtypes.uintx.marlin_qqq_tensor import (
     _linear_int8_act_int4_weight_marlin_qqq_check,
@@ -240,8 +240,8 @@ def _register_aqt_quantized_linear_dispatches():
             _linear_q_dq_impl,
         ),
         (
-            _linear_bf16_act_uint4_weight_int8_zero_check,
-            _linear_bf16_act_uint4_weight_int8_zero_impl,
+            _linear_fp_act_uint4_weight_int8_zero_check,
+            _linear_fp_act_uint4_weight_int8_zero_impl,
         ),
         (
             _linear_bf16_act_uint4_weight_float_zero_check,
@@ -267,7 +267,6 @@ def _(func, types, args, kwargs):
         raise NotImplementedError(
             f"{func} is not implemented for non floating point input"
         )
-
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
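
The pairs registered in _register_aqt_quantized_linear_dispatches() couple a dispatch condition with an implementation, so the rename above only changes which check guards the XPU uint4 + int8-zero-point kernel. A simplified sketch of that (check, impl) dispatch pattern, with illustrative names rather than the exact torchao internals:

# Illustrative sketch of the (dispatch_condition, impl) table pattern used above.
_AQT_LINEAR_DISPATCH = []

def register_linear_dispatch(check_fn, impl_fn):
    # Entries are consulted in registration order.
    _AQT_LINEAR_DISPATCH.append((check_fn, impl_fn))

def quantized_linear(input_tensor, weight_tensor, bias):
    # Use the first implementation whose condition accepts these inputs.
    for check_fn, impl_fn in _AQT_LINEAR_DISPATCH:
        if check_fn(input_tensor, weight_tensor, bias):
            return impl_fn(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no quantized linear dispatch matched these inputs")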

torchao/dtypes/uintx/int4_xpu_layout.py

Lines changed: 3 additions & 7 deletions
@@ -89,26 +89,22 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, b
     return y.to(orig_dtype)


-def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
     return (
-        # input is native bfloat16 tensor
         not is_traceable_wrapper_subclass(input_tensor)
-        and input_tensor.dtype == torch.bfloat16
         and
         # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and _aqt_is_xpu_layout_uint4(weight_tensor)
-        and weight_tensor.dtype == torch.bfloat16
         and len(weight_tensor.shape) == 2
         and weight_tensor.zero_point_domain == ZeroPointDomain.INT
         and weight_tensor.tensor_impl.scale_and_zero is None
-        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
         and weight_tensor.tensor_impl.zero.dtype == torch.int8
         and isinstance(weight_tensor._layout, Int4XPULayout)
     )


-def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     )
@@ -129,7 +125,7 @@ def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bi
     orig_act_size = act_mat.size()
     orig_dtype = act_mat.dtype

-    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
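
Net effect of the relaxed check: the activation only needs to be a plain (non-wrapper-subclass) floating-point tensor, and it is no longer cast to bf16 before the reshape, so fp16 inputs keep their dtype into the fused kernel and the result is cast back to orig_dtype afterwards. A small self-contained illustration of that dtype-preserving reshape (plain PyTorch, no torchao required):

# Illustrative only: the activation keeps its original dtype through the reshape,
# since the .to(torch.bfloat16) cast was removed in the hunk above.
import torch

act = torch.randn(2, 3, 4096, dtype=torch.float16)
orig_dtype = act.dtype
act_mat = act.reshape(-1, act.shape[-1])  # previously followed by .to(torch.bfloat16)
assert act_mat.dtype == orig_dtype == torch.float16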
