[TRTLLM-10929][feat] add fp8 combine in moe_a2a#11844

Open
dc3671 wants to merge 13 commits into NVIDIA:main from dc3671:add-fp8-combine

Conversation


@dc3671 dc3671 commented Mar 3, 2026

Description

Summary

Adds an FP8 combine path to the MoE All-to-All communication kernel. When fp8_combine=True, the BF16 expert outputs are quantized to FP8 before transmission over NVLink and accumulated back to BF16 on the receiving rank. This reduces NVLink bandwidth by 2× on the combine payload at the cost of FP8 rounding error.

in payload      fp8        bf16
  quant/copy    2.656 μs   3.072 μs
  combine       7.680 μs   15.712 μs

in workspace    fp8        bf16
  quant/copy    2.400 μs   1.440 μs
  combine       7.456 μs   15.232 μs

E2E perf - DS-R1 nvfp4 DEP16 EPLB0 MTP0 8K1K samples=1024:

            fp8         bf16       ratio
throughput  10796.54    10257.16   105.26%
tpot        85.92       91.10      106.03%
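As a back-of-envelope illustration of the 2× payload reduction (the hidden size below is a hypothetical example, not a value from this PR):

```python
# Combine-payload bytes per token for BF16 vs FP8 transmission.
hidden_size = 7168                # elements per token (assumed for illustration)
bf16_payload = hidden_size * 2    # BF16 = 2 bytes per element
fp8_payload = hidden_size * 1     # FP8 e4m3 = 1 byte per element
print(bf16_payload, fp8_payload, bf16_payload // fp8_payload)
```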

Both staging modes are supported:

  • payload_in_workspace=False (external payload): the prepare kernel reads from the caller's BF16 tensor, writes compact FP8 to the workspace at a token_idx × EPT × sizeof(FP8) (1-byte) stride, and the combine kernel reads at that compact stride.
  • payload_in_workspace=True (in-place): the prepare kernel reads BF16 already resident in the workspace and overwrites it with FP8 at the same BF16-stride offset (token_idx × EPT × sizeof(BF16)), keeping adjacent-token byte ranges non-overlapping across CTAs. The combine kernel then reads at the BF16 stride instead of the compact FP8 stride.
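A minimal Python sketch of the two destination-offset rules above (the helper name is hypothetical; EPT is elements per token):

```python
SIZEOF_BF16 = 2  # bytes per BF16 element
SIZEOF_FP8 = 1   # bytes per FP8 element

def fp8_dst_offset(token_idx: int, ept: int, payload_in_workspace: bool) -> int:
    """Byte offset of a token's FP8 output in the workspace (illustrative)."""
    if payload_in_workspace:
        # In-place: keep the BF16 stride so adjacent tokens' byte ranges
        # stay disjoint while each CTA overwrites its own token.
        return token_idx * ept * SIZEOF_BF16
    # External payload: compact FP8 stride (1 byte per element).
    return token_idx * ept * SIZEOF_FP8
```

For token 3 with EPT = 4096 this gives 12288 bytes in compact mode and 24576 in-place.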

Kernel changes (moeAlltoAllKernels.cu)

Unified prepare-combine kernel

moeA2APrepareCombineKernel<ThreadingPolicy, bool FP8_COMBINE, SrcT> replaces the two separate kernels (byte-copy and FP8-quant). The FP8_COMBINE boolean is a compile-time template parameter resolved via SWITCH_BOOL, so all branches dead-code-eliminate at compile time:

┌─────────────┬──────────┬─────────────────────────────────────────────────────────────────┐
│ FP8_COMBINE │ payload  │                             Action                              │
├─────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ false       │ non-null │ vectorized byte-copy external → workspace                       │
├─────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ false       │ null     │ flag increment only (no-op, 1-block grid)                       │
├─────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ true        │ non-null │ vectorized_quant BF16→FP8, compact FP8 stride dst               │
├─────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ true        │ null     │ vectorized_quant BF16→FP8 in-place, BF16 stride dst (race-free) │
└─────────────┴──────────┴─────────────────────────────────────────────────────────────────┘
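The dispatch table can be mirrored as a small Python function (illustrative only; in the CUDA code FP8_COMBINE is a compile-time bool resolved via SWITCH_BOOL, so each branch is dead-code-eliminated):

```python
def prepare_action(fp8_combine: bool, payload_is_null: bool) -> str:
    """Return the action the unified prepare-combine kernel performs."""
    if not fp8_combine:
        # BF16 path: copy the external payload, or just bump the flag.
        return "flag_increment_only" if payload_is_null else "byte_copy"
    # FP8 path: always quantize; the stride depends on where the payload lives.
    if payload_is_null:
        return "quant_in_place_bf16_stride"   # race-free in-place overwrite
    return "quant_compact_fp8_stride"
```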

vectorized_quant

Replaced the scalar element loop with vectorized_quant<ThreadingPolicy, SrcT, DstT>, backed by vec_t<SrcT, N> wide loads and vec_t<DstT, N> wide stores. The vector width N is selected as the largest of {16, 8, 4, 2, 1} that evenly divides elements_per_token. For the in-place case, all SrcT values are loaded into registers before any DstT bytes are written, making single-CTA in-place quantization safe.
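The width-selection rule sketched in Python (hypothetical helper name):

```python
def pick_vec_width(elements_per_token: int) -> int:
    """Largest of {16, 8, 4, 2, 1} that evenly divides elements_per_token."""
    for n in (16, 8, 4, 2, 1):
        if elements_per_token % n == 0:
            return n
    return 1  # unreachable: 1 always divides
```

For example, hidden sizes 7168 and 2880 both select N = 16, while a contrived 6-element token would fall back to N = 2.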

stride_per_token in the combine kernel

vectorized_combine_impl / vectorized_combine / moeA2ACombineKernel gain a stride_per_token parameter (byte distance between tokens in the recv buffer). This decouples the loop bound (size_per_token = EPT × sizeof(T)) from the address stride:

  • Non-FP8 or FP8 external: stride_per_token = EPT × sizeof(TKernelType) (same as before)
  • FP8 in-place: stride_per_token = EPT × sizeof(payload_dtype) (e.g. EPT × 2 for BF16 payload)

moe_a2a_combine_launch computes fp8_in_place = (fp8_combine && prepare_payload == nullptr) and selects the correct stride before calling the kernel.
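A sketch of that stride selection (names are hypothetical; sizes in bytes):

```python
def combine_stride_per_token(ept: int, fp8_combine: bool,
                             prepare_payload_is_null: bool,
                             sizeof_kernel_t: int, sizeof_payload_t: int) -> int:
    """Byte distance between consecutive tokens in the recv buffer."""
    fp8_in_place = fp8_combine and prepare_payload_is_null
    if fp8_in_place:
        # FP8 bytes were left at the original BF16 stride.
        return ept * sizeof_payload_t
    # Non-FP8, or FP8 with an external payload: compact stride.
    return ept * sizeof_kernel_t
```

With EPT = 4096, a BF16 payload (2 bytes) and FP8 kernel type (1 byte): external FP8 gives a 4096-byte stride, in-place FP8 gives 8192.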

Launch-function cleanup

  • Removed TLLM_CHECK_WITH_INFO that hard-blocked fp8_combine + payload_in_workspace=True.
  • global_token_num is always ep_size × max_tokens_per_rank when fp8_combine=True (even in-place tokens need quantizing); it falls back to 1 (flag-increment only) for the BF16 no-op path.
  • moe_a2a_prepare_combine_launch now uses SWITCH_BOOL(fp8_combine) + SWITCH_DTYPE to dispatch a single templated kernel for both BF16 copy and FP8 quant paths.
  • moe_a2a_combine_launch uses effective_dtype = fp8_combine ? kFP8 : dtype so the FP8 accumulation kernel is dispatched regardless of the payload's declared dtype.
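A Python mirror of the two launch-side decisions above (illustrative; the real code works with dtype enums and raw pointers, and the names here are assumptions):

```python
def launch_params(fp8_combine: bool, payload_is_null: bool, ep_size: int,
                  max_tokens_per_rank: int, dtype: str) -> tuple:
    if fp8_combine:
        # Even in-place tokens need quantizing, so launch over all tokens.
        global_token_num = ep_size * max_tokens_per_rank
    elif payload_is_null:
        global_token_num = 1  # BF16 no-op path: flag increment only
    else:
        global_token_num = ep_size * max_tokens_per_rank
    # Dispatch the FP8 accumulation kernel regardless of declared dtype.
    effective_dtype = "kFP8" if fp8_combine else dtype
    return global_token_num, effective_dtype
```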

Test changes (test_moe_a2a.py)

  • test_combine: added payload_in_workspace parameter (was hardcoded False). Added 4 True cases: uniform, zero-token ranks, non-uniform tokens, top_k=4.
  • test_combine_fp8: added payload_in_workspace parameter. Added 4 True cases matching the False set. Increased hidden_size to 2880 (gpt-oss-20b) for realistic magnitude. Tightened tolerances to rtol=0.13 / atol=0.5 (FP8 e4m3fn 3-bit mantissa → ≤12.5% relative error per element).
  • run_moe_a2a_dispatch_moe_combine_fp8_single_rank: introduced _combine() helper that optionally stages the MoE output through the workspace buffer. Two independent dispatch_and_fake_moe() calls for rounds 1 (BF16 reference) and 2 (FP8 under test) avoid any tensor aliasing between rounds.
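The tolerance choice in test_combine_fp8 can be sanity-checked arithmetically (a sketch of the reasoning, not code from the test):

```python
# e4m3 has a 3-bit mantissa, so representable values within a binade are
# spaced 2**-3 = 12.5% apart relative to the value; per-element rounding
# error is bounded by one spacing, and rtol = 0.13 covers it with margin.
mantissa_bits = 3
max_rel_error = 2.0 ** -mantissa_bits  # 0.125
rtol = 0.13
print(max_rel_error, rtol > max_rel_error)
```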

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • New Features

    • Added FP8 quantization support for Mixture of Experts (MOE) all-to-all communication, reducing NVLink transfer bandwidth by half while maintaining BF16 output precision.
    • Introduced optional fp8_combine parameter across MOE communication APIs to enable FP8 quantization during communication operations.
  • Tests

    • Added comprehensive FP8 quantization test coverage for MOE all-to-all communication workflows.

@dc3671 dc3671 requested a review from bobboli March 3, 2026 01:32
@dc3671 dc3671 requested review from a team as code owners March 3, 2026 01:32
@dc3671 dc3671 requested review from dongxuy04 and yuxianq March 3, 2026 01:32
coderabbitai bot commented Mar 3, 2026

📝 Walkthrough

Walkthrough

This pull request adds FP8 quantization support to MOE AllToAll communication kernels. The changes introduce FP8-aware prepare and combine kernels, extend vectorized accumulation to handle multiple types, and add a fp8_combine parameter throughout the C++ operator and Python API layers to enable BF16→FP8 quantization during NVLink transfer, with corresponding test coverage.

Changes

Cohort / File(s) Summary
CUDA Kernel Implementation
cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
Extends dtype handling to include kFP8, replaces rigid vectorized accumulate with generic multi-template support (T, InT), introduces FP8-specific prepare/combine kernels with quantization utilities (vectorized_quant), updates kernel signatures to accept stride_per_token parameter, and adds FP8-aware kernel dispatch in launch wrappers while preserving non-FP8 backward compatibility.
Kernel Header & Operator Binding
cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h, cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
Adds fp8_combine boolean flag to MoeA2ACombineParams struct; updates C++ operator to accept fp8Combine parameter, sets output tensor dtype to BF16 when FP8 combine is enabled, and updates PyTorch binding signature to expose the new flag.
Python Distributed & Communication APIs
tensorrt_llm/_torch/distributed/moe_alltoall.py, tensorrt_llm/_torch/modules/fused_moe/communication/base.py, tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
Adds fp8_combine parameter to MoeAlltoAll.combine() and Communication.combine() methods with updated docstrings explaining FP8 quantization behavior; threads parameter through to underlying kernel calls.
MOE Module Integration
tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
Introduces combine_fp8 parameter to ConfigurableMoE constructor, stores as instance attribute, and propagates it to the combine() call in _forward_chunk_impl via the fp8_combine keyword argument.
Test Suite
tests/unittest/_torch/multi_gpu/test_moe_a2a.py
Adds payload_in_workspace parameter to existing test functions, introduces new test_combine_fp8() test path with corresponding FP8-specific worker function, adds verify_combine_fp8() verification routine to validate FP8 outputs against BF16 references with FP8 rounding tolerances, and extends test scaffolding with FP8-aware skips and multi-round verification.

Sequence Diagram

sequenceDiagram
    participant PyTorch as PyTorch Op
    participant Prepare as Prepare Kernel
    participant Recv as Recv Buffer
    participant Quant as Quantize
    participant Combine as Combine Kernel
    participant Output as Output BF16

    rect rgba(100, 150, 255, 0.5)
        Note over PyTorch,Output: Traditional BF16 Combine Flow
        PyTorch->>Prepare: launch prepare kernel
        Prepare->>Recv: write payload
        Recv->>Combine: read BF16 payload
        Combine->>Combine: accumulate & compute
        Combine->>Output: write BF16 result
    end

    rect rgba(150, 200, 100, 0.5)
        Note over PyTorch,Output: FP8 Combine Flow (fp8_combine=True)
        PyTorch->>Prepare: launch FP8 prepare kernel
        Prepare->>Quant: quantize BF16→FP8
        Quant->>Recv: store FP8 payload
        Recv->>Combine: read FP8 data, convert to BF16
        Combine->>Combine: accumulate float32, cast to BF16
        Combine->>Output: write BF16 result
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 61.54%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The PR title clearly summarizes the main change: adding FP8 combine support to the MoE All-to-All kernel. It follows the required format with ticket ID [TRTLLM-10929] and feature type [feat].
  • Description check (✅ Passed): The PR description is mostly complete, with a detailed technical explanation of the FP8 combine feature, implementation details, kernel changes, and test coverage. However, the Test Coverage section is empty.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp (1)

440-455: ⚠️ Potential issue | 🟠 Major

Validate BF16 input when fp8Combine is enabled.

fp8Combine forces BF16 output, but current input checks still accept FP16/FP32. That permits silent dtype behavior changes and a path outside the documented BF16→FP8→BF16 contract.

✅ Suggested guard
     else
     {
         TORCH_CHECK(false, "Unsupported data type for payload");
     }
+    TORCH_CHECK(!fp8Combine || scalarType == at::kBFloat16,
+        "fp8_combine requires bfloat16 payload");

Also applies to: 489-491

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp` around lines 440 - 455, When
fp8Combine is true you must enforce BF16 input—update the dtype-checking logic
around scalarType/nvDtype (the branch that currently maps at::kHalf,
at::kBFloat16, at::kFloat and the TORCH_CHECK) to add a guard: if fp8Combine is
enabled, require scalarType == at::kBFloat16 and fail with a clear TORCH_CHECK
message otherwise; apply the same guard to the identical mapping block later in
this file (the other scalarType→nvDtype conversion) so inputs cannot be
FP16/FP32 when fp8Combine forces BF16 outputs.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a632f0f and 3f23655.

📒 Files selected for processing (8)
  • cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
  • cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
  • cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
  • tensorrt_llm/_torch/distributed/moe_alltoall.py
  • tensorrt_llm/_torch/modules/fused_moe/communication/base.py
  • tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
  • tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
  • tests/unittest/_torch/multi_gpu/test_moe_a2a.py

@dc3671 dc3671 requested a review from xxi-nv March 3, 2026 04:05
@dc3671 dc3671 requested a review from a team as a code owner March 3, 2026 06:36

dc3671 commented Mar 3, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
PR_Github #37448 [ run ] triggered by Bot. Commit: 1e184dc Link to invocation

@tensorrt-cicd
PR_Github #37448 [ run ] completed with state SUCCESS. Commit: 1e184dc
/LLM/main/L0_MergeRequest_PR pipeline #28987 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@dc3671 dc3671 requested a review from a team as a code owner March 11, 2026 03:19
@dc3671 dc3671 requested a review from yizhang-nv March 11, 2026 03:19
@dc3671 dc3671 force-pushed the add-fp8-combine branch 2 times, most recently from 4a06b46 to 11c93e1 on March 11, 2026 09:34

dc3671 commented Mar 11, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
PR_Github #38580 [ run ] triggered by Bot. Commit: 11c93e1 Link to invocation

@tensorrt-cicd
PR_Github #38580 [ run ] completed with state FAILURE. Commit: 11c93e1
/LLM/main/L0_MergeRequest_PR pipeline #29917 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


dc3671 commented Mar 11, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
PR_Github #38599 [ run ] triggered by Bot. Commit: d912556 Link to invocation

// Convert SrcT → DstT.
vec_t<DstT, VEC_SIZE> out_vec;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
Collaborator

Suggested change
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)

Collaborator Author

I previously used 890, but CI failed while building:

"ptxas /tmp/tmpxft_00002857_00000000-9_moeAlltoAllKernels.compute_90a.ptx, line 50850; error : Feature 'cvt.e4m3x2.bf16x2' not supported on .target 'sm_90a'"

Collaborator Author

The bf16x2 form of the conversion is only supported on Blackwell, but the f16 form is supported on Hopper.

Collaborator

Got it, thanks. Please note this in a code comment.

Collaborator Author

Separated the bf16 and fp16 logic with different arch requirements.

@bobboli bobboli Mar 12, 2026

I suggest using template specialization for vectorized_quant_impl, so that the code could be much cleaner without constexpr if.

@tensorrt-cicd
PR_Github #38599 [ run ] completed with state SUCCESS. Commit: d912556
/LLM/main/L0_MergeRequest_PR pipeline #29935 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

dc3671 added 12 commits March 11, 2026 20:09
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>

dc3671 commented Mar 12, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd

PR_Github #38661 [ run ] triggered by Bot. Commit: ebe54b4 Link to invocation

Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>