Enable fp16+int4 mixed precision path for int4 xpu path with int zero point #2240

Merged: 3 commits merged on May 29, 2025
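
This PR renames the bf16-only int8-zero-point dispatch for the Int4 XPU layout to a floating-point activation path, drops the bf16-specific dtype checks, and removes the forced bf16 cast on the activation, so fp16 activations can now reach the int4 XPU kernel when the weight uses an integer zero-point domain. A minimal usage sketch follows; it is not taken from the PR, it assumes an XPU device is available, and it assumes this torchao version's `int4_weight_only` accepts `layout` and `zero_point_domain` arguments.

```python
# Usage sketch (illustrative, not from the PR): int4 weight-only quantization with the
# XPU layout and integer zero points, driven by an fp16 activation.
# Assumptions: an XPU device is available; int4_weight_only(...) in this torchao version
# accepts layout= and zero_point_domain= keyword arguments.
import torch
from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to("xpu", torch.float16)
quantize_(
    model,
    int4_weight_only(
        group_size=128,
        layout=Int4XPULayout(),
        zero_point_domain=ZeroPointDomain.INT,  # int8 zero points, the case this PR targets
    ),
)

x = torch.randn(8, 1024, device="xpu", dtype=torch.float16)  # fp16 activation, no bf16 cast required
with torch.no_grad():
    y = model(x)
print(y.dtype)  # expected: torch.float16
```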
torchao/dtypes/affine_quantized_tensor_ops.py (9 changes: 4 additions & 5 deletions)
@@ -46,8 +46,8 @@
 from torchao.dtypes.uintx.int4_xpu_layout import (
     _linear_bf16_act_uint4_weight_float_zero_check,
     _linear_bf16_act_uint4_weight_float_zero_impl,
-    _linear_bf16_act_uint4_weight_int8_zero_check,
-    _linear_bf16_act_uint4_weight_int8_zero_impl,
+    _linear_fp_act_uint4_weight_int8_zero_check,
+    _linear_fp_act_uint4_weight_int8_zero_impl,
 )
 from torchao.dtypes.uintx.marlin_qqq_tensor import (
     _linear_int8_act_int4_weight_marlin_qqq_check,
@@ -235,8 +235,8 @@ def _register_aqt_quantized_linear_dispatches():
             _linear_q_dq_impl,
         ),
         (
-            _linear_bf16_act_uint4_weight_int8_zero_check,
-            _linear_bf16_act_uint4_weight_int8_zero_impl,
+            _linear_fp_act_uint4_weight_int8_zero_check,
+            _linear_fp_act_uint4_weight_int8_zero_impl,
         ),
         (
             _linear_bf16_act_uint4_weight_float_zero_check,
@@ -262,7 +262,6 @@ def _(func, types, args, kwargs):
         raise NotImplementedError(
             f"{func} is not implemented for non floating point input"
         )
-
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
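
For context, the pairs above are consumed by torchao's quantized-linear dispatcher: each entry is a (check, impl) pair, and the first check that matches the input/weight tensors selects the kernel, so this change only swaps in the renamed fp-activation pair. A simplified sketch of that dispatch pattern follows; the names and structure are illustrative, not copied from torchao.

```python
# Simplified sketch of the (check, impl) dispatch pattern used by the registration above.
# Names (_DISPATCH_TABLE, quantized_linear, ...) are illustrative, not torchao's.
from typing import Callable, List, Tuple

_DISPATCH_TABLE: List[Tuple[Callable, Callable]] = []

def register_quantized_linear_dispatch(check: Callable, impl: Callable) -> None:
    """Register a (condition, implementation) pair; entries are tried in order."""
    _DISPATCH_TABLE.append((check, impl))

def quantized_linear(input_tensor, weight_tensor, bias):
    # Use the first registered implementation whose condition accepts these tensors.
    for check, impl in _DISPATCH_TABLE:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no quantized linear dispatch path matched the given tensors")
```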
torchao/dtypes/uintx/int4_xpu_layout.py (10 changes: 3 additions & 7 deletions)
@@ -89,26 +89,22 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, bias):
     return y.to(orig_dtype)


-def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
     return (
-        # input is native bfloat16 tensor
         not is_traceable_wrapper_subclass(input_tensor)
-        and input_tensor.dtype == torch.bfloat16
         and
         # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and _aqt_is_xpu_layout_uint4(weight_tensor)
-        and weight_tensor.dtype == torch.bfloat16
         and len(weight_tensor.shape) == 2
         and weight_tensor.zero_point_domain == ZeroPointDomain.INT
         and weight_tensor.tensor_impl.scale_and_zero is None
-        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
         and weight_tensor.tensor_impl.zero.dtype == torch.int8
         and isinstance(weight_tensor._layout, Int4XPULayout)
     )


-def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     )
@@ -129,7 +125,7 @@ def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
     orig_act_size = act_mat.size()
     orig_dtype = act_mat.dtype

-    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
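
The functional change in the impl is the dropped `.to(torch.bfloat16)` cast: the activation now keeps its original floating-point dtype (fp16 or bf16) through the reshape and the XPU int4 matmul, and the existing `orig_dtype` bookkeeping restores the output dtype. A rough, self-contained sketch of that shape/dtype handling follows; a plain matmul stands in for the packed int4 XPU kernel, which is not shown in this hunk, and the helper name is hypothetical.

```python
import torch
from typing import Optional

def _fp_act_int4_weight_linear_sketch(
    act: torch.Tensor,
    dequantized_weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Sketch of the shape/dtype bookkeeping in the impl above. The real path calls a packed
    # int4 XPU matmul; here a plain matmul on a dequantized weight stands in for it.
    orig_act_size = act.size()
    orig_dtype = act.dtype  # may now be torch.float16 as well as torch.bfloat16

    # Flatten leading dims; note there is no longer a forced .to(torch.bfloat16) cast here.
    act_2d = act.reshape(-1, act.shape[-1])

    y = torch.nn.functional.linear(act_2d, dequantized_weight, bias)  # stand-in for the int4 kernel

    # Restore the original leading shape and the activation's dtype.
    y = y.reshape(*orig_act_size[:-1], y.shape[-1])
    return y.to(orig_dtype)

# fp16 activations now flow through without an intermediate bf16 cast.
a = torch.randn(2, 4, 64, dtype=torch.float16)
w = torch.randn(32, 64, dtype=torch.float16)
out = _fp_act_int4_weight_linear_sketch(a, w)
print(out.shape, out.dtype)  # torch.Size([2, 4, 32]) torch.float16
```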