Skip to content

Add float8_e8m0fnu support to type canonicalization for dot_scaled#10009

Open
GeisYaO wants to merge 9 commits intotriton-lang:mainfrom
GeisYaO:fix-dot-scaled-e8m0fnu-dtype
Open

Add float8_e8m0fnu support to type canonicalization for dot_scaled#10009
GeisYaO wants to merge 9 commits intotriton-lang:mainfrom
GeisYaO:fix-dot-scaled-e8m0fnu-dtype

Conversation

@GeisYaO
Copy link
Copy Markdown

@GeisYaO GeisYaO commented Apr 12, 2026

This PR fixes an issue in dot_scaled where the scale factor's handle was being used directly without ensuring the correct bitcast to tl.uint8. It also adds a comprehensive end-to-end test for dot_scaled with e8m0fnu data type on AMD.

Summary of changes:

  • In python/triton/_utils.py: Added float8_e8m0fnu to the type canonicalization dictionary.
    • In python/triton/language/semantic.py: Bitcast lhs_scale and rhs_scale to tl.uint8 before retrieving their handles in dot_scaled.

@GeisYaO GeisYaO requested a review from ptillet as a code owner April 12, 2026 16:03
Copy link
Copy Markdown
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the title doesn't seem aligned with the patch as this changes more than Gluon and it is not specific to AMD. Also I don't think we should need that


def test_dot_scaled_e8m0fnu():
@triton.jit
def kernel(lhs_ptr, lhs_scale_ptr, rhs_ptr, rhs_scale_ptr, out_ptr,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the kernel is not even called

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, the test is incomplete - the kernel function is defined but never invoked with proper grid/args. I'll rewrite it as a proper pytest that actually launches the kernel and validates results. Should I add it to the existing test_dot_scaled.py instead of a separate file?

Comment thread python/triton/language/semantic.py Outdated
Comment on lines +1598 to +1599
rhs_scale_handle = None if rhs_scale_is_none else self.bitcast(rhs_scale, tl.uint8).handle
lhs_scale_handle = None if lhs_scale_is_none else self.bitcast(lhs_scale, tl.uint8).handle
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to change the type passed by user

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. If the preferred approach is to not auto-bitcast, would option (A) - only adding float8_e8m0fnu to the type canonicalization dict - be acceptable? That way float8_e8m0fnu tensors can at least pass through the argument binding stage, and users would handle any necessary casting on their side.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only adding float8_e8m0fnu to the type canonicalization dict - be acceptable?

yes that sounds reasonable

@GeisYaO
Copy link
Copy Markdown
Author

GeisYaO commented Apr 13, 2026

Thank you for the review, @ThomasRaoux!

You're right - the [AMD][GLUON] prefix is misleading since this change is not AMD-specific. I'll update the title.

Regarding "I don't think we should need that" - could you clarify the preferred approach? The core issue is:

  1. _utils.py: type_canonicalisation_dict doesn't have float8_e8m0fnu, causing a KeyError at kernel arg binding when users pass torch.float8_e8m0fnu tensors (e.g., AITER quantization outputs e8m0 scales as this dtype).
  2. semantic.py: Even after the dict fix, the IR layer rejects non-uint8 scale handles in dot_scaled.
    Should the fix be:
  • (A) Only add the type mapping in _utils.py (treating float8_e8m0fnu as u8 everywhere) and let users handle the bitcast themselves?
    • (B) Add native float8_e8m0fnu support as a first-class type throughout the compiler?
    • (C) Some other approach you'd prefer?
      Happy to rework the PR in whatever direction you suggest.

@GeisYaO GeisYaO changed the title [AMD][GLUON] Fix scale factor bitcast in dot_scaled and add test Add float8_e8m0fnu support to type canonicalization for dot_scaled Apr 13, 2026
@GeisYaO
Copy link
Copy Markdown
Author

GeisYaO commented Apr 13, 2026

Changes updated per your feedback:
1.Kept only the _utils.py type mapping ("float8_e8m0fnu": "u8")
2.Reverted all semantic.py changes
3.Removed the test file
The PR now contains a single 1-line addition. Ready for re-review when you get a chance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants