Consolidate ZeroPointDomain.NONE & None zero point domains
#1556
base: main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1556
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
473c42f to c53a9d5
LGTM
Thanks for reviewing, @jerryzh168! I'll fix the lint error once the other CI checks complete.
The UTs that failed did so because of some Inductor codegen issues that have since been fixed, but the CI jobs used PyTorch v2.4. I'll skip some of those UTs so that they only run with PyTorch v2.5 & beyond. Thanks!
For https://github.com/pytorch/ao/actions/runs/12760810579/job/35567540271?pr=1556 you can install pre-commit
and run it:
then it will run the formatting before every
837ee22 to c44df1f
c83f470 to da2e9e0
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)
Is there a change of behavior when you change zero_point_domain from None to ZeroPointDomain.NONE?
Yes, input_zero_point would now be None. So, instead of removing that line, I have now added self.assertIsNone(input_zero_point). Thanks!
I see, so what was the meaning of zero_point_domain == None before?
Some APIs were creating a None zero_point when a zero_point_domain of ZeroPointDomain.NONE or None was used, while choose_qparams_affine was not.
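To make the consolidated behavior concrete, here is a minimal pure-Python sketch; it is not torchao's actual code. The enum below is an illustrative stand-in that contains only the two members named in this thread, and `choose_qparams_sketch` with its parameters is a hypothetical simplification of `choose_qparams_affine` that operates on a plain list of floats. It shows the intended outcome: `ZeroPointDomain.NONE` yields a `None` zero point.

```python
from enum import Enum, auto
from typing import List, Optional, Tuple


class ZeroPointDomain(Enum):
    """Illustrative stand-in; only the members named in this thread."""
    INT = auto()
    NONE = auto()


def choose_qparams_sketch(
    values: List[float],
    quant_min: int,
    quant_max: int,
    zero_point_domain: ZeroPointDomain,
) -> Tuple[float, Optional[int]]:
    """Hypothetical helper: return (scale, zero_point) for a list of floats."""
    eps = 1e-12  # guards against a zero scale when all values are 0
    if zero_point_domain is ZeroPointDomain.NONE:
        # Symmetric case: the range is centred on 0, so no zero point is needed.
        max_abs = max(abs(v) for v in values)
        scale = max(max_abs / ((quant_max - quant_min) / 2), eps)
        return scale, None
    # Asymmetric case: map [lo, hi] (extended to include 0) onto the integer range.
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    scale = max((hi - lo) / (quant_max - quant_min), eps)
    zero_point = quant_min - round(lo / scale)
    return scale, zero_point


scale, zero_point = choose_qparams_sketch(
    [-1.0, 0.5, 2.0], -128, 127, ZeroPointDomain.NONE
)
assert zero_point is None
```

With `ZeroPointDomain.INT` the same sketch returns an integer zero point inside `[quant_min, quant_max]`, which is the distinction callers such as the float8 observer test above need to account for.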
@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

# ZeroPointDomain.NONE should work
def test_none_zero_point_domain(self):
We could also have a test for zero_point_domain being None, and throw an IllegalArgumentError exception in that case, I think.
Thanks again for reviewing! I added code that raises ValueError if zero_point_domain is None, but only for choose_qparams_affine and choose_qparams_affine_with_min_max.
For other public-facing APIs that have zero_point_domain as a parameter, I added asserts instead.
Please advise if this is fine, or if other places in the code should also throw exceptions instead of asserting.
Thanks!
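The trade-off between raising and asserting can be sketched as follows. This is an illustrative example, not the PR's code: `choose_qparams_checked` is a hypothetical name and its body is a placeholder. One general Python fact is relevant here: `assert` statements are skipped when the interpreter runs with `-O`, whereas an explicit `raise ValueError` always fires. The test case at the end shows the kind of check suggested in the review.

```python
import unittest
from enum import Enum, auto


class ZeroPointDomain(Enum):
    """Illustrative stand-in for the enum discussed in this PR."""
    INT = auto()
    NONE = auto()


def choose_qparams_checked(zero_point_domain):
    """Hypothetical entry point that rejects a bare None up front."""
    if zero_point_domain is None:
        # Explicit raise: still enforced under `python -O`, unlike `assert`.
        raise ValueError("zero_point_domain must not be None")
    return zero_point_domain  # placeholder for the real computation


class TestNoneZeroPointDomain(unittest.TestCase):
    def test_none_raises(self):
        # A bare None must be rejected with ValueError.
        with self.assertRaises(ValueError):
            choose_qparams_checked(None)

    def test_enum_none_is_accepted(self):
        # The enum member ZeroPointDomain.NONE remains a valid argument.
        self.assertIs(
            choose_qparams_checked(ZeroPointDomain.NONE), ZeroPointDomain.NONE
        )
```

Saving this block to a file and running `python -m unittest` on it would execute both tests.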
@@ -302,10 +303,8 @@ def from_hp_to_intx_static(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
zero_point_domain should not be Optional in L303 I think
Also, maybe raise ValueError here; that might be more user-friendly than an assertion.
Thank you! Made those modifications
@@ -85,6 +85,7 @@ def __new__(
dtype=None,
strides=None,
):
assert zero_point_domain is not None, "zero_point_domain must not be None"
For the error message, I think we should include "please use ZeroPointDomain.NONE instead".
And at some point it should be OK to remove these asserts, once we are confident that there is no misuse in the codebase, I think.
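One possible shape for the suggestions in these review threads (a non-Optional annotation plus a message that points callers at the enum member) is sketched below. The helper name `validate_zero_point_domain` is hypothetical, and the exact message wording is an assumption that combines the existing assert text with the reviewer's suggested phrase; the message actually merged in the PR may differ.

```python
from enum import Enum, auto


class ZeroPointDomain(Enum):
    """Illustrative stand-in for the enum discussed in this PR."""
    INT = auto()
    NONE = auto()


def validate_zero_point_domain(zero_point_domain: ZeroPointDomain) -> ZeroPointDomain:
    """Hypothetical helper; note the annotation is no longer Optional."""
    if zero_point_domain is None:
        # Assumed wording: tells the caller what to pass instead of None.
        raise ValueError(
            "zero_point_domain must not be None; "
            "please use ZeroPointDomain.NONE instead"
        )
    return zero_point_domain


try:
    validate_zero_point_domain(None)
except ValueError as err:
    message = str(err)
```

Centralizing the check in one helper would also make it easy to remove later, once no caller passes `None` any more.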
cd7abef to 8116c0c
Summary
ZeroPointDomain.NONE & None zero point domains were being used. The latter was being used for float8. This PR consolidates both & retains ZeroPointDomain.NONE.
ZeroPointDomain.NONE would now produce None for zero-point.
int8_dynamic_activation_int8_weight now uses ZeroPointDomain.NONE as weight zero point domain (as weight is quantized symmetrically to int8).
Some of the older changes in this PR (such as supporting torch.compile with optional zero_point) were rendered redundant by more recent changes in the main branch, so I removed them & modified the description accordingly. Thanks!