Add float8_e8m0fnu support to type canonicalization for dot_scaled#10009
Add float8_e8m0fnu support to type canonicalization for dot_scaled#10009GeisYaO wants to merge 9 commits intotriton-lang:mainfrom
Conversation
…ith e8m0fnu on AMD
ThomasRaoux
left a comment
There was a problem hiding this comment.
the title doesn't seem aligned with the patch as this changes more than Gluon and it is not specific to AMD. Also I don't think we should need that
|
|
||
| def test_dot_scaled_e8m0fnu(): | ||
| @triton.jit | ||
| def kernel(lhs_ptr, lhs_scale_ptr, rhs_ptr, rhs_scale_ptr, out_ptr, |
There was a problem hiding this comment.
the kernel is not even called
There was a problem hiding this comment.
You're right, the test is incomplete - the kernel function is defined but never invoked with proper grid/args. I'll rewrite it as a proper pytest that actually launches the kernel and validates results. Should I add it to the existing test_dot_scaled.py instead of a separate file?
| rhs_scale_handle = None if rhs_scale_is_none else self.bitcast(rhs_scale, tl.uint8).handle | ||
| lhs_scale_handle = None if lhs_scale_is_none else self.bitcast(lhs_scale, tl.uint8).handle |
There was a problem hiding this comment.
I don't think we want to change the type passed by user
There was a problem hiding this comment.
Understood. If the preferred approach is to not auto-bitcast, would option (A) - only adding float8_e8m0fnu to the type canonicalization dict - be acceptable? That way float8_e8m0fnu tensors can at least pass through the argument binding stage, and users would handle any necessary casting on their side.
There was a problem hiding this comment.
only adding float8_e8m0fnu to the type canonicalization dict - be acceptable?
yes that sounds reasonable
|
Thank you for the review, @ThomasRaoux! You're right - the [AMD][GLUON] prefix is misleading since this change is not AMD-specific. I'll update the title. Regarding "I don't think we should need that" - could you clarify the preferred approach? The core issue is:
|
…nd restore typed signature.ed method
|
Changes updated per your feedback: |
This PR fixes an issue in
dot_scaledwhere the scale factor's handle was being used directly without ensuring the correct bitcast totl.uint8. It also adds a comprehensive end-to-end test fordot_scaledwithe8m0fnudata type on AMD.Summary of changes:
python/triton/_utils.py: Addedfloat8_e8m0fnuto the type canonicalization dictionary.python/triton/language/semantic.py: Bitcastlhs_scaleandrhs_scaletotl.uint8before retrieving their handles indot_scaled.