Skip to content

Fix dot_scaled to accept float8e8m0fnu scales#10595

Open
KaitaoQiu wants to merge 1 commit into
triton-lang:mainfrom
KaitaoQiu:feature/dot-scaled-e8m0-firstclass
Open

Fix dot_scaled to accept float8e8m0fnu scales#10595
KaitaoQiu wants to merge 1 commit into
triton-lang:mainfrom
KaitaoQiu:feature/dot-scaled-e8m0-firstclass

Conversation

@KaitaoQiu

Copy link
Copy Markdown

Summary

tl.dot_scaled currently only accepts uint8 scale tensors. E8M0 (float8_e8m0fnu) is the native scale format defined by the OCP Microscaling (MX) spec for MXFP4/MXFP8, and is what PyTorch and AMD's AITER quantization helpers emit by default. As a result users have to reinterpret their scales with scale.view(torch.uint8) before every dot_scaled call, and passing a float8_e8m0fnu tensor fails outright during specialization because the dtype name is not recognized.

This PR makes float8_e8m0fnu a first-class dtype so dot_scaled accepts E8M0 scales directly. The approach reuses MLIR's existing Float8E8M0FNUType and carries the type through the frontend instead of bitcasting it to uint8 up front, so no special-casing is scattered across the stack. The existing FloatType path in DecomposeScaledBlocked (FpToFpOp for the value conversion, CmpFOp UNO for NaN masking) already handles a floating-point scale, so the type flows through the current lowering. uint8/int8 scales keep working unchanged for backward compatibility.

What changed

  • Recognize the torch.float8_e8m0fnu / triton float8e8m0fnu dtype strings and canonicalize them to fp8e8m0fnu.
  • Add the tl.float8e8m0fnu dtype and a builder binding that returns MLIR Float8E8M0FNUType.
  • Exclude E8M0 from is_fp8() (it is an exponent-only scale format, not a compute fp8 type) and gate the architecture supported_fp8_dtypes check on is_fp8() so the scale type is not rejected.
  • Accept float8e8m0fnu in the AMD scale-format deduction / conversion paths (gfx1250, CDNA3, CDNA4), keeping uint8/int8 accepted.
  • Map the dtype to uint8 storage in the interpreter.

Testing

  • Added test_compile_only_dot_scaled_e8m0_scale: compiles a tl.dot_scaled kernel whose scales are *fp8e8m0fnu for hip:gfx950.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Comment thread python/triton/language/core.py Outdated
Comment on lines +455 to +457
# Note: e8m0 is an 8-bit float but is an exponent-only *scale* format,
# not a compute fp8 type, so it is intentionally excluded here.
return 'fp8' in self.name and not self.is_fp8e8m0()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it need to be false for e8m0?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, addressed it in the new reversion.

triton.compile(src, target=GPUTarget("hip", "gfx942", 64))


def test_compile_only_dot_scaled_e8m0_scale() -> None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make it an execution test to make sure the rest of the compiler is working as expected

@KaitaoQiu KaitaoQiu force-pushed the feature/dot-scaled-e8m0-firstclass branch from 7b587fb to b877e20 Compare June 15, 2026 17:57
@KaitaoQiu KaitaoQiu force-pushed the feature/dot-scaled-e8m0-firstclass branch from b877e20 to 69dbf78 Compare June 15, 2026 17:58
@KaitaoQiu KaitaoQiu requested a review from ThomasRaoux June 15, 2026 21:07
Comment on lines +421 to +422
if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
pytest.skip("Requires compute capability >= 10")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we support dot_scaled on all targets

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants