[None][feat] Optimize 6KD fp8 blockscale gemm #11502
CarstyYou wants to merge 12 commits into NVIDIA:main from
Conversation
📝 Walkthrough

This PR updates the SM120 block-scaled GEMM kernel with a new three-parameter template configuration (Stages parameter), introduces a scheduler for coordinated block allocation across groups, adds grouped layout tracking to kernel arguments, and refactors tile configurations and barrier management. Test imports are simplified.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant BlockIdx as Block Index
    participant Scheduler as SM120BlockScaledScheduler
    participant GroupLayout as grouped_layout
    participant BlockState as Block State
    BlockIdx->>Scheduler: get_next_block()
    activate Scheduler
    alt current_iter < total_blocks
        Scheduler->>GroupLayout: access layout table
        activate GroupLayout
        GroupLayout-->>Scheduler: group assignment for current_iter
        deactivate GroupLayout
        Scheduler->>BlockState: update m_block_idx, n_block_idx
        activate BlockState
        Scheduler->>Scheduler: advance current_iter
        BlockState-->>Scheduler: state updated
        deactivate BlockState
        Scheduler->>Scheduler: get_swizzle_block_idx()
        Scheduler-->>BlockIdx: return true (next block available)
    else no more blocks
        Scheduler-->>BlockIdx: return false (iteration complete)
    end
    deactivate Scheduler
```
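The hand-off above can be sketched in Python. This is a toy model, not repo code: the class and field names mirror the diagram, but the linear iter-to-(m, n) mapping is an assumed stand-in for the kernel's grouped_layout lookup and swizzle step.

```python
class SM120BlockScaledSchedulerSketch:
    """Toy model of the persistent-block scheduler in the diagram above.

    total_blocks, current_iter, m_block_idx and n_block_idx mirror the
    diagram's names; the row-major iter -> (m, n) mapping is an assumed
    stand-in for the kernel's swizzled grouped-layout lookup.
    """

    def __init__(self, num_m_blocks, num_n_blocks):
        self.num_m_blocks = num_m_blocks
        self.num_n_blocks = num_n_blocks
        self.total_blocks = num_m_blocks * num_n_blocks
        self.current_iter = 0
        self.m_block_idx = 0
        self.n_block_idx = 0

    def get_next_block(self):
        # Returns False once every block has been handed out.
        if self.current_iter >= self.total_blocks:
            return False
        # Stand-in for the grouped_layout table lookup + swizzle.
        self.m_block_idx = self.current_iter // self.num_n_blocks
        self.n_block_idx = self.current_iter % self.num_n_blocks
        self.current_iter += 1
        return True


sched = SM120BlockScaledSchedulerSketch(num_m_blocks=2, num_n_blocks=3)
visited = []
while sched.get_next_block():
    visited.append((sched.m_block_idx, sched.n_block_idx))
```

Each persistent CTA would run such a loop until get_next_block() returns False, then exit, instead of the previous one-launch-per-problem scheme.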
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
In `cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh`:
- Around lines 55-57: The calculation of kNumStagePerSF uses integer division and can silently drop K tiles when kNumTileKPerSF is not divisible by AB_Stages. Tighten the compile-time checks: add a guard that kNumTileKPerSF % AB_Stages == 0, and revise the existing static_assert on kNumStagePerSF (or add a new one) so that both kNumStagePerSF > 0 && kNumStagePerSF <= 2 and kNumTileKPerSF % AB_Stages == 0 are enforced, with a clear error message.
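The failure mode can be illustrated outside the kernel. This Python sketch mirrors the divisibility and range checks the comment asks for; the parameter names follow the CUDA symbols, but the function itself is illustrative, not repo code.

```python
def num_stage_per_sf(k_num_tile_k_per_sf: int, ab_stages: int) -> int:
    # Integer division alone silently drops K tiles when kNumTileKPerSF
    # is not a multiple of AB_Stages, e.g. 5 // 4 == 1 loses one tile;
    # the extra guard turns that into a hard error instead.
    assert k_num_tile_k_per_sf % ab_stages == 0, \
        "kNumTileKPerSF must be divisible by AB_Stages"
    stages = k_num_tile_k_per_sf // ab_stages
    # Mirrors the existing range check on kNumStagePerSF.
    assert 0 < stages <= 2, "kNumStagePerSF must be in (0, 2]"
    return stages
```

In the CUDA header the same two conditions would live in static_asserts so misconfigured template parameters fail at compile time rather than producing wrong results.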
In `tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py`:
- Line 190: The test calls per_block_cast_to_fp8_e8m0 on a 3D tensor b (shape (num_groups, n, k)), but the helper asserts 2D input. Update the SM 120 branch (where getSMVersion() == 120) to handle the 3D case: iterate over groups, invoke per_block_cast_to_fp8_e8m0 (or per_block_cast_to_fp8) on each 2D slice b[i], and assign the results into b_fp8[i] and b_scales[i], mirroring the loop used in the else branch. Preallocate b_fp8 and b_scales with the correct shapes before the loop.
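A minimal sketch of the suggested loop, using NumPy and a stub in place of the repo's per_block_cast_to_fp8_e8m0 helper. The stub's return shapes and the 128-wide block size are assumptions; only the preallocate-then-loop-over-groups pattern is the point.

```python
import numpy as np

BLOCK = 128  # assumed block size for the per-block scales

def per_block_cast_stub(x_2d):
    # Stand-in for per_block_cast_to_fp8_e8m0: accepts only 2D input and
    # returns (quantized tensor, per-BLOCKxBLOCK block scales).
    assert x_2d.ndim == 2, "helper only accepts 2D tensors"
    n, k = x_2d.shape
    scales = np.ones(((n + BLOCK - 1) // BLOCK, (k + BLOCK - 1) // BLOCK),
                     dtype=np.float32)
    return x_2d.astype(np.float32), scales

def quantize_grouped(b):
    # SM 120 branch fix: quantize each 2D slice b[i] instead of passing
    # the whole 3D (num_groups, n, k) tensor to the 2D-only helper.
    num_groups, n, k = b.shape
    fp8_0, scales_0 = per_block_cast_stub(b[0])
    # Preallocate outputs with the grouped shapes before the loop.
    b_fp8 = np.empty((num_groups, *fp8_0.shape), dtype=fp8_0.dtype)
    b_scales = np.empty((num_groups, *scales_0.shape), dtype=scales_0.dtype)
    for i in range(num_groups):
        b_fp8[i], b_scales[i] = per_block_cast_stub(b[i])
    return b_fp8, b_scales
```

The actual test operates on torch tensors; the same slice-and-assign pattern applies there unchanged.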
🧹 Nitpick comments (4)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh (3)
65-66: Remove commented-out code. Per coding guidelines, dead code should not be disabled with comments; use `#if`/`#endif` with a mnemonic condition if the code needs to be preserved, otherwise remove it. As per coding guidelines: "Do not use comments to disable code."
363-429: Scheduler: member naming and a minor style note. The `SM120BlockScaledScheduler` struct members (`current_iter`, `m_block_idx`, `num_groups`, etc.) use snake_case, but the coding guidelines require class/struct member variables to use camelCase with an `m` prefix (e.g., `mCurrentIter`, `mBlockIdx`). That said, existing CUTLASS-style code in the codebase may differ; flagging for awareness. As per coding guidelines: "Use camelCase prefixed with 'm' for public, private and protected class member variables."
386-396: `get_swizzle_block_idx`: binding `const&` to temporaries is unusual. Lines 389-390 bind `const&` to rvalue temporaries (`num_m_blocks * kNum1DBlocksPerGroup`, `block_idx / num_blocks_per_group`). While valid C++ (lifetime extension), this is non-idiomatic and can confuse readers. Prefer plain `auto const`.

Suggested diff:

```diff
- auto const& num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
- auto const& group_idx = block_idx / num_blocks_per_group;
+ auto const num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
+ auto const group_idx = block_idx / num_blocks_per_group;
```

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh (1)
98-105: `get_grid_shape`: CUDA API return values are unchecked. `cudaGetDevice` and `cudaDeviceGetAttribute` can fail; if they do, `sm_count` remains uninitialized and produces an incorrect grid dimension. Check the return values or zero-initialize `sm_count`. Note: this function is currently unused; the active callers in `fp8_blockscale_gemm_kernel.cuh` compute grid dimensions directly and bypass it. However, this is a latent hazard if the function is called in future implementations.
...ensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/sm120_blockwise_gemm/sm120_utils.cuh
209fc71 to 46a64dc
12123cc to e9774d0
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
e9774d0 to 9e5c949
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
xxi-nv left a comment:
The python MoE parts LGTM.
…ASS FP8 BlockScale Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
9f70c85 to c1f4c29
/bot run

PR_Github #37678 [ run ] triggered by Bot. Commit:
PR_Github #37678 [ run ] completed with state

/bot run --stage-list "DGX_B200-PyTorch-1"

PR_Github #37787 [ run ] triggered by Bot. Commit:
PR_Github #37787 [ run ] completed with state
Barry-Delaney left a comment:
Overall LGTM. Thanks for the contribution! Left minor comments.
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
06a4c31 to d4cf137
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
844a5e8 to 54cc308
/bot run

PR_Github #38622 [ run ] triggered by Bot. Commit:
PR_Github #38622 [ run ] completed with state
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
a53e8d6 to 867652c
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Summary by CodeRabbit
Description
This PR adds SM120 (Blackwell GeForce / RTX PRO 6000) support for FP8 block-scale MoE GEMM and optimizes the existing dense/BMM kernels.
Changes
Commit 1: Optimize 6KD fp8 blockscale gemm
- `SM120BlockScaledBuilder` with a configurable `Stages` template parameter (default 4) for flexible pipeline depth tuning
- `SM120BlockScaledScheduler` for persistent-thread block scheduling across groups, replacing per-problem kernel launches
- Tile configurations (`PermMmaTileN/K`, `TiledMma` layout) and barrier management (`store_full_mbar`/`store_empty_mbar`) reworked for better SM120 occupancy
- Grid launched as `dim3(num_device_sms, 1, 1)` for persistent kernel style

Commit 2: Fix stride overflow

- `int` to `int64_t` in SM120 kernel dispatch and scheduler to prevent overflow on large problem sizes

Commit 3: Support MoE GEMM for SM120

- `SM120BlockScaledMoeKernel`: a dedicated MoE grouped GEMM kernel for SM120 with per-expert TMA addressing and `SM120BlockScaledMoeScheduler`
- `scale_1x128_kernel_sm120`: an online BF16-to-FP8 quantization kernel with E8M0 scale generation, using `compute_padded_offset` (alignment=4) for MoE-aware scale layout
- `grouped_gemm_dispatch_sm120` to wire quantization + GEMM kernels together
- `CutlassFp8BlockScaleGemmRunner::moeGemm` and the `fp8_block_scaling_moe_gemm_blackwell_geforce` thop entry point
- Integration with `CutlassFusedMoE` via `DeepSeekFP8BlockScalesFusedMoEMethod`
- `DeepGemm` with `resmooth_and_transform_fp8_scale` for weight preprocessing
- Unit tests (`test_fp8_block_scale_moe_gemm`) covering various `num_rows`, `num_experts`, `top_k`, and matrix dimensions

Key design notes

- Online activation quantization (`internal_quantize_a=true`): the BF16 activation is quantized to FP8 + E8M0 scale inside the kernel workspace, then fed to the CUTLASS grouped GEMM
- Weight scales are preprocessed with `resmooth_to_fp8_e8m0` + `transform_sf_into_required_layout` into int32 packed col-major TMA-aligned format
- Workspace offsets use `deep_gemm::compute_padded_offset` (alignment=32), which is always >= `sm120_blockscaled_gemm::compute_padded_offset` (alignment=4), ensuring no buffer overflow

Test Coverage

- `tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py::test_fp8_block_scale_moe_gemm`: 18 parametrized cases covering:
  - `num_rows`: 7, 64, 128
  - `num_experts, top_k`: (4, 3), (8, 4), (16, 5)
  - `k, n`: (7168, 2112), (2048, 7168)
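The workspace-sizing claim above (alignment=32 padding always covers alignment=4 padding) can be checked with a small sketch. The `compute_padded_offset` here is an assumed round-up-to-alignment helper matching the semantics implied by the description, not the exact repo code.

```python
def compute_padded_offset(offset: int, alignment: int) -> int:
    # Assumed semantics: round offset up to the next multiple of alignment.
    return (offset + alignment - 1) // alignment * alignment

# The deep_gemm-style padding (alignment=32) never yields a smaller
# workspace offset than the sm120 padding (alignment=4): every multiple
# of 32 is also a multiple of 4, and rounding up is monotonic.
for offset in range(1024):
    assert compute_padded_offset(offset, 32) >= compute_padded_offset(offset, 4)
```

This is why sizing the shared workspace with the coarser alignment is safe for the finer-aligned SM120 kernels.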
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

Details

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`

Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.