Fix dot_scaled to accept float8e8m0fnu scales#10595
Open
KaitaoQiu wants to merge 1 commit into
This was referenced Jun 13, 2026
ThomasRaoux
requested changes
Jun 15, 2026
Comment on lines +455 to +457
# Note: e8m0 is an 8-bit float but is an exponent-only *scale* format,
# not a compute fp8 type, so it is intentionally excluded here.
return 'fp8' in self.name and not self.is_fp8e8m0()
Collaborator
why does it need to be false for e8m0?
Author
Good catch, addressed in the new revision.
triton.compile(src, target=GPUTarget("hip", "gfx942", 64))

def test_compile_only_dot_scaled_e8m0_scale() -> None:
Collaborator
Can you make it an execution test, to make sure the rest of the compiler works as expected?
7b587fb to b877e20
b877e20 to 69dbf78
ThomasRaoux
reviewed
Jun 16, 2026
Comment on lines +421 to +422
if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
pytest.skip("Requires compute capability >= 10")
Collaborator
we support dot_scaled on all targets
Summary
tl.dot_scaled currently only accepts uint8 scale tensors. E8M0 (float8_e8m0fnu) is the native scale format defined by the OCP Microscaling (MX) spec for MXFP4/MXFP8, and is what PyTorch and AMD's AITER quantization helpers emit by default. As a result, users have to reinterpret their scales with scale.view(torch.uint8) before every dot_scaled call, and passing a float8_e8m0fnu tensor fails outright during specialization because the dtype name is not recognized.

This PR makes float8_e8m0fnu a first-class dtype so dot_scaled accepts E8M0 scales directly. The approach reuses MLIR's existing Float8E8M0FNUType and carries the type through the frontend instead of bitcasting it to uint8 up front, so no special-casing is scattered across the stack. The existing FloatType path in DecomposeScaledBlocked (FpToFpOp for the value conversion, CmpFOp UNO for NaN masking) already handles a floating-point scale, so the type flows through the current lowering. uint8/int8 scales keep working unchanged for backward compatibility.
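For readers less familiar with the format, here is a minimal sketch of what an E8M0 scale byte means under the OCP MX definition: an 8-bit biased exponent (bias 127) with no sign or mantissa bits, and 0xFF reserved for NaN. The helper names decode_e8m0 and encode_e8m0 are hypothetical and exist only for illustration; they are not part of Triton or PyTorch, and this pure-Python version is not the lowering the PR uses.

```python
import math

E8M0_BIAS = 127   # exponent bias of the E8M0 format
E8M0_NAN = 0xFF   # the only special encoding: NaN (no infinities, no zero)


def decode_e8m0(byte: int) -> float:
    """Interpret one raw scale byte as an E8M0 value, i.e. 2**(byte - 127)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 values are exactly one byte")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)


def encode_e8m0(scale: float) -> int:
    """Encode a power-of-two scale as an E8M0 byte (hypothetical helper).

    Anything that is not NaN or an exact, in-range power of two is rejected
    rather than rounded, to keep the sketch unambiguous.
    """
    if math.isnan(scale):
        return E8M0_NAN
    mantissa, exponent = math.frexp(scale)  # scale == mantissa * 2**exponent
    if scale <= 0 or mantissa != 0.5:
        raise ValueError("E8M0 can only represent positive powers of two")
    biased = exponent - 1 + E8M0_BIAS
    if not 0 <= biased <= 0xFE:
        raise ValueError("scale is outside the E8M0 exponent range")
    return biased


if __name__ == "__main__":
    for raw in (126, 127, 128, 255):
        print(raw, decode_e8m0(raw))
```

Because the encoding is just the exponent byte, viewing a float8_e8m0fnu tensor as uint8 preserves every bit, which is why the scale.view(torch.uint8) workaround described above is lossless, only inconvenient.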
What changed
Testing
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these rules.
I have run pre-commit run --from-ref origin/main --to-ref HEAD.
Select one of the following.
/test for lit tests
/unittest for C++ tests
/python/test for end-to-end tests
FILL THIS IN.
Select one of the following.
lit tests.
lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)