
Commit abb309a

xiaowangintel authored and pytorchmergebot committed
Enable AWQ on Intel GPU. (#2248)
Following the request in pytorch/pytorch#153019, this enables awq-uint4 for Intel GPU in pytorch/ao, now that RTN support is ready.

How to run the AWQ quantization example:

```bash
cd torchao/prototype/awq
python example.py --device xpu <huggingface-model> awq-uint4-128  # e.g. <huggingface-model> = meta-llama/Llama-3.1-8B-Instruct
```

Results of meta-llama/Llama-3.1-8B-Instruct on Intel GPU:
{'perplexity': {'perplexity': 10.099576950073242, 'prediction_time': 0.20489671968780787}}

Results of meta-llama/Llama-3.1-8B-Instruct on an NVIDIA A100 GPU:
{'perplexity': {'perplexity': 10.160041809082031, 'prediction_time': 0.4466673863672577}}

Pull Request resolved: #2248
Approved by: https://github.com/liangan1, https://github.com/jerryzh168
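For reference, a minimal programmatic sketch of the XPU path that `example.py` takes after calibration. This is a hedged illustration rather than part of the diff: `model` is a hypothetical placeholder for a Hugging Face model that has already gone through `insert_awq_observer_` and a calibration pass (see `torchao/prototype/awq/example.py`), and group size 128 matches the `awq-uint4-128` setting above.

```python
import torch

from torchao.dtypes import Int4XPULayout
from torchao.prototype.awq import AWQObservedLinear, awq_uintx
from torchao.quantization import quantize_

model = ...  # placeholder: an observed/calibrated model already moved to an "xpu" device

# Build the AWQ config as example.py does; on XPU, select the Int4XPULayout
# added by this PR so the int4 weights use the Intel GPU packing.
awq_config = awq_uintx(quant_dtype=torch.uint4, group_size=128, use_hqq=False)
awq_config.layout = Int4XPULayout()

# Only replace the linears that were wrapped by the AWQ observer.
quantize_(model, awq_config, lambda m, fqn: isinstance(m, AWQObservedLinear))
```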

File tree: 6 files changed, +51 -44 lines changed


test/quantization/test_quant_primitives.py

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
    if TORCH_VERSION_AT_LEAST_2_5:
        if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+       if check_xpu_version(w.device):
+           w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

    return w_int4x8

@@ -730,6 +732,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
            not (check_xpu_version(input.device))
        ):
            input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+       if check_xpu_version(input.device):
+           input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
        w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
            input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
        )
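The two branches above exercise the different nibble orders: the existing CUDA/CPU packing puts even columns in the high nibble, while the new XPU packing puts odd columns there. A small self-contained sketch of that round trip (plain CPU tensors, no XPU required), mirroring the expressions in the diff:

```python
import torch

# Toy int4 values (0..15) stored as int32, shape [rows, cols], even number of columns.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# CUDA/CPU order: even columns in the high nibble, odd columns in the low nibble.
packed_cuda = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)
# XPU order: odd columns in the high nibble, even columns in the low nibble.
packed_xpu = (w[::, 1::2] << 4 | w[::, ::2]).to(torch.uint8)

# Unpacking the XPU layout: low nibble -> even columns, high nibble -> odd columns.
restored = torch.empty_like(w)
restored[::, ::2] = (packed_xpu & 0xF).to(torch.int32)
restored[::, 1::2] = (packed_xpu >> 4).to(torch.int32)
assert torch.equal(restored, w)
```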

torchao/dtypes/uintx/int4_xpu_layout.py

Lines changed: 7 additions & 10 deletions
@@ -242,14 +242,15 @@ def from_plain(
    ):
        assert isinstance(_layout, Int4XPULayout)

-       from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
        if TORCH_VERSION_AT_LEAST_2_8:
            assert int_data.dtype == torch.int32, (
                "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
            )
-           packed_weight = convert_weight_to_int4pack_xpu(
-               int_data, zero_point.dtype != scale.dtype
+           packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
+               torch.uint8
+           )
+           packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+               packed_weight.contiguous(), 8
            )
        else:
            assert False, "INT4 not supported on XPU until 2.8"

@@ -370,8 +371,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        from torchao.quantization.quant_primitives import (
-           ZeroPointDomain,
            quantize_affine,
+           quantize_affine_float_zero_point,
        )
        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

@@ -394,7 +395,6 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        quant_max = 15
        assert len(block_size) == 2 and block_size[0] == 1
        if self.scale_and_zero is None:
-           zero_point_domain = ZeroPointDomain.INT
            dequantized = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
                torch.eye(eye_shape, device=device, dtype=original_dtype),
                self.packed_weight,

@@ -411,10 +411,8 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                target_dtype,
                quant_min,
                quant_max,
-               zero_point_domain,
            )
        else:
-           zero_point_domain = ZeroPointDomain.FLOAT
            dequantized = torch.ops.aten._weight_int4pack_mm(
                torch.eye(eye_shape, device=device, dtype=original_dtype),
                self.packed_weight,

@@ -425,15 +423,14 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
            scale = scale.reshape(scale.shape[:-1]).contiguous()
            zero = zero.reshape(zero.shape[:-1]).contiguous()
-           int_data = quantize_affine(
+           int_data = quantize_affine_float_zero_point(
                dequantized,
                block_size,
                scale,
                zero,
                target_dtype,
                quant_min,
                quant_max,
-               zero_point_domain,
            )
        return int_data, scale, zero
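In short, `from_plain` no longer routes through the removed `convert_weight_to_int4pack_xpu` helper: it packs the nibble pairs in the XPU order itself and hands the result to the aten op. A rough sketch of that step, assuming a PyTorch >= 2.8 build with XPU available and `int_data` holding plain int4 values stored as int32 (both are assumptions for illustration, not taken from the diff):

```python
import torch

# Hypothetical plain int4 values (0..15) stored as int32, shape [n, k], on an XPU device.
int_data = torch.randint(0, 16, (32, 64), dtype=torch.int32, device="xpu")

# XPU nibble order: odd columns into the high nibble, even columns into the low nibble.
packed = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)

# innerKTiles=8 matches the value hard-coded in from_plain above.
packed_weight = torch.ops.aten._convert_weight_to_int4pack(packed.contiguous(), 8)
```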

torchao/prototype/awq/api.py

Lines changed: 12 additions & 3 deletions
@@ -5,12 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
+from typing import Optional

 import torch

 import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
+    Int4XPULayout,
+    Layout,
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )

@@ -105,12 +108,14 @@ class AWQUIntXConfig(AOBaseConfig):

    Args:
        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+       `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
        group_size: Quantization granularity. Use -1 for channel wise quantization
        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
    """

    quant_dtype: torch.dtype = torch.uint4
+   layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8)
    group_size: int = 64
    use_hqq: bool = False
    set_inductor_config: bool = True

@@ -142,9 +147,13 @@ def _awq_uintx_transform(
        target_dtype = torch.int32
        eps = 1e-6
        preserve_zero = False
-       zero_point_dtype = torch.bfloat16
-       zero_point_domain = ZeroPointDomain.FLOAT
-       _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+       _layout = config.layout
+       if isinstance(_layout, Int4XPULayout):
+           zero_point_dtype = torch.int8
+           zero_point_domain = ZeroPointDomain.INT
+       else:
+           zero_point_dtype = torch.bfloat16
+           zero_point_domain = ZeroPointDomain.FLOAT
    else:
        target_dtype = torch.uint8
        eps = torch.finfo(torch.float32).eps
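With the new `layout` field, the zero-point handling follows the layout, as the branch above shows. A short sketch of constructing the config both ways; this only illustrates the dataclass field added here, with the defaults as documented in the docstring above:

```python
import torch

from torchao.dtypes import Int4XPULayout
from torchao.prototype.awq.api import AWQUIntXConfig

# Default layout stays TensorCoreTiledLayout(inner_k_tiles=8), which keeps the
# previous behaviour: bfloat16 zero points in the float zero-point domain.
default_config = AWQUIntXConfig(quant_dtype=torch.uint4, group_size=128)

# Int4XPULayout makes _awq_uintx_transform pick torch.int8 zero points and
# ZeroPointDomain.INT, per the branch added in this diff.
xpu_config = AWQUIntXConfig(
    quant_dtype=torch.uint4, group_size=128, layout=Int4XPULayout()
)
```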

torchao/prototype/awq/example.py

Lines changed: 15 additions & 2 deletions
@@ -11,6 +11,7 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from torchao.dtypes import Int4XPULayout
 from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 from torchao.quantization import int4_weight_only, quantize_

@@ -71,6 +72,8 @@ def wiki2_eval(
            log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
            if device.startswith("cuda"):
                torch.cuda.synchronize()
+           if device.startswith("xpu"):
+               torch.xpu.synchronize()
            t2 = time.time()
            t.append((t2 - t1))
            lls.append(log_likelihood)

@@ -229,9 +232,14 @@ def wikitext2_ppl(
        use_hqq = "hqq" in quant
        print(f"running {quant_dtype} quantization")
        t0 = time.time()
+       awq_uintx_config = awq_uintx(
+           quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+       )
+       if "xpu" in device:
+           awq_uintx_config.layout = Int4XPULayout()
        quantize_(
            model,
-           awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
+           awq_uintx_config,
            is_observed_linear,
        )
        print(f"time for quantization: {time.time() - t0:.02f} seconds")

@@ -242,7 +250,12 @@ def wikitext2_ppl(
        group_size = int(quant.split("-")[1])
        use_hqq = "hqq" in quant
        print(f"running {quant} quantization with group size {group_size}")
-       quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+       int4_weight_only_config = int4_weight_only(
+           group_size=group_size, use_hqq=use_hqq
+       )
+       if "xpu" in device:
+           int4_weight_only_config.layout = Int4XPULayout()
+       quantize_(model, int4_weight_only_config)
        if compile:
            model = torch.compile(model)
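The int4 weight-only baseline in the second branch follows the same pattern: build the config, then switch its layout when the target device is XPU. A brief hedged sketch, with `model` as a placeholder for a model already on an `xpu` device:

```python
from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_

model = ...  # placeholder: a model already moved to an "xpu" device

# Baseline int4 weight-only config; on XPU, select the Intel GPU layout as example.py now does.
int4_config = int4_weight_only(group_size=128, use_hqq=False)
int4_config.layout = Int4XPULayout()
quantize_(model, int4_config)
```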

torchao/quantization/subclass.py

Lines changed: 0 additions & 7 deletions
@@ -697,13 +697,6 @@ def to_qtensor_components(
            int_data = aten._convert_weight_to_int4pack_for_cpu(
                input_int4x8, inner_k_tiles
            )
-           if check_xpu_version(input_float.device):
-               from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
-               int_data = convert_weight_to_int4pack_xpu(
-                   input_int4x8,
-                   zero_point_domain_is_int=zero_point_domain == ZeroPointDomain.INT,
-               )
        else:
            int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
        return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

torchao/quantization/utils.py

Lines changed: 13 additions & 22 deletions
@@ -127,6 +127,11 @@ def cuda(self):
            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
        ]

+   def xpu(self):
+       self.values = [
+           val.xpu() if isinstance(val, torch.Tensor) else val for val in self.values
+       ]
+

def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
    if dtype is not None and tensor_arg.dtype != dtype:

@@ -415,25 +420,6 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
    return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)


-def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
-   assert weight.device.type == "xpu"
-
-   if zero_point_domain_is_int:
-       # int_data = weight.to(dtype=torch.uint8)
-       int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8)
-       packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-           int_data,
-           8,  # TODO:remove
-       )
-   else:
-       out = weight.to(dtype=torch.uint8)
-       out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
-       packed_weight = out.view(torch.int32)
-
-   # Second, N * K/2 uint8 -> N * K/8 int32
-   return packed_weight
-
-
def groupwise_affine_quantize_tensor_from_qparams(
    w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
):

@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
            not (check_xpu_version(int_data.device))
        ):
            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+       if check_xpu_version(int_data.device):
+           int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
    return int_data


@@ -491,7 +479,6 @@ def groupwise_affine_dequantize_tensor_from_qparams(
        TORCH_VERSION_AT_LEAST_2_5
        and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
        and not (check_cpu_version(w_int4x8.device))
-       and not (check_xpu_version(w_int4x8.device))
    ):
        data = w_int4x8.to(torch.int32)
        high_bits = data >> 4

@@ -501,8 +488,12 @@
            dtype=torch.int32,
            device=w_int4x8.device,
        )
-       w_int32[::, ::2] = high_bits
-       w_int32[::, 1::2] = low_bits
+       if not (check_xpu_version(w_int4x8.device)):
+           w_int32[::, ::2] = high_bits
+           w_int32[::, 1::2] = low_bits
+       else:
+           w_int32[::, ::2] = low_bits
+           w_int32[::, 1::2] = high_bits
    else:
        w_int32 = w_int4x8
