
Match QAT prepare and convert numerics exactly #1964

Merged: 1 commit merged into main on Apr 4, 2025

Conversation

andrewor14 (Contributor) commented on Mar 26, 2025

**Summary:** Previously, `Int8DynActInt4QATQuantizer` had slightly diverging numerics between the prepare and convert steps. This is because the prepare step uses quantization primitives shared with AQT (specifically `quantize_affine` and `dequantize_affine`), while the convert step relies on old ops from the `torch.ops.quantized_decomposed` namespace. The divergence is negligible for small models, but the quantization errors compound for larger models with many linear layers.

More specifically, the divergence occurs in three places during activation quantization:

1. **Choose qparams.** The prepare step casts the qparams to `torch.float32`, whereas the convert step casts the scales to `torch.float64` and the zero points to `torch.int64`.
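
As an illustration of how the dtype choice alone can shift results, here is a minimal sketch (hypothetical per-token qparam math, not the actual torchao or quantized_decomposed implementations):

```
import torch

x = torch.randn(4, 16)          # hypothetical per-token activations
qmin, qmax = -128, 127          # int8 range used for dynamic activations
min_val = x.amin(dim=-1, keepdim=True).clamp(max=0)
max_val = x.amax(dim=-1, keepdim=True).clamp(min=0)

# Prepare-style: scales and zero points kept in float32
scale32 = ((max_val - min_val) / (qmax - qmin)).to(torch.float32)
zp32 = (qmin - torch.round(min_val / scale32)).to(torch.float32)

# Convert-style: float64 scales and int64 zero points
scale64 = (max_val.double() - min_val.double()) / (qmax - qmin)
zp64 = (qmin - torch.round(min_val.double() / scale64)).to(torch.int64)

# The float32 and float64 scales differ slightly, which can shift a
# rounding boundary downstream for some tokens
print((scale32.double() - scale64).abs().max())
```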

2. **Quantize.** The prepare step rounds before adding the zero points and uses torch functions, while the convert step adds the zero points before rounding and uses Tensor methods.

```
# Prepare
x = torch.clamp(
    torch.round(x * (1.0 / scale)) + zero_point, qmin, qmax,
)

# Convert
x = (
    x.mul(1.0 / scale)
    .add(zero_point)
    .round()
    .clamp(qmin, qmax)
    .to(quantize_dtype)
)
```
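
To see why the ordering can matter, here is a tiny illustration with hypothetical values: under round-half-to-even, rounding before versus after adding an odd integer zero point can differ at exact .5 boundaries.

```
import torch

scaled = torch.tensor([2.5])          # x * (1.0 / scale), hypothetical value
zero_point = torch.tensor([1.0])

prepare_style = torch.round(scaled) + zero_point   # round(2.5) = 2.0 -> 3.0
convert_style = scaled.add(zero_point).round()     # round(3.5) = 4.0
print(prepare_style.item(), convert_style.item())  # 3.0 4.0
```
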
3. **Dequantize.** The prepare step casts to `torch.int32` before subtracting the zero points, and casts back to the original dtype before multiplying by the scale. The convert step only casts at the very end.
```
# Prepare
x = x.to(torch.int32) - zero_point.to(torch.int32)
x = x.to(orig_dtype)
x = x * scale

# Convert
x = x - zero_point
x = x * scale
x = x.to(orig_dtype)
```

This commit makes the convert path use the same torchao quantization primitives as the prepare path, resolving the three differences above. The prepare and convert steps now match exactly in numerics across many trials.
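
As a sanity check of the property this change enforces, here is a self-contained sketch (plain torch, hypothetical shapes, not the torchao implementation) showing that prepare-style fake quantization matches quantize-then-dequantize once both paths use the same formulas and dtypes:

```
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)                  # hypothetical per-token activations
qmin, qmax = -128, 127

# Shared qparams, computed once and kept in float32 on both paths
min_val = x.amin(dim=-1, keepdim=True).clamp(max=0)
max_val = x.amax(dim=-1, keepdim=True).clamp(min=0)
scale = ((max_val - min_val) / (qmax - qmin)).to(torch.float32)
zero_point = (qmin - torch.round(min_val / scale)).to(torch.float32)

# "Prepare"-style fake quantize: quantize and immediately dequantize
q = torch.clamp(torch.round(x * (1.0 / scale)) + zero_point, qmin, qmax)
fake_quantized = (q - zero_point) * scale

# "Convert"-style: quantize to int8, then dequantize with the same formulas
q_int8 = torch.clamp(torch.round(x * (1.0 / scale)) + zero_point, qmin, qmax).to(torch.int8)
dequantized = (q_int8.to(torch.int32) - zero_point.to(torch.int32)) * scale

torch.testing.assert_close(fake_quantized, dequantized)  # exact match expected
```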

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert

andrewor14 requested a review from jerryzh168 on March 26, 2025

pytorch-bot commented on Mar 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1964

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures as of commit 4c45344 with merge base dfbd681.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Mar 26, 2025
andrewor14 force-pushed the qat-8da4w-match-prepare-convert branch from 1fe118a to df409ed on March 26, 2025
andrewor14 added the topic: improvement label on Mar 26, 2025
andrewor14 force-pushed the qat-8da4w-match-prepare-convert branch 3 times, most recently from cb9942c to 709feab, on March 27, 2025
andrewor14 requested a review from jerryzh168 on March 27, 2025
facebook-github-bot commented: @andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

jerryzh168 (Contributor) commented: Thanks for the updates; it will be good to set up a deprecation plan for the `quantized_decomposed.choose_qparams` op then.

andrewor14 force-pushed the qat-8da4w-match-prepare-convert branch from 709feab to 890438a on March 27, 2025

facebook-github-bot commented: @andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

andrewor14 force-pushed the qat-8da4w-match-prepare-convert branch from 890438a to d9af870 on March 31, 2025
andrewor14 force-pushed the qat-8da4w-match-prepare-convert branch from d9af870 to 4c45344 on March 31, 2025
facebook-github-bot commented: @andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

andrewor14 merged commit 6922733 into main on Apr 4, 2025
19 of 20 checks passed
jainapurva pushed a commit that referenced this pull request Apr 8, 2025
andrewor14 added a commit that referenced this pull request Apr 21, 2025
**Summary:** The previous PR #1964 got this to match for fp32,
but there were three additional sources of numerical discrepancies
with bf16:

1. QAT asymmetric per token choose qparams diverged from
  `choose_qparams_affine`, which had simpler logic
2. QAT per token fake quantize cast the input to fp32 before
  fake quantizing it
3. QAT symmetric per group choose qparams used a hardcoded
  eps value that did not match `choose_qparams_affine`

These are all resolved in this commit: (1) QAT now uses
`choose_qparams_affine` instead of the custom function for
asymmetric per token, which is now deleted, (2) QAT no
longer casts the input to fp32, and (3) QAT now uses
an eps value that corresponds to the input dtype. The result
is an exact match in numerics between the prepare and convert
steps for fp32, bf16, and fp16.
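
For point (3), a minimal sketch of what a dtype-dependent eps could look like (illustrative only; `dtype_eps` is a hypothetical helper, not the torchao code):

```
import torch

def dtype_eps(x: torch.Tensor) -> float:
    # Lower bound for scales tied to the input dtype rather than a fixed constant
    return torch.finfo(x.dtype).eps

print(dtype_eps(torch.zeros(1, dtype=torch.bfloat16)))  # 0.0078125
print(dtype_eps(torch.zeros(1, dtype=torch.float32)))   # ~1.19e-07
```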

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert
python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert