[CuTeDSL] Add BF16 Grouped GEMM example for Hopper (SM90)#3059
Closed
vruga wants to merge 1 commit into NVIDIA:main from
Conversation
Adds `examples/python/CuTeDSL/hopper/grouped_gemm.py`, a Python/CuTeDSL grouped GEMM kernel targeting the NVIDIA Hopper SM90 architecture. This is the CuTeDSL equivalent of the C++ example `57_hopper_grouped_gemm`, which only supports FP8. The new example adds BF16 (and Float16) support and a full Python implementation.

Key design points:

- Uses SM90 WGMMA (warp group MMA) instead of SM100 tcgen05.
- Register-based accumulators (`make_rmem_tensor`); SM90 has no TMEM.
- `PipelineTmaAsync` for the A/B mainloop pipeline.
- Warp specialization: a DMA warp group (TMA loads + tensormap A/B updates) and one or two MMA warp groups (WGMMA + epilogue + tensormap C updates).
- `TensorMapManager` (arch-agnostic) for per-group TMA descriptor updates, supporting both SMEM and GMEM update modes.
- `StaticPersistentGroupTileScheduler` for persistent multi-group scheduling.
- Supports BF16/Float16 inputs; Float16/BFloat16/Float32 outputs.
- Includes host-side helpers, a reference check, and a benchmarking harness.

Closes NVIDIA#3040

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
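To illustrate what the kernel computes (not the CuTeDSL implementation itself), here is a minimal NumPy reference sketch of a grouped GEMM. The per-group problem sizes below are hypothetical; a real MoE workload would have one group per expert, each with its own token count. BF16 is emulated with float32 here, since the actual kernel consumes BF16/FP16 inputs and accumulates in FP32.

```python
import numpy as np

# Hypothetical per-group problem sizes (M, N, K). Each group is an
# independent GEMM; M typically differs per group (e.g. tokens per expert).
problems = [(128, 256, 64), (96, 256, 64), (200, 256, 64)]

rng = np.random.default_rng(0)
As = [rng.standard_normal((m, k)).astype(np.float32) for m, n, k in problems]
Bs = [rng.standard_normal((k, n)).astype(np.float32) for m, n, k in problems]

# Grouped GEMM semantics: C_i = A_i @ B_i for every group i, launched as
# one persistent kernel rather than one kernel call per group.
Cs = [a @ b for a, b in zip(As, Bs)]

for (m, n, k), c in zip(problems, Cs):
    assert c.shape == (m, n)
```

A reference check like the one the PR mentions would compare the device results group by group against such a host-side loop.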
What happened to this PR? Is there anything wrong with the implementation?
Author
I submitted the same PR again after verifying it.
Summary
Adds `examples/python/CuTeDSL/hopper/grouped_gemm.py`, a Python/CuTeDSL grouped GEMM kernel for the NVIDIA Hopper SM90 architecture with BF16 (and Float16) input support. This is the CuTeDSL equivalent of `examples/57_hopper_grouped_gemm`, which only supports FP8. This PR adds a Python implementation that also covers the BF16 use case, the precision most commonly required for accuracy-sensitive workloads like Mixture-of-Experts (MoE) serving on Hopper.