fix: extend moe alltoall top-k specializations#3021

Open
bobboli wants to merge 3 commits into flashinfer-ai:main from bobboli:moe-alltoall-topk-22

Conversation

@bobboli
Contributor

@bobboli bobboli commented Apr 9, 2026

Summary

  • extend moe alltoall dispatch launch specialization to top_k values 6, 10, 16, and 22
  • add explicit combine reduction trees for top_k values 6, 10, 16, and 22 while keeping the generic fallback for other valid cases
  • align combine regression coverage with representative MoE model parameter sets and targeted Qwen coverage for dtype and workspace staging

Test plan

  • python3 -m py_compile tests/comm/test_trtllm_moe_alltoall.py
  • GPU pytest not run in this environment

Summary by CodeRabbit

  • Performance & Optimization

    • Extended MOE all-to-all kernel limits to 22 (was 8).
    • Optimized communication kernel reduction logic for faster throughput.
    • Adjusted internal configuration sizing for larger scenarios.
  • Tests

    • Expanded test parameter coverage with new model-driven cases and increased max world size from 8 to 16.

Add specialized dispatch and combine paths for larger top-k values up to 22 while keeping the generic fallback for other valid cases. Remove the unused kMaxExperts limit and align the payload cap with the current four-payload moe alltoall path.
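The dispatch pattern described above can be sketched as follows. This is a hypothetical host-side C++ sketch (the real code templates CUDA kernels on `TOP_K` via a `SWITCH_TOP_K` macro; `combine_specialized` and `dispatch_top_k` are invented names for illustration): a runtime `top_k` is mapped onto compile-time specializations, with unsupported values rejected.

```cpp
#include <stdexcept>

// Stand-in for launching a TOP_K-specialized reduction kernel; the template
// parameter lets the compiler unroll the per-expert accumulation.
template <int TOP_K>
int combine_specialized() {
  return TOP_K;
}

// Map the runtime value onto the enumerated compile-time specializations.
int dispatch_top_k(int top_k) {
  switch (top_k) {
    case 1:  return combine_specialized<1>();
    case 2:  return combine_specialized<2>();
    case 4:  return combine_specialized<4>();
    case 6:  return combine_specialized<6>();
    case 8:  return combine_specialized<8>();
    case 10: return combine_specialized<10>();
    case 16: return combine_specialized<16>();
    case 22: return combine_specialized<22>();
    default:
      throw std::runtime_error("unsupported top_k");
  }
}
```

Note the trade-off this shape implies: every specialized value multiplies compiled kernel variants, so only commonly used top-k values are enumerated.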
@coderabbitai
Contributor

coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

Walkthrough

The PR extends MOE all-to-all kernels with compile-time specializations for additional top_k values (22, 16, 10, 6), adds a vectorized accumulate_vec device helper, refactors vectorized_combine_impl into structured else if constexpr branches using pairwise/vectorized accumulation, updates kMaxTopK to 22 and kMaxPayloads to 4, and expands related tests.

Changes

Cohort / File(s) Summary
MOE Kernel Implementation
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu, csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
Added compile-time specializations for top_k = 22,16,10,6; introduced accumulate_vec<T, ELEMS_PER_VEC>(dst, src) device helper; refactored vectorized_combine_impl to structured else if constexpr branches with pairwise/vectorized accumulation; changed config constants kMaxTopK 8→22 and kMaxPayloads 8→4.
Tests
tests/comm/test_trtllm_moe_alltoall.py
Expanded COMBINE_PARAMS to include model-driven parameter sets (Mixtral, GPT-OSS, DeepSeek-V2/V3, Qwen3, Nemotron), adjusted vector_dim cases, and raised max_world_size from 8 to 16.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

op: moe, op: comm

Suggested reviewers

  • samuellees
  • yzh119
  • jiahanc
  • bkryu
  • aleozlx
  • nv-yunzheq
  • cyx-6

Poem

🐰 Hop-hop, I stitched the kernels bright,
Top-k grew taller, now twenty-two in sight.
Vectors fold neatly, sums intertwine,
A rabbit's small cheer for code well-aligned. 🥕✨

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Description check ❓ Inconclusive — The PR description covers the summary of changes and test plan but does not follow the template structure with required sections like Description, Related Issues, and Pre-commit Checks. Resolution: follow the repository's PR template by adding those sections with proper checkboxes and formatting.

✅ Passed checks (1 passed)

  • Title check ✅ Passed — The title clearly and specifically describes the main change: extending MoE alltoall top-k specializations to additional values.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)

53-97: ⚠️ Potential issue | 🔴 Critical

top_k <= 22 is not actually supported yet.

Lines 481 and 953 now accept any top_k in [1, 22], but this switch still instantiates only 1, 2, 4, 6, 8, 10, 16, 22. Values like 3, 5, 7, 9, 11-15, 17-21 will still fail at runtime in the default arm, and the generic reduction fallback on Lines 749-755 is unreachable for them. Either enumerate the remaining cases so the generic path can run, or tighten validation back to the truly supported set until that fallback is wired through.

Also applies to: 749-755
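The reviewer's option (B) — routing unspecialized values into the generic reduction instead of the default error arm — can be sketched like this (hypothetical names; the real code dispatches templated CUDA kernel launches, not scalar functions):

```cpp
#include <stdexcept>
#include <vector>

constexpr int kMaxTopK = 22;

// Fast path: TOP_K fixed at compile time (stand-in for a specialized launch).
template <int TOP_K>
float combine_fixed(const std::vector<float>& v) {
  float acc = 0.f;
  for (int k = 0; k < TOP_K; ++k) acc += v[k];
  return acc;
}

// Generic fallback: top_k only known at runtime.
float combine_generic(const std::vector<float>& v, int top_k) {
  float acc = 0.f;
  for (int k = 0; k < top_k; ++k) acc += v[k];
  return acc;
}

// Enumerate the fast specializations; everything else in [1, kMaxTopK]
// routes to the generic path rather than an error.
float combine(const std::vector<float>& v, int top_k) {
  if (top_k < 1 || top_k > kMaxTopK)
    throw std::invalid_argument("top_k out of range");
  switch (top_k) {
    case 2:  return combine_fixed<2>(v);
    case 8:  return combine_fixed<8>(v);
    case 22: return combine_fixed<22>(v);
    default: return combine_generic(v, top_k);  // e.g. top_k = 3, 5, 7, ...
  }
}
```

Option (A), by contrast, would keep the switch as-is and tighten the call-site validation to exactly the enumerated set.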

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 53 - 97, The SWITCH_TOP_K macro currently only handles top_k values
{1,2,4,6,8,10,16,22} but callers accept any 1..22, causing unexpected hits to
the default error; either (A) update the call-site validation where top_k is
accepted (the functions that currently permit any top_k) to restrict allowed
values to exactly {1,2,4,6,8,10,16,22}, or (B) expand SWITCH_TOP_K to enumerate
the remaining top_k values (3,5,7,9,11-15,17-21) and route them into the generic
reduction fallback path present in this file (the generic reduction fallback
block), so those unsupported sizes run the generic reduction instead of hitting
the default error. Ensure you reference and update the SWITCH_TOP_K macro and
the call-site validation (or add cases that invoke the generic fallback) so
behavior matches the accepted top_k range.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h`:
- Around line 24-25: The header's formatting is out of sync (pre-commit rewrites
the lines defining kMaxTopK and kMaxPayloads), so run clang-format on the header
that declares the static constexpr symbols kMaxTopK and kMaxPayloads, save the
reformatted file, and re-stage the updated file so pre-commit no longer modifies
those lines during CI.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 74bd405f-090c-4b55-bc9d-7528049e20d6

📥 Commits

Reviewing files that changed from the base of the PR and between c2b4db2 and 7b1262c.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request increases the maximum supported top-k experts for MoE All-to-All kernels from 8 to 22. It introduces a new accumulate_vec helper function and refactors the reduction logic in vectorized_combine_impl to support additional TOP_K values (22, 16, 10, and 6) using optimized accumulation paths. Furthermore, the kMaxTopK constant was updated, kMaxPayloads was reduced, and kMaxExperts was removed from the configuration. I have no feedback to provide.

@yzh119 yzh119 added the run-ci label Apr 9, 2026
@yzh119
Collaborator

yzh119 commented Apr 9, 2026

Hi @bobboli are these topk values tested?

@samuellees
Collaborator

samuellees commented Apr 10, 2026

Hi @bobboli are these topk values tested?

Hi @yzh119 @bobboli , I tested some cases locally and they work well. Meanwhile I'm verifying the cases in an LLM framework and will report back once I have results. We can add more tests then~

Update: verified here sgl-project/sglang#22669

Cover the combine path with parameter sets that mirror representative MoE model configurations and add focused Qwen coverage for dtype and workspace staging. This keeps the regression matrix closer to real-world routing shapes while preserving targeted edge-case checks.
@bobboli
Contributor Author

bobboli commented Apr 13, 2026

Hi @bobboli are these topk values tested?

Thanks, I have refined the test coverage.

@bobboli
Contributor Author

bobboli commented Apr 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !536 has been created, and the CI pipeline #48383439 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/comm/test_trtllm_moe_alltoall.py (1)

475-478: max_world_size=16 is currently unexercised by this test matrix.

The assertion now allows 16, but COMBINE_PARAMS only goes up to 8. Consider either adding a 16-rank tuple (if feasible) or keeping the bound at 8 to match actual coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 475 - 478, The test sets
max_world_size = 16 but the test matrix (COMBINE_PARAMS) only covers up to 8
ranks; either reduce the bound or expand the matrix. Fix by updating the
max_world_size variable in tests/comm/test_trtllm_moe_alltoall.py to 8 to match
COMBINE_PARAMS, or alternatively add a 16-rank entry into COMBINE_PARAMS so that
world_size==16 is actually exercised; reference the max_world_size symbol and
the COMBINE_PARAMS collection when making the change.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 26e6f439-5130-413f-94ba-5777db29e72a

📥 Commits

Reviewing files that changed from the base of the PR and between 7b1262c and 9299fdb.

📒 Files selected for processing (1)
  • tests/comm/test_trtllm_moe_alltoall.py

Comment on lines +55 to 71
# (world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace)
COMBINE_PARAMS = [
    (2, 64, 8, 2, torch.bfloat16, True),  # Small input, 2 ranks
    (4, 32, 32768, 4, torch.bfloat16, True),  # Large input, 4 ranks
    (8, 16, 2048, 8, torch.bfloat16, True),  # Medium input, 8 ranks
    (8, 16, 2048, 8, torch.bfloat16, False),  # Medium input, 8 ranks
    (2, 64, 8, 2, torch.float16, True),  # Small input, 2 ranks
    (4, 32, 32768, 4, torch.float16, True),  # Large input, 4 ranks
    (8, 16, 2048, 8, torch.float16, True),  # Medium input, 8 ranks
    (8, 16, 2048, 8, torch.float16, False),  # Medium input, 8 ranks
    # Coverage for popular model specifications
    (4, 16, 4096, 2, torch.bfloat16, True),  # Mixtral-8x7B
    (4, 16, 2880, 4, torch.bfloat16, True),  # GPT-OSS-120B
    (8, 16, 5120, 6, torch.bfloat16, True),  # DeepSeek-V2
    (8, 16, 7168, 8, torch.bfloat16, True),  # DeepSeek-V3
    (8, 16, 4096, 8, torch.bfloat16, True),  # Qwen3-235B-A22B
    (8, 16, 4096, 10, torch.bfloat16, True),  # Qwen3.5-397B-A17B
    (8, 16, 4096, 22, torch.bfloat16, True),  # Nemotron-3-Super-120B-A12B
    # Coverage for num_tokens
    (8, 1, 4096, 8, torch.bfloat16, True),
    # Coverage for dtype
    (8, 16, 4096, 8, torch.float16, True),
    # Coverage for payload_in_workspace
    (8, 16, 4096, 8, torch.bfloat16, False),
]
Contributor


⚠️ Potential issue | 🟠 Major

Add explicit top_k=16 coverage and reduce skip-prone shapes for new top-k cases.

This matrix validates top_k 6/10/22 but misses top_k=16, even though that specialization is part of this PR. Also, all new high top-k cases use world_size=8, num_tokens=16 (128 SM requirement), which may be skipped on many GPUs and leave the new paths untested.

Proposed matrix adjustment
 COMBINE_PARAMS = [
@@
-    (8, 16, 5120, 6, torch.bfloat16, True),     # DeepSeek-V2
+    (8, 4, 5120, 6, torch.bfloat16, True),      # DeepSeek-V2 (lower SM requirement)
@@
-    (8, 16, 4096, 10, torch.bfloat16, True),    # Qwen3.5-397B-A17B 
-    (8, 16, 4096, 22, torch.bfloat16, True),    # Nemotron-3-Super-120B-A12B
+    (8, 4, 4096, 10, torch.bfloat16, True),     # Qwen3.5-397B-A17B (lower SM requirement)
+    (8, 4, 4096, 16, torch.bfloat16, True),     # Explicit coverage for top_k=16 specialization
+    (8, 4, 4096, 22, torch.bfloat16, True),     # Nemotron-3-Super-120B-A12B (lower SM requirement)
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 55-55: pre-commit hook 'clang-format' failed. Hook modified file formatting.


[error] 55-55: pre-commit hook 'ruff-format' failed. Hook reformatted file.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 55 - 71, The test matrix
COMBINE_PARAMS is missing explicit coverage for top_k=16 and some high-top-k
rows use skip-prone shapes (world_size=8, num_tokens=16); add at least one tuple
that exercises top_k=16 (e.g., same dtype/payload as others) and either replace
or add lower-resource variants for the high top_k cases so they run on smaller
GPUs (for example use world_size=4 and/or num_tokens=8 for the new top_k values)
to ensure the new top-k codepaths are exercised without requiring 128 SM
hardware.

Collaborator

@samuellees samuellees left a comment


Please consider adding test cases for the new top_k values (e.g. (32, 7168, 512, 10) and (64, 4096, 256, 16)) to SINGLE_GPU_PARAMS — I verified these pass on B200 but they should be in CI~
