
Commit b3e3693

[SM121] Enable native block-scaled dot_scaled for DGX Spark (GB10) (#10010)
SM121 (GB10 DGX Spark) supports the same mma.sync block-scaled instructions as SM120 (RTX 5090) but was excluded from the native lowering path by exact compute-capability checks. Without this fix, dot_scaled on SM121 falls through to DecomposeScaledBlocked, which upcasts to bf16: roughly 10 TFLOPS versus roughly 270 TFLOPS with native mma.sync block-scaled FP4. Tested on GB10 with both MXFP4 (scale_vec::2X, ue8m0) and NVFP4 (scale_vec::4X, ue4m3).

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because current test paths already cover the flow; AFAIK there are no GB10s in CI to verify, but it does work on my GB10.
- Select one of the following.
  - [x] I have not added any `lit` tests.
1 parent f7c1d69 commit b3e3693
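
For context, here is a minimal standalone sketch of the new family check. The `isSm12x` helper and the `main` driver are hypothetical and not part of the patch; they only illustrate why integer division by 10 admits SM121 alongside SM120.

```cpp
#include <cassert>

// Hypothetical helper mirroring the patched predicate: integer division by 10
// maps both 120 and 121 to 12, so the whole sm12x family passes the gate.
static bool isSm12x(int computeCapability) {
  return computeCapability / 10 == 12;
}

int main() {
  assert(isSm12x(120));   // SM120 (RTX 5090): already took the native path
  assert(isSm12x(121));   // SM121 (GB10 DGX Spark): now also takes it
  assert(!isSm12x(89));   // SM89 remains outside this gate
  assert(!isSm12x(100));  // SM100 remains outside this gate as well
  return 0;
}
```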

1 file changed: 2 additions & 2 deletions

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -677,7 +677,7 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
   mlir::LogicalResult
   matchAndRewrite(triton::DotScaledOp dotOp,
                   mlir::PatternRewriter &rewriter) const override {
-    if (computeCapability != 120)
+    if (computeCapability / 10 != 12)
       return failure();
 
     auto numCTAs = lookupNumCTAs(rewriter);
@@ -924,7 +924,7 @@ static bool mmav2SupportsFp8Operands(int computeCapability) {
   // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and
   // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has
   // hardware support for fp8 operands w/ mmav2.
-  return computeCapability == 89 || computeCapability == 120;
+  return computeCapability == 89 || computeCapability / 10 == 12;
 }
 
 // promote operands of dot op if the existing combination is not natively
