INT4 XPU enabling #1577
base: main
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1577
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit b096666 with merge base 3fb1665. This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/integration/test_integration.py
Outdated
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
    layout_list = []
    if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
        layout_list.append(Int4CPULayout())
    elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
here as well, 2_6 or 2_7?
2.7
    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
btw for this one, we have some unpacking op for tensor core tiled layout that we should really be using:
ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
Lines 311 to 312 in cf45336
m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);
might be better to do the same instead of hacking with quantize ops
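A minimal sketch of that suggestion (not this PR's implementation): get_plain() could call the dedicated unpack op directly instead of round-tripping through quantize ops. The op name comes from the registration quoted above; the argument names and the scale/zero packing are assumptions for illustration.

```python
import torch

# Sketch only: recover plain int data via the dedicated unpack op, assuming the
# op takes (packed_weight, inner_k_tiles) and that scale_and_zero is packed
# with shape [..., 2] (scale at index 0, zero at index 1).
def get_plain_via_unpack(packed_weight, scale_and_zero, inner_k_tiles):
    int_data = torch.ops.torchao.unpack_tensor_core_tiled_layout(
        packed_weight, inner_k_tiles
    )
    scale, zero = scale_and_zero.unbind(-1)
    return int_data, scale, zero
```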
Sure, I will check.
By the way, why was the op added in pytorch/pytorch#137566 instead of in torchao? Any plans to move it to torchao?
@mingfeima @EikanWang can you comment?
The situation for XPU (the Intel GPUs) is different from CPU and CUDA here. I'm not sure whether providing SYCL or oneDNN XPU ops in ao is a feasible solution.
91067e2 to 895376f
@jerryzh168 please review again.
test/dtypes/test_affine_quantized.py
Outdated
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
# TODO(#1690): delete this once config migration is done
cc @vkuzo we can delete these now?
if self.scale_and_zero is not None:
    return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
else:
    return ["packed_weight", "scale", "zero"], [self.transposed, self._layout]
why do we have two formats here? maybe should split into multiple layouts?
Integer zp and floating zp. I don't split them into two layouts because it would be confusing from the user side.
current:
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.Float))
but if different layouts
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutIntZP(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutFloatZP(), zero_point_domain=ZeroPointDomain.Float))
I think the current implementation is more straightforward for users.
but layout defines how we store the packed weights actually, using a single layout for multiple things is breaking this abstraction I feel
is the concern around specifying zero_point_domain multiple times? we could remove that and just infer the zero_point_domain from layout I think (the latter API)
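A small sketch of that "infer the zero_point_domain from layout" idea, using the hypothetical Int4XPULayoutIntZP / Int4XPULayoutFloatZP classes from the example above (names and mapping are illustrative, not code from this PR):

```python
from dataclasses import dataclass

from torchao.quantization.quant_primitives import ZeroPointDomain  # assumed import path

# Hypothetical per-domain layout classes, only for illustration.
@dataclass(frozen=True)
class Int4XPULayoutIntZP: ...

@dataclass(frozen=True)
class Int4XPULayoutFloatZP: ...

# Illustrative mapping (distinct from the real LAYOUT_TO_ZERO_POINT_DOMAIN constant).
_DEFAULT_ZP_DOMAIN_FOR_LAYOUT = {
    Int4XPULayoutIntZP: ZeroPointDomain.INT,
    Int4XPULayoutFloatZP: ZeroPointDomain.FLOAT,
}

def infer_zero_point_domain(layout):
    # the transform would pick the domain from the layout, so users pass only the layout
    return _DEFAULT_ZP_DOMAIN_FOR_LAYOUT[type(layout)]
```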
Since only XPU supports integer zp, can I move it in the next PR?

> layout defines how we store the packed weights actually

it should include the layout of scales and zeros, right?
> Since only XPU supports integer zp, can I move it in the next PR?

what is this referring to?

> layout defines how we store the packed weights actually
> it should include the layout of scales and zeros, right?

yeah that's correct. Ideally I think we should not use layout to control whether we have packed weight / scale_and_zero / scale, zero; the duplication should actually happen at the tensor level (we create different tensor subclass tensors), not in the layout. Feel free to go that route if you want.
Can I separate into different layouts, and bind the zero point domain into each layout in the next PR?
yeah sure
@@ -242,6 +247,11 @@ def matmul(self, x):
        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
            x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
        )
        if is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7 \
                and not isinstance(self.scales_and_zeros, torch.Tensor):
            c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
is this supposed to match line 550 in GPTQ.py?
Removed hqq support in this PR to simplify the logic.
torchao/quantization/GPTQ.py
Outdated
@@ -546,6 +546,14 @@ def linear_forward_int4(
            groupsize,
            scales_and_zeros.to(scales_precision),
        ).to(dtype=x.dtype)
    elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
        c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
also why do we have this function? can the slicing (scales_and_zeros[0] and scales_and_zeros[1]) be done in _weight_int4pack_mm itself?
Removed GPTQ support in this PR to simplify the logic. I will open another PR after I separate the layouts.
@@ -166,6 +167,10 @@ def process_hqq_quants(self, W_q, meta):
        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
            W_q_torch, self.inner_k_tiles
        )
    if is_device(W_q.device.type, "Xpu") and TORCH_VERSION_AT_LEAST_2_7:
nit: "xpu"?
@@ -407,6 +407,8 @@ def _quantize_affine_no_dtype_cast(
    shape_after_reduction = shape_for_reduction
    for i in reduction_dims:
        shape_after_reduction[i] = 1
    if shape_after_reduction[0] == 12288:
remove?
@@ -954,6 +956,7 @@ def _choose_qparams_affine(
    if preserve_zero:
        zero_point = quant_min - torch.round(min_val_neg / scale)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)
        zero_point_dtype = torch.int32
why is this set here?
In fact, preserve_zero and the INT zero point domain are coupled here; I think there is some duplication. The reason for setting this dtype to int here is that many places calling this function use the default floating-point zero point dtype.
preserve_zero is about whether zero (in the original floating point domain) is exactly representable or not; it's not coupled with the zero point domain, I think. Even if zero is exactly representable, we can still choose zero_point_domain to be float.
Yes, from the math side they are not related, but the code here implies this; see the conditional dispatch from lines 954 to 966. We need a refactor here.
I think it would be better to add an assert instead of changing the condition if there is coupling?
torchao/quantization/utils.py
Outdated
@@ -315,7 +316,7 @@ def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32):
    return dequantized


-def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
+def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16, zero_point_domain_is_int=False):
nit: I'm wondering if we should just expose zero_point_domain as an arg directly here
which way do you prefer? how about
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, data_dtype=torch.bfloat16, scale_dtype=torch.bfloat16, zero_dtype=torch.bfloat16):
sure, this sounds good. what about preserve_zero and zero_point_domain? I don't think these can be fully inferred?
is this going to be updated? I feel it's a bit weird to introduce a boolean flag for zero_point_domain when we can just pass zero_point_domain itself around
287d069 is it okay?
torchao/quantization/quant_api.py
Outdated
@@ -850,6 +860,7 @@ def _int4_weight_only_transform(
        zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
    ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"

    preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] if zero_point_domain != ZeroPointDomain.INT else True
does zero_point_dtype need to change after this is set?
Same as the above. In fact, preserve_zero and the INT zero point domain are coupled.
there are three things, preserve_zero = {True, False}, zero_point_domain = {FLOAT, INT, NONE} and zero_point_dtype = {float, int, ...}
it's true that not all combinations are valid, but I don't think they are coupled, see
ao/torchao/quantization/quant_primitives.py
Lines 755 to 770 in 64bcf4c
preserve_zero (bool): a flag to indicate whether we need zero to be exactly
    representable or not, this is typically required for ops that needs zero padding, like convolution
    it's less important for ops that doesn't have zero padding in the op itself, like linear.
    For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
    we'll make sure there is a integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such
    gurantee.
    If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
    if zero_point is in integer domain, zero point is added to the quantized integer value during
    quantization
    if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
    value during quantization
    default is ZeroPointDomain.INT
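For concreteness, a small illustrative sketch (not the quant_primitives implementation) of how these knobs stay separate; the preserve_zero branch mirrors the _choose_qparams_affine snippet quoted earlier, the rest is assumed:

```python
# Illustration only: preserve_zero controls how zero_point is computed,
# independently of whether it is later interpreted in the INT or FLOAT domain.
def example_qparams(min_val, max_val, quant_min=-8, quant_max=7, preserve_zero=True):
    min_val_neg = min(min_val, 0.0)
    max_val_pos = max(max_val, 0.0)
    scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
    if preserve_zero:
        # rounded and clamped so that fp 0.0 has an exact integer image
        zero_point = round(quant_min - min_val_neg / scale)
        zero_point = max(quant_min, min(quant_max, zero_point))
    else:
        # zero need not be exactly representable: no rounding or clamping
        zero_point = quant_min - min_val_neg / scale
    # Separately, the caller decides whether zero_point lives in the INT domain
    # (added to the quantized integer) or the FLOAT domain (subtracted from the
    # unquantized value), as the quoted docstring describes.
    return scale, zero_point
```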
how about 946f530, expose it as an independent argument?
yeah this makes sense
looks fine for now, but at some point we will probably add new tensor subclass tensors for these
The PR is currently a draft. It will add two kinds of INT4 support on XPU, floating zero points and integer zero points, following the discussion in #1264.
Integer zero points, which have been natively supported via oneDNN in pytorch/pytorch#137566.
Floating zero points, the default behaviour in this repo, supported by intel/torch-xpu-ops#1130, with more implementations on the way.
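A hedged usage sketch of the two modes (the quantize_ spelling is taken from the examples earlier in this thread; the import paths and the ZeroPointDomain.FLOAT spelling are assumptions):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import Int4XPULayout  # assumed import path
from torchao.quantization.quant_primitives import ZeroPointDomain  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("xpu", torch.bfloat16)

# integer zero points (oneDNN-backed path from pytorch/pytorch#137566)
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(),
                                  zero_point_domain=ZeroPointDomain.INT))

# floating zero points (the default behaviour in this repo, intel/torch-xpu-ops#1130)
# quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(),
#                                   zero_point_domain=ZeroPointDomain.FLOAT))
```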