[TRTLLM-10929][feat] add fp8 combine in moe_a2a #11844
dc3671 wants to merge 13 commits into NVIDIA:main
Conversation
📝 Walkthrough

This pull request adds FP8 quantization support to the MoE AllToAll communication kernels. The changes introduce FP8-aware prepare and combine kernels, extend vectorized accumulation to handle multiple types, and add an `fp8_combine` option across the communication APIs.
Sequence Diagram

```mermaid
sequenceDiagram
    participant PyTorch as PyTorch Op
    participant Prepare as Prepare Kernel
    participant Recv as Recv Buffer
    participant Quant as Quantize
    participant Combine as Combine Kernel
    participant Output as Output BF16

    rect rgba(100, 150, 255, 0.5)
        Note over PyTorch,Output: Traditional BF16 Combine Flow
        PyTorch->>Prepare: launch prepare kernel
        Prepare->>Recv: write payload
        Recv->>Combine: read BF16 payload
        Combine->>Combine: accumulate & compute
        Combine->>Output: write BF16 result
    end

    rect rgba(150, 200, 100, 0.5)
        Note over PyTorch,Output: FP8 Combine Flow (fp8_combine=True)
        PyTorch->>Prepare: launch FP8 prepare kernel
        Prepare->>Quant: quantize BF16→FP8
        Quant->>Recv: store FP8 payload
        Recv->>Combine: read FP8 data, convert to BF16
        Combine->>Combine: accumulate float32, cast to BF16
        Combine->>Output: write BF16 result
    end
```
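The numeric contract of the FP8 flow above (quantize each contribution to FP8, accumulate in float32, cast the sum back to BF16) can be sketched in host C++. This is a toy model, not the kernel code: the e4m3 rounding below ignores subnormals and NaN handling, and `to_bf16` is a plain truncation stand-in for the device cast.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Round to the nearest e4m3-representable value (4 significant bits,
// max finite value 448). Toy model: subnormals and NaN are ignored.
float quant_e4m3(float x)
{
    if (x == 0.f)
        return 0.f;
    float ax = std::min(std::fabs(x), 448.f);
    int e;
    float m = std::frexp(ax, &e);          // ax = m * 2^e, m in [0.5, 1)
    float q = std::round(m * 16.f) / 16.f; // snap to 4 significant bits
    return std::copysign(std::ldexp(q, e), x);
}

// Truncate a float to bf16 precision (drop the low 16 mantissa bits).
float to_bf16(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    std::memcpy(&x, &bits, sizeof(bits));
    return x;
}

// Combine step: contributions arrive as FP8, are accumulated in
// float32, and only the final sum is cast down to BF16.
float fp8_combine_sum(const std::vector<float>& contributions)
{
    float acc = 0.f;
    for (float c : contributions)
        acc += quant_e4m3(c); // what arrived over NVLink
    return to_bf16(acc);
}
```

The key property this models is that rounding error enters once per contribution at quantization time, while the accumulation itself stays in full float32 precision.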
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp (1)
Lines 440-455: ⚠️ Potential issue | 🟠 Major

Validate BF16 input when `fp8Combine` is enabled.

`fp8Combine` forces BF16 output, but the current input checks still accept FP16/FP32. That permits silent dtype behavior changes and a path outside the documented BF16→FP8→BF16 contract.

✅ Suggested guard:

```diff
 else { TORCH_CHECK(false, "Unsupported data type for payload"); }
+TORCH_CHECK(!fp8Combine || scalarType == at::kBFloat16,
+    "fp8_combine requires bfloat16 payload");
```

Also applies to lines 489-491.
📒 Files selected for processing (8)

- cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
- tensorrt_llm/_torch/distributed/moe_alltoall.py
- tensorrt_llm/_torch/modules/fused_moe/communication/base.py
- tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tests/unittest/_torch/multi_gpu/test_moe_a2a.py
/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #37448 [ run ] triggered by Bot. Commit:

PR_Github #37448 [ run ] completed with state
cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (outdated review threads, resolved)
Force-pushed 4a06b46 to 11c93e1 (Compare)
/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #38580 [ run ] triggered by Bot. Commit:

PR_Github #38580 [ run ] completed with state

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #38599 [ run ] triggered by Bot. Commit:
Review thread on the vectorized quantization code:

```cpp
// Convert SrcT → DstT.
vec_t<DstT, VEC_SIZE> out_vec;
```

```cpp
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
```

Suggested change:

```cpp
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
```
> I previously used 890, but CI failed on the build:
> "ptxas /tmp/tmpxft_00002857_00000000-9_moeAlltoAllKernels.compute_90a.ptx, line 50850; error : Feature 'cvt.e4m3x2.bf16x2' not supported on .target 'sm_90a'"

> bf16x2 conversion is only supported on Blackwell, but f16 can be supported on Hopper.

> Got it, thanks. Please note this in a code comment.

> Separated the bf16 and fp16 logic with different arch requirements.

> I suggest using template specialization for vectorized_quant_impl, so that the code could be much cleaner without constexpr if.
PR_Github #38599 [ run ] completed with state
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #38661 [ run ] triggered by Bot. Commit:
Description
Summary
Adds an FP8 combine path to the MoE All-to-All communication kernel. When fp8_combine=True, the BF16 expert outputs are quantized to FP8 before transmission over NVLink and accumulated back to BF16 on the receiving rank. This halves the NVLink traffic of the combine payload at the cost of FP8 rounding error.
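The 2× traffic reduction follows directly from element width: the combine payload shrinks from 2 bytes per element (BF16) to 1 byte (e4m3). A back-of-envelope sketch (the hidden size and token count below are illustrative, not taken from the benchmark config):

```cpp
#include <cstddef>

// Bytes transmitted for the combine payload, given the element dtype.
// With fp8_combine, each element is 1 byte (e4m3) instead of 2 (bf16).
size_t combine_payload_bytes(size_t num_tokens, size_t hidden, bool fp8_combine)
{
    size_t bytes_per_elem = fp8_combine ? 1 : 2; // e4m3 vs bf16
    return num_tokens * hidden * bytes_per_elem;
}
```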
E2E perf - DS-R1 nvfp4 DEP16 EPLB0 MTP0 8K1K samples=1024:
Both staging modes are supported:
Kernel changes (moeAlltoAllKernels.cu)
Unified prepare-combine kernel
moeA2APrepareCombineKernel<ThreadingPolicy, bool FP8_COMBINE, SrcT> replaces the two separate kernels (byte-copy and FP8-quant). The FP8_COMBINE boolean is a compile-time template parameter resolved via SWITCH_BOOL, so all branches dead-code-eliminate at compile time:
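The SWITCH_BOOL idiom can be sketched in plain host C++. The macro body below is a guess at the general shape, not the actual TRT-LLM definition: a runtime bool is mapped to a constexpr template parameter, so each instantiation dead-code-eliminates the branch it does not take.

```cpp
// Hypothetical SWITCH_BOOL sketch: dispatch a runtime flag to a
// compile-time constant visible inside the lambda-like body.
#define SWITCH_BOOL(flag, NAME, ...)                                           \
    do                                                                         \
    {                                                                          \
        if (flag)                                                              \
        {                                                                      \
            constexpr bool NAME = true;                                        \
            __VA_ARGS__;                                                       \
        }                                                                      \
        else                                                                   \
        {                                                                      \
            constexpr bool NAME = false;                                       \
            __VA_ARGS__;                                                       \
        }                                                                      \
    } while (0)

// A templated "kernel" stand-in: the FP8 branch is resolved at
// compile time, so each instantiation contains only one path.
template <bool FP8_COMBINE>
int bytes_per_payload_elem()
{
    if constexpr (FP8_COMBINE)
        return 1; // e4m3 payload
    else
        return 2; // bf16 payload
}

// Launcher: runtime bool in, correct template instantiation out.
int launch_prepare(bool fp8_combine)
{
    int bytes = 0;
    SWITCH_BOOL(fp8_combine, kFp8Combine, bytes = bytes_per_payload_elem<kFp8Combine>());
    return bytes;
}
```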
vectorized_quant
Replaced the scalar element loop with vectorized_quant<ThreadingPolicy, SrcT, DstT>, backed by vec_t<SrcT, N> wide loads and vec_t<DstT, N> wide stores. The vector width N is selected as the largest of {16, 8, 4, 2, 1} that evenly divides elements_per_token. For the in-place case, all SrcT values are loaded into registers before any DstT bytes are written, making single-CTA in-place quantization safe.
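The width-selection rule described above is simple enough to sketch directly (function name is illustrative):

```cpp
#include <initializer_list>

// Pick the largest vector width in {16, 8, 4, 2, 1} that evenly
// divides elements_per_token, so every token is covered by whole
// vec_t<SrcT, N> loads with no scalar tail.
int pick_vec_size(int elements_per_token)
{
    for (int n : {16, 8, 4, 2, 1})
    {
        if (elements_per_token % n == 0)
            return n;
    }
    return 1; // unreachable: 1 divides everything
}
```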
stride_per_token in the combine kernel
vectorized_combine_impl / vectorized_combine / moeA2ACombineKernel gain a stride_per_token parameter (byte distance between tokens in the recv buffer). This decouples the loop bound (size_per_token = EPT × sizeof(T)) from the address stride:
moe_a2a_combine_launch computes fp8_in_place = (fp8_combine && prepare_payload == nullptr) and selects the correct stride before calling the kernel.
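A toy host-side illustration of why the loop bound and the address stride must be separate parameters: in the in-place FP8 case each token holds only size_per_token bytes of FP8 payload, but tokens remain laid out at the original BF16 pitch, so the stride between tokens is larger than the bytes read per token. Names here are illustrative, not kernel code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Walk num_tokens tokens in a recv buffer: read size_per_token bytes
// of payload per token, but advance by stride_per_token between
// tokens. Gap bytes (stride > size) are skipped, as with in-place FP8.
int64_t sum_payload(const std::vector<uint8_t>& recv, size_t num_tokens,
    size_t size_per_token, size_t stride_per_token)
{
    int64_t acc = 0;
    for (size_t t = 0; t < num_tokens; ++t)
    {
        const uint8_t* tok = recv.data() + t * stride_per_token;
        for (size_t b = 0; b < size_per_token; ++b)
            acc += tok[b]; // only the first size_per_token bytes count
    }
    return acc;
}
```

With stride == size this degenerates to the dense BF16 layout; with stride > size it models the FP8 payload sitting in the front half of each BF16-pitched token slot.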
Launch-function cleanup
Test changes (test_moe_a2a.py)
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.

Summary by CodeRabbit
New Features
- New `fp8_combine` parameter across MOE communication APIs to enable FP8 quantization during communication operations.

Tests