perf(autotuner): replace power-of-2 token buckets with hybrid spacing & fix missing routing_replay_out arg (#3115)
Conversation
…le_moe

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info: Configuration used: defaults | Review profile: CHILL | Plan: Pro
📒 Files selected for processing (1) | 🚧 Files skipped from review as similar to previous changes (1)
📝 Walkthrough

Replaces power-of-2 token bucketing with a new four-phase hybrid bucketing across MoE and GEMM autotuning; updates callsites to use the new mapping utilities. Also threads an optional `routing_replay_out` argument through the MoE forward path.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client/Caller
    participant Tuner as Autotuner/Tuner
    participant Utils as Bucketing Utils
    participant Runner as MoE Runner
    participant Kernel as trtllm_fp8_per_tensor_scale_moe_op
    Client->>Tuner: request tuning / forward(input with num_tokens)
    Tuner->>Utils: get_hybrid_num_tokens_buckets(max_tokens)
    Tuner->>Utils: map_to_hybrid_bucket(num_tokens, max_tokens)
    Tuner->>Runner: select tactic / provide mapped bucket
    Client->>Runner: forward(..., routing_replay_out=?)
    Runner->>Kernel: call trtllm_fp8_per_tensor_scale_moe_op(..., routing_replay_out)
    Kernel-->>Runner: result
    Runner-->>Client: output
```
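The mapping step invoked by the tuner in the diagram can be sketched as below. This is a hypothetical reconstruction: `map_to_hybrid_bucket` and `next_positive_power_of_2` mirror the names used in this PR, and the phase boundaries (256 / 2048 / 4096) come from the PR description; the actual flashinfer code may differ.

```python
# Hypothetical sketch of the hybrid bucket mapping described in this PR.
# Phase boundaries follow the PR description; not the real implementation.

def next_positive_power_of_2(x: int) -> int:
    """Smallest power of 2 that is >= max(x, 1)."""
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

_PHASE1_END = 256
_PHASE2_STEP, _PHASE2_END = 256, 2048
_PHASE3_STEP, _PHASE3_END = 512, 4096

def map_to_hybrid_bucket(num_tokens: int, max_num_tokens: int) -> int:
    """Round num_tokens up to its hybrid bucket, capped at max_num_tokens."""
    if num_tokens <= _PHASE1_END:
        bucket = next_positive_power_of_2(num_tokens)            # phase 1: power-of-2
    elif num_tokens <= _PHASE2_END:
        bucket = -(-num_tokens // _PHASE2_STEP) * _PHASE2_STEP   # phase 2: step 256
    elif num_tokens <= _PHASE3_END:
        bucket = -(-num_tokens // _PHASE3_STEP) * _PHASE3_STEP   # phase 3: step 512
    else:
        bucket = next_positive_power_of_2(num_tokens)            # phase 4: power-of-2
    return min(bucket, max_num_tokens)
```

With this shape, an input of 10 tokens rounds up to bucket 16, and anything past 4096 snaps back to power-of-2 spacing, capped at `max_num_tokens`.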
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~28 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Code Review
This pull request replaces the power-of-2 token-bucket generation with a hybrid scheme across several modules to improve autotuning for MoE workloads. The new logic uses four phases with progressively coarser spacing, mixing power-of-2 and linear steps. A `routing_replay_out` parameter is also added to the MoE forward functions. One logic inconsistency was identified: the buckets generated by `get_hybrid_num_tokens_buckets` may not align with the mapping function when `min_num_tokens` is greater than one, which could lead to autotuner lookup failures.
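For reference, the bucket list such a scheme produces can be sketched as follows. This is a self-contained approximation using the boundaries from the PR description (256 / 2048 / 4096); `hybrid_buckets` is an illustrative name, not the actual `get_hybrid_num_tokens_buckets` implementation.

```python
# Hypothetical sketch of four-phase hybrid bucket generation; boundaries
# assumed from the PR description, not the actual flashinfer code.

def hybrid_buckets(max_num_tokens: int) -> tuple:
    buckets = []
    # Phase 1: power-of-2 up to 256
    m = 1
    while m <= min(max_num_tokens, 256):
        buckets.append(m)
        m *= 2
    # Phase 2: linear step 256 in (256, 2048]
    m = 512
    while m <= min(max_num_tokens, 2048):
        buckets.append(m)
        m += 256
    # Phase 3: linear step 512 in (2048, 4096]
    m = 2560
    while m <= min(max_num_tokens, 4096):
        buckets.append(m)
        m += 512
    # Phase 4: power-of-2 beyond 4096
    m = 8192
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2
    # Always include max_num_tokens itself as the last bucket
    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)
    return tuple(sorted(set(buckets)))

print(hybrid_buckets(8192))
# (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 1280, 1536, 1792,
#  2048, 2560, 3072, 3584, 4096, 8192)
```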
```python
# Phase 1: power-of-2 up to _PHASE1_END
m = max(min_num_tokens, 1)
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2

# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP

# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP

# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2

if not buckets or buckets[-1] != max_num_tokens:
    buckets.append(max_num_tokens)

return tuple(sorted(set(buckets)))
```
The implementation of `get_hybrid_num_tokens_buckets` has a critical inconsistency with `map_to_hybrid_bucket` when `min_num_tokens > 1`.

- **Phase 1 mismatch:** If `min_num_tokens` is not a power of 2 (e.g., 10), Phase 1 currently generates buckets starting from that value (e.g., `[10, 20, 40, ...]`). However, `map_to_hybrid_bucket` uses `next_positive_power_of_2(x)` for Phase 1, which means an input of size 10 will map to bucket 16. Since 16 is not in the generated list, the autotuner will fail to find a tuned tactic for this size.
- **Phase 2-4 filtering:** The loops for subsequent phases use fixed starting points (e.g., `_PHASE1_END + _PHASE2_STEP`), which results in buckets smaller than `min_num_tokens` being added to the list if `min_num_tokens` is large.
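The Phase 1 mismatch is easy to reproduce with a small sketch. The helpers below are hypothetical stand-ins (`buggy_phase1_buckets` is an illustrative name, not flashinfer API), assuming the 256 phase boundary from this PR:

```python
# Repro sketch of the phase-1 mismatch described above.

def next_positive_power_of_2(x: int) -> int:
    """Smallest power of 2 that is >= max(x, 1)."""
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def buggy_phase1_buckets(min_num_tokens: int, phase1_end: int = 256) -> list:
    """Phase 1 as originally written: doubles from min_num_tokens itself."""
    buckets, m = [], max(min_num_tokens, 1)
    while m <= phase1_end:
        buckets.append(m)
        m *= 2
    return buckets

buckets = buggy_phase1_buckets(10)       # [10, 20, 40, 80, 160]
mapped = next_positive_power_of_2(10)    # 16 -- what the mapper picks
# 16 is not among the generated buckets, so the autotuner lookup misses:
print(mapped in buckets)                 # False
```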
The robust fix is to always generate the full set of potential buckets starting from 1 (to ensure consistency with the mapping logic) and then filter the final result to keep only those within the [min_num_tokens, max_num_tokens] range.
```python
buckets: List[int] = []
# Phase 1: power-of-2 up to _PHASE1_END
m = 1
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2
# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP
# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP
# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2
if not buckets or buckets[-1] != max_num_tokens:
    buckets.append(max_num_tokens)
return tuple(sorted(set(b for b in buckets if min_num_tokens <= b <= max_num_tokens)))
```
Actionable comments posted: 2
ℹ️ Review info

⚙️ Run configuration: defaults | Review profile: CHILL | Plan: Pro
Run ID: f83d9649-3a2d-43ac-98e8-a7f2b490a9f6

📒 Files selected for processing (5):

- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- flashinfer/gemm/gemm_base.py
- flashinfer/trtllm_low_latency_gemm.py
```python
    This function uses four phases with progressively coarser spacing::

    Phase 1: [min .. 256] — power-of-2 (step ×2)
    Phase 2: (256 .. 2048] — linear step 256
    Phase 3: (2048 .. 4096] — linear step 512
    Phase 4: (4096 .. max] — power-of-2 (step ×2)
    """
```
Replace ambiguous multiplication signs in the docstring.
Ruff flags the Unicode × characters here; use plain x to keep pre-commit clean.
Proposed fix

```diff
-    Phase 1: [min .. 256] — power-of-2 (step ×2)
+    Phase 1: [min .. 256] — power-of-2 (step x2)
     Phase 2: (256 .. 2048] — linear step 256
     Phase 3: (2048 .. 4096] — linear step 512
-    Phase 4: (4096 .. max] — power-of-2 (step ×2)
+    Phase 4: (4096 .. max] — power-of-2 (step x2)
```
🧰 Tools

🪛 Ruff (0.15.10)

- [warning] 219-219: Docstring contains ambiguous `×` (MULTIPLICATION SIGN). Did you mean `x` (LATIN SMALL LETTER X)? (RUF002)
- [warning] 222-222: Docstring contains ambiguous `×` (MULTIPLICATION SIGN). Did you mean `x` (LATIN SMALL LETTER X)? (RUF002)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/utils.py` around lines 217 - 223, The docstring in
fused_moe/utils.py (around the function that describes the four phases) contains
Unicode multiplication characters "×" which trigger Ruff; replace those with the
ASCII letter "x" (e.g., change "step ×2" to "step x2") so the docstring uses
plain ASCII and pre-commit passes; update all occurrences in that docstring text
accordingly.
```python
buckets: List[int] = []

# Phase 1: power-of-2 up to _PHASE1_END
m = max(min_num_tokens, 1)
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2

# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP

# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP

# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2

if not buckets or buckets[-1] != max_num_tokens:
    buckets.append(max_num_tokens)

return tuple(sorted(set(buckets)))
```
Honor `min_num_tokens` across all phases.

Line 233 always starts Phase 2 at 512, so `get_hybrid_num_tokens_buckets(4096, min_num_tokens=1024)` still emits 512 and 768. Line 227 also starts Phase 1 from arbitrary non-powers of 2 like 129, breaking the documented power-of-2 phase.
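A minimal sketch of the Phase 2 issue; the constants and loop shape mirror the quoted code, but `phase2_buckets` is an illustrative helper, not flashinfer API:

```python
# Sketch reproducing the min_num_tokens issue: Phase 2 always starts at
# _PHASE1_END + _PHASE2_STEP, so buckets below min_num_tokens leak in.

_PHASE1_END, _PHASE2_STEP, _PHASE2_END = 256, 256, 2048

def phase2_buckets(min_num_tokens: int, max_num_tokens: int) -> list:
    buckets, m = [], _PHASE1_END + _PHASE2_STEP   # always 512
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP
    return buckets

out = phase2_buckets(min_num_tokens=1024, max_num_tokens=4096)
# 512 and 768 are below min_num_tokens=1024 but still appear:
print([b for b in out if b < 1024])   # [512, 768]
```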
Proposed fix

```diff
     buckets: List[int] = []
+    min_num_tokens = max(min_num_tokens, 1)
+    if max_num_tokens < min_num_tokens:
+        raise ValueError("max_num_tokens must be >= min_num_tokens")
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/utils.py` around lines 224 - 253,
get_hybrid_num_tokens_buckets is not honoring min_num_tokens across phases:
phase1 starts at min_num_tokens without rounding up to the next power-of-2, and
phases 2/3 always start at fixed boundaries (e.g., _PHASE1_END+_PHASE2_STEP)
which can emit values below min_num_tokens. Fix by computing phase starts
relative to min_num_tokens: for phase1 set m to the smallest power-of-2 >=
min_num_tokens (use bit math or loop) and then multiply by 2; for phase2 set m
to the smallest value >= min_num_tokens and >= (_PHASE1_END+_PHASE2_STEP) that
aligns to the _PHASE2_STEP grid (ceil to next multiple of _PHASE2_STEP); for
phase3 do the same alignment with _PHASE3_STEP and _PHASE2_END; and for phase4
start at max(min_num_tokens, _PHASE3_END*2) then multiply by 2; ensure every
appended bucket >= min_num_tokens and <= max_num_tokens and keep the final
sorting/unique logic intact (variables: get_hybrid_num_tokens_buckets,
min_num_tokens, max_num_tokens, _PHASE1_END, _PHASE2_STEP, _PHASE2_END,
_PHASE3_STEP, _PHASE3_END).
/bot run
samuellees left a comment:

LGTM, waiting for the CI pass
/bot run

/bot run
[FAILED] Pipeline #49156002: 1/20 passed
Hi @StudyingShao, could you please:

Thx!
📌 Description

This PR includes two improvements:

1. **perf(autotuner): Replace power-of-2 token buckets with hybrid spacing** — Pure power-of-2 spacing creates huge gaps at large values (e.g. a jump from 1024 to 2048), forcing the autotuner to pick a kernel optimised for a very different workload size. The new hybrid scheme uses four phases with progressively coarser spacing:
   - `[min .. 256]` — power-of-2 (step ×2)
   - `(256 .. 2048]` — linear step 256
   - `(2048 .. 4096]` — linear step 512
   - `(4096 .. max]` — power-of-2 (step ×2)

   All callsites in MoE, GEMM, and low-latency GEMM autotuners are updated to use the new `get_hybrid_num_tokens_buckets` / `map_to_hybrid_bucket` API.
2. **fix: Pass missing `routing_replay_out` arg to `trtllm_fp8_per_tensor_scale_moe`** — Two call sites in `fused_moe/core.py` were missing the `routing_replay_out` argument, causing it to be silently dropped.

🔍 Related Issues
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Added or updated tests as needed (e.g., `unittest`, etc.).

Reviewer Notes
Changed files:

- `flashinfer/fused_moe/utils.py` — Core implementation: new `get_hybrid_num_tokens_buckets`, `map_to_hybrid_bucket`, `map_to_hybrid_bucket_uncapped`; removed old `get_last_power_of_2_num_tokens_buckets`
- `flashinfer/fused_moe/core.py` — Updated all MoE autotuner callsites + added missing `routing_replay_out` arg
- `flashinfer/fused_moe/cute_dsl/tuner.py` — Updated CuTe DSL FP4 MoE tuner callsite
- `flashinfer/gemm/gemm_base.py` — Updated GEMM (FP8, BF16, FP4, MXFP8, TGV) autotuner configs
- `flashinfer/trtllm_low_latency_gemm.py` — Updated low-latency GEMM autotuner config