Commit 04f4c0c
fix: MXFP4/MXFP8 failures in SM120 FAST_BUILD and expand all_tiles[] (#2994)
**Problem**
MXFP4 and MXFP8 GEMM operations were failing on SM120 because:
- The FAST_BUILD path returned a single hardcoded CtaShape128x128x64B tile regardless of GROUPED_GEMM, and that tile is not valid for all MXFP4/MXFP8 configurations
- The full-build all_tiles[] table was missing tiles needed by those dtypes (128x128x128B, 128x128x64B, 256x128x64B), leaving the autotuner with no viable candidate in some cases
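The "no viable candidate" failure described above can be illustrated with a small Python model. This is only a sketch: `viable_candidates`, `old_table`, and `valid_for_config` are hypothetical names, the pre-fix table contents are a placeholder, and the real selection logic lives in the CUTLASS kernel heuristics, which this does not reproduce.

```python
# Tiles the commit message lists as missing from the full-build table.
MISSING_TILES = ["128x128x128B", "128x128x64B", "256x128x64B"]


def viable_candidates(all_tiles, valid_for_config):
    """Return the table entries the configuration accepts, in table order."""
    return [tile for tile in all_tiles if tile in valid_for_config]


# Placeholder entry standing in for whatever the pre-fix table actually held.
old_table = ["placeholder_tile"]
new_table = old_table + MISSING_TILES

# Assumption: a configuration that accepts only tiles from the added set.
valid_for_config = {"128x128x128B", "128x128x64B"}

print(viable_candidates(old_table, valid_for_config))  # []
print(viable_candidates(new_table, valid_for_config))  # ['128x128x128B', '128x128x64B']
```

With the placeholder table the autotuner would have nothing to try; once the three tiles are appended, at least one candidate survives the filter.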
**Fix**
- FAST_BUILD: differentiate grouped vs. non-grouped paths with tiles known to work for MXFP4/MXFP8:
  - Grouped: 128x128x128B + 128x128x64B
  - Non-grouped: 128x128x256B + 128x128x64B
- Full-build all_tiles[]: add the three missing tiles so the autotuner has a complete candidate set for MXFP4/MXFP8 workloads
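As a rough illustration of the FAST_BUILD change, the grouped vs. non-grouped split can be modeled in Python. The enum and the function name `fast_build_tiles` are hypothetical; only the tile shapes and the grouped/non-grouped pairing come from the description above, and the actual change is made in the C++ kernel sources.

```python
from enum import Enum


class CtaShape(Enum):
    # Member names follow the tile labels in the commit message (M x N x K bytes).
    CtaShape128x128x64B = (128, 128, 64)
    CtaShape128x128x128B = (128, 128, 128)
    CtaShape128x128x256B = (128, 128, 256)


def fast_build_tiles(grouped_gemm: bool) -> list:
    """Return the FAST_BUILD candidate tiles for the grouped or non-grouped path."""
    if grouped_gemm:
        return [CtaShape.CtaShape128x128x128B, CtaShape.CtaShape128x128x64B]
    return [CtaShape.CtaShape128x128x256B, CtaShape.CtaShape128x128x64B]


for grouped in (True, False):
    names = [tile.name for tile in fast_build_tiles(grouped)]
    print(f"grouped={grouped}: {names}")
```

Both paths keep 128x128x64B as a second candidate, so the previous tile is still tried whenever it is valid.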
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Performance & Optimizations**
  * More predictable kernel candidate selection and expanded tile/configuration options for SM120-class GPUs to improve tuning and performance.
  * Broadened handling of grouped computation patterns to enable additional configuration choices.
* **Build/Compatibility**
  * Refined CUDA 12.9+ architecture suffixing for more accurate build targeting.
* **Chores**
  * Added type annotations and minor signature clarifications (no runtime behavior changes).
* **Bug Fixes**
  * MoE fusion path now forwards additional tensors/parameters to improve fused operation correctness.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Co-authored-by: samuellees <lsam@nvidia.com>
---------
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Sam (Kesen Li) <lsam@nvidia.com>
Co-authored-by: Alex Yang <aleyang@nvidia.com>
1 parent 19055a6 commit 04f4c0c
2 files changed
Lines changed: 34 additions & 18 deletions
File tree
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels
- flashinfer
Lines changed: 29 additions & 16 deletions
[Diff body not captured; only line numbers survive. Added: new lines 590-604. Replaced: original lines 596-611 with new lines 611-624.]
[Diff body not captured; only line numbers survive. Replaced: original line 39 with new line 39, and original line 51 with new lines 51-54.]