Skip to content

[SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM#3280

Open
tgmerritt wants to merge 2 commits into
NVIDIA:mainfrom
tgmerritt:feat/sm120-array-tma-collective
Open

[SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM#3280
tgmerritt wants to merge 2 commits into
NVIDIA:mainfrom
tgmerritt:feat/sm120-array-tma-collective

Conversation

@tgmerritt
Copy link
Copy Markdown

Summary

Implements CollectiveMma and CollectiveBuilder specializations for MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM (MoE expert dispatch) with tensor- and token-level FP8 scaling on SM_120/SM_121 consumer Blackwell hardware.

Closes #3263.


What this adds

New file: include/cutlass/gemm/collective/sm120_mma_array_tma.hpp

CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized. Key design choices:

  • Handles both Cooperative (4×2 UMMA atom layout) and Pingpong (2×2) schedules via a single template
  • Pointer-array indirection: loads ptr_A[problem_idx] / ptr_B[problem_idx] in load_init, so each problem in the grouped batch can have a distinct A/B matrix pointer
  • TMA loads for both operands (same as SM_100 array path), using SM_1xx_TMA_LOAD_IM2COL or SM_1xx_TMA_LOAD descriptors depending on smem layout
  • F8F6F4 MMA via rr_op_selector_sm120 — the non-blockscaled path (scale factors in epilogue, not mainloop)
  • Pipeline: PipelineTmaUmmaAsync, stage count auto-computed from smem capacity

New file: include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl

CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecializedCooperativeSm120<N> and KernelPtrArrayTmaWarpSpecializedPingpongSm120<N> schedule tags. Mirrors the structure of sm120_mma_builder.inl but:

  • enable_if requires is_base_of<KernelPtrArrayTmaWarpSpecializedCooperative> or Pingpong
  • GmemLayoutATag is already a pointer type (RowMajor*) — TagToStrideA_t propagates the pointer, so no extra * needed
  • SchedulerPipelineStageCount extracted from the schedule tag's non-type parameter

Modified: sm120_mma_builder.inl

Removes KernelPtrArrayTmaWarpSpecialized{Cooperative,Pingpong} from the dense GEMM builder's enable_if condition. Previously these were explicitly excluded with a static_assert(!IsPtrArrayKernel, ...). With the new array builder present, they now route correctly.

Modified: collective_builder.hpp + collective_mma.hpp

One #include added to each — the new array builder and collective, respectively.


Hardware validation

Validated on real SM_121 hardware (NVIDIA DGX Spark, GB10, 128 GB LPDDR5X unified memory) running:

  • Container: vLLM source build, CUDA 13.0.2, TORCH_CUDA_ARCH_LIST=12.0a;12.1a;12.1+PTX
  • Model: RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B total / 4B active params, FP8-Dynamic quantization)
  • Previous behavior: SM120 path gated off; vLLM fell back to Triton MoE backend for all forward passes
  • With this patch: SM120 CUTLASS grouped GEMM collective activates and produces numerically correct outputs

Throughput comparison (wall-clock, single-stream, MTP speculative decoding, 3 iterations each)

Prompt size Baseline (Triton fallback) This patch (SM120 CUTLASS) Delta
Short (128 tok output) 76.3 tok/s 81.9 tok/s +7.3%
Medium (512 tok output) 89.8 tok/s 89.8 tok/s ≈0% (within noise)
Long (1024 tok output) 87.6 tok/s 85.5 tok/s -2.4% (within noise)

Short-sequence improvement is the clearest signal: decode-dominated workloads where grouped GEMM runs every forward pass. Medium/long results are within the variance of 3-run speculative-decoding benchmarks (accept-rate variance dominates at longer outputs).


Why not duplicate


AI assistance disclosure

This implementation was developed with Claude (Anthropic) AI assistance, including analysis of the SM_100 array collective as a reference pattern, iterative compile debugging on real SM_121 hardware (four full Docker builds), and identification of the double-pointer bug in an early version of the StrideA/StrideB types. All changed lines have been reviewed by the human submitter (Tyler Merritt). Build and inference validation ran on physical DGX Spark hardware.

Related vLLM issues: #43507

…ped GEMM

Adds CollectiveMma and CollectiveBuilder specializations for
MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM
(MoE expert dispatch) with tensor- and token-level FP8 scaling on
SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10).

New files:
- include/cutlass/gemm/collective/sm120_mma_array_tma.hpp
  CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized.
  Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules.
  Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B.
  Supports F8F6F4 MMA with TMA loads for both A and B operands.

- include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl
  CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized
  Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts
  from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized
  dispatch policy, produces correctly-typed CollectiveOp.

Modified files:
- collective_mma.hpp: include sm120_mma_array_tma.hpp
- collective_builder.hpp: include sm120_array_mma_builder.inl
- sm120_mma_builder.inl: remove ptr-array schedules from enable_if
  (they now route to sm120_array_mma_builder.inl) and drop the
  IsPtrArrayKernel static_assert that enforced the restriction

Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running
vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B
total / 4B active). Previously fell back to a non-CUTLASS Triton path;
with this patch, the SM120 CUTLASS grouped GEMM collective activates and
produces correct outputs. Short-sequence throughput improved ~7% vs the
fallback baseline (76.3 → 81.9 tok/s).

Closes NVIDIA#3263

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
@hwu36
Copy link
Copy Markdown
Collaborator

hwu36 commented May 28, 2026

@depaulmillz for review

@depaulmillz
Copy link
Copy Markdown
Contributor

Thanks. Would you be able to also add some unit tests for this under test/unit/gemm/device/sm120_tensorop_gemm? Here is an example of how we structure the tests https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu

Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder
specializations introduced for MainloopSm120ArrayTmaWarpSpecialized,
covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and
KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across
e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs,
and two tile shapes.

Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new
cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per
reviewer request in PR NVIDIA#3280.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] SM_120/SM_121 CollectiveBuilder specialization for tensor/token-scaled FP8 grouped GEMM (ptr-array)

3 participants