fix: extend moe alltoall top-k specializations#3021
bobboli wants to merge 3 commits into flashinfer-ai:main from
Conversation
Add specialized dispatch and combine paths for larger top-k values up to 22 while keeping the generic fallback for other valid cases. Remove the unused kMaxExperts limit and align the payload cap with the current four-payload moe alltoall path.
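To illustrate the shape of this change, here is a minimal host-side sketch of a compile-time top-k dispatch with a generic runtime fallback. All names (`combine_specialized`, `combine_generic`, `dispatch_top_k`) are invented for illustration and stand in for kernel launches; they are not the actual functions in this PR.

```cpp
#include <cassert>

// Stand-in for a kernel instantiated with a compile-time TOP_K, so the inner
// reduction loop can be fully unrolled.
template <int TOP_K>
int combine_specialized(int x) { return x * TOP_K; }

// Stand-in for the generic fallback that takes top_k at runtime.
int combine_generic(int x, int top_k) { return x * top_k; }

// Dispatch: specialized paths for the instantiated top_k values, generic
// fallback for every other valid value.
int dispatch_top_k(int x, int top_k) {
  switch (top_k) {
    case 1:  return combine_specialized<1>(x);
    case 2:  return combine_specialized<2>(x);
    case 4:  return combine_specialized<4>(x);
    case 6:  return combine_specialized<6>(x);
    case 8:  return combine_specialized<8>(x);
    case 10: return combine_specialized<10>(x);
    case 16: return combine_specialized<16>(x);
    case 22: return combine_specialized<22>(x);
    default: return combine_generic(x, top_k);  // other valid top_k values
  }
}
```

The review discussion below centers on exactly this pattern: whether the `default` arm reaches a working generic path or an error.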
📝 Walkthrough

The PR extends MOE all-to-all kernels with compile-time specializations for additional top_k values (22, 16, 10, 6) and adds a vectorized accumulation helper.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
Lines 53-97: ⚠️ Potential issue | 🔴 Critical

`top_k <= 22` is not actually supported yet. Lines 481 and 953 now accept any `top_k` in `[1, 22]`, but this switch still instantiates only 1, 2, 4, 6, 8, 10, 16, 22. Values like 3, 5, 7, 9, 11-15, 17-21 will still fail at runtime in the `default` arm, and the generic reduction fallback on lines 749-755 is unreachable for them. Either enumerate the remaining cases so the generic path can run, or tighten validation back to the truly supported set until that fallback is wired through.

Also applies to: 749-755
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu` around lines 53 - 97, The SWITCH_TOP_K macro currently only handles top_k values {1,2,4,6,8,10,16,22} but callers accept any 1..22, causing unexpected hits to the default error; either (A) update the call-site validation where top_k is accepted (the functions that currently permit any top_k) to restrict allowed values to exactly {1,2,4,6,8,10,16,22}, or (B) expand SWITCH_TOP_K to enumerate the remaining top_k values (3,5,7,9,11-15,17-21) and route them into the generic reduction fallback path present in this file (the generic reduction fallback block), so those unsupported sizes run the generic reduction instead of hitting the default error. Ensure you reference and update the SWITCH_TOP_K macro and the call-site validation (or add cases that invoke the generic fallback) so behavior matches the accepted top_k range.
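Option (A) above amounts to a membership check at the call site. A hedged sketch of what that validation could look like follows; `is_specialized_top_k` is an invented name for illustration, not a function in the file.

```cpp
#include <cassert>

// Accept only the top_k values for which a specialized instantiation exists
// (per the comment above: {1, 2, 4, 6, 8, 10, 16, 22}), so a request can be
// rejected up front instead of hitting the switch's default error arm.
bool is_specialized_top_k(int top_k) {
  switch (top_k) {
    case 1: case 2: case 4: case 6: case 8: case 10: case 16: case 22:
      return true;
    default:
      return false;
  }
}
```

Option (B), by contrast, keeps the wide `[1, 22]` acceptance and routes the non-enumerated values into the generic reduction path.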
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h`:
- Around line 24-25: The header's formatting is out of sync (pre-commit rewrites
the lines defining kMaxTopK and kMaxPayloads), so run clang-format on the header
that declares the static constexpr symbols kMaxTopK and kMaxPayloads, save the
reformatted file, and re-stage the updated file so pre-commit no longer modifies
those lines during CI.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 74bd405f-090c-4b55-bc9d-7528049e20d6
📒 Files selected for processing (2)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
Code Review
This pull request increases the maximum supported top-k experts for MoE All-to-All kernels from 8 to 22. It introduces a new accumulate_vec helper function and refactors the reduction logic in vectorized_combine_impl to support additional TOP_K values (22, 16, 10, and 6) using optimized accumulation paths. Furthermore, the kMaxTopK constant was updated, kMaxPayloads was reduced, and kMaxExperts was removed from the configuration. I have no feedback to provide.
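To make the review concrete, here is a minimal host-side sketch of what a vectorized accumulation helper in the spirit of `accumulate_vec` might look like. The signature and lane layout are assumptions for illustration, not the kernel's actual API.

```cpp
#include <cassert>

// Sum TOP_K per-expert contribution vectors of VEC lanes each into one output
// vector. Because TOP_K is a compile-time constant, the inner loop can be
// fully unrolled, which is the point of specializing the combine path per
// top_k value instead of looping over a runtime k.
template <int TOP_K, int VEC, typename T>
void accumulate_vec(T contrib[TOP_K][VEC], T out[VEC]) {
  for (int v = 0; v < VEC; ++v) {
    T acc = T(0);
    for (int k = 0; k < TOP_K; ++k) acc += contrib[k][v];  // expert k, lane v
    out[v] = acc;
  }
}
```

In the real kernel this reduction runs per thread over vector-width loads; the sketch only shows the accumulation structure.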
Hi @bobboli, are these top_k values tested?
Hi @yzh119 @bobboli, I tested some cases locally and they work well. Meanwhile I'm verifying the cases in an LLM framework and will give feedback once I have results; we can add more tests then.

Update: verified here sgl-project/sglang#22669
Cover the combine path with parameter sets that mirror representative MoE model configurations and add focused Qwen coverage for dtype and workspace staging. This keeps the regression matrix closer to real-world routing shapes while preserving targeted edge-case checks.
Thanks, I have refined the test coverage.
/bot run
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/comm/test_trtllm_moe_alltoall.py (1)
Lines 475-478: `max_world_size = 16` is currently unexercised by this test matrix. The assertion now allows 16, but COMBINE_PARAMS only goes up to 8. Consider either adding a 16-rank tuple (if feasible) or keeping the bound at 8 to match actual coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 475 - 478, The test sets max_world_size = 16 but the test matrix (COMBINE_PARAMS) only covers up to 8 ranks; either reduce the bound or expand the matrix. Fix by updating the max_world_size variable in tests/comm/test_trtllm_moe_alltoall.py to 8 to match COMBINE_PARAMS, or alternatively add a 16-rank entry into COMBINE_PARAMS so that world_size==16 is actually exercised; reference the max_world_size symbol and the COMBINE_PARAMS collection when making the change.
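The underlying idea of this nitpick can be sketched as deriving the bound from the matrix itself, so an assertion like `max_world_size` can never drift from what the matrix exercises. Names here are illustrative, not the test file's actual code.

```cpp
#include <cassert>
#include <algorithm>
#include <vector>

// Illustrative stand-in for one row of the test's parameter matrix.
struct CombineParam { int world_size, num_tokens, vector_dim, top_k; };

// Compute the largest world size the matrix actually covers, so a bound check
// can be asserted against coverage rather than hard-coded.
int max_world_size_covered(const std::vector<CombineParam>& params) {
  int m = 0;
  for (const auto& p : params) m = std::max(m, p.world_size);
  return m;
}
```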
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/comm/test_trtllm_moe_alltoall.py`:
- Around line 55-71: The test matrix COMBINE_PARAMS is missing explicit coverage
for top_k=16 and some high-top-k rows use skip-prone shapes (world_size=8,
num_tokens=16); add at least one tuple that exercises top_k=16 (e.g., same
dtype/payload as others) and either replace or add lower-resource variants for
the high top_k cases so they run on smaller GPUs (for example use world_size=4
and/or num_tokens=8 for the new top_k values) to ensure the new top-k codepaths
are exercised without requiring 128 SM hardware.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 26e6f439-5130-413f-94ba-5777db29e72a
📒 Files selected for processing (1)
tests/comm/test_trtllm_moe_alltoall.py
```python
# (world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace)
COMBINE_PARAMS = [
    (2, 64, 8, 2, torch.bfloat16, True),  # Small input, 2 ranks
    (4, 32, 32768, 4, torch.bfloat16, True),  # Large input, 4 ranks
    (8, 16, 2048, 8, torch.bfloat16, True),  # Medium input, 8 ranks
    (8, 16, 2048, 8, torch.bfloat16, False),  # Medium input, 8 ranks
    (2, 64, 8, 2, torch.float16, True),  # Small input, 2 ranks
    (4, 32, 32768, 4, torch.float16, True),  # Large input, 4 ranks
    (8, 16, 2048, 8, torch.float16, True),  # Medium input, 8 ranks
    (8, 16, 2048, 8, torch.float16, False),  # Medium input, 8 ranks
    # Coverage for popular model specifications
    (4, 16, 4096, 2, torch.bfloat16, True),  # Mixtral-8x7B
    (4, 16, 2880, 4, torch.bfloat16, True),  # GPT-OSS-120B
    (8, 16, 5120, 6, torch.bfloat16, True),  # DeepSeek-V2
    (8, 16, 7168, 8, torch.bfloat16, True),  # DeepSeek-V3
    (8, 16, 4096, 8, torch.bfloat16, True),  # Qwen3-235B-A22B
    (8, 16, 4096, 10, torch.bfloat16, True),  # Qwen3.5-397B-A17B
    (8, 16, 4096, 22, torch.bfloat16, True),  # Nemotron-3-Super-120B-A12B
    # Coverage for num_tokens
    (8, 1, 4096, 8, torch.bfloat16, True),
    # Coverage for dtype
    (8, 16, 4096, 8, torch.float16, True),
    # Coverage for payload_in_workspace
    (8, 16, 4096, 8, torch.bfloat16, False),
]
```
Add explicit top_k=16 coverage and reduce skip-prone shapes for new top-k cases.
This matrix validates top_k 6/10/22 but misses top_k=16, even though that specialization is part of this PR. Also, all new high top-k cases use world_size=8, num_tokens=16 (128 SM requirement), which may be skipped on many GPUs and leave the new paths untested.
Proposed matrix adjustment:

```diff
 COMBINE_PARAMS = [
 @@
-    (8, 16, 5120, 6, torch.bfloat16, True),  # DeepSeek-V2
+    (8, 4, 5120, 6, torch.bfloat16, True),  # DeepSeek-V2 (lower SM requirement)
 @@
-    (8, 16, 4096, 10, torch.bfloat16, True),  # Qwen3.5-397B-A17B
-    (8, 16, 4096, 22, torch.bfloat16, True),  # Nemotron-3-Super-120B-A12B
+    (8, 4, 4096, 10, torch.bfloat16, True),  # Qwen3.5-397B-A17B (lower SM requirement)
+    (8, 4, 4096, 16, torch.bfloat16, True),  # Explicit coverage for top_k=16 specialization
+    (8, 4, 4096, 22, torch.bfloat16, True),  # Nemotron-3-Super-120B-A12B (lower SM requirement)
```

🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 55-55: pre-commit hook 'clang-format' failed. Hook modified file formatting.
[error] 55-55: pre-commit hook 'ruff-format' failed. Hook reformatted file.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 55 - 71, The test matrix
COMBINE_PARAMS is missing explicit coverage for top_k=16 and some high-top-k
rows use skip-prone shapes (world_size=8, num_tokens=16); add at least one tuple
that exercises top_k=16 (e.g., same dtype/payload as others) and either replace
or add lower-resource variants for the high top_k cases so they run on smaller
GPUs (for example use world_size=4 and/or num_tokens=8 for the new top_k values)
to ensure the new top-k codepaths are exercised without requiring 128 SM
hardware.
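The "128 SM" concern above comes from shapes where world_size = 8 and num_tokens = 16 require world_size * num_tokens co-resident blocks. The sketch below encodes that heuristic as a hedged assumption (the real test's occupancy rule may differ), showing why (8, 4, ...) shapes fit smaller GPUs.

```cpp
#include <cassert>

// Assumed occupancy model for illustration: a combine test shape needs
// world_size * num_tokens blocks resident at once, so it only runs on a
// device with at least that many SMs.
bool fits_device(int world_size, int num_tokens, int sm_count) {
  return world_size * num_tokens <= sm_count;
}
```

Under this model, (8, 16) shapes need 128 SMs, while the proposed (8, 4) variants need only 32 and run on far more GPUs.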
samuellees left a comment:
Please consider adding test cases for the new top_k values (e.g. (32, 7168, 512, 10) and (64, 4096, 256, 16)) to SINGLE_GPU_PARAMS. I verified these pass on B200, but they should be in CI.
Summary
- Add specialized dispatch and combine paths for top_k values 6, 10, 16, and 22 while keeping the generic fallback for other valid cases

Test plan
- python3 -m py_compile tests/comm/test_trtllm_moe_alltoall.py

Summary by CodeRabbit
Performance & Optimization
Tests