Consolidate ZeroPointDomain.NONE & None zero point domains #1556

Open
wants to merge 4 commits into
base: main

sanchitintel
Contributor

@sanchitintel sanchitintel commented Jan 13, 2025

Summary

  • Both ZeroPointDomain.NONE and a bare None were in use as zero point domains; the latter was used for float8. This PR consolidates the two and retains ZeroPointDomain.NONE.
  • Using ZeroPointDomain.NONE now produces None for the zero point.
  • int8_dynamic_activation_int8_weight now uses ZeroPointDomain.NONE as weight zero point domain (as weight is quantized symmetrically to int8).
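
To make the summary concrete, here is a minimal, dependency-free sketch of symmetric int8 quantization in which the NONE domain yields no zero point at all. It is not torchao code: the `ZeroPointDomain` enum below is an illustrative stand-in that only mirrors the member names, and `choose_qparams_sketch` and `quantize_sketch` are hypothetical helpers.

```python
from enum import Enum, auto
from typing import List, Optional, Tuple


class ZeroPointDomain(Enum):
    """Illustrative stand-in; only the member names mirror the PR."""
    INT = auto()
    FLOAT = auto()
    NONE = auto()


def choose_qparams_sketch(
    values: List[float],
    quant_max: int,
    zero_point_domain: ZeroPointDomain,
) -> Tuple[float, Optional[int]]:
    """Hypothetical helper: symmetric scale, plus a zero point unless the domain is NONE."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / quant_max if max_abs > 0 else 1.0
    if zero_point_domain is ZeroPointDomain.NONE:
        return scale, None  # no zero point is materialized
    return scale, 0  # symmetric case with an explicit zero point of 0


def quantize_sketch(
    values: List[float],
    scale: float,
    zero_point: Optional[int],
    quant_min: int,
    quant_max: int,
) -> List[int]:
    """Hypothetical helper: round to the nearest level, shift if needed, clamp."""
    offset = 0 if zero_point is None else zero_point
    return [
        min(quant_max, max(quant_min, round(v / scale) + offset))
        for v in values
    ]


weights = [-1.27, 0.0, 0.5, 1.27]
scale, zero_point = choose_qparams_sketch(weights, 127, ZeroPointDomain.NONE)
quantized = quantize_sketch(weights, scale, zero_point, -128, 127)
```

Because symmetric quantization maps 0.0 to integer 0, the zero point carries no information for these weights, which is why dropping it is a reasonable choice here.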

Some of the older changes in this PR (such as supporting torch.compile with optional zero_point) were rendered redundant by more recent changes in the main branch, so I removed them & modified the description accordingly. Thanks!

pytorch-bot bot commented Jan 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1556

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jan 13, 2025
@sanchitintel force-pushed the zero_point_domain_none branch from 473c42f to c53a9d5 January 13, 2025 23:26
Contributor

@jerryzh168 jerryzh168 left a comment

LGTM

@sanchitintel
Contributor Author

Thanks for reviewing, @jerryzh168! I'll fix the lint error once the other CI checks complete.

@sanchitintel

This comment was marked as outdated.

@sanchitintel
Contributor Author

The unit tests that failed did so because of Inductor codegen issues that have since been fixed; the CI jobs used PyTorch v2.4. I'll skip some of those tests there, so that they run only with PyTorch v2.5 and later.

Thanks!
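
A version-gated skip of this kind could look roughly like the sketch below. It is only an illustration: `parse_major_minor`, `INSTALLED_TORCH_VERSION`, and the test class are hypothetical names, and the version string is hard-coded so the snippet runs without PyTorch installed (a real test module would presumably read `torch.__version__`).

```python
import re
import unittest
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.5.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


# Assumption: stands in for torch.__version__ on a CI job that uses PyTorch 2.4.
INSTALLED_TORCH_VERSION = "2.4.1"
TORCH_AT_LEAST_2_5 = parse_major_minor(INSTALLED_TORCH_VERSION) >= (2, 5)


class CompileDependentTests(unittest.TestCase):
    @unittest.skipIf(not TORCH_AT_LEAST_2_5, "needs Inductor fixes from PyTorch 2.5+")
    def test_requires_newer_inductor(self):
        # Placeholder body; the real test would exercise the compiled model.
        self.assertTrue(TORCH_AT_LEAST_2_5)
```

With the assumed 2.4.1 version string the test is reported as skipped rather than failed, so the older CI job stays green while newer jobs still run it.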

@jerryzh168
Contributor

For https://github.com/pytorch/ao/actions/runs/12760810579/job/35567540271?pr=1556, you can install pre-commit:

pip install pre-commit

and run it:

pre-commit run

To have the formatting run before every git commit, also run pre-commit install once.

@jerryzh168 added the topic: bug fix and topic: improvement labels and removed the topic: bug fix label Jan 15, 2025
@sanchitintel force-pushed the zero_point_domain_none branch from 837ee22 to c44df1f January 15, 2025 06:09
@sanchitintel marked this pull request as ready for review January 15, 2025 06:26
@sanchitintel force-pushed the zero_point_domain_none branch from c83f470 to da2e9e0 January 15, 2025 06:48
@sanchitintel changed the title from "Fix ZeroPointDomain.NONE support & make it default for da8w8 weights" to "Consolidate ZeroPointDomain.NONE & None zero point domains" Jan 15, 2025
test/integration/test_integration.py (resolved)
test/integration/test_integration.py (outdated, resolved)
torchao/quantization/quant_primitives.py (outdated, resolved)
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
            input_scale.item(),
            max_val / max_fp8,
        )
        self.assertIsNotNone(input_zero_point)
Contributor

Is there a change in behavior when you change zero_point_domain from None to ZeroPointDomain.NONE?

Contributor Author

Yes, input_zero_point is now None. So, instead of just removing that line, I replaced it with self.assertIsNone(input_zero_point). Thanks!

Contributor

I see. So what did zero_point_domain == None mean before?

Contributor Author

Some APIs produced a None zero_point when zero_point_domain was either ZeroPointDomain.NONE or None, while choose_qparams_affine did not.

@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
        torch.testing.assert_close(dequantized, fake_quantized)
        torch.testing.assert_close(expected_mask, mask)

    # ZeroPointDomain.NONE should work
    def test_none_zero_point_domain(self):
Contributor

I think we could also add a test where zero_point_domain is None and an IllegalArgumentError exception is thrown.

Contributor Author

Thanks again for reviewing! I added code that raises ValueError if zero_point_domain is None, but only in choose_qparams_affine and choose_qparams_affine_with_min_max.

For other public-facing APIs that have zero_point_domain as a parameter, I added asserts instead.

Please advise if this is fine, or if other places in the code should also throw exceptions instead of asserting.

Thanks!
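
For reference, the kind of test discussed in this thread might be sketched as follows. This is a self-contained illustration, not the test added in the PR: the enum is a stand-in, and `choose_qparams_checked` is a hypothetical function that only mimics the described behavior (ValueError for a bare None, a None zero point for ZeroPointDomain.NONE).

```python
import unittest
from enum import Enum, auto


class ZeroPointDomain(Enum):
    """Illustrative stand-in; only the member names mirror the PR."""
    INT = auto()
    FLOAT = auto()
    NONE = auto()


def choose_qparams_checked(max_abs: float, quant_max: int, zero_point_domain):
    """Hypothetical function that rejects a bare None domain."""
    if zero_point_domain is None:
        raise ValueError("zero_point_domain must not be None")
    scale = max_abs / quant_max
    zero_point = None if zero_point_domain is ZeroPointDomain.NONE else 0
    return scale, zero_point


class TestZeroPointDomain(unittest.TestCase):
    def test_none_enum_member_gives_no_zero_point(self):
        _, zero_point = choose_qparams_checked(1.0, 127, ZeroPointDomain.NONE)
        self.assertIsNone(zero_point)

    def test_bare_none_is_rejected(self):
        with self.assertRaises(ValueError):
            choose_qparams_checked(1.0, 127, None)
```

Python has no built-in IllegalArgumentError, so ValueError is the closest standard exception for an argument with the right type but an unacceptable value.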

@@ -302,10 +303,8 @@ def from_hp_to_intx_static(
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        _layout: Layout = PlainLayout(),
    ):
        assert zero_point_domain is not None, "zero_point_domain must not be None"
Contributor

I think zero_point_domain should not be Optional in L303.

Contributor

Also, maybe raise a ValueError here; that might be more user-friendly than an assertion.

Contributor Author

Thank you! I made those modifications.

@@ -85,6 +85,7 @@ def __new__(
        dtype=None,
        strides=None,
    ):
        assert zero_point_domain is not None, "zero_point_domain must not be None"
Contributor

@jerryzh168 jerryzh168 Jan 17, 2025

For the error message, I think we should include "please use ZeroPointDomain.NONE" instead.

At some point it should be OK to remove these asserts, once we are confident that there is no misuse in the codebase, I think.
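
One way to combine the suggestions in this thread (a friendlier message and an exception rather than an assert) is a small shared guard, sketched below. The helper name `require_zero_point_domain`, the stand-in enum, and the exact message wording are assumptions for illustration and are not taken from the PR.

```python
from enum import Enum, auto


class ZeroPointDomain(Enum):
    """Illustrative stand-in; only the member names mirror the PR."""
    INT = auto()
    FLOAT = auto()
    NONE = auto()


def require_zero_point_domain(zero_point_domain) -> ZeroPointDomain:
    """Hypothetical guard: return the domain unchanged, or explain how to fix the call."""
    if zero_point_domain is None:
        raise ValueError(
            "zero_point_domain must not be None; "
            "please use ZeroPointDomain.NONE instead"
        )
    if not isinstance(zero_point_domain, ZeroPointDomain):
        raise TypeError(
            f"expected a ZeroPointDomain, got {type(zero_point_domain).__name__}"
        )
    return zero_point_domain


# Usage: capture the message a caller would see when passing a bare None.
try:
    require_zero_point_domain(None)
    message = ""
except ValueError as exc:
    message = str(exc)
```

A practical difference from the assert-based check is that assert statements are stripped when Python runs with the -O flag, whereas an explicit raise always executes.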

@sanchitintel force-pushed the zero_point_domain_none branch from cd7abef to 8116c0c January 17, 2025 21:27
Labels
CLA Signed, topic: improvement
4 participants