[SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM#3280
Open
tgmerritt wants to merge 2 commits into
Open
[SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM#3280tgmerritt wants to merge 2 commits into
tgmerritt wants to merge 2 commits into
Conversation
…ped GEMM Adds CollectiveMma and CollectiveBuilder specializations for MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM (MoE expert dispatch) with tensor- and token-level FP8 scaling on SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10). New files: - include/cutlass/gemm/collective/sm120_mma_array_tma.hpp CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized. Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules. Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B. Supports F8F6F4 MMA with TMA loads for both A and B operands. - include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized dispatch policy, produces correctly-typed CollectiveOp. Modified files: - collective_mma.hpp: include sm120_mma_array_tma.hpp - collective_builder.hpp: include sm120_array_mma_builder.inl - sm120_mma_builder.inl: remove ptr-array schedules from enable_if (they now route to sm120_array_mma_builder.inl) and drop the IsPtrArrayKernel static_assert that enforced the restriction Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B total / 4B active). Previously fell back to a non-CUTLASS Triton path; with this patch, the SM120 CUTLASS grouped GEMM collective activates and produces correct outputs. Short-sequence throughput improved ~7% vs the fallback baseline (76.3 → 81.9 tok/s). Closes NVIDIA#3263 Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
Collaborator
|
@depaulmillz for review |
Contributor
|
Thanks. Would you be able to also add some unit tests for this under |
Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder specializations introduced for MainloopSm120ArrayTmaWarpSpecialized, covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs, and two tile shapes. Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per reviewer request in PR NVIDIA#3280. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implements
CollectiveMmaandCollectiveBuilderspecializations forMainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM (MoE expert dispatch) with tensor- and token-level FP8 scaling on SM_120/SM_121 consumer Blackwell hardware.Closes #3263.
What this adds
New file:
include/cutlass/gemm/collective/sm120_mma_array_tma.hppCollectiveMmaspecialization forMainloopSm120ArrayTmaWarpSpecialized. Key design choices:Cooperative(4×2 UMMA atom layout) andPingpong(2×2) schedules via a single templateptr_A[problem_idx]/ptr_B[problem_idx]inload_init, so each problem in the grouped batch can have a distinct A/B matrix pointerSM_1xx_TMA_LOAD_IM2COLorSM_1xx_TMA_LOADdescriptors depending on smem layoutrr_op_selector_sm120— the non-blockscaled path (scale factors in epilogue, not mainloop)PipelineTmaUmmaAsync, stage count auto-computed from smem capacityNew file:
include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inlCollectiveBuilderspecialization forKernelPtrArrayTmaWarpSpecializedCooperativeSm120<N>andKernelPtrArrayTmaWarpSpecializedPingpongSm120<N>schedule tags. Mirrors the structure ofsm120_mma_builder.inlbut:enable_ifrequiresis_base_of<KernelPtrArrayTmaWarpSpecializedCooperative>orPingpongGmemLayoutATagis already a pointer type (RowMajor*) —TagToStrideA_tpropagates the pointer, so no extra*neededSchedulerPipelineStageCountextracted from the schedule tag's non-type parameterModified:
sm120_mma_builder.inlRemoves
KernelPtrArrayTmaWarpSpecialized{Cooperative,Pingpong}from the dense GEMM builder'senable_ifcondition. Previously these were explicitly excluded with astatic_assert(!IsPtrArrayKernel, ...). With the new array builder present, they now route correctly.Modified:
collective_builder.hpp+collective_mma.hppOne
#includeadded to each — the new array builder and collective, respectively.Hardware validation
Validated on real SM_121 hardware (NVIDIA DGX Spark, GB10, 128 GB LPDDR5X unified memory) running:
TORCH_CUDA_ARCH_LIST=12.0a;12.1a;12.1+PTXRedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic(Gemma 4 MoE, 26B total / 4B active params, FP8-Dynamic quantization)Throughput comparison (wall-clock, single-stream, MTP speculative decoding, 3 iterations each)
Short-sequence improvement is the clearest signal: decode-dominated workloads where grouped GEMM runs every forward pass. Medium/long results are within the variance of 3-run speculative-decoding benchmarks (accept-rate variance dominates at longer outputs).
Why not duplicate
gh pr list --repo NVIDIA/cutlass --state open --search "SM120 array"— no resultsgh pr list --repo NVIDIA/cutlass --state open --search "MainloopSm120Array"— no resultsAI assistance disclosure
This implementation was developed with Claude (Anthropic) AI assistance, including analysis of the SM_100 array collective as a reference pattern, iterative compile debugging on real SM_121 hardware (four full Docker builds), and identification of the double-pointer bug in an early version of the
StrideA/StrideBtypes. All changed lines have been reviewed by the human submitter (Tyler Merritt). Build and inference validation ran on physical DGX Spark hardware.Related vLLM issues: #43507