Commit 890438a
Match QAT prepare and convert numerics exactly
**Summary:** Previously, `Int8DynActInt4QATQuantizer` had slightly diverging numerics between the prepare and convert steps. This is because the prepare step uses quantization primitives shared with AQT (specifically `quantize_affine` and `dequantize_affine`), while the convert step relies on old ops from the `torch.ops.quantized_decomposed` namespace. The divergence is negligible for small models, but the quantization errors begin to compound for larger models with many linear layers. More specifically, the divergence occurs in three places during activation quantization:

1. **Choose qparams.** The prepare step casts the qparams to `torch.float32`, whereas the convert step casts the scales to `torch.float64` and the zero points to `torch.int64`.

2. **Quantize.** The prepare step rounds before adding the zero points and uses torch functions, while the convert step adds the zero points before rounding and uses torch tensor methods.

   ```
   # prepare
   x = torch.clamp(
       torch.round(x * (1.0 / scale)) + zero_point,
       qmin,
       qmax,
   )

   # convert
   x = (
       x.mul(1.0 / scale)
       .add(zero_point)
       .round()
       .clamp(qmin, qmax)
       .to(quantize_dtype)
   )
   ```

3. **Dequantize.** The prepare step casts to `torch.int32` before subtracting the zero points, and casts back to the original dtype before multiplying by the scale. The convert step only casts at the very end.

   ```
   # prepare
   x = x.to(torch.int32) - zero_point.to(torch.int32)
   x = x.to(orig_dtype)
   x = x * scale

   # convert
   x = x - zero_point
   x = x * scale
   x = x.to(orig_dtype)
   ```

This commit makes the convert path use the same torchao quantization primitives as the prepare path, thereby resolving the three differences above. Now the prepare and convert steps match exactly in terms of numerics over many trials.

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert
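As an aside on divergence point 1 above, the dtype in which the qparams are computed is enough to perturb the scales on its own. The following standalone sketch is not from the codebase and uses a simplified min/max scale formula purely for illustration:

```python
import torch

# Simplified stand-in for choose_qparams: scale = (max - min) / (qmax - qmin).
# Only meant to show that float32 vs float64 qparam math yields different scales.
torch.manual_seed(0)
x = torch.randn(16, 2048)
qmin, qmax = -128, 127
min_val = x.amin(dim=-1, keepdim=True)
max_val = x.amax(dim=-1, keepdim=True)

scale_fp32 = (max_val - min_val) / (qmax - qmin)                    # prepare-like: float32
scale_fp64 = (max_val.double() - min_val.double()) / (qmax - qmin)  # convert-like: float64

# The scales differ by a tiny amount, and such differences can compound
# across many linear layers, as described in the summary above.
print("max |scale_fp32 - scale_fp64|:", (scale_fp32.double() - scale_fp64).abs().max().item())
```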
1 parent dfbd681 · commit 890438a

File tree

3 files changed: +109 −24 lines

test/quantization/test_qat.py (+71)

```diff
@@ -133,6 +133,18 @@ def forward(self, x):
         return x
 
 
+class M4(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512).to(torch.float),)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 class ModelWithLinearBias(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1389,6 +1401,65 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantize_per_token_vs_convert(self):
+        """
+        Test that the following produce the exact same numerics:
+        1. FakeQuantizer with asymmetric per_token config
+        2. torchao.quantization.utils.per_token_dynamic_quant
+        """
+        from torchao.quantization.utils import per_token_dynamic_quant
+
+        torch.manual_seed(self.SEED)
+        x = torch.randn(1, 235, 2048)
+        config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer_out = fake_quantizer(x)
+        baseline_out = per_token_dynamic_quant(x)
+        torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert(self):
+        """
+        Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produce
+        numerics that match exactly over N trials.
+        """
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.utils import compute_error
+
+        num_trials = 1000
+        group_size = 16
+        non_inf_sqnr = []
+
+        for seed in range(self.SEED, self.SEED + num_trials):
+            torch.manual_seed(seed)
+            m = M4()
+            torch.manual_seed(seed)
+            x = m.example_inputs()
+
+            quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+            prepared = quantizer.prepare(m)
+            prepared_out = prepared(*x)
+            converted = quantizer.convert(prepared)
+            converted_out = converted(*x)
+            sqnr = compute_error(prepared_out, converted_out).item()
+            if sqnr != float("inf"):
+                non_inf_sqnr.append(sqnr)
+
+        avg_sqnr = (
+            sum(non_inf_sqnr) / len(non_inf_sqnr) if len(non_inf_sqnr) > 0 else -1
+        )
+        fail_message = "%s/%s trials did not match exactly, average sqnr = %s" % (
+            len(non_inf_sqnr),
+            num_trials,
+            avg_sqnr,
+        )
+        self.assertEqual(len(non_inf_sqnr), 0, fail_message)
+
 
 if __name__ == "__main__":
     unittest.main()
```
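For context on the assertion in `test_qat_8da4w_prepare_vs_convert`: `compute_error` returns an SQNR that is infinite when the two tensors are bit-identical, so counting the finite values counts the trials where prepare and convert disagreed. A minimal sketch of that behavior, with illustrative tensors only:

```python
import torch
from torchao.quantization.utils import compute_error

a = torch.randn(4, 8)

# Identical tensors -> zero error power -> SQNR of +inf, so the trial is
# treated as an exact match and excluded from non_inf_sqnr.
print(compute_error(a, a))

# Any perturbation -> finite SQNR, so the trial would count as a mismatch.
print(compute_error(a, a + 1e-3))
```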

torchao/_executorch_ops.py (+2)

```diff
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
+# TODO: delete these ops
+
 
 def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs):
     """
```
torchao/quantization/utils.py (+36 −24)

```diff
@@ -539,36 +539,48 @@ def group_quantize_tensor_symmetric(
     return w_int8, scales, zeros
 
 
-def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor:
-    orig_dtype = input.dtype
-    # TODO: we may need to make the choose_qparams op configurable
-    from torchao._executorch_ops import (
-        _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper,
-    )
-
-    (
-        scales,
-        zero_points,
-    ) = _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(
-        input, torch.int8
-    )
-
-    # TODO: get these from torch.int8
+def per_token_dynamic_quant(
+    input: torch.Tensor,
+    scale_dtype: torch.dtype = torch.float32,
+    zero_point_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = _get_per_token_block_size(input)
     quant_min = -128
     quant_max = 127
-    from torchao._executorch_ops import _quantized_decomposed_quantize_per_token_wrapper
+    quant_dtype = torch.int8
+    output_dtype = input.dtype
 
-    input = _quantized_decomposed_quantize_per_token_wrapper(
-        input, scales, zero_points, quant_min, quant_max, torch.int8
+    scales, zero_points = choose_qparams_affine(
+        input,
+        mapping_type,
+        block_size,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
     )
-    from torchao._executorch_ops import (
-        _quantized_decomposed_dequantize_per_token_wrapper,
+    q = quantize_affine(
+        input,
+        block_size,
+        scales,
+        zero_points,
+        quant_dtype,
+        quant_min,
+        quant_max,
     )
-
-    input = _quantized_decomposed_dequantize_per_token_wrapper(
-        input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype
+    dq = dequantize_affine(
+        q,
+        block_size,
+        scales,
+        zero_points,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        output_dtype=output_dtype,
    )
-    return input.to(orig_dtype)
+    return dq
 
 
 def recommended_inductor_config_setter():
```
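For completeness, a small usage sketch of the updated `per_token_dynamic_quant` (input shape chosen arbitrarily); the new `scale_dtype` and `zero_point_dtype` keywords default to `torch.float32`, so the helper now matches the QAT prepare numerics by default:

```python
import torch
from torchao.quantization.utils import per_token_dynamic_quant

# Any floating-point activation tensor; the last dimension holds each token's features.
x = torch.randn(2, 128, 512)

# Fake-quantize per token through the choose_qparams_affine -> quantize_affine ->
# dequantize_affine chain, using float32 scales and zero points by default.
y = per_token_dynamic_quant(x)
print(y.shape, y.dtype)  # same shape and dtype as the input
```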
