
Enable AWQ on Intel GPU. #2248

Closed
wants to merge 10 commits into from
Changes from all commits
4 changes: 4 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -135,6 +135,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
if TORCH_VERSION_AT_LEAST_2_5:
if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if check_xpu_version(w.device):
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8

@@ -730,6 +732,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
not (check_xpu_version(input.device))
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if check_xpu_version(input.device):
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
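A note on the nibble order these tests cover: the default (CUDA) path packs the even columns of the int4 tensor into the high nibble of each byte, while the new XPU path packs the odd columns into the high nibble. Below is a minimal, self-contained round-trip sketch in plain PyTorch; it is illustrative only, not code from this PR, and the tensor names are made up.

import torch

w = torch.randint(0, 16, (2, 8), dtype=torch.int32)  # 4-bit values stored as int32

# Default (CUDA) packing: even columns land in the high nibble.
packed_default = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)

# XPU packing exercised by the tests above: odd columns land in the high nibble.
packed_xpu = (w[::, 1::2] << 4 | w[::, ::2]).to(torch.uint8)

# Unpacking the XPU layout recovers the original tensor.
low = (packed_xpu & 0xF).to(torch.int32)   # even columns
high = (packed_xpu >> 4).to(torch.int32)   # odd columns
w_roundtrip = torch.empty_like(w)
w_roundtrip[::, ::2] = low
w_roundtrip[::, 1::2] = high
assert torch.equal(w_roundtrip, w)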
17 changes: 7 additions & 10 deletions torchao/dtypes/uintx/int4_xpu_layout.py
@@ -242,14 +242,15 @@ def from_plain(
):
assert isinstance(_layout, Int4XPULayout)

from torchao.quantization.utils import convert_weight_to_int4pack_xpu

if TORCH_VERSION_AT_LEAST_2_8:
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
packed_weight = convert_weight_to_int4pack_xpu(
int_data, zero_point.dtype != scale.dtype
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
torch.uint8
)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
else:
assert False, "INT4 not supported on XPU until 2.8"
@@ -370,8 +371,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
quantize_affine,
quantize_affine_float_zero_point,
)
from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

@@ -394,7 +395,6 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
quant_max = 15
assert len(block_size) == 2 and block_size[0] == 1
if self.scale_and_zero is None:
zero_point_domain = ZeroPointDomain.INT
dequantized = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
torch.eye(eye_shape, device=device, dtype=original_dtype),
self.packed_weight,
@@ -411,10 +411,8 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
else:
zero_point_domain = ZeroPointDomain.FLOAT
dequantized = torch.ops.aten._weight_int4pack_mm(
torch.eye(eye_shape, device=device, dtype=original_dtype),
self.packed_weight,
Expand All @@ -425,15 +423,14 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# TODO: move this to `unpack_tinygemm_scales_and_zeros`?
scale = scale.reshape(scale.shape[:-1]).contiguous()
zero = zero.reshape(zero.shape[:-1]).contiguous()
int_data = quantize_affine(
int_data = quantize_affine_float_zero_point(
# [PR review comment, Contributor]: this is actually specific to fbgemm I think, we'd need to rename in a future PR cc @jainapurva
dequantized,
block_size,
scale,
zero,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
return int_data, scale, zero

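A note on the get_plain() pattern above: the packed weight is first dequantized by feeding an identity matrix through the fused int4 matmul op, and the result is then re-quantized with the stored scales and zeros to recover the plain int data. The sketch below illustrates the same idea with a toy symmetric scheme in plain PyTorch; it does not use the torchao primitives or the XPU ops, and all names are illustrative.

import torch

scale = 0.05
w_int = torch.randint(0, 16, (4, 8), dtype=torch.int32)  # plain int4 data
w_dequant = (w_int.to(torch.float32) - 8.0) * scale      # what the packed kernel represents

# Multiplying by an identity "activation" leaves the dequantized weight unchanged,
# which is how get_plain() extracts it from the fused matmul kernel.
eye = torch.eye(w_dequant.shape[1])
recovered_dequant = w_dequant @ eye

# Re-quantizing with the stored qparams recovers the original int data.
recovered_int = torch.round(recovered_dequant / scale + 8.0).to(torch.int32)
assert torch.equal(recovered_int, w_int)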
15 changes: 12 additions & 3 deletions torchao/prototype/awq/api.py
@@ -5,12 +5,15 @@
# LICENSE file in the root directory of this source tree.
import types
from dataclasses import dataclass
from typing import Optional

import torch

import torchao
from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
Int4XPULayout,
Layout,
TensorCoreTiledLayout,
to_affine_quantized_intx,
)
@@ -105,12 +108,14 @@ class AWQUIntXConfig(AOBaseConfig):

Args:
quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used, but any dtype from torch.uint1 through torch.uint8 is accepted
layout: Layout type for the quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
group_size: Quantization granularity. Use -1 for channel-wise quantization
weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
"""

quant_dtype: torch.dtype = torch.uint4
layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8)
group_size: int = 64
use_hqq: bool = False
set_inductor_config: bool = True
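With the new layout field, callers can opt into the XPU packing explicitly instead of always getting TensorCoreTiledLayout. A hedged usage sketch, based on the example.py changes later in this PR; the torch.xpu.is_available() device check is an assumption of mine, not code from the PR:

import torch
from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.prototype.awq import awq_uintx

config = awq_uintx(quant_dtype=torch.uint4, group_size=64, use_hqq=False)
if torch.xpu.is_available():  # running on an Intel GPU (assumed check)
    config.layout = Int4XPULayout()  # XPU-specific int4 packing
else:
    config.layout = TensorCoreTiledLayout(inner_k_tiles=8)  # existing default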
@@ -142,9 +147,13 @@ def _awq_uintx_transform(
target_dtype = torch.int32
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
_layout = TensorCoreTiledLayout(inner_k_tiles=8)
_layout = config.layout
if isinstance(_layout, Int4XPULayout):
zero_point_dtype = torch.int8
zero_point_domain = ZeroPointDomain.INT
else:
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
else:
target_dtype = torch.uint8
eps = torch.finfo(torch.float32).eps
17 changes: 15 additions & 2 deletions torchao/prototype/awq/example.py
@@ -11,6 +11,7 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.dtypes import Int4XPULayout
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
from torchao.quantization import int4_weight_only, quantize_

@@ -71,6 +72,8 @@ def wiki2_eval(
log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
if device.startswith("cuda"):
torch.cuda.synchronize()
if device.startswith("xpu"):
torch.xpu.synchronize()
t2 = time.time()
t.append((t2 - t1))
lls.append(log_likelihood)
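The added branch mirrors the CUDA path so that the timing in the hunk above measures finished XPU kernels rather than queued ones. If more backends show up, the two branches could be folded into one helper; a hypothetical sketch (the helper name is made up and is not part of this PR):

import torch

def synchronize_device(device: str) -> None:
    # Block until queued kernels finish so time.time() measures real work.
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()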
@@ -229,9 +232,14 @@ def wikitext2_ppl(
use_hqq = "hqq" in quant
print(f"running {quant_dtype} quantization")
t0 = time.time()
awq_uintx_config = awq_uintx(
quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
)
if "xpu" in device:
awq_uintx_config.layout = Int4XPULayout()
quantize_(
model,
awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
awq_uintx_config,
is_observed_linear,
)
print(f"time for quantization: {time.time() - t0:.02f} seconds")
@@ -242,7 +250,12 @@
group_size = int(quant.split("-")[1])
use_hqq = "hqq" in quant
print(f"running {quant} quantization with group size {group_size}")
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
int4_weight_only_config = int4_weight_only(
group_size=group_size, use_hqq=use_hqq
)
if "xpu" in device:
int4_weight_only_config.layout = Int4XPULayout()
quantize_(model, int4_weight_only_config)
if compile:
model = torch.compile(model)

7 changes: 0 additions & 7 deletions torchao/quantization/subclass.py
@@ -697,13 +697,6 @@ def to_qtensor_components(
int_data = aten._convert_weight_to_int4pack_for_cpu(
input_int4x8, inner_k_tiles
)
if check_xpu_version(input_float.device):
from torchao.quantization.utils import convert_weight_to_int4pack_xpu

int_data = convert_weight_to_int4pack_xpu(
input_int4x8,
zero_point_domain_is_int=zero_point_domain == ZeroPointDomain.INT,
)
else:
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
35 changes: 13 additions & 22 deletions torchao/quantization/utils.py
@@ -127,6 +127,11 @@ def cuda(self):
val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
]

def xpu(self):
self.values = [
val.xpu() if isinstance(val, torch.Tensor) else val for val in self.values
]


def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
if dtype is not None and tensor_arg.dtype != dtype:
@@ -415,25 +420,6 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)


def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
assert weight.device.type == "xpu"

if zero_point_domain_is_int:
# int_data = weight.to(dtype=torch.uint8)
int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
int_data,
8, # TODO:remove
)
else:
out = weight.to(dtype=torch.uint8)
out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
packed_weight = out.view(torch.int32)

# Second, N * K/2 uint8 -> N * K/8 int32
return packed_weight


def groupwise_affine_quantize_tensor_from_qparams(
w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
):
@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
not (check_xpu_version(int_data.device))
):
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if check_xpu_version(int_data.device):
# [PR review comment, Contributor]: should probably encapsulate these better when we have a better design for layout conversions: #2249
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
return int_data


@@ -491,7 +479,6 @@ def groupwise_affine_dequantize_tensor_from_qparams(
TORCH_VERSION_AT_LEAST_2_5
and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
and not (check_cpu_version(w_int4x8.device))
and not (check_xpu_version(w_int4x8.device))
):
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
Expand All @@ -501,8 +488,12 @@ def groupwise_affine_dequantize_tensor_from_qparams(
dtype=torch.int32,
device=w_int4x8.device,
)
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
if not (check_xpu_version(w_int4x8.device)):
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
w_int32[::, ::2] = low_bits
w_int32[::, 1::2] = high_bits
else:
w_int32 = w_int4x8

Expand Down