Enable AWQ on Intel GPU. #2248


Closed · 10 commits

Conversation

@xiaowangintel (Contributor) commented May 23, 2025

Following the request in pytorch/pytorch#153019, this PR enables awq-uint4 for Intel GPU in pytorch/ao, now that RTN support is ready.

How to run the AWQ quantization example:

cd torchao/prototype/awq
python example.py --device xpu <huggingface-model-id> awq-uint4-128  # e.g. meta-llama/Llama-3.1-8B-Instruct

# Results of meta-llama/Llama-3.1-8B-Instruct on Intel GPU:
{'perplexity': {'perplexity': 10.099576950073242, 'prediction_time': 0.20489671968780787}}

# Results of meta-llama/Llama-3.1-8B-Instruct on NVIDIA A100 GPU:
{'perplexity': {'perplexity': 10.160041809082031, 'prediction_time': 0.4466673863672577}}
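
For reference, the flow inside example.py is roughly the following. This is a minimal sketch, assuming the prototype AWQ API names at the time of this PR (insert_awq_observer_, AWQObservedLinear, AWQUIntXConfig); calibration-data loading and evaluation are simplified.

# Minimal sketch of the torchao/prototype/awq flow (API names assumed as of this PR).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.prototype.awq import AWQObservedLinear, AWQUIntXConfig, insert_awq_observer_
from torchao.quantization import quantize_

device = "xpu"  # or "cuda"
model_id = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 1. Replace nn.Linear modules with observed linears that record activation statistics.
insert_awq_observer_(model, n_validation_examples=10, validation_sequence_len=512,
                     quant_dtype=torch.uint4, group_size=128)

# 2. Run calibration batches through the model so the observers see real activations
#    (a single toy batch here; example.py uses a real calibration dataset).
batch = tokenizer("Example calibration text.", return_tensors="pt").to(device)
with torch.no_grad():
    model(batch["input_ids"])

# 3. Quantize only the observed linears to AWQ uint4 with group size 128.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, AWQUIntXConfig(quant_dtype=torch.uint4, group_size=128), is_observed_linear)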

pytorch-bot (bot) commented May 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2248

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f60041c with merge base d963a88:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 23, 2025
@xiaowangintel (Contributor, Author)

@liangan1 Could you help review this PR?

@liangan1 (Contributor)

How about the perplexity on CUDA?

@liangan1 (Contributor) left a comment:

LGTM

@xiaowangintel changed the title from "[WIP] Enable AWQ on Intel GPU." to "Enable AWQ on Intel GPU." May 23, 2025
@liangan1 (Contributor)

@EikanWang

@@ -429,15 +428,14 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
         scale = scale.reshape(scale.shape[:-1]).contiguous()
         zero = zero.reshape(zero.shape[:-1]).contiguous()
-        int_data = quantize_affine(
+        int_data = quantize_affine_float_zero_point(
(Contributor)

This is actually specific to fbgemm, I think; we'd need to rename it in a future PR. cc @jainapurva

Comment on lines 156 to 159
if "xpu" in device.type:
_layout = Int4XPULayout()
else:
_layout = TensorCoreTiledLayout(inner_k_tiles=8)
(Contributor)

Can the layout be explicitly passed in instead of inferred from the device?

(Contributor)

I think it should be OK. We should follow Int4WeightOnlyConfig and let the user specify the layout information.

(Contributor, Author)

Yes, the modification is done.
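
For illustration, the user-facing config could carry the layout directly, mirroring Int4WeightOnlyConfig. A hypothetical sketch (the field name and defaults are assumptions, not the PR's exact code):

# Hypothetical sketch: the caller chooses the packing layout explicitly,
# mirroring Int4WeightOnlyConfig, instead of inferring it from the device.
from dataclasses import dataclass, field

from torchao.core.config import AOBaseConfig
from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.dtypes.utils import Layout


@dataclass
class AWQUIntXConfig(AOBaseConfig):  # simplified; other fields omitted
    group_size: int = 64
    # TensorCoreTiledLayout(inner_k_tiles=8) targets CUDA tinygemm;
    # Int4XPULayout() targets Intel GPU.
    layout: Layout = field(default_factory=lambda: TensorCoreTiledLayout(inner_k_tiles=8))


# Usage on Intel GPU (hypothetical):
config = AWQUIntXConfig(group_size=128, layout=Int4XPULayout())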

@@ -114,6 +116,7 @@ class AWQUIntXConfig(AOBaseConfig):
     group_size: int = 64
     use_hqq: bool = False
     set_inductor_config: bool = True
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.FLOAT
(Contributor)

Can this be removed if we have the layout?

(Contributor)

Yes, I agree with you. Following the logic of #2149, preserve_zero and zero_point_domain are too complex to expose in the user-facing UX. It is better to use the layout to decide the zero_point_domain information.

(Contributor, Author)

Yes, the modification is done.

from torchao.dtypes import Int4XPULayout


zero_point_domain_dict = {"float":ZeroPointDomain.FLOAT, "int":ZeroPointDomain.INT, "none":ZeroPointDomain.NONE}
(Contributor)

FYI, we used to use this to distinguish between different types of kernels, but now we keep an integer zero point with preserve_zero as the default for the common path, and split out the other q/dq ops for specific kernels like tinygemm: #2149.

I think these are just different ways to implement things, and we don't necessarily need categorizations like zero_point_domain and preserve_zero, since they might complicate the UX.

(Contributor, Author)

Yes, the modification is done.
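
Per the discussion above, the string flag goes away and the zero-point handling follows from the layout. A hypothetical sketch of that mapping (the exact rule shown is an assumption for illustration):

# Hypothetical sketch: derive the zero-point domain from the layout
# instead of exposing a string flag in the example script.
from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.dtypes.utils import Layout
from torchao.quantization.quant_primitives import ZeroPointDomain


def zero_point_domain_for(layout: Layout) -> ZeroPointDomain:
    # tinygemm-style packed layouts carry a float zero point;
    # everything else stays on the common integer-zero-point path.
    if isinstance(layout, (TensorCoreTiledLayout, Int4XPULayout)):
        return ZeroPointDomain.FLOAT
    return ZeroPointDomain.INT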

@jerryzh168 (Contributor) left a comment:

I feel we can use the layout as the user-facing interface.

@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         not (check_xpu_version(int_data.device))
     ):
         int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+    if check_xpu_version(int_data.device):
(Contributor)

We should probably encapsulate these better once we have a better design for layout conversions: #2249.
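
For context, the packing line quoted above stores two uint4 values per byte along the last dimension. A standalone illustration of that packing and its inverse (not code from the PR):

# Standalone demonstration of the 4-bit packing in the hunk above:
# adjacent column pairs are packed into one uint8, even column in the
# high nibble, odd column in the low nibble.
import torch

int_data = torch.randint(0, 16, (2, 8), dtype=torch.uint8)  # uint4 values in uint8 storage

packed = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)  # shape (2, 4)

# Inverse: recover the two nibbles from each byte and re-interleave them.
high = packed >> 4
low = packed & 0xF
unpacked = torch.stack([high, low], dim=-1).reshape(int_data.shape)

assert torch.equal(unpacked, int_data)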

@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import types
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union
(Contributor)

Suggested change:
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Optional

(Contributor, Author)

Done.

@xiaowangintel requested a review from jerryzh168 May 29, 2025 09:22
@liangan1 (Contributor)

@pytorchbot label topic: new feature

pytorch-bot (bot) commented May 29, 2025

Didn't find following labels among repository labels: topic:,new,feature

@jerryzh168 added the topic: new feature label May 29, 2025
@liangan1 (Contributor)

@pytorchbot merge

@pytorchmergebot (Collaborator)
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: CLA Signed · Merged · topic: new feature