
Commit 4f11061

Enable AWQ on Intel GPU.
1 parent 45e545e commit 4f11061

File tree

test/quantization/test_quant_primitives.py
torchao/dtypes/uintx/int4_xpu_layout.py
torchao/prototype/awq/api.py
torchao/prototype/awq/example.py
torchao/quantization/utils.py

5 files changed: +42 -38 lines

test/quantization/test_quant_primitives.py

Lines changed: 2 additions & 2 deletions

@@ -135,7 +135,7 @@ def _groupwise_affine_quantize_tensor_from_qparams(
    if TORCH_VERSION_AT_LEAST_2_5:
        if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
-       if (check_xpu_version(w.device)):
+       if check_xpu_version(w.device):
            w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

    return w_int4x8

@@ -732,7 +732,7 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
                not (check_xpu_version(input.device))
            ):
                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-           if (check_xpu_version(input.device)):
+           if check_xpu_version(input.device):
                input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
                input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
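The functional change in this test helper mirrors the layout difference the commit introduces: on CPU/CUDA two int4 values are packed with the even column in the high nibble, while the XPU path puts the odd column in the high nibble. A minimal self-contained sketch of the two packing orders (plain tensors only, no torchao helpers; the round-trip check is illustrative, not from the commit):

import torch

# Toy int4 values (0..15) stored in an int32 tensor with an even number of columns.
w = torch.randint(0, 16, (2, 8), dtype=torch.int32)

# Non-XPU order: even columns land in the high nibble of each byte.
packed_default = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)

# XPU order (this commit): odd columns land in the high nibble instead.
packed_xpu = (w[::, 1::2] << 4 | w[::, ::2]).to(torch.uint8)

# Unpacking must use the matching order; here we invert the XPU packing.
recovered = torch.empty_like(w)
recovered[::, 1::2] = (packed_xpu >> 4).to(torch.int32)   # high nibble -> odd columns
recovered[::, ::2] = (packed_xpu & 0x0F).to(torch.int32)  # low nibble  -> even columns
assert torch.equal(recovered, w)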

torchao/dtypes/uintx/int4_xpu_layout.py

Lines changed: 9 additions & 9 deletions

@@ -50,9 +50,9 @@ def _linear_bf16_act_uint4_weight_float_zero_check(input_tensor, weight_tensor,


 def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, bias):
-    assert weight_tensor.block_size[0] == 1, (
-        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
-    )
+    assert (
+        weight_tensor.block_size[0] == 1
+    ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     assert input_tensor.shape[-1] == weight_tensor.shape[1], (
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "

@@ -105,9 +105,9 @@ def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bia


 def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
-    assert weight_tensor.block_size[0] == 1, (
-        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
-    )
+    assert (
+        weight_tensor.block_size[0] == 1
+    ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     assert input_tensor.shape[-1] == weight_tensor.shape[1], (
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "

@@ -243,9 +243,9 @@ def from_plain(
        assert isinstance(_layout, Int4XPULayout)

        if TORCH_VERSION_AT_LEAST_2_8:
-           assert int_data.dtype == torch.int32, (
-               "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
-           )
+           assert (
+               int_data.dtype == torch.int32
+           ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
            packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
                torch.uint8
            )
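For context on the reformatted asserts: `block_size[0] == 1` is what makes the quantization groupwise, i.e. block_size = (1, group_size), so each row of the weight is split into groups of group_size columns that share one scale and one zero point. A hedged sketch of that dequantization pattern (toy shapes and a float zero point; not the torchao kernel path):

import torch

# Toy groupwise setup: block_size = (1, group_size) over an (out, in) weight.
out_features, in_features, group_size = 4, 16, 8
w_q = torch.randint(0, 16, (out_features, in_features), dtype=torch.int32)
scales = torch.rand(out_features, in_features // group_size, dtype=torch.bfloat16)
zeros = torch.full_like(scales, 8.0)  # float zero point, as in the bf16-act path above

# Each group of `group_size` columns shares one (scale, zero) pair.
w_grouped = w_q.reshape(out_features, -1, group_size)
w_deq = (w_grouped - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)
w_deq = w_deq.reshape(out_features, in_features).to(torch.bfloat16)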

torchao/prototype/awq/api.py

Lines changed: 10 additions & 10 deletions

@@ -38,9 +38,9 @@
     AWQObserver,
 )

-assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
-    "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"
-)
+assert (
+    len(_DTYPE_TO_BIT_WIDTH) > 0
+), "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"


 def insert_awq_observer_(

@@ -63,9 +63,9 @@ def insert_awq_observer_(
        group_size: Quantization granularity. Use -1 for channel wise quantization
    """
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
-   assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, (
-       "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
-   )
+   assert (
+       quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+   ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
    # AQT config
    mapping_type = MappingType.ASYMMETRIC
    quantization_granularity = PerGroup(group_size)

@@ -137,10 +137,10 @@ def _awq_uintx_transform(
    torchao.quantization.utils.recommended_inductor_config_setter()
    observed_linear = module

-   assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, (
-       "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
-   )
-
+   assert (
+       quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+   ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
+
    equalization_scale = observed_linear.act_obs.calculate_qparams()
    # AQT config
    if quant_dtype == torch.uint4:
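These asserts only changed formatting, but the guard itself is worth spelling out: `_DTYPE_TO_BIT_WIDTH` maps the sub-byte `torch.uint1` .. `torch.uint7` dtypes (available in torch 2.3+) to their bit widths, and `torch.uint8` is accepted separately. An illustrative stand-in mapping (not imported from torchao):

import torch

# Stand-in for torchao's _DTYPE_TO_BIT_WIDTH; the real mapping is built from
# the torch.uintX shell dtypes introduced in torch 2.3.
_dtype_to_bit_width = {
    torch.uint1: 1, torch.uint2: 2, torch.uint3: 3, torch.uint4: 4,
    torch.uint5: 5, torch.uint6: 6, torch.uint7: 7,
}

quant_dtype = torch.uint4
assert quant_dtype in _dtype_to_bit_width or quant_dtype == torch.uint8, (
    "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
)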

torchao/prototype/awq/example.py

Lines changed: 6 additions & 2 deletions

@@ -232,7 +232,9 @@ def wikitext2_ppl(
        use_hqq = "hqq" in quant
        print(f"running {quant_dtype} quantization")
        t0 = time.time()
-       awq_uintx_config = awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq)
+       awq_uintx_config = awq_uintx(
+           quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+       )
        if "xpu" in device:
            awq_uintx_config.layout = Int4XPULayout()
        quantize_(

@@ -248,7 +250,9 @@ def wikitext2_ppl(
        group_size = int(quant.split("-")[1])
        use_hqq = "hqq" in quant
        print(f"running {quant} quantization with group size {group_size}")
-       int4_weight_only_config = int4_weight_only(group_size=group_size, use_hqq=use_hqq)
+       int4_weight_only_config = int4_weight_only(
+           group_size=group_size, use_hqq=use_hqq
+       )
        if "xpu" in device:
            int4_weight_only_config.layout = Int4XPULayout()
        quantize_(model, int4_weight_only_config)
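Putting the example's XPU branch together, here is a hedged end-to-end sketch of the int4 weight-only path (the toy model, sizes, and group size are placeholders; the AWQ path shown in the first hunk additionally needs calibration via insert_awq_observer_ before quantize_):

import torch
from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_

device = "xpu"
# Placeholder model; int4 weight-only expects bfloat16 linear weights.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(device, torch.bfloat16)

int4_weight_only_config = int4_weight_only(group_size=64, use_hqq=False)
if "xpu" in device:
    # Route the packed int4 weights through the Intel GPU layout this commit enables.
    int4_weight_only_config.layout = Int4XPULayout()

quantize_(model, int4_weight_only_config)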

torchao/quantization/utils.py

Lines changed: 15 additions & 15 deletions

@@ -231,19 +231,20 @@ def quant_int8_per_token_matmul(
      Y_i_j_fp32 = sx * sw dot(X_i, W_j)
    """

-   assert x_vals_int8.dtype == torch.int8, (
-       f"x dtype {x_vals_int8.dtype} not yet supported"
-   )
-   assert w_vals_int8_t.dtype == torch.int8, (
-       f"w dtype {w_vals_int8_t.dtype} not yet supported"
-   )
-
-   assert x_scales.dtype in [
-       torch.float,
-       torch.bfloat16,
-   ], (
-       f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
-   )
+   assert (
+       x_vals_int8.dtype == torch.int8
+   ), f"x dtype {x_vals_int8.dtype} not yet supported"
+   assert (
+       w_vals_int8_t.dtype == torch.int8
+   ), f"w dtype {w_vals_int8_t.dtype} not yet supported"
+
+   assert (
+       x_scales.dtype
+       in [
+           torch.float,
+           torch.bfloat16,
+       ]
+   ), f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"

    #
    # 1. do the matrix form of dot(X_i, W_j)
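The asserts above only changed formatting; the docstring formula they guard, Y_i_j_fp32 = sx * sw dot(X_i, W_j), is easy to check numerically. A hedged toy version with per-token activation scales and per-channel weight scales (plain int32 matmul, not the fused kernel torchao dispatches to):

import torch

# Toy shapes: 4 tokens, hidden size 16, 8 output channels.
x_vals_int8 = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
w_vals_int8_t = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
x_scales = torch.rand(4, 1)   # one scale per token (row of X)
w_scales = torch.rand(1, 8)   # one scale per output channel (column of W^T)

# dot(X_i, W_j) accumulated in int32, then rescaled per (token, channel).
acc = x_vals_int8.to(torch.int32) @ w_vals_int8_t.to(torch.int32)
y_fp32 = acc.to(torch.float32) * x_scales * w_scales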
@@ -488,8 +489,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
        dtype=torch.int32,
        device=w_int4x8.device,
    )
-   if (not (check_xpu_version(w_int4x8.device))
-   ):
+   if not (check_xpu_version(w_int4x8.device)):
        w_int32[::, ::2] = high_bits
        w_int32[::, 1::2] = low_bits
    else:
