
[None][feat] Optimize 6KD fp8 blockscale gemm#11502

Open
CarstyYou wants to merge 12 commits into NVIDIA:main from CarstyYou:user/xiy/6kd_kernel

Conversation


@CarstyYou CarstyYou commented Feb 13, 2026


Description

This PR adds SM120 (Blackwell GeForce / RTX PRO 6000) support for FP8 block-scale MoE GEMM and optimizes the existing dense/BMM kernels.

Changes

Commit 1: Optimize 6KD fp8 blockscale gemm

  • Refactor SM120BlockScaledBuilder with a configurable Stages template parameter (default 4) for flexible pipeline depth tuning
  • Introduce SM120BlockScaledScheduler for persistent-thread block scheduling across groups, replacing per-problem kernel launches
  • Rework tile configurations (PermMmaTileN/K, TiledMma layout) and barrier management (store_full_mbar/store_empty_mbar) for better SM120 occupancy
  • Update grid size computation from dynamic to explicit dim3(num_device_sms, 1, 1) for persistent kernel style
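
The fixed dim3(num_device_sms, 1, 1) grid implies each resident block loops over output tiles rather than mapping one launch block per tile. A minimal sketch of that assignment in Python (hypothetical helper for illustration; the actual SM120BlockScaledScheduler additionally applies swizzling and group lookup):

```python
def persistent_tile_assignment(num_tiles: int, num_sms: int):
    """Return the tile indices each persistent block would process.

    With a persistent-kernel grid of `num_sms` blocks, block `sm_id`
    walks tiles sm_id, sm_id + num_sms, sm_id + 2 * num_sms, ...
    """
    return [list(range(sm_id, num_tiles, num_sms)) for sm_id in range(num_sms)]

assignments = persistent_tile_assignment(num_tiles=10, num_sms=4)
# Every tile is covered exactly once across the persistent blocks.
covered = sorted(t for tiles in assignments for t in tiles)
assert covered == list(range(10))
```

The payoff is that one kernel launch services all groups, avoiding per-problem launch overhead.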

Commit 2: Fix stride overflow

  • Change stride variables from int to int64_t in SM120 kernel dispatch and scheduler to prevent overflow on large problem sizes
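
To see why int strides overflow, consider a large-but-plausible problem size (the numbers below are illustrative, not taken from the PR). Simulating 32-bit storage shows the wrap that int64_t avoids:

```python
import ctypes

# Hypothetical large problem: stride in elements exceeds INT32_MAX.
num_rows, row_stride = 200_000, 16_384
stride = num_rows * row_stride            # 3_276_800_000
wrapped = ctypes.c_int32(stride).value    # what a 32-bit int variable would hold
assert stride > 2**31 - 1
assert wrapped != stride                  # wrapped to a negative value
assert ctypes.c_int64(stride).value == stride  # int64_t holds it exactly
```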

Commit 3: Support MoE GEMM for SM120

  • Add SM120BlockScaledMoeKernel — a dedicated MoE grouped GEMM kernel for SM120 with per-expert TMA addressing and SM120BlockScaledMoeScheduler
  • Add scale_1x128_kernel_sm120 — an online BF16-to-FP8 quantization kernel with E8M0 scale generation, using compute_padded_offset (alignment=4) for MoE-aware scale layout
  • Add grouped_gemm_dispatch_sm120 to wire quantization + GEMM kernels together
  • Integrate into CutlassFp8BlockScaleGemmRunner::moeGemm and fp8_block_scaling_moe_gemm_blackwell_geforce thop entry point
  • Wire SM120 MoE path into CutlassFusedMoE via DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm with resmooth_and_transform_fp8_scale for weight preprocessing
  • Add unit tests for SM120 MoE GEMM (test_fp8_block_scale_moe_gemm) covering various num_rows, num_experts, top_k, and matrix dimensions
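
The online 1x128 quantization above can be sketched in plain Python (hypothetical helper for illustration; the real scale_1x128_kernel_sm120 runs on device and packs the E8M0 exponents): each 128-element row block gets a power-of-two scale, since E8M0 stores only an exponent, chosen so the scaled values fit FP8 E4M3's range.

```python
import math

FP8_E4M3_MAX = 448.0  # max representable magnitude in FP8 E4M3

def quantize_1x128_e8m0(block):
    """Quantize one 1x128 activation block with a power-of-two (E8M0-style) scale."""
    amax = max(abs(x) for x in block) or 1.0
    # E8M0 stores only an exponent, so round the scale up to a power of two.
    scale = 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))
    return [x / scale for x in block], scale

vals, scale = quantize_1x128_e8m0([3.5, -100.0, 7.25] + [0.0] * 125)
assert scale == 0.25                                 # 2^ceil(log2(100 / 448))
assert all(abs(v) <= FP8_E4M3_MAX for v in vals)     # fits the FP8 range
```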

Key design notes

  • SM120 MoE GEMM uses online A-matrix quantization (internal_quantize_a=true): the BF16 activation is quantized to FP8 + E8M0 scale inside the kernel workspace, then fed to the CUTLASS grouped GEMM
  • Weight scale (SFB) is preprocessed offline via resmooth_to_fp8_e8m0 + transform_sf_into_required_layout into int32 packed col-major TMA-aligned format
  • Workspace allocation uses deep_gemm::compute_padded_offset (alignment=32) which is always >= sm120_blockscaled_gemm::compute_padded_offset (alignment=4), ensuring no buffer overflow
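
The workspace-safety argument in the last note relies on a simple property of align-up padding: rounding up to a multiple of 32 never yields less than rounding up to a multiple of 4. A sketch (the real compute_padded_offset signatures may differ):

```python
def compute_padded_offset(offset: int, alignment: int) -> int:
    # Round `offset` up to the next multiple of `alignment`.
    return (offset + alignment - 1) // alignment * alignment

# Padding with the coarser alignment always dominates the finer one,
# so a workspace sized with alignment=32 covers all alignment=4 offsets.
for off in range(1024):
    assert compute_padded_offset(off, 32) >= compute_padded_offset(off, 4)
```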

Test Coverage

  • tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py::test_fp8_block_scale_moe_gemm — 18 parametrized cases covering:
    • num_rows: 7, 64, 128
    • num_experts, top_k: (4, 3), (8, 4), (16, 5)
    • k, n: (7168, 2112), (2048, 7168)
    • All cases pass with diff < 0.001 on SM120
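
The PR does not show the exact diff metric; a common choice for GEMM accuracy tests is a normalized L2 error against the reference output, along these lines:

```python
def normalized_l2_diff(ref, out):
    # ||ref - out||_2 / ||ref||_2: a scale-invariant accuracy measure.
    num = sum((r - o) ** 2 for r, o in zip(ref, out)) ** 0.5
    den = sum(r * r for r in ref) ** 0.5 or 1.0
    return num / den

assert normalized_l2_diff([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]) == 0.0
assert normalized_l2_diff([1.0, 0.0], [1.0, 0.0005]) < 0.001
```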

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.


coderabbitai bot commented Feb 13, 2026

📝 Walkthrough


This PR updates the SM120 block-scaled GEMM kernel with a new three-parameter template configuration (Stages parameter), introduces a scheduler for coordinated block allocation across groups, adds grouped layout tracking to kernel arguments, and refactors tile configurations and barrier management. Test imports are simplified.

Changes

Cohort / File(s) Summary
SM120 GEMM Kernel Infrastructure
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh, cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh, cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh
Introduced Stages template parameter to SM120BlockScaledBuilder (default 4), added grouped_layout pointer to kernel Params and Arguments, added SM120BlockScaledScheduler for block scheduling across groups, reworked tile configurations (PermMmaTileN/K), updated TiledMma layout, added barrier storage fields (store_full_mbar/store_empty_mbar), replaced tuple destructuring with explicit get<> calls, and changed grid size computation from dynamic to explicit dim3(num_device_sms, 1, 1).
Test Updates
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
Simplified function calls from qualified namespace (fp8_utils.per_block_cast_to_fp8_e8m0) to direct import (per_block_cast_to_fp8_e8m0).

Sequence Diagram(s)

sequenceDiagram
    participant BlockIdx as Block Index
    participant Scheduler as SM120BlockScaledScheduler
    participant GroupLayout as grouped_layout
    participant BlockState as Block State
    
    BlockIdx->>Scheduler: get_next_block()
    activate Scheduler
    alt current_iter < total_blocks
        Scheduler->>GroupLayout: access layout table
        activate GroupLayout
        GroupLayout-->>Scheduler: group assignment for current_iter
        deactivate GroupLayout
        
        Scheduler->>BlockState: update m_block_idx, n_block_idx
        activate BlockState
        Scheduler->>Scheduler: advance current_iter
        BlockState-->>Scheduler: state updated
        deactivate BlockState
        
        Scheduler->>Scheduler: get_swizzle_block_idx()
        Scheduler-->>BlockIdx: return true (next block available)
    else no more blocks
        Scheduler-->>BlockIdx: return false (iteration complete)
    end
    deactivate Scheduler

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Merge Conflict Detection ✅ Passed: No merge conflicts detected when merging into main
Title check ✅ Passed The title mentions '6KD fp8 blockscale gemm' optimization, which directly aligns with the file changes targeting SM120 block-scaled GEMM kernels and related kernel optimizations.
Description check ✅ Passed PR description follows the template structure with all required sections (Description, Test Coverage, PR Checklist) completed. The author provides clear objectives for SM120 support and kernel optimizations.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In
`cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh`:
- Around lines 55-57: kNumStagePerSF is computed with integer division, which can silently drop K tiles when kNumTileKPerSF is not divisible by AB_Stages. Strengthen the compile-time checks: alongside the existing static_assert that kNumStagePerSF > 0 && kNumStagePerSF <= 2, add a static_assert that kNumTileKPerSF % AB_Stages == 0 with a clear error message.

In `tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py`:
- Line 190: The test calls per_block_cast_to_fp8_e8m0 on a 3D tensor b of shape (num_groups, n, k), but that function asserts 2D input. In the getSMVersion() == 120 branch, preallocate b_fp8 and b_scales with the correct shapes, then iterate over groups and quantize each 2D slice b[i] into b_fp8[i] and b_scales[i], mirroring the loop in the else branch.
🧹 Nitpick comments (4)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh (3)

65-66: Remove commented-out code.

Per coding guidelines, dead code should not be disabled with comments; use #if / #endif with a mnemonic condition if the code needs to be preserved, otherwise remove it. As per coding guidelines: "Do not use comments to disable code."


363-429: Scheduler: member naming and a minor style note.

The SM120BlockScaledScheduler struct members (current_iter, m_block_idx, num_groups, etc.) use snake_case, but the coding guidelines require class/struct member variables to use camelCase with m prefix (e.g., mCurrentIter, mBlockIdx). That said, I see existing CUTLASS-style code in the codebase may differ—flagging for awareness. As per coding guidelines: "Use camelCase prefixed with 'm' for public, private and protected class member variables."


386-396: get_swizzle_block_idx: binding const& to temporaries is unusual.

Lines 389–390 bind const& to rvalue temporaries (num_m_blocks * kNum1DBlocksPerGroup, block_idx / num_blocks_per_group). While valid C++ (lifetime extension), this is non-idiomatic and can confuse readers. Prefer plain auto const.

Suggested diff
-        auto const& num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-        auto const& group_idx = block_idx / num_blocks_per_group;
+        auto const num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
+        auto const group_idx = block_idx / num_blocks_per_group;
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh (1)

98-105: get_grid_shape: CUDA API return values are unchecked.

cudaGetDevice and cudaDeviceGetAttribute can fail; if they do, sm_count remains uninitialized and produces an incorrect grid dimension. Check the return values or zero-initialize sm_count.

Note: This function is currently unused—the active callers in fp8_blockscale_gemm_kernel.cuh compute grid dimensions directly and bypass this function. However, this is a latent hazard if the function is called in future implementations.

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Feb 13, 2026
@CarstyYou CarstyYou changed the title [feat] update 6KD fp8 dense & bmm [None][feat] Optimize per_block_cast_to_fp8_e8m0 and update 6KD fp8 blockscale gemm Feb 14, 2026
@CarstyYou CarstyYou changed the title [None][feat] Optimize per_block_cast_to_fp8_e8m0 and update 6KD fp8 blockscale gemm [None][feat] Optimize 6KD fp8 blockscale gemm Feb 14, 2026
@CarstyYou CarstyYou requested a review from a team as a code owner March 3, 2026 11:54
@CarstyYou CarstyYou requested a review from HuiGao-NV March 3, 2026 11:54
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from 12123cc to e9774d0 Compare March 3, 2026 12:09
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from e9774d0 to 9e5c949 Compare March 3, 2026 12:10
@xxi-nv xxi-nv requested a review from Barry-Delaney March 4, 2026 03:10
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>

@xxi-nv xxi-nv left a comment


The python MoE parts LGTM.

…ASS FP8 BlockScale

Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from 9f70c85 to c1f4c29 Compare March 4, 2026 09:45
@CarstyYou
Contributor Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #37678 [ run ] triggered by Bot. Commit: a561198 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37678 [ run ] completed with state SUCCESS. Commit: a561198
/LLM/main/L0_MergeRequest_PR pipeline #29161 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


xxi-nv commented Mar 5, 2026

/bot run --stage-list "DGX_B200-PyTorch-1"

@tensorrt-cicd
Collaborator

PR_Github #37787 [ run ] triggered by Bot. Commit: a561198 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37787 [ run ] completed with state SUCCESS. Commit: a561198
/LLM/main/L0_MergeRequest_PR pipeline #29255 (Partly Tested) completed with status: 'SUCCESS'

Link to invocation


@HuiGao-NV HuiGao-NV left a comment


LGTM


@Barry-Delaney Barry-Delaney left a comment


Overall LGTM. Thanks for the contribution! Left minor comments.

Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from 06a4c31 to d4cf137 Compare March 11, 2026 17:16
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from 844a5e8 to 54cc308 Compare March 11, 2026 18:15
@CarstyYou
Contributor Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #38622 [ run ] triggered by Bot. Commit: a53e8d6 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #38622 [ run ] completed with state SUCCESS. Commit: a53e8d6
/LLM/main/L0_MergeRequest_PR pipeline #29956 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
@CarstyYou CarstyYou force-pushed the user/xiy/6kd_kernel branch from a53e8d6 to 867652c Compare March 12, 2026 02:37
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>