
INT4 XPU enabling #1577

Open. Wants to merge 23 commits into main from xpu_int4.
Conversation

airMeng (Collaborator) commented Jan 17, 2025

This PR is currently a draft.

The PR adds two kinds of INT4 support on XPU: floating zero points and integer zero points, following the discussion in #1264.

Integer zero points are natively supported via oneDNN (pytorch/pytorch#137566).

Floating zero points, the default behavior in this repo, are supported by intel/torch-xpu-ops#1130, with more implementations on the way.
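For illustration, the intended user-facing calls (taken from the review discussion below; the import paths, the ZeroPointDomain.FLOAT spelling, and the toy model are assumptions) would look roughly like this:

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.dtypes import Int4XPULayout  # added by this PR

# assumes an XPU build of PyTorch and an available Intel GPU
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="xpu")
)

# integer zero points (natively supported via oneDNN)
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(),
                                  zero_point_domain=ZeroPointDomain.INT))

# floating zero points (the default behavior in this repo)
# quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(),
#                                   zero_point_domain=ZeroPointDomain.FLOAT))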


pytorch-bot bot commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1577

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b096666 with merge base 3fb1665:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

airMeng marked this pull request as draft January 17, 2025 03:20
facebook-github-bot added the CLA Signed label Jan 17, 2025
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
layout_list = []
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
layout_list.append(Int4CPULayout())
elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
Contributor

here as well, 2_6 or 2_7?

Contributor

2.7


__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
jerryzh168 (Contributor) Jan 17, 2025

btw for this one, we have some unpacking op for tensor core tiled layout that we should really be using:

m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);

might be better to do the same instead of hacking with quantize ops
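For reference, a rough sketch of what an unpack-based path could look like; the op names come from the registration lines above, while the argument order and the idea of an analogous XPU op are assumptions, not verified against this PR:

import torch

def unpack_to_plain(packed_weight, scales_and_zeros, group_size, inner_k_tiles):
    # plain (unpacked) int4 values, still in integer form
    int_data = torch.ops.torchao.unpack_tensor_core_tiled_layout(
        packed_weight, inner_k_tiles
    )
    # dequantized floating-point weight, if the plain fp tensor is needed
    dequant = torch.ops.torchao.dequantize_tensor_core_tiled_layout(
        packed_weight, scales_and_zeros, group_size, inner_k_tiles
    )
    return int_data, dequant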

Collaborator Author

Sure, I will check.

jerryzh168 (Contributor)

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

airMeng (Collaborator Author) commented Jan 17, 2025

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

@mingfeima @EikanWang can you comment?

mingfeima

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

@mingfeima @EikanWang can you comment?

The situation for XPU (the Intel GPUs) is different from CPU and CUDA here. I am not sure whether providing SYCL or oneDNN XPU ops in ao is a feasible solution.

airMeng force-pushed the xpu_int4 branch 2 times, most recently from 91067e2 to 895376f, February 24, 2025 01:40
airMeng marked this pull request as ready for review February 26, 2025 09:53
sunjiweiswift (Contributor)

@jerryzh168 please review again.

_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
# TODO(#1690): delete this once config migration is done
Contributor

cc @vkuzo we can delete these now?

Comment on lines +211 to +214
if self.scale_and_zero is not None:
return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
else:
return ["packed_weight", "scale", "zero"], [self.transposed, self._layout]
Contributor

why do we have two formats here? maybe should split into multiple layouts?

Collaborator Author

Integer zp and floating zp. I don't split them into two layouts because it would be confusing from the user side.

current:

quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.Float))

but with different layouts:

quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutIntZP(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutFloatZP(), zero_point_domain=ZeroPointDomain.Float))

I think the current implementation is more straightforward for users.

Contributor

But layout defines how we actually store the packed weights; using a single layout for multiple things breaks this abstraction, I feel.

Is the concern around specifying zero_point_domain multiple times? We could remove that and just infer the zero_point_domain from the layout, I think (the latter API).
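A minimal sketch of that inference, assuming the import paths below; Int4XPULayoutIntZP is a hypothetical name for a split-out integer-zp layout, and the real mapping in this PR (LAYOUT_TO_ZERO_POINT_DOMAIN) holds lists of allowed domains rather than single values:

from torchao.dtypes import Int4XPULayout            # added by this PR
from torchao.quantization.quant_primitives import ZeroPointDomain

class Int4XPULayoutIntZP(Int4XPULayout):
    """Hypothetical split-out layout for integer zero points."""

# one place that knows which zero point domain each layout implies,
# so the user would not need to pass zero_point_domain at all
_LAYOUT_TO_ZP_DOMAIN = {
    Int4XPULayout: ZeroPointDomain.FLOAT,
    Int4XPULayoutIntZP: ZeroPointDomain.INT,
}

def infer_zero_point_domain(layout):
    return _LAYOUT_TO_ZP_DOMAIN[type(layout)]

# usage sketch:
# quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutIntZP()))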

Collaborator Author

Since only XPU supports integer zp, can I move it to the next PR?

If layout defines how we actually store the packed weights, it should include the layout of scales and zeros too, right?

Contributor

Since only XPU supports integer zp, can I move it to the next PR?

What is this referring to?

If layout defines how we actually store the packed weights, it should include the layout of scales and zeros too, right?

Yeah, that's correct. Ideally I think we should not use layout to control whether we have packed_weight / scale_and_zero / scale, zero; the duplication should actually happen at the tensor level (we create different tensor subclass tensors), not the layout. Feel free to go that route if you want.

Collaborator Author

Can I separate these into different layouts and bind the zero point domain to each layout in the next PR?

Contributor

yeah sure

@@ -242,6 +247,11 @@ def matmul(self, x):
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
)
if is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7 \
and not isinstance(self.scales_and_zeros, torch.Tensor):
c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
jerryzh168 (Contributor) Mar 17, 2025

is this supposed to match line 550 in GPTQ.py?

Collaborator Author

Removed HQQ support in this PR to simplify the logic.

@@ -546,6 +546,14 @@ def linear_forward_int4(
groupsize,
scales_and_zeros.to(scales_precision),
).to(dtype=x.dtype)
elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
Contributor

Also, why do we have this function? Can the slicing (scales_and_zeros[0] and scales_and_zeros[1]) be done in _weight_int4pack_mm itself?

Collaborator Author

Removed GPTQ support in this PR to simplify the logic. I will open another PR after I separate the layouts.

airMeng requested a review from jerryzh168 March 19, 2025 06:46
@@ -166,6 +167,10 @@ def process_hqq_quants(self, W_q, meta):
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
W_q_torch, self.inner_k_tiles
)
if is_device(W_q.device.type, "Xpu") and TORCH_VERSION_AT_LEAST_2_7:
Contributor

nit: "xpu"?

@@ -407,6 +407,8 @@ def _quantize_affine_no_dtype_cast(
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
if shape_after_reduction[0] == 12288:
Contributor

remove?

@@ -954,6 +956,7 @@ def _choose_qparams_affine(
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
zero_point_dtype = torch.int32
Contributor

why is this set here?

airMeng (Collaborator Author) Mar 19, 2025

In fact, preserve_zero and the INT zero point domain are coupled here; I think there is some duplication.
The reason for setting zero_point_dtype to int here is that many call sites of this function rely on the default floating-point dtype.

Contributor

preserve_zero is about whether zero (in the original floating point domain) is exactly representable or not; it's not coupled with the zero point domain, I think. Even if zero is exactly representable, we can still choose zero_point_domain to be float.

Collaborator Author

Yes, from the math side they are not related, but the code here implies it; see the conditional dispatch from line 954 to 966. We need a refactor here.

Contributor

I think it would be better to add an assert instead of changing the condition if there is coupling.

@@ -315,7 +316,7 @@ def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float3
return dequantized


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16, zero_point_domain_is_int=False):
Contributor

nit: I'm wondering if we should just expose zero_point_domain as an arg directly here

Collaborator Author

which way do you prefer? how about

def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, data_dtype=torch.bfloat16, scale_dtype=torch.bfloat16, zero_dtype=torch.bfloat16):

Contributor

sure, this sounds good. what about preserve_zero and zero_point_domain? I don't think these can be fully inferred?

Contributor

is this going to be updated? I feel it's a bit weird to introduce a boolean flag for zero_point_domain when we can just pass zero_point_domain itself around
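As a sketch of that alternative, assuming import paths and that the INT branch follows the rounded-and-clamped zero point from the hunk above (this is not the code in this PR, only an illustration of passing zero_point_domain directly):

import torch
from torchao.quantization.quant_primitives import ZeroPointDomain

def get_groupwise_affine_qparams(
    w, n_bit=4, groupsize=128, dtype=torch.bfloat16,
    zero_point_domain=ZeroPointDomain.FLOAT,
):
    # assumes w.numel() is divisible by groupsize
    to_quant = w.reshape(-1, groupsize).float()
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_val = to_quant.amax(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / (2**n_bit - 1)
    if zero_point_domain == ZeroPointDomain.FLOAT:
        # floating zero point, expressed in the original (fp) domain
        zeros = (min_val + scales * (2 ** (n_bit - 1))).to(dtype)
    else:
        # integer zero point, rounded and clamped to the quantized range
        zeros = torch.clamp(-torch.round(min_val / scales), 0, 2**n_bit - 1).to(torch.int32)
    return scales.to(dtype), zeros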

Collaborator Author

Is 287d069 okay?

@@ -850,6 +860,7 @@ def _int4_weight_only_transform(
zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"

preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] if zero_point_domain!=ZeroPointDomain.INT else True
Contributor

Does zero_point_dtype need to change after this is set?

Collaborator Author

Same as above: in fact, preserve_zero and the INT zero point domain are coupled here.

Contributor

There are three things: preserve_zero = {True, False}, zero_point_domain = {FLOAT, INT, NONE}, and zero_point_dtype = {float, int, ...}.

It's true that not all combinations are valid, but I don't think they are coupled; see:

preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable or not; this is typically required for ops that need zero padding, like convolution.
It's less important for ops that don't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is an integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8]; 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such a
guarantee.
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
zero_point_domain might be something we can remove though. cc @jainapurva
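To make that concrete, a tiny numeric sketch (assuming uint4 asymmetric quantization over a single group, adapted from the docstring example above) showing that one preserve_zero=True quantization can expose its zero point in either domain:

import torch

# adapted from the docstring example, with a negative value so the zero point is nonzero
x = torch.tensor([1.2, -0.7, 3.0, 4.0, 0.4, 0.0])
qmin, qmax, mid = 0, 15, 8  # uint4 range and mid-point (2 ** (n_bit - 1))
scale = (x.max() - x.min()).clamp(min=1e-6) / (qmax - qmin)

# preserve_zero=True: round and clamp the zero point so 0.0 has an exact integer image
zp_int = torch.clamp(qmin - torch.round(x.min() / scale), qmin, qmax)
q = torch.clamp(torch.round(x / scale) + zp_int, qmin, qmax)

# the same quantization can report its zero point in either domain
zp_float = (mid - zp_int) * scale            # float-domain form of the same zero point
deq_int = (q - zp_int) * scale               # ZeroPointDomain.INT dequant
deq_float = (q - mid) * scale + zp_float     # ZeroPointDomain.FLOAT dequant
print(torch.allclose(deq_int, deq_float), deq_int[-1].item())  # True 0.0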

Collaborator Author

How about 946f530, exposing it as an independent argument?

Contributor

yeah this makes sense

jerryzh168 (Contributor) left a comment

looks fine for now, but at some point we will probably add new tensor subclass tensors for these

Labels: CLA Signed, topic: new feature